Source code for structured_stochasticity.evaluation

"""
Evaluation and trajectory aggregation for structured stochasticity experiments.

Key insight: With K trajectories, we need a strategy to select or combine answers.
Options include:
- Majority vote: Most common answer wins
- Best-of-K: Verify each, take first correct
- Weighted: Weight by confidence signals
- Verifier: Use separate model to check answers
"""

from collections import Counter
from dataclasses import dataclass, field
from typing import Optional, Callable
import re

from .tasks import Task, TaskInstance, TaskResult

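
# --- Illustrative sketch (added for exposition, not part of the experiment code).
# A minimal, self-contained toy showing the selection problem described in the
# module docstring: majority vote picks the modal extracted answer, while the
# oracle bound only asks whether *any* trajectory was correct. All values below
# are made up for illustration.
def _toy_selection_example() -> None:
    """Compare majority vote against the oracle upper bound on toy answers."""
    answers = ["14", "14", "27"]   # hypothetical extracted answers, K = 3
    correct = "27"                 # hypothetical ground truth

    majority = Counter(answers).most_common(1)[0][0]
    print("majority picks:", majority, "->", majority == correct)   # wrong
    print("oracle (any correct):", correct in answers)              # right
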

@dataclass
class AggregatedResult:
    """Result of aggregating multiple trajectories."""

    selected_response: str
    is_correct: bool
    k_trajectories: int
    selection_method: str

    # Detailed breakdown
    individual_results: list[TaskResult] = field(default_factory=list)
    agreement_rate: float = 0.0

    # If any trajectory got it right
    any_correct: bool = False
    num_correct: int = 0

    metadata: dict = field(default_factory=dict)

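
# Illustrative sketch: the gap between `is_correct` (after selection) and
# `any_correct` (before selection) is the selection gap -- some trajectory
# solved the problem but aggregation did not pick it. The field values below
# are made up for illustration.
def _example_selection_gap() -> AggregatedResult:
    return AggregatedResult(
        selected_response="The answer is 14.",  # hypothetical selected trajectory
        is_correct=False,                       # the selected answer was wrong...
        k_trajectories=3,
        selection_method="majority_vote",
        any_correct=True,                       # ...but one of the K=3 was right
        num_correct=1,
        agreement_rate=2 / 3,                   # two of three answers agreed
    )
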
class TrajectoryAggregator:
    """
    Aggregates multiple reasoning trajectories into a single answer.

    This is where the "structured stochasticity" hypothesis gets tested:
    if K>1 trajectories systematically improve accuracy vs K=1, it
    suggests reasoning collapse is indeed a trajectory problem.
    """

    def __init__(
        self,
        method: str = "majority_vote",
        verifier: Optional[Callable[[str, TaskInstance], float]] = None
    ):
        """
        Args:
            method: Aggregation method
                - "majority_vote": Select most common answer
                - "first_valid": Take first response that parses as valid
                - "verifier": Use verifier function to score responses
                - "oracle": Cheating mode - pick correct if any (for upper bound)
            verifier: Optional function (response, instance) -> score
        """
        self.method = method
        self.verifier = verifier

    def aggregate(
        self,
        responses: list[str],
        task: Task,
        instance: TaskInstance
    ) -> AggregatedResult:
        """
        Aggregate K responses into a single result.

        Args:
            responses: List of K response strings
            task: Task being evaluated
            instance: The specific task instance

        Returns:
            AggregatedResult with selected answer and metadata
        """
        k = len(responses)

        # Verify each trajectory individually
        individual_results = [
            task.verify_solution(instance, resp) for resp in responses
        ]
        num_correct = sum(r.is_correct for r in individual_results)
        any_correct = num_correct > 0

        # Select based on method
        if self.method == "oracle":
            # Cheating: pick correct answer if any trajectory found it
            selected_idx = self._oracle_select(individual_results)
        elif self.method == "majority_vote":
            selected_idx = self._majority_vote(responses, task, instance)
        elif self.method == "first_valid":
            selected_idx = self._first_valid(individual_results)
        elif self.method == "verifier" and self.verifier is not None:
            selected_idx = self._verifier_select(responses, instance)
        else:
            selected_idx = 0  # Default to first

        selected_response = responses[selected_idx]
        selected_result = individual_results[selected_idx]

        # Calculate agreement rate
        extracted_answers = [self._extract_answer(r, task) for r in responses]
        agreement_rate = self._calculate_agreement(extracted_answers)

        return AggregatedResult(
            selected_response=selected_response,
            is_correct=selected_result.is_correct,
            k_trajectories=k,
            selection_method=self.method,
            individual_results=individual_results,
            agreement_rate=agreement_rate,
            any_correct=any_correct,
            num_correct=num_correct,
            metadata={
                "selected_idx": selected_idx,
                "extracted_answers": extracted_answers
            }
        )

    def _oracle_select(self, results: list[TaskResult]) -> int:
        """Select first correct result (oracle/cheating mode)."""
        for i, r in enumerate(results):
            if r.is_correct:
                return i
        return 0

    def _majority_vote(
        self,
        responses: list[str],
        task: Task,
        instance: TaskInstance
    ) -> int:
        """Select response with most common extracted answer."""
        answers = [self._extract_answer(r, task) for r in responses]
        if not any(answers):
            return 0

        # Count non-None answers
        valid_answers = [a for a in answers if a is not None]
        if not valid_answers:
            return 0

        counter = Counter(valid_answers)
        most_common = counter.most_common(1)[0][0]

        # Return index of first response with this answer
        for i, a in enumerate(answers):
            if a == most_common:
                return i
        return 0

    def _first_valid(self, results: list[TaskResult]) -> int:
        """Select first response that doesn't have an error."""
        for i, r in enumerate(results):
            if r.error_message is None:
                return i
        return 0

    def _verifier_select(self, responses: list[str], instance: TaskInstance) -> int:
        """Use verifier to score and select best response."""
        scores = [self.verifier(r, instance) for r in responses]
        return max(range(len(scores)), key=lambda i: scores[i])

    def _extract_answer(self, response: str, task: Task) -> Optional[str]:
        """
        Extract a canonical answer from response for comparison.

        This is task-specific - we try to extract what the model claims
        as its final answer.
        """
        # Generic patterns for "final answer"
        patterns = [
            r"(?:final answer|answer is|result is)[:\s]*([^\n.]+)",
            r"(?:therefore|thus|so)[,:\s]*([^\n.]+)",
            r"=\s*(\d+)\s*$",
        ]

        response_lower = response.lower()
        for pattern in patterns:
            match = re.search(pattern, response_lower)
            if match:
                return match.group(1).strip()

        # For Tower of Hanoi, count moves as proxy
        if task.name == "tower_of_hanoi":
            moves = re.findall(r"[Mm]ove.*?from.*?to", response)
            return str(len(moves)) if moves else None

        return None

    def _calculate_agreement(self, answers: list[Optional[str]]) -> float:
        """Calculate what fraction of trajectories agree on an answer."""
        valid = [a for a in answers if a is not None]
        if not valid:
            return 0.0
        counter = Counter(valid)
        most_common_count = counter.most_common(1)[0][1]
        return most_common_count / len(valid)

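
# --- Usage sketch (assumptions noted inline).
# End-to-end example of TrajectoryAggregator.aggregate(). The real Task /
# TaskInstance / TaskResult classes live in .tasks; the stand-ins below are
# hypothetical and provide only the attributes the aggregator actually uses
# (verify_solution, name, is_correct, error_message, prompt).
def _aggregator_usage_sketch() -> None:
    @dataclass
    class _StubResult:
        is_correct: bool
        error_message: Optional[str] = None

    @dataclass
    class _StubInstance:
        prompt: str
        answer: str

    class _StubTask:
        name = "toy_sum"  # hypothetical task name

        def verify_solution(self, instance, response):
            return _StubResult(is_correct=instance.answer in response)

    instance = _StubInstance(prompt="What is 6 + 8?", answer="14")
    responses = [
        "Let me think. The answer is 14.",
        "The answer is 15.",
        "So the answer is 14.",
    ]

    aggregator = TrajectoryAggregator(method="majority_vote")
    result = aggregator.aggregate(responses, _StubTask(), instance)
    print(result.is_correct, result.num_correct, f"{result.agreement_rate:.2f}")
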
@dataclass
class EvaluationResult:
    """Results from evaluating across complexity levels."""

    task_name: str
    k_trajectories: int
    selection_method: str

    # Accuracy by complexity
    accuracy_by_complexity: dict[int, float] = field(default_factory=dict)

    # Maximum complexity solved (highest level whose accuracy still meets the
    # threshold, 90% in Evaluator.evaluate)
    max_solved_complexity: int = 0

    # Detailed results
    all_results: list[AggregatedResult] = field(default_factory=list)

    # Oracle upper bound (what's possible if we could always pick right)
    oracle_accuracy_by_complexity: dict[int, float] = field(default_factory=dict)

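
# Illustrative sketch: with the 90% threshold applied in Evaluator.evaluate()
# below, an accuracy curve like {3: 1.0, 4: 0.9, 5: 0.4} gives
# max_solved_complexity = 4. All numbers here are made up for illustration.
def _example_evaluation_result() -> EvaluationResult:
    return EvaluationResult(
        task_name="toy_sum",  # hypothetical task name
        k_trajectories=5,
        selection_method="majority_vote",
        accuracy_by_complexity={3: 1.0, 4: 0.9, 5: 0.4},
        oracle_accuracy_by_complexity={3: 1.0, 4: 1.0, 5: 0.7},
        max_solved_complexity=4,
    )
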
class Evaluator:
    """
    Runs evaluation across complexity levels.

    Main experimental loop:
    1. For each complexity level
    2. Generate task instances
    3. For each K value
    4. Generate K trajectories per instance
    5. Aggregate and verify
    6. Report accuracy curves
    """

    def __init__(
        self,
        task: Task,
        aggregator: Optional[TrajectoryAggregator] = None
    ):
        self.task = task
        self.aggregator = aggregator or TrajectoryAggregator()

    def evaluate(
        self,
        generate_fn: Callable[[str, int], list[str]],
        complexity_range: tuple[int, int],
        k_values: list[int] = [1, 5, 10],
        trials_per_complexity: int = 10,
        verbose: bool = True
    ) -> dict[int, EvaluationResult]:
        """
        Run full evaluation.

        Args:
            generate_fn: Function (prompt, k) -> list of k responses
            complexity_range: (min_complexity, max_complexity)
            k_values: List of K values to test
            trials_per_complexity: How many instances per complexity level
            verbose: Print progress

        Returns:
            Dict mapping K -> EvaluationResult
        """
        results = {k: EvaluationResult(
            task_name=self.task.name,
            k_trajectories=k,
            selection_method=self.aggregator.method
        ) for k in k_values}

        for complexity in range(complexity_range[0], complexity_range[1] + 1):
            if verbose:
                print(f"\nComplexity {complexity}:")

            for k in k_values:
                correct = 0
                oracle_correct = 0

                for trial in range(trials_per_complexity):
                    # Generate instance
                    instance = self.task.generate_instance(complexity)

                    # Generate K trajectories
                    responses = generate_fn(instance.prompt, k)

                    # Aggregate
                    agg_result = self.aggregator.aggregate(
                        responses, self.task, instance
                    )
                    results[k].all_results.append(agg_result)

                    if agg_result.is_correct:
                        correct += 1
                    if agg_result.any_correct:
                        oracle_correct += 1

                accuracy = correct / trials_per_complexity
                oracle_accuracy = oracle_correct / trials_per_complexity

                results[k].accuracy_by_complexity[complexity] = accuracy
                results[k].oracle_accuracy_by_complexity[complexity] = oracle_accuracy

                if verbose:
                    print(f"  K={k}: {accuracy:.1%} (oracle: {oracle_accuracy:.1%})")

        # Calculate max solved complexity
        for k, result in results.items():
            for c in sorted(result.accuracy_by_complexity.keys()):
                if result.accuracy_by_complexity[c] >= 0.9:  # 90% threshold
                    result.max_solved_complexity = c
                else:
                    break

        return results

    def compare_k_scaling(
        self,
        results: dict[int, EvaluationResult]
    ) -> dict:
        """
        Analyze how performance scales with K.

        This is the key analysis: if max_solved_complexity increases with K
        under constant token budgets, it supports the hypothesis.
        """
        scaling = {
            "k_values": sorted(results.keys()),
            "max_complexity_by_k": {},
            "accuracy_at_complexity": {},
        }

        for k, result in sorted(results.items()):
            scaling["max_complexity_by_k"][k] = result.max_solved_complexity

            for c, acc in result.accuracy_by_complexity.items():
                if c not in scaling["accuracy_at_complexity"]:
                    scaling["accuracy_at_complexity"][c] = {}
                scaling["accuracy_at_complexity"][c][k] = acc

        return scaling

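
# --- End-to-end usage sketch.
# A runnable toy demonstration of the Evaluator loop. The task and generator
# below are hypothetical stand-ins, not the .tasks implementations: the fake
# "model" answers correctly with a probability that decays with complexity,
# so larger K should recover accuracy at higher complexity via majority vote.
if __name__ == "__main__":
    import random

    @dataclass
    class _ToyResult:
        is_correct: bool
        error_message: Optional[str] = None

    @dataclass
    class _ToyInstance:
        prompt: str
        answer: str

    class _ToyTask:
        name = "toy_sum"  # hypothetical task name

        def generate_instance(self, complexity: int) -> _ToyInstance:
            return _ToyInstance(
                prompt=f"What is {complexity} + {complexity}?",
                answer=str(2 * complexity),
            )

        def verify_solution(self, instance: _ToyInstance, response: str) -> _ToyResult:
            return _ToyResult(is_correct=instance.answer in response)

    rng = random.Random(0)

    def toy_generate_fn(prompt: str, k: int) -> list[str]:
        # Answer the toy prompt correctly with a complexity-dependent probability.
        n = int(prompt.split()[2])
        p_correct = max(0.1, 1.0 - 0.1 * n)
        return [
            f"The answer is {2 * n if rng.random() < p_correct else 2 * n + 1}."
            for _ in range(k)
        ]

    evaluator = Evaluator(_ToyTask(), TrajectoryAggregator(method="majority_vote"))
    results = evaluator.evaluate(
        toy_generate_fn, complexity_range=(1, 8), k_values=[1, 5]
    )
    print(evaluator.compare_k_scaling(results))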