"""
PyTorch hooks for accessing and modifying hidden states during inference.
This module provides the core mechanism for injecting noise into the model's
internal representations. It uses PyTorch's forward hook system to intercept
activations at specified layers and apply perturbations.
"""
from dataclasses import dataclass, field
from typing import Optional, Union
import warnings

import torch
import torch.nn as nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)

from .injection import NoiseInjector, create_injector
@dataclass
class HookConfig:
"""Configuration for hidden state hooks."""
layers: list[int] = field(default_factory=lambda: [0, 1, 2])
injection_point: str = "post" # "pre" or "post" layer norm
enabled: bool = True
class HiddenStateHook:
"""
A hook that intercepts and optionally modifies hidden states.
This is the low-level building block. For most uses, prefer
NoisyInferenceWrapper which handles setup and cleanup.
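
    Example (a minimal sketch; ``create_injector`` is assumed to accept the
    same keyword arguments used elsewhere in this module, and the layer path
    is illustrative):

        >>> injector = create_injector(strategy="gaussian", scale=0.1, device="cpu")
        >>> hook = HiddenStateHook(injector, layer_idx=0)
        >>> hook.register(model.model.layers[0])  # doctest: +SKIP
        >>> hook.remove()  # doctest: +SKIP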
"""
def __init__(
self,
injector: NoiseInjector,
layer_idx: int,
enabled: bool = True
):
self.injector = injector
self.layer_idx = layer_idx
self.enabled = enabled
self.handle: Optional[torch.utils.hooks.RemovableHandle] = None
# For debugging/analysis
self.call_count = 0
self.last_input_norm: Optional[float] = None
self.last_output_norm: Optional[float] = None
def __call__(
self,
module: nn.Module,
input: tuple[torch.Tensor, ...],
output: Union[torch.Tensor, tuple]
) -> Union[torch.Tensor, tuple]:
"""
        Hook function called after the forward pass of the target module.
Handles both simple tensor outputs and tuple outputs
(common in transformer layers).
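
        Example (calling the hook instance directly on a dummy tensor; the
        module argument is unused here, so ``None`` suffices):

            >>> h = torch.randn(1, 4, 16)
            >>> hook(None, (h,), h).shape  # doctest: +SKIP
            torch.Size([1, 4, 16])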
"""
self.call_count += 1
if not self.enabled:
return output
# Handle tuple outputs (e.g., (hidden_states, attention_weights, ...))
if isinstance(output, tuple):
hidden_states = output[0]
rest = output[1:]
else:
hidden_states = output
rest = None
        # Track the pre-injection norm for debugging
self.last_input_norm = hidden_states.norm().item()
# Inject noise
perturbed = self.injector.inject(hidden_states)
self.last_output_norm = perturbed.norm().item()
# Reconstruct output format
if rest is not None:
return (perturbed,) + rest
return perturbed
def register(self, module: nn.Module) -> "HiddenStateHook":
"""Register this hook on a module."""
self.handle = module.register_forward_hook(self)
return self
def remove(self):
"""Remove the hook."""
if self.handle is not None:
self.handle.remove()
self.handle = None
def reset_stats(self):
"""Reset debugging statistics."""
self.call_count = 0
self.last_input_norm = None
self.last_output_norm = None
class NoisyInferenceWrapper:
"""
Wraps a transformer model to enable noisy inference.
This is the main interface for running experiments. It handles:
- Identifying target layers in different model architectures
- Registering/removing hooks
- Generating multiple trajectories with different noise samples
- Aggregating results
Example:
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
        >>> wrapper = NoisyInferenceWrapper(model, injection_layers=[0, 1, 2])
>>> outputs = wrapper.generate_trajectories(input_ids, k=5)
"""
# Known architectures and their layer access patterns
LAYER_PATTERNS = {
"llama": "model.layers",
"mistral": "model.layers",
"gpt2": "transformer.h",
"gpt_neo": "transformer.h",
"opt": "model.decoder.layers",
"falcon": "transformer.h",
"phi": "model.layers",
"qwen": "model.layers",
"gemma": "model.layers",
}
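    # For example, the "model.layers" pattern resolves by chained getattr:
    # self.model -> self.model.model -> self.model.model.layers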
def __init__(
self,
model: PreTrainedModel,
injection_layers: Optional[list[int]] = None,
noise_scale: float = 0.1,
noise_strategy: str = "gaussian",
injection_mode: str = "continuous",
        device: Optional[Union[str, torch.device]] = None,
**injector_kwargs
):
"""
Initialize the wrapper.
Args:
model: HuggingFace transformer model
injection_layers: Which layers to inject noise into.
                If None, defaults to the first 25% of layers.
noise_scale: Magnitude of noise injection
noise_strategy: Type of noise ("gaussian", "uniform", "annealed", "once")
injection_mode: "continuous" (every forward) or "once" (per generation)
device: Device for noise tensors
**injector_kwargs: Additional arguments for noise injector
"""
self.model = model
self.device = device or next(model.parameters()).device
self.noise_scale = noise_scale
self.noise_strategy = noise_strategy
self.injection_mode = injection_mode
# Detect architecture and get layers
self.layers = self._get_layers()
self.num_layers = len(self.layers)
# Default to early layers if not specified
if injection_layers is None:
n_inject = max(1, self.num_layers // 4)
injection_layers = list(range(n_inject))
self.injection_layers = injection_layers
# Create injectors and hooks
self.injectors: dict[int, NoiseInjector] = {}
self.hooks: dict[int, HiddenStateHook] = {}
for layer_idx in self.injection_layers:
if layer_idx >= self.num_layers:
warnings.warn(f"Layer {layer_idx} exceeds model depth {self.num_layers}, skipping")
continue
# Use "once" strategy if injection_mode is "once"
strategy = "once" if injection_mode == "once" else noise_strategy
injector = create_injector(
strategy=strategy,
scale=noise_scale,
device=str(self.device),
**injector_kwargs
)
self.injectors[layer_idx] = injector
hook = HiddenStateHook(injector, layer_idx)
self.hooks[layer_idx] = hook
self._hooks_registered = False
def _get_layers(self) -> nn.ModuleList:
"""Extract the transformer layers from the model."""
model_type = getattr(self.model.config, "model_type", "").lower()
# Try known patterns
for arch, pattern in self.LAYER_PATTERNS.items():
if arch in model_type:
try:
layers = self.model
for attr in pattern.split("."):
layers = getattr(layers, attr)
return layers
except AttributeError:
continue
# Fallback: try common patterns
for pattern in ["model.layers", "transformer.h", "decoder.layers"]:
try:
layers = self.model
for attr in pattern.split("."):
layers = getattr(layers, attr)
return layers
except AttributeError:
continue
raise ValueError(
f"Could not find layers in model of type {model_type}. "
f"Known architectures: {list(self.LAYER_PATTERNS.keys())}"
)
def register_hooks(self):
"""Register all hooks on the model."""
if self._hooks_registered:
return
for layer_idx, hook in self.hooks.items():
hook.register(self.layers[layer_idx])
self._hooks_registered = True
def remove_hooks(self):
"""Remove all hooks from the model."""
for hook in self.hooks.values():
hook.remove()
self._hooks_registered = False
def reset_injectors(self):
"""Reset all injectors (call between generations)."""
for injector in self.injectors.values():
injector.reset()
def set_enabled(self, enabled: bool):
"""Enable or disable all hooks."""
for hook in self.hooks.values():
hook.enabled = enabled
def generate(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
max_new_tokens: int = 512,
**generate_kwargs
) -> torch.Tensor:
"""
Generate with noise injection.
Args:
input_ids: Input token IDs
attention_mask: Attention mask
max_new_tokens: Maximum tokens to generate
**generate_kwargs: Additional arguments for model.generate()
Returns:
Generated token IDs
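
        Example (illustrative; ``wrapper`` and tokenized ``inputs`` as in the
        class docstring):

            >>> ids = wrapper.generate(inputs["input_ids"], max_new_tokens=64)  # doctest: +SKIP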
"""
        self.register_hooks()
        self.reset_injectors()
        # Default to sampling, but let callers override via generate_kwargs
        # (passing do_sample explicitly above would raise a duplicate-kwarg error)
        generate_kwargs.setdefault("do_sample", True)
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                **generate_kwargs
            )
        # Hooks are intentionally left registered so repeated generations can
        # reuse them; call remove_hooks() to restore clean inference.
        return outputs
def generate_trajectories(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
k: int = 5,
max_new_tokens: int = 512,
**generate_kwargs
) -> list[torch.Tensor]:
"""
Generate K independent trajectories with different noise samples.
This is the core experimental method. Each trajectory gets a fresh
noise sample, enabling exploration of different reasoning paths.
Args:
input_ids: Input token IDs
attention_mask: Attention mask
k: Number of trajectories to generate
max_new_tokens: Maximum tokens per trajectory
**generate_kwargs: Additional arguments for generation
Returns:
List of K generated token ID tensors
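
        Example (illustrative; ``wrapper`` and ``tokenizer`` as in the class
        docstring):

            >>> inputs = tokenizer("2+2=", return_tensors="pt")  # doctest: +SKIP
            >>> trajs = wrapper.generate_trajectories(inputs["input_ids"], k=3)  # doctest: +SKIP
            >>> len(trajs)  # doctest: +SKIP
            3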
"""
        trajectories = []
        self.register_hooks()
        # Default to sampling, but let callers override via generate_kwargs
        generate_kwargs.setdefault("do_sample", True)
        for _ in range(k):
            self.reset_injectors()  # Fresh noise for each trajectory
            with torch.no_grad():
                output = self.model.generate(
                    input_ids=input_ids.clone(),
                    attention_mask=attention_mask.clone() if attention_mask is not None else None,
                    max_new_tokens=max_new_tokens,
                    **generate_kwargs
                )
            trajectories.append(output)
        return trajectories
def generate_trajectories_decoded(
self,
prompt: str,
        tokenizer: PreTrainedTokenizerBase,
k: int = 5,
max_new_tokens: int = 512,
**generate_kwargs
) -> list[str]:
"""
Convenience method: generate K trajectories and decode to strings.
Args:
prompt: Input prompt string
tokenizer: Tokenizer for encoding/decoding
k: Number of trajectories
max_new_tokens: Max tokens per trajectory
Returns:
List of K decoded response strings
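
        Example (illustrative):

            >>> answers = wrapper.generate_trajectories_decoded(
            ...     "What is 7 * 8?", tokenizer, k=3, max_new_tokens=32
            ... )  # doctest: +SKIP
            >>> len(answers)  # doctest: +SKIP
            3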
"""
inputs = tokenizer(prompt, return_tensors="pt").to(self.device)
trajectories = self.generate_trajectories(
input_ids=inputs["input_ids"],
attention_mask=inputs.get("attention_mask"),
k=k,
max_new_tokens=max_new_tokens,
**generate_kwargs
)
decoded = []
for traj in trajectories:
            # Decode only the generated continuation (batch size is 1 here,
            # so take row 0 and skip the prompt tokens)
            generated = traj[0, inputs["input_ids"].shape[1]:]
text = tokenizer.decode(generated, skip_special_tokens=True)
decoded.append(text)
return decoded
def get_stats(self) -> dict:
"""Get debugging statistics from hooks."""
return {
layer_idx: {
"call_count": hook.call_count,
"last_input_norm": hook.last_input_norm,
"last_output_norm": hook.last_output_norm,
}
for layer_idx, hook in self.hooks.items()
}
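    # Context-manager support: hooks are registered on entry and removed on
    # exit, e.g. ``with NoisyInferenceWrapper(model) as w: w.generate(ids)``.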
def __enter__(self):
self.register_hooks()
return self
def __exit__(self, *args):
self.remove_hooks()
def __repr__(self):
return (
f"NoisyInferenceWrapper("
f"layers={self.injection_layers}, "
f"scale={self.noise_scale}, "
f"strategy={self.noise_strategy}, "
f"mode={self.injection_mode})"
)