amnesia_bench.py (33 changes: 30 additions & 3 deletions)
@@ -209,11 +209,12 @@ def with_exponential_backoff(fn, max_retries=20, base_delay=2.0, max_delay=120.0
 class LLMClient:
     """Wrapper for llama.cpp or any OpenAI-compatible /v1/chat/completions endpoint."""
 
-    def __init__(self, server_url: str = SERVER_URL, temperature: float = TEMPERATURE, api_key: str = None, model_name: str = None):
+    def __init__(self, server_url: str = SERVER_URL, temperature: float = TEMPERATURE, api_key: str = None, model_name: str = None, reasoning_enabled: bool = False):
         self.server_url = server_url.rstrip("/")
         self.temperature = temperature
         self.model_name = model_name  # passed to API as model field (required by OpenRouter)
         self.auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}
+        self.reasoning_enabled = reasoning_enabled  # send reasoning: {enabled: true} for OpenRouter
 
     def generate(self, messages: list[dict], max_tokens: int) -> dict:
         """
@@ -227,6 +228,8 @@ def generate(self, messages: list[dict], max_tokens: int) -> dict:
"temperature": self.temperature,
"stream": True,
}
if self.reasoning_enabled:
payload["reasoning"] = {"enabled": True}

def _do_request():
if self.model_name:
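In effect, when `reasoning_enabled` is set, the request `generate()` sends looks roughly like this. A minimal sketch: the endpoint URL, model id, and message content are illustrative, and only the payload keys mirror the hunk above.

```python
# Sketch of the request body with the new "reasoning" field. Endpoint,
# model id, and message content are illustrative, not taken from the repo.
import requests

payload = {
    "messages": [{"role": "user", "content": "Factor 3233 into primes."}],
    "max_tokens": 1024,
    "temperature": 0.0,
    "stream": True,
    "reasoning": {"enabled": True},  # the field this PR adds
    "model": "openai/gpt-oss-120b",  # attached only when model_name is set
}
resp = requests.post(
    "https://openrouter.ai/api/v1/chat/completions",  # assumed OpenRouter URL
    headers={"Authorization": "Bearer <OPENROUTER_API_KEY>"},
    json=payload,
    stream=True,
)
```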
@@ -293,9 +296,15 @@ def _do_request():
         else:
             full_content = content
 
+        # Build reasoning_details for multi-turn pass-back (OpenRouter format)
+        reasoning_details = None
+        if reasoning and self.reasoning_enabled:
+            reasoning_details = [{"type": "thinking", "thinking": reasoning}]
+
         return {
             "content": full_content,
             "reasoning_content": reasoning,
+            "reasoning_details": reasoning_details,
             "final_content": content,
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
@@ -423,6 +432,7 @@ def _do_request():
         return {
             "content": content,
             "reasoning_content": "",
+            "reasoning_details": None,
             "final_content": content,
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
@@ -449,6 +459,7 @@ def create_client(
     api_key: str = None,
     model_name: str = None,
     temperature: float = TEMPERATURE,
+    reasoning_enabled: bool = True,
 ) -> Union[LLMClient, GeminiClient]:
     """
     Create appropriate client based on server_url scheme.
@@ -482,6 +493,7 @@ def create_client(
             temperature=temperature,
             api_key=api_key,
             model_name=or_model,
+            reasoning_enabled=reasoning_enabled,
         )
     elif server_url.startswith("http"):
         return LLMClient(server_url=server_url, temperature=temperature)
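Note the diverging defaults: `create_client()` opts in (`reasoning_enabled: bool = True`) while `LLMClient.__init__` defaults to `False`, so reasoning is requested only for clients built through the factory unless a caller opts out. A hypothetical call site; the server spec spelling, environment variable, and model id are assumptions, only the keyword names come from this diff.

```python
# Hypothetical call site; keyword names match the diff, values are examples.
import os

server_spec = "..."  # whatever spelling routes to the OpenRouter branch (elided above)
client = create_client(
    server_spec,
    api_key=os.environ.get("OPENROUTER_API_KEY"),
    model_name="openai/gpt-oss-120b",  # illustrative model id
    temperature=0.0,
    reasoning_enabled=False,           # opt this client out of reasoning tokens
)
```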
@@ -946,6 +958,7 @@ def run_trial(
             finish = "truncated"
 
         content = resp["content"]
+        reasoning_details = resp.get("reasoning_details")
         total_now = resp["total_tokens"]
         peak_tokens = max(peak_tokens, total_now)
 
@@ -1007,11 +1020,17 @@
                 code_output=combined_output,
             )
             conversation.append(code_turn)
-            messages.append({"role": "assistant", "content": content})
+            asst_msg = {"role": "assistant", "content": content}
+            if reasoning_details:
+                asst_msg["reasoning_details"] = reasoning_details
+            messages.append(asst_msg)
             messages.append({"role": "user", "content": f"Code output:\n{combined_output}"})
             continue
 
-        messages.append({"role": "assistant", "content": content})
+        asst_msg = {"role": "assistant", "content": content}
+        if reasoning_details:
+            asst_msg["reasoning_details"] = reasoning_details
+        messages.append(asst_msg)
         messages.append({"role": "user", "content": "Continue solving."})
         conversation.append(Turn(role="user", content="Continue solving."))
 
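Stripped of the TIR code-output branch, the pass-back pattern both call sites now share looks like this. A simplified sketch: the helper name, turn budget, and prompts are placeholders, while the message shapes match the hunk above.

```python
def solve_with_passback(client, problem: str, turn_budget: int = 8) -> list[dict]:
    """Hypothetical helper showing the pass-back pattern run_trial uses."""
    messages = [{"role": "user", "content": problem}]
    for _ in range(turn_budget):
        resp = client.generate(messages, max_tokens=2048)
        asst_msg = {"role": "assistant", "content": resp["content"]}
        if resp.get("reasoning_details"):
            # Replay OpenRouter-format reasoning blocks on the next turn
            asst_msg["reasoning_details"] = resp["reasoning_details"]
        messages.append(asst_msg)
        messages.append({"role": "user", "content": "Continue solving."})
    return messages
```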
@@ -1341,6 +1360,7 @@ def run_all_models(
     trials: int = TRIALS_PER_WINDOW,
     temperature: float = TEMPERATURE,
     cli_api_key: str = None,
+    reasoning_enabled: bool = True,
 ):
     """
     Iterate over all models in models.json, run all problems for each model.
@@ -1379,6 +1399,7 @@ def run_all_models(
                 api_key=api_key,
                 model_name=mname,
                 temperature=temperature,
+                reasoning_enabled=reasoning_enabled,
             )
         except ValueError as e:
             print(f" ERROR: Could not create client for {mname}: {e} — skipping")
@@ -1481,6 +1502,8 @@ def main():
parser.add_argument("--temperature", type=float, default=TEMPERATURE)
parser.add_argument("--config", type=str, default=None,
help="Run specific config only: NoTIR_HardCut, TIR_HardCut, NoTIR_Compact, TIR_Compact")
parser.add_argument("--no-reasoning", action="store_true",
help="Disable reasoning tokens for OpenRouter models (saves cost/latency)")
parser.add_argument("--results-dir", type=str, default=None,
help="Results directory for --scores / --analyze (default: ./results)")

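With the flag wired through in the next hunk, an invocation along the lines of `python amnesia_bench.py --run-all-models --no-reasoning` keeps every OpenRouter client from requesting reasoning tokens; omitting the flag leaves reasoning on by default.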
@@ -1536,6 +1559,8 @@ def main():
     else:
         api_key = os.environ.get("GEMINI_API_KEY")
 
+    reasoning_enabled = not args.no_reasoning
+
     # Multi-model mode
     if args.run_all_models:
         run_all_models(
@@ -1546,6 +1571,7 @@
             trials=trials_per_window,
             temperature=args.temperature,
             cli_api_key=api_key,
+            reasoning_enabled=reasoning_enabled,
         )
         return
 
@@ -1559,6 +1585,7 @@
             api_key=api_key,
             model_name=model_name,
             temperature=args.temperature,
+            reasoning_enabled=reasoning_enabled,
         )
     except ValueError as e:
         print(f"ERROR: {e}")