"""
Main experiment runner for structured stochasticity experiments.
This module ties everything together and provides a CLI
for running experiments.
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
import json
import time
from datetime import datetime
import yaml
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
from .hooks import NoisyInferenceWrapper
from .tasks import Task, TaskFactory
from .evaluation import Evaluator, TrajectoryAggregator, EvaluationResult
@dataclass
class ExperimentConfig:
"""Configuration for an experiment run."""
# Model
model_name: str = "meta-llama/Llama-3.2-1B"
device: str = "cuda"
torch_dtype: str = "float16"
# Noise injection
injection_layers: list[int] = field(default_factory=lambda: [0, 1, 2, 3])
noise_scale: float = 0.1
noise_strategy: str = "gaussian"
injection_mode: str = "continuous"
# Task
task_name: str = "tower_of_hanoi"
complexity_range: tuple[int, int] = (3, 7)
trials_per_complexity: int = 10
# Evaluation
k_values: list[int] = field(default_factory=lambda: [1, 3, 5, 10])
selection_method: str = "majority_vote"
max_new_tokens: int = 1024
# Output
output_dir: str = "experiments"
experiment_name: Optional[str] = None
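
    # Usage sketch (hypothetical values): any field can be overridden directly,
    # e.g. ExperimentConfig(noise_scale=0.05, k_values=[1, 5], trials_per_complexity=3).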
@classmethod
def from_yaml(cls, path: str) -> "ExperimentConfig":
"""Load config from YAML file."""
with open(path) as f:
data = yaml.safe_load(f)
        # Flatten nested sections if present; otherwise treat the file as flat
        sections = ["model", "injection", "task", "evaluation", "output"]
        flat = {}
        if any(section in data for section in sections):
            for section in sections:
                flat.update(data.get(section, {}))
        else:
            flat = dict(data)
        # YAML has no tuple type, so restore the declared tuple
        if "complexity_range" in flat:
            flat["complexity_range"] = tuple(flat["complexity_range"])
        return cls(**{k: v for k, v in flat.items() if k in cls.__dataclass_fields__})
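
    # Sketch of a YAML layout from_yaml accepts (hypothetical file
    # configs/example.yaml; a flat `key: value` file also works):
    #
    #   model:
    #     model_name: meta-llama/Llama-3.2-1B
    #     torch_dtype: bfloat16
    #   injection:
    #     injection_layers: [0, 1]
    #     noise_scale: 0.05
    #   evaluation:
    #     k_values: [1, 3, 5]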
def to_dict(self) -> dict:
"""Convert to dictionary."""
return {
"model_name": self.model_name,
"device": self.device,
"torch_dtype": self.torch_dtype,
"injection_layers": self.injection_layers,
"noise_scale": self.noise_scale,
"noise_strategy": self.noise_strategy,
"injection_mode": self.injection_mode,
"task_name": self.task_name,
"complexity_range": self.complexity_range,
"trials_per_complexity": self.trials_per_complexity,
"k_values": self.k_values,
"selection_method": self.selection_method,
"max_new_tokens": self.max_new_tokens,
"output_dir": self.output_dir,
"experiment_name": self.experiment_name,
}
class Experiment:
"""
Main experiment class.
Handles:
- Loading model and tokenizer
- Setting up noise injection
- Running evaluation
- Saving results
"""
def __init__(self, config: ExperimentConfig):
self.config = config
self.model = None
self.tokenizer = None
self.wrapper = None
self.task = None
self.results = None
def setup(self):
"""Load model, tokenizer, and task."""
print(f"Loading model: {self.config.model_name}")
dtype_map = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}
dtype = dtype_map.get(self.config.torch_dtype, torch.float16)
self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model = AutoModelForCausalLM.from_pretrained(
self.config.model_name,
torch_dtype=dtype,
device_map=self.config.device
)
print(f"Setting up noise injection: {self.config.noise_strategy}, "
f"scale={self.config.noise_scale}, layers={self.config.injection_layers}")
self.wrapper = NoisyInferenceWrapper(
self.model,
injection_layers=self.config.injection_layers,
noise_scale=self.config.noise_scale,
noise_strategy=self.config.noise_strategy,
injection_mode=self.config.injection_mode,
device=self.config.device
)
print(f"Loading task: {self.config.task_name}")
self.task = TaskFactory.create(self.config.task_name)
def generate_responses(self, prompt: str, k: int) -> list[str]:
"""Generate K responses for a prompt."""
return self.wrapper.generate_trajectories_decoded(
prompt=prompt,
tokenizer=self.tokenizer,
k=k,
max_new_tokens=self.config.max_new_tokens,
            temperature=1.0,  # standard sampling; trajectory diversity comes from noise injection
top_p=0.95,
)
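
    # Usage sketch (after setup(); `exp` is an Experiment instance):
    #   responses = exp.generate_responses(task_prompt, k=5)
    # Returns 5 decoded completions, each sampled with noise injection active.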
def run(self) -> dict[int, EvaluationResult]:
"""Run the full experiment."""
if self.model is None:
self.setup()
aggregator = TrajectoryAggregator(method=self.config.selection_method)
evaluator = Evaluator(self.task, aggregator)
print(f"\nRunning evaluation:")
print(f" Complexity range: {self.config.complexity_range}")
print(f" K values: {self.config.k_values}")
print(f" Trials per complexity: {self.config.trials_per_complexity}")
print()
self.results = evaluator.evaluate(
generate_fn=self.generate_responses,
complexity_range=self.config.complexity_range,
k_values=self.config.k_values,
trials_per_complexity=self.config.trials_per_complexity,
verbose=True
)
return self.results
def save_results(self, output_path: Optional[str] = None):
"""Save experiment results to JSON."""
if self.results is None:
raise ValueError("No results to save. Run experiment first.")
if output_path is None:
output_dir = Path(self.config.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
name = self.config.experiment_name or datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = output_dir / f"{name}.json"
output_path = Path(output_path)
# Convert results to serializable format
output = {
"config": self.config.to_dict(),
"timestamp": datetime.now().isoformat(),
"results": {}
}
for k, result in self.results.items():
output["results"][k] = {
"task_name": result.task_name,
"k_trajectories": result.k_trajectories,
"selection_method": result.selection_method,
"accuracy_by_complexity": result.accuracy_by_complexity,
"oracle_accuracy_by_complexity": result.oracle_accuracy_by_complexity,
"max_solved_complexity": result.max_solved_complexity,
}
with open(output_path, "w") as f:
json.dump(output, f, indent=2)
print(f"\nResults saved to: {output_path}")
return output_path
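
    # Shape of the saved JSON (a sketch; json.dump writes the integer K keys
    # as strings):
    #   {
    #     "config": {...},
    #     "timestamp": "2025-01-01T00:00:00",
    #     "results": {"1": {"accuracy_by_complexity": {...},
    #                       "max_solved_complexity": ...}, ...}
    #   }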
def print_summary(self):
"""Print summary of results."""
if self.results is None:
print("No results available.")
return
print("\n" + "="*60)
print("EXPERIMENT SUMMARY")
print("="*60)
print(f"\nTask: {self.config.task_name}")
print(f"Model: {self.config.model_name}")
print(f"Noise: {self.config.noise_strategy}, scale={self.config.noise_scale}")
print(f"Injection layers: {self.config.injection_layers}")
print(f"\nAccuracy by K and Complexity:")
print("-" * 50)
# Header
k_values = sorted(self.results.keys())
header = "Complexity |" + "|".join(f" K={k:2d} " for k in k_values)
print(header)
print("-" * len(header))
# Get all complexity levels
all_complexities = set()
for result in self.results.values():
all_complexities.update(result.accuracy_by_complexity.keys())
for c in sorted(all_complexities):
row = f" {c:2d} |"
for k in k_values:
acc = self.results[k].accuracy_by_complexity.get(c, 0)
row += f" {acc:5.1%} |"
print(row)
print("-" * 50)
# Max solved complexity
print("\nMax Solved Complexity (≥90% accuracy):")
for k in k_values:
print(f" K={k}: {self.results[k].max_solved_complexity}")
        # Key finding: compare the smallest and largest K
        k_min = min(k_values)
        k_min_result = self.results[k_min].max_solved_complexity
        k_max_val = max(k_values)
        k_max_result = self.results[k_max_val].max_solved_complexity
        if k_max_result > k_min_result:
            print(f"\n✓ POSITIVE SIGNAL: K={k_max_val} solves complexity "
                  f"{k_max_result} vs K={k_min} solving {k_min_result}")
        else:
            print(f"\n○ No improvement: K={k_max_val} matches K={k_min} "
                  f"(both solve up to complexity {k_min_result})")
def main():
"""CLI entry point."""
import click
@click.command()
@click.option("--config", "-c", type=str, help="Path to YAML config file")
@click.option("--model", "-m", type=str, help="Model name (overrides config)")
@click.option("--scale", "-s", type=float, help="Noise scale (overrides config)")
@click.option("--sweep", nargs=3, multiple=True,
help="Parameter sweep: --sweep param val1 val2 ...")
def run_experiment(config, model, scale, sweep):
"""Run structured stochasticity experiment."""
if config:
exp_config = ExperimentConfig.from_yaml(config)
else:
exp_config = ExperimentConfig()
# Override with CLI args
if model:
exp_config.model_name = model
        if scale is not None:  # a plain truthiness check would ignore --scale 0
            exp_config.noise_scale = scale
# Handle sweeps
if sweep:
# TODO: Implement parameter sweeps
print("Parameter sweeps not yet implemented")
return
# Run experiment
experiment = Experiment(exp_config)
experiment.run()
experiment.print_summary()
experiment.save_results()
run_experiment()
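
# Example invocations (the module path is hypothetical; flags as defined above):
#   python -m <package>.runner --config configs/example.yaml
#   python -m <package>.runner --model meta-llama/Llama-3.2-1B --scale 0.05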
# For running experiments programmatically
def run_quick_experiment(
model_name: str = "meta-llama/Llama-3.2-1B",
noise_scale: float = 0.1,
complexity_range: tuple[int, int] = (3, 5),
    k_values: Optional[list[int]] = None,  # defaults to [1, 5]; None avoids a mutable default
trials: int = 5
) -> dict:
"""
Quick experiment runner for notebooks/scripts.
Returns dict with results summary.
"""
    if k_values is None:
        k_values = [1, 5]
    config = ExperimentConfig(
model_name=model_name,
noise_scale=noise_scale,
complexity_range=complexity_range,
k_values=k_values,
trials_per_complexity=trials
)
experiment = Experiment(config)
experiment.run()
experiment.print_summary()
return {
"config": config.to_dict(),
"accuracy_by_k": {
k: result.accuracy_by_complexity
for k, result in experiment.results.items()
},
"max_solved_by_k": {
k: result.max_solved_complexity
for k, result in experiment.results.items()
}
}
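
# Example (e.g. in a notebook), assuming a CUDA device is available:
#   summary = run_quick_experiment(noise_scale=0.05, trials=3)
#   summary["max_solved_by_k"]  # -> {1: ..., 5: ...}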
if __name__ == "__main__":
main()