Source code for structured_stochasticity.hooks

"""
PyTorch hooks for accessing and modifying hidden states during inference.

This module provides the core mechanism for injecting noise into the model's
internal representations. It uses PyTorch's forward hook system to intercept
activations at specified layers and apply perturbations.
"""

from dataclasses import dataclass, field
from typing import Optional, Union
import warnings

import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoTokenizer

from .injection import NoiseInjector, GaussianNoiseInjector, create_injector
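
# Background sketch (illustrative only; not part of this module's API): the
# raw PyTorch forward-hook mechanism that HiddenStateHook builds on. A forward
# hook receives (module, input, output), and if it returns a value, that value
# replaces the module's output downstream.
def _forward_hook_primer() -> None:
    layer = nn.Linear(8, 8)
    handle = layer.register_forward_hook(
        lambda mod, inp, out: out + 0.1 * torch.randn_like(out)
    )
    _ = layer(torch.randn(2, 8))  # output now carries additive Gaussian noise
    handle.remove()  # restore the unmodified layer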


@dataclass
class HookConfig:
    """Configuration for hidden state hooks."""

    layers: list[int] = field(default_factory=lambda: [0, 1, 2])
    injection_point: str = "post"  # "pre" or "post" layer norm
    enabled: bool = True
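

# Usage sketch for HookConfig. Hedged: nothing else in this module consumes
# HookConfig directly, so the mapping of its fields onto hook/injector setup
# is an assumption; the defaults mirror NoisyInferenceWrapper's behavior.
def _hook_config_example() -> HookConfig:
    config = HookConfig(layers=[0, 1], injection_point="post", enabled=True)
    assert config.injection_point in ("pre", "post")
    return config
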

class HiddenStateHook:
    """
    A hook that intercepts and optionally modifies hidden states.

    This is the low-level building block. For most uses, prefer
    NoisyInferenceWrapper, which handles setup and cleanup.
    """

    def __init__(
        self,
        injector: NoiseInjector,
        layer_idx: int,
        enabled: bool = True,
    ):
        self.injector = injector
        self.layer_idx = layer_idx
        self.enabled = enabled
        self.handle: Optional[torch.utils.hooks.RemovableHandle] = None

        # For debugging/analysis
        self.call_count = 0
        self.last_input_norm: Optional[float] = None
        self.last_output_norm: Optional[float] = None

    def __call__(
        self,
        module: nn.Module,
        input: tuple[torch.Tensor, ...],
        output: Union[torch.Tensor, tuple],
    ) -> Union[torch.Tensor, tuple]:
        """
        Hook function called after the forward pass of the target module.

        Handles both plain tensor outputs and tuple outputs (common in
        transformer layers).
        """
        self.call_count += 1
        if not self.enabled:
            return output

        # Handle tuple outputs (e.g., (hidden_states, attention_weights, ...))
        if isinstance(output, tuple):
            hidden_states = output[0]
            rest = output[1:]
        else:
            hidden_states = output
            rest = None

        # Track norms for debugging
        self.last_input_norm = hidden_states.norm().item()

        # Inject noise
        perturbed = self.injector.inject(hidden_states)
        self.last_output_norm = perturbed.norm().item()

        # Reconstruct the original output format
        if rest is not None:
            return (perturbed,) + rest
        return perturbed

    def register(self, module: nn.Module) -> "HiddenStateHook":
        """Register this hook on a module."""
        self.handle = module.register_forward_hook(self)
        return self

    def remove(self):
        """Remove the hook."""
        if self.handle is not None:
            self.handle.remove()
            self.handle = None

    def reset_stats(self):
        """Reset debugging statistics."""
        self.call_count = 0
        self.last_input_norm = None
        self.last_output_norm = None
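

# Usage sketch: driving HiddenStateHook by hand, without the wrapper. Hedged:
# assumes a GPT-2-style model exposing transformer.h; create_injector is
# called with the same (strategy, scale, device) arguments that
# NoisyInferenceWrapper.__init__ uses below.
def _manual_hook_example(model: PreTrainedModel, input_ids: torch.Tensor) -> None:
    injector = create_injector(strategy="gaussian", scale=0.05, device="cpu")
    hook = HiddenStateHook(injector, layer_idx=0).register(model.transformer.h[0])
    try:
        with torch.no_grad():
            model(input_ids)  # the hook fires once per forward pass
        print(hook.call_count, hook.last_input_norm, hook.last_output_norm)
    finally:
        hook.remove()  # always detach, or the model stays permanently noisy
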

class NoisyInferenceWrapper:
    """
    Wraps a transformer model to enable noisy inference.

    This is the main interface for running experiments. It handles:

    - Identifying target layers in different model architectures
    - Registering/removing hooks
    - Generating multiple trajectories with different noise samples
    - Aggregating results

    Example:
        >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
        >>> wrapper = NoisyInferenceWrapper(model, injection_layers=[0, 1, 2])
        >>> outputs = wrapper.generate_trajectories(input_ids, k=5)
    """

    # Known architectures and their layer access patterns
    LAYER_PATTERNS = {
        "llama": "model.layers",
        "mistral": "model.layers",
        "gpt2": "transformer.h",
        "gpt_neo": "transformer.h",
        "opt": "model.decoder.layers",
        "falcon": "transformer.h",
        "phi": "model.layers",
        "qwen": "model.layers",
        "gemma": "model.layers",
    }

    def __init__(
        self,
        model: PreTrainedModel,
        injection_layers: Optional[list[int]] = None,
        noise_scale: float = 0.1,
        noise_strategy: str = "gaussian",
        injection_mode: str = "continuous",
        device: Optional[str] = None,
        **injector_kwargs,
    ):
        """
        Initialize the wrapper.

        Args:
            model: HuggingFace transformer model
            injection_layers: Which layers to inject noise into.
                If None, defaults to the first 25% of layers.
            noise_scale: Magnitude of noise injection
            noise_strategy: Type of noise ("gaussian", "uniform", "annealed", "once")
            injection_mode: "continuous" (every forward) or "once" (per generation)
            device: Device for noise tensors
            **injector_kwargs: Additional arguments for the noise injector
        """
        self.model = model
        self.device = device or next(model.parameters()).device
        self.noise_scale = noise_scale
        self.noise_strategy = noise_strategy
        self.injection_mode = injection_mode

        # Detect architecture and get layers
        self.layers = self._get_layers()
        self.num_layers = len(self.layers)

        # Default to early layers if not specified
        if injection_layers is None:
            n_inject = max(1, self.num_layers // 4)
            injection_layers = list(range(n_inject))
        self.injection_layers = injection_layers

        # Create injectors and hooks
        self.injectors: dict[int, NoiseInjector] = {}
        self.hooks: dict[int, HiddenStateHook] = {}

        for layer_idx in self.injection_layers:
            if layer_idx >= self.num_layers:
                warnings.warn(
                    f"Layer {layer_idx} exceeds model depth {self.num_layers}, skipping"
                )
                continue

            # Use the "once" strategy when injection_mode is "once"
            strategy = "once" if injection_mode == "once" else noise_strategy
            injector = create_injector(
                strategy=strategy,
                scale=noise_scale,
                device=str(self.device),
                **injector_kwargs,
            )
            self.injectors[layer_idx] = injector

            hook = HiddenStateHook(injector, layer_idx)
            self.hooks[layer_idx] = hook

        self._hooks_registered = False

    def _get_layers(self) -> nn.ModuleList:
        """Extract the transformer layers from the model."""
        model_type = getattr(self.model.config, "model_type", "").lower()

        # Try known patterns
        for arch, pattern in self.LAYER_PATTERNS.items():
            if arch in model_type:
                try:
                    layers = self.model
                    for attr in pattern.split("."):
                        layers = getattr(layers, attr)
                    return layers
                except AttributeError:
                    continue

        # Fallback: try common patterns
        for pattern in ["model.layers", "transformer.h", "decoder.layers"]:
            try:
                layers = self.model
                for attr in pattern.split("."):
                    layers = getattr(layers, attr)
                return layers
            except AttributeError:
                continue

        raise ValueError(
            f"Could not find layers in model of type {model_type}. "
            f"Known architectures: {list(self.LAYER_PATTERNS.keys())}"
        )

    def register_hooks(self):
        """Register all hooks on the model."""
        if self._hooks_registered:
            return
        for layer_idx, hook in self.hooks.items():
            hook.register(self.layers[layer_idx])
        self._hooks_registered = True

    def remove_hooks(self):
        """Remove all hooks from the model."""
        for hook in self.hooks.values():
            hook.remove()
        self._hooks_registered = False

    def reset_injectors(self):
        """Reset all injectors (call between generations)."""
        for injector in self.injectors.values():
            injector.reset()

    def set_enabled(self, enabled: bool):
        """Enable or disable all hooks."""
        for hook in self.hooks.values():
            hook.enabled = enabled

    def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        max_new_tokens: int = 512,
        **generate_kwargs,
    ) -> torch.Tensor:
        """
        Generate with noise injection.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            max_new_tokens: Maximum tokens to generate
            **generate_kwargs: Additional arguments for model.generate()

        Returns:
            Generated token IDs
        """
        self.register_hooks()
        self.reset_injectors()

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=True,  # Enable sampling
                **generate_kwargs,
            )

        # Hooks are intentionally left registered so repeated generations can
        # reuse them; in "once" mode, reset_injectors() at the top of the next
        # call resamples the noise. Call remove_hooks() to restore the
        # unmodified model.
        return outputs

    def generate_trajectories(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        k: int = 5,
        max_new_tokens: int = 512,
        **generate_kwargs,
    ) -> list[torch.Tensor]:
        """
        Generate K independent trajectories with different noise samples.

        This is the core experimental method. Each trajectory gets a fresh
        noise sample, enabling exploration of different reasoning paths.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            k: Number of trajectories to generate
            max_new_tokens: Maximum tokens per trajectory
            **generate_kwargs: Additional arguments for generation

        Returns:
            List of K generated token ID tensors
        """
        trajectories = []
        self.register_hooks()

        for _ in range(k):
            self.reset_injectors()  # Fresh noise for each trajectory
            with torch.no_grad():
                output = self.model.generate(
                    input_ids=input_ids.clone(),
                    attention_mask=(
                        attention_mask.clone() if attention_mask is not None else None
                    ),
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    **generate_kwargs,
                )
            trajectories.append(output)

        return trajectories

    def generate_trajectories_decoded(
        self,
        prompt: str,
        tokenizer: AutoTokenizer,
        k: int = 5,
        max_new_tokens: int = 512,
        **generate_kwargs,
    ) -> list[str]:
        """
        Convenience method: generate K trajectories and decode to strings.

        Args:
            prompt: Input prompt string
            tokenizer: Tokenizer for encoding/decoding
            k: Number of trajectories
            max_new_tokens: Max tokens per trajectory

        Returns:
            List of K decoded response strings
        """
        inputs = tokenizer(prompt, return_tensors="pt").to(self.device)

        trajectories = self.generate_trajectories(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            k=k,
            max_new_tokens=max_new_tokens,
            **generate_kwargs,
        )

        decoded = []
        for traj in trajectories:
            # Decode only the generated part (not the prompt)
            generated = traj[0, inputs["input_ids"].shape[1]:]
            text = tokenizer.decode(generated, skip_special_tokens=True)
            decoded.append(text)

        return decoded

    def get_stats(self) -> dict:
        """Get debugging statistics from hooks."""
        return {
            layer_idx: {
                "call_count": hook.call_count,
                "last_input_norm": hook.last_input_norm,
                "last_output_norm": hook.last_output_norm,
            }
            for layer_idx, hook in self.hooks.items()
        }

    def __enter__(self):
        self.register_hooks()
        return self

    def __exit__(self, *args):
        self.remove_hooks()

    def __repr__(self):
        return (
            f"NoisyInferenceWrapper("
            f"layers={self.injection_layers}, "
            f"scale={self.noise_scale}, "
            f"strategy={self.noise_strategy}, "
            f"mode={self.injection_mode})"
        )
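

# End-to-end sketch mirroring the class docstring example. Hedged: the
# checkpoint name is illustrative (any HuggingFace causal LM should work),
# and temperature is an ordinary transformers generate() kwarg passed through
# **generate_kwargs.
def _trajectories_example() -> None:
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

    wrapper = NoisyInferenceWrapper(
        model,
        injection_layers=[0, 1, 2],
        noise_scale=0.1,
        noise_strategy="gaussian",
    )

    # Using the wrapper as a context manager guarantees hook removal.
    with wrapper:
        responses = wrapper.generate_trajectories_decoded(
            prompt="A farmer has 17 sheep. All but 9 run away. How many are left?",
            tokenizer=tokenizer,
            k=5,
            max_new_tokens=128,
            temperature=0.7,
        )

    for i, text in enumerate(responses):
        print(f"--- trajectory {i} ---\n{text}")
    print(wrapper.get_stats())  # per-layer call counts and hidden-state norms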