Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions spd/app/backend/routers/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from spd.app.backend.app_tokenizer import AppTokenizer
from spd.app.backend.compute import (
MAX_OUTPUT_NODES_PER_POS,
Edge,
compute_prompt_attributions,
compute_prompt_attributions_optimized,
Expand Down Expand Up @@ -211,9 +212,6 @@ class CompleteMessageWithOptimization(BaseModel):
ProgressCallback = Callable[[int, int, str], None]


MAX_OUTPUT_NODES_PER_POS = 15


def _build_out_probs(
ci_masked_out_logits: torch.Tensor,
target_out_logits: torch.Tensor,
Expand Down
36 changes: 14 additions & 22 deletions spd/investigate/agent_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,28 +161,20 @@

def _format_model_info(model_info: dict[str, Any]) -> str:
"""Format model architecture info for inclusion in the agent prompt."""
parts = [f"- **Architecture**: {model_info.get('summary', 'Unknown')}"]

target_config = model_info.get("target_model_config")
if target_config:
if "n_layer" in target_config:
parts.append(f"- **Layers**: {target_config['n_layer']}")
if "n_embd" in target_config:
parts.append(f"- **Hidden dim**: {target_config['n_embd']}")
if "vocab_size" in target_config:
parts.append(f"- **Vocab size**: {target_config['vocab_size']}")
if "n_ctx" in target_config:
parts.append(f"- **Context length**: {target_config['n_ctx']}")

topology = model_info.get("topology")
if topology and topology.get("block_structure"):
block = topology["block_structure"][0]
attn = ", ".join(block.get("attn_projections", []))
ffn = ", ".join(block.get("ffn_projections", []))
parts.append(f"- **Attention projections**: {attn}")
parts.append(f"- **FFN projections**: {ffn}")

return "\n".join(parts)
target_config = model_info["target_model_config"]
topology = model_info["topology"]
block = topology["block_structure"][0]

return "\n".join(
[
f"- **Architecture**: {model_info['summary']}",
f"- **Layers**: {target_config['n_layer']}",
f"- **Hidden dim**: {target_config['n_embd']}",
f"- **Vocab size**: {target_config['vocab_size']}",
f"- **Attention projections**: {', '.join(block['attn_projections'])}",
f"- **FFN projections**: {', '.join(block['ffn_projections'])}",
]
)


def get_agent_prompt(
Expand Down
37 changes: 11 additions & 26 deletions spd/investigate/scripts/run_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,19 @@ def find_available_port(start_port: int = 8000, max_attempts: int = 100) -> int:
)


def wait_for_backend(port: int, timeout: float = 120.0) -> bool:
"""Wait for the backend to become healthy."""
def wait_for_backend(port: int, timeout: float = 120.0) -> None:
"""Wait for the backend to become healthy. Raises on timeout."""
url = f"http://localhost:{port}/api/health"
start = time.time()
while time.time() - start < timeout:
try:
resp = requests.get(url, timeout=5)
if resp.status_code == 200:
return True
return
except requests.exceptions.ConnectionError:
pass
time.sleep(1)
return False
raise RuntimeError(f"Backend on port {port} failed to start within {timeout}s")


def load_run(port: int, wandb_path: str, context_length: int) -> None:
Expand All @@ -106,26 +106,16 @@ def log_event(events_path: Path, event: InvestigationEvent) -> None:
f.write(event.model_dump_json() + "\n")


def run_agent(
wandb_path: str,
inv_id: str,
context_length: int = 128,
max_turns: int = 50,
) -> None:
"""Run a single investigation agent.

Args:
wandb_path: WandB path of the SPD run.
inv_id: Unique identifier for this investigation.
context_length: Context length for prompts.
max_turns: Maximum agentic turns before stopping (prevents runaway agents).
"""
def run_agent(inv_id: str) -> None:
"""Run a single investigation agent. All config read from metadata.json."""
inv_dir = get_investigation_output_dir(inv_id)
assert inv_dir.exists(), f"Investigation directory does not exist: {inv_dir}"

# Read prompt from metadata
metadata: dict[str, Any] = json.loads((inv_dir / "metadata.json").read_text())
prompt = metadata["prompt"]
wandb_path: str = metadata["wandb_path"]
prompt: str = metadata["prompt"]
context_length: int = metadata["context_length"]
max_turns: int = metadata["max_turns"]

write_claude_settings(inv_dir)

Expand Down Expand Up @@ -192,12 +182,7 @@ def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None:

try:
logger.info(f"[{inv_id}] Waiting for backend...")
if not wait_for_backend(port):
log_event(
events_path,
InvestigationEvent(event_type="error", message="Backend failed to start"),
)
raise RuntimeError("Backend failed to start")
wait_for_backend(port)

logger.info(f"[{inv_id}] Backend ready, loading run...")
log_event(
Expand Down
8 changes: 1 addition & 7 deletions spd/investigate/scripts/run_slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,7 @@ def launch_investigation(
}
(output_dir / "metadata.json").write_text(json.dumps(metadata, indent=2))

cmd = (
f"{sys.executable} -m spd.investigate.scripts.run_agent "
f'"{wandb_path}" '
f"--inv_id {inv_id} "
f"--context_length {context_length} "
f"--max_turns {max_turns}"
)
cmd = f"{sys.executable} -m spd.investigate.scripts.run_agent {inv_id}"

slurm_config = SlurmConfig(
job_name=job_name,
Expand Down
Loading