From 464e311529344492e7a1c9527a040d6af5978da9 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Fri, 6 Mar 2026 14:26:30 +0000 Subject: [PATCH 1/3] Fix duplicate defaults: context_length and max_turns Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/investigate/scripts/run_agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spd/investigate/scripts/run_agent.py b/spd/investigate/scripts/run_agent.py index 54806ed36..1e35888bb 100644 --- a/spd/investigate/scripts/run_agent.py +++ b/spd/investigate/scripts/run_agent.py @@ -109,8 +109,8 @@ def log_event(events_path: Path, event: InvestigationEvent) -> None: def run_agent( wandb_path: str, inv_id: str, - context_length: int = 128, - max_turns: int = 50, + context_length: int, + max_turns: int, ) -> None: """Run a single investigation agent. From 1ea3858fd53ce5a7ff3fac5917dd6c56b51f1601 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Fri, 6 Mar 2026 14:26:58 +0000 Subject: [PATCH 2/3] Deduplicate MAX_OUTPUT_NODES_PER_POS constant Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/backend/routers/graphs.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/spd/app/backend/routers/graphs.py b/spd/app/backend/routers/graphs.py index a51b1649c..baefbaac3 100644 --- a/spd/app/backend/routers/graphs.py +++ b/spd/app/backend/routers/graphs.py @@ -19,6 +19,7 @@ from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.compute import ( + MAX_OUTPUT_NODES_PER_POS, Edge, compute_prompt_attributions, compute_prompt_attributions_optimized, @@ -211,9 +212,6 @@ class CompleteMessageWithOptimization(BaseModel): ProgressCallback = Callable[[int, int, str], None] -MAX_OUTPUT_NODES_PER_POS = 15 - - def _build_out_probs( ci_masked_out_logits: torch.Tensor, target_out_logits: torch.Tensor, From 05453462db39d53928440f25860611051dc435f3 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Fri, 6 Mar 2026 16:21:26 +0000 Subject: [PATCH 3/3] Simplify investigate module: single inv_id arg, fail-fast patterns - run_agent reads all config from metadata.json instead of duplicating as CLI args (wandb_path, context_length, max_turns) - wait_for_backend raises directly instead of returning bool - _format_model_info accesses keys directly instead of .get() fallbacks Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/investigate/agent_prompt.py | 36 +++++++++++---------------- spd/investigate/scripts/run_agent.py | 37 +++++++++------------------- spd/investigate/scripts/run_slurm.py | 8 +----- 3 files changed, 26 insertions(+), 55 deletions(-) diff --git a/spd/investigate/agent_prompt.py b/spd/investigate/agent_prompt.py index d53a47ac3..b7e182bb7 100644 --- a/spd/investigate/agent_prompt.py +++ b/spd/investigate/agent_prompt.py @@ -161,28 +161,20 @@ def _format_model_info(model_info: dict[str, Any]) -> str: """Format model architecture info for inclusion in the agent prompt.""" - parts = [f"- **Architecture**: {model_info.get('summary', 'Unknown')}"] - - target_config = model_info.get("target_model_config") - if target_config: - if "n_layer" in target_config: - parts.append(f"- **Layers**: {target_config['n_layer']}") - if "n_embd" in target_config: - parts.append(f"- **Hidden dim**: {target_config['n_embd']}") - if "vocab_size" in target_config: - parts.append(f"- **Vocab size**: {target_config['vocab_size']}") - if "n_ctx" in target_config: - parts.append(f"- **Context length**: {target_config['n_ctx']}") - - topology = model_info.get("topology") - if topology and topology.get("block_structure"): - block = topology["block_structure"][0] - attn = ", ".join(block.get("attn_projections", [])) - ffn = ", ".join(block.get("ffn_projections", [])) - parts.append(f"- **Attention projections**: {attn}") - parts.append(f"- **FFN projections**: {ffn}") - - return "\n".join(parts) + target_config = model_info["target_model_config"] + topology = model_info["topology"] + block = topology["block_structure"][0] + + return "\n".join( + [ + f"- **Architecture**: {model_info['summary']}", + f"- **Layers**: {target_config['n_layer']}", + f"- **Hidden dim**: {target_config['n_embd']}", + f"- **Vocab size**: {target_config['vocab_size']}", + f"- **Attention projections**: {', '.join(block['attn_projections'])}", + f"- **FFN projections**: {', '.join(block['ffn_projections'])}", + ] + ) def get_agent_prompt( diff --git a/spd/investigate/scripts/run_agent.py b/spd/investigate/scripts/run_agent.py index 1e35888bb..33ccf60d7 100644 --- a/spd/investigate/scripts/run_agent.py +++ b/spd/investigate/scripts/run_agent.py @@ -67,19 +67,19 @@ def find_available_port(start_port: int = 8000, max_attempts: int = 100) -> int: ) -def wait_for_backend(port: int, timeout: float = 120.0) -> bool: - """Wait for the backend to become healthy.""" +def wait_for_backend(port: int, timeout: float = 120.0) -> None: + """Wait for the backend to become healthy. Raises on timeout.""" url = f"http://localhost:{port}/api/health" start = time.time() while time.time() - start < timeout: try: resp = requests.get(url, timeout=5) if resp.status_code == 200: - return True + return except requests.exceptions.ConnectionError: pass time.sleep(1) - return False + raise RuntimeError(f"Backend on port {port} failed to start within {timeout}s") def load_run(port: int, wandb_path: str, context_length: int) -> None: @@ -106,26 +106,16 @@ def log_event(events_path: Path, event: InvestigationEvent) -> None: f.write(event.model_dump_json() + "\n") -def run_agent( - wandb_path: str, - inv_id: str, - context_length: int, - max_turns: int, -) -> None: - """Run a single investigation agent. - - Args: - wandb_path: WandB path of the SPD run. - inv_id: Unique identifier for this investigation. - context_length: Context length for prompts. - max_turns: Maximum agentic turns before stopping (prevents runaway agents). - """ +def run_agent(inv_id: str) -> None: + """Run a single investigation agent. All config read from metadata.json.""" inv_dir = get_investigation_output_dir(inv_id) assert inv_dir.exists(), f"Investigation directory does not exist: {inv_dir}" - # Read prompt from metadata metadata: dict[str, Any] = json.loads((inv_dir / "metadata.json").read_text()) - prompt = metadata["prompt"] + wandb_path: str = metadata["wandb_path"] + prompt: str = metadata["prompt"] + context_length: int = metadata["context_length"] + max_turns: int = metadata["max_turns"] write_claude_settings(inv_dir) @@ -192,12 +182,7 @@ def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: try: logger.info(f"[{inv_id}] Waiting for backend...") - if not wait_for_backend(port): - log_event( - events_path, - InvestigationEvent(event_type="error", message="Backend failed to start"), - ) - raise RuntimeError("Backend failed to start") + wait_for_backend(port) logger.info(f"[{inv_id}] Backend ready, loading run...") log_event( diff --git a/spd/investigate/scripts/run_slurm.py b/spd/investigate/scripts/run_slurm.py index 703ed2f78..eb0e58cf3 100644 --- a/spd/investigate/scripts/run_slurm.py +++ b/spd/investigate/scripts/run_slurm.py @@ -62,13 +62,7 @@ def launch_investigation( } (output_dir / "metadata.json").write_text(json.dumps(metadata, indent=2)) - cmd = ( - f"{sys.executable} -m spd.investigate.scripts.run_agent " - f'"{wandb_path}" ' - f"--inv_id {inv_id} " - f"--context_length {context_length} " - f"--max_turns {max_turns}" - ) + cmd = f"{sys.executable} -m spd.investigate.scripts.run_agent {inv_id}" slurm_config = SlurmConfig( job_name=job_name,