"""
Evaluation and trajectory aggregation for structured stochasticity experiments.
Key insight: with K sampled trajectories per instance, we need a strategy to
select or combine their answers. Implemented options:
- Majority vote ("majority_vote"): the most common extracted answer wins
- First valid ("first_valid"): take the first response whose verification reports no error
- Verifier ("verifier"): score responses with a separate verifier function and keep the best
- Oracle ("oracle"): pick a correct response if any trajectory found one (upper bound / "cheating" mode)
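
Minimal usage sketch (``responses``, ``task``, and ``instance`` are assumed to come
from the surrounding experiment code)::

    aggregator = TrajectoryAggregator(method="majority_vote")
    result = aggregator.aggregate(responses, task, instance)
    print(result.is_correct, result.agreement_rate)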
"""
from collections import Counter
from dataclasses import dataclass, field
from typing import Optional, Callable
import re
from .tasks import Task, TaskInstance, TaskResult
@dataclass
class AggregatedResult:
"""Result of aggregating multiple trajectories."""
selected_response: str
is_correct: bool
k_trajectories: int
selection_method: str
# Detailed breakdown
individual_results: list[TaskResult] = field(default_factory=list)
agreement_rate: float = 0.0
# If any trajectory got it right
any_correct: bool = False
num_correct: int = 0
metadata: dict = field(default_factory=dict)
class TrajectoryAggregator:
"""
Aggregates multiple reasoning trajectories into a single answer.
This is where the "structured stochasticity" hypothesis gets tested:
if K>1 trajectories systematically improve accuracy vs K=1,
it suggests reasoning collapse is indeed a trajectory problem.
"""
def __init__(
self,
method: str = "majority_vote",
verifier: Optional[Callable[[str, TaskInstance], float]] = None
):
"""
Args:
            method: Aggregation method. One of:
                - "majority_vote": select the most common extracted answer
                - "first_valid": take the first response whose verification reports no error
                - "verifier": use the verifier function to score responses and keep the best
                - "oracle": pick a correct response if any exists (upper bound / "cheating" mode)
            verifier: Optional scoring function (response, instance) -> float;
                required when method == "verifier"
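
        Example (sketch; ``my_verifier`` is a hypothetical scoring function)::

            aggregator = TrajectoryAggregator(
                method="verifier",
                verifier=my_verifier,  # (response, instance) -> float
            )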
"""
self.method = method
self.verifier = verifier
def aggregate(
self,
responses: list[str],
task: Task,
instance: TaskInstance
) -> AggregatedResult:
"""
Aggregate K responses into a single result.
Args:
responses: List of K response strings
task: Task being evaluated
instance: The specific task instance
Returns:
AggregatedResult with selected answer and metadata
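
        Example (illustrative sketch; ``responses`` would normally come from a
        sampling backend, and ``task``/``instance`` from the tasks module)::

            agg = aggregator.aggregate(responses, task, instance)
            agg.is_correct        # correctness of the selected trajectory
            agg.num_correct       # how many trajectories were individually correct
            agg.agreement_rate    # fraction agreeing on the modal answer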
"""
k = len(responses)
# Verify each trajectory individually
individual_results = [
task.verify_solution(instance, resp)
for resp in responses
]
num_correct = sum(r.is_correct for r in individual_results)
any_correct = num_correct > 0
# Select based on method
if self.method == "oracle":
# Cheating: pick correct answer if any trajectory found it
selected_idx = self._oracle_select(individual_results)
elif self.method == "majority_vote":
selected_idx = self._majority_vote(responses, task, instance)
elif self.method == "first_valid":
selected_idx = self._first_valid(individual_results)
elif self.method == "verifier" and self.verifier is not None:
selected_idx = self._verifier_select(responses, instance)
else:
selected_idx = 0 # Default to first
selected_response = responses[selected_idx]
selected_result = individual_results[selected_idx]
# Calculate agreement rate
extracted_answers = [self._extract_answer(r, task) for r in responses]
agreement_rate = self._calculate_agreement(extracted_answers)
return AggregatedResult(
selected_response=selected_response,
is_correct=selected_result.is_correct,
k_trajectories=k,
selection_method=self.method,
individual_results=individual_results,
agreement_rate=agreement_rate,
any_correct=any_correct,
num_correct=num_correct,
metadata={
"selected_idx": selected_idx,
"extracted_answers": extracted_answers
}
)
def _oracle_select(self, results: list[TaskResult]) -> int:
"""Select first correct result (oracle/cheating mode)."""
for i, r in enumerate(results):
if r.is_correct:
return i
return 0
def _majority_vote(
self,
responses: list[str],
task: Task,
instance: TaskInstance
) -> int:
"""Select response with most common extracted answer."""
        answers = [self._extract_answer(r, task) for r in responses]
        # Keep only answers that were successfully extracted
        valid_answers = [a for a in answers if a is not None]
        if not valid_answers:
            return 0
counter = Counter(valid_answers)
most_common = counter.most_common(1)[0][0]
# Return index of first response with this answer
for i, a in enumerate(answers):
if a == most_common:
return i
return 0
def _first_valid(self, results: list[TaskResult]) -> int:
"""Select first response that doesn't have an error."""
for i, r in enumerate(results):
if r.error_message is None:
return i
return 0
def _verifier_select(self, responses: list[str], instance: TaskInstance) -> int:
"""Use verifier to score and select best response."""
scores = [self.verifier(r, instance) for r in responses]
return max(range(len(scores)), key=lambda i: scores[i])
def _extract_answer(self, response: str, task: Task) -> Optional[str]:
"""
Extract a canonical answer from response for comparison.
This is task-specific - we try to extract what the model
claims as its final answer.
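
        Example (illustrative)::

            self._extract_answer("Thus, the final answer is: 42", task)  # -> "42"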
"""
        # Generic patterns for a stated "final answer", tried in order
        patterns = [
            # e.g. "the final answer is: 42" / "final answer: 42" -> "42"
            r"(?:final answer is|final answer|answer is|result is)[:\s]*([^\n.]+)",
            r"(?:therefore|thus|so)[,:\s]*([^\n.]+)",
            r"=\s*(\d+)\s*$",
        ]
response_lower = response.lower()
for pattern in patterns:
match = re.search(pattern, response_lower)
if match:
return match.group(1).strip()
# For Tower of Hanoi, count moves as proxy
if task.name == "tower_of_hanoi":
moves = re.findall(r"[Mm]ove.*?from.*?to", response)
return str(len(moves)) if moves else None
return None
def _calculate_agreement(self, answers: list[Optional[str]]) -> float:
"""Calculate what fraction of trajectories agree on an answer."""
valid = [a for a in answers if a is not None]
if not valid:
return 0.0
counter = Counter(valid)
most_common_count = counter.most_common(1)[0][1]
return most_common_count / len(valid)
@dataclass
class EvaluationResult:
"""Results from evaluating across complexity levels."""
task_name: str
k_trajectories: int
selection_method: str
# Accuracy by complexity
accuracy_by_complexity: dict[int, float] = field(default_factory=dict)
    # Maximum complexity solved (highest level reached with accuracy >= 90%, counting up from the minimum)
max_solved_complexity: int = 0
# Detailed results
all_results: list[AggregatedResult] = field(default_factory=list)
# Oracle upper bound (what's possible if we could always pick right)
oracle_accuracy_by_complexity: dict[int, float] = field(default_factory=dict)
class Evaluator:
"""
Runs evaluation across complexity levels.
    Main experimental loop:
        for each complexity level:
            for each K value:
                for each trial: generate an instance, sample K trajectories,
                then aggregate and verify
        report accuracy curves per K
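
    Example (sketch; ``sample_k`` is a hypothetical callable returning K sampled
    responses for a prompt)::

        evaluator = Evaluator(task, TrajectoryAggregator(method="majority_vote"))
        results = evaluator.evaluate(sample_k, complexity_range=(1, 8), k_values=[1, 5, 10])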
"""
def __init__(
self,
task: Task,
aggregator: Optional[TrajectoryAggregator] = None
):
self.task = task
self.aggregator = aggregator or TrajectoryAggregator()
def evaluate(
self,
generate_fn: Callable[[str, int], list[str]],
complexity_range: tuple[int, int],
        k_values: Optional[list[int]] = None,
trials_per_complexity: int = 10,
verbose: bool = True
) -> dict[int, EvaluationResult]:
"""
Run full evaluation.
Args:
generate_fn: Function (prompt, k) -> list of k responses
complexity_range: (min_complexity, max_complexity)
            k_values: List of K values to test (defaults to [1, 5, 10])
trials_per_complexity: How many instances per complexity level
verbose: Print progress
Returns:
Dict mapping K -> EvaluationResult
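
        Example (illustrative; ``call_model`` stands in for whatever sampling
        backend the experiment uses)::

            def generate_fn(prompt: str, k: int) -> list[str]:
                return [call_model(prompt, temperature=1.0) for _ in range(k)]

            results = evaluator.evaluate(generate_fn, complexity_range=(1, 10))
            results[5].accuracy_by_complexity   # accuracy per complexity level for K=5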
"""
        if k_values is None:
            k_values = [1, 5, 10]
        results = {k: EvaluationResult(
            task_name=self.task.name,
            k_trajectories=k,
            selection_method=self.aggregator.method
        ) for k in k_values}
for complexity in range(complexity_range[0], complexity_range[1] + 1):
if verbose:
print(f"\nComplexity {complexity}:")
for k in k_values:
correct = 0
oracle_correct = 0
for trial in range(trials_per_complexity):
# Generate instance
instance = self.task.generate_instance(complexity)
# Generate K trajectories
responses = generate_fn(instance.prompt, k)
# Aggregate
agg_result = self.aggregator.aggregate(
responses, self.task, instance
)
results[k].all_results.append(agg_result)
if agg_result.is_correct:
correct += 1
if agg_result.any_correct:
oracle_correct += 1
accuracy = correct / trials_per_complexity
oracle_accuracy = oracle_correct / trials_per_complexity
results[k].accuracy_by_complexity[complexity] = accuracy
results[k].oracle_accuracy_by_complexity[complexity] = oracle_accuracy
if verbose:
print(f" K={k}: {accuracy:.1%} (oracle: {oracle_accuracy:.1%})")
# Calculate max solved complexity
for k, result in results.items():
for c in sorted(result.accuracy_by_complexity.keys()):
if result.accuracy_by_complexity[c] >= 0.9: # 90% threshold
result.max_solved_complexity = c
else:
break
return results
def compare_k_scaling(
self,
results: dict[int, EvaluationResult]
) -> dict:
"""
Analyze how performance scales with K.
This is the key analysis: if max_solved_complexity increases
with K under constant token budgets, it supports the hypothesis.
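
        Example (illustrative; the numbers are hypothetical)::

            scaling = evaluator.compare_k_scaling(results)
            scaling["max_complexity_by_k"]        # e.g. {1: 4, 5: 6, 10: 7}
            scaling["accuracy_at_complexity"][6]  # per-K accuracy at complexity 6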
"""
scaling = {
"k_values": sorted(results.keys()),
"max_complexity_by_k": {},
"accuracy_at_complexity": {},
}
for k, result in sorted(results.items()):
scaling["max_complexity_by_k"][k] = result.max_solved_complexity
for c, acc in result.accuracy_by_complexity.items():
if c not in scaling["accuracy_at_complexity"]:
scaling["accuracy_at_complexity"][c] = {}
scaling["accuracy_at_complexity"][c][k] = acc
return scaling