From 91439209f1ec18f703f195f18900c503496f57e0 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 30 Jan 2026 20:40:32 +0000 Subject: [PATCH 01/62] Add agent swarm for parallel behavior investigation Implements a SLURM-based system for launching parallel Claude Code agents that investigate behaviors in SPD model decompositions. Key components: - spd-swarm CLI: Submits SLURM array job for N agents - Each agent starts isolated app backend (unique port, separate database) - Detailed system prompt guides agents through investigation methodology - Findings written to append-only JSONL files (events.jsonl, explanations.jsonl) New files: - spd/agent_swarm/schemas.py: BehaviorExplanation, SwarmEvent schemas - spd/agent_swarm/agent_prompt.py: Detailed API and methodology instructions - spd/agent_swarm/scripts/run_slurm_cli.py: CLI entry point - spd/agent_swarm/scripts/run_slurm.py: SLURM submission logic - spd/agent_swarm/scripts/run_agent.py: Worker script for each job Also adds SPD_APP_DB_PATH env var support for database isolation. 
https://claude.ai/code/session_01UMpYFZ3A98vsPkqoq6zvT6 --- CLAUDE.md | 24 ++ pyproject.toml | 1 + spd/agent_swarm/CLAUDE.md | 124 +++++++++ spd/agent_swarm/__init__.py | 22 ++ spd/agent_swarm/agent_prompt.py | 330 +++++++++++++++++++++++ spd/agent_swarm/schemas.py | 120 +++++++++ spd/agent_swarm/scripts/__init__.py | 1 + spd/agent_swarm/scripts/run_agent.py | 284 +++++++++++++++++++ spd/agent_swarm/scripts/run_slurm.py | 119 ++++++++ spd/agent_swarm/scripts/run_slurm_cli.py | 62 +++++ spd/app/backend/database.py | 18 +- 11 files changed, 1103 insertions(+), 2 deletions(-) create mode 100644 spd/agent_swarm/CLAUDE.md create mode 100644 spd/agent_swarm/__init__.py create mode 100644 spd/agent_swarm/agent_prompt.py create mode 100644 spd/agent_swarm/schemas.py create mode 100644 spd/agent_swarm/scripts/__init__.py create mode 100644 spd/agent_swarm/scripts/run_agent.py create mode 100644 spd/agent_swarm/scripts/run_slurm.py create mode 100644 spd/agent_swarm/scripts/run_slurm_cli.py diff --git a/CLAUDE.md b/CLAUDE.md index 30da03eb6..2e636885c 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -156,6 +156,7 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: ├── scripts/ # Standalone utility scripts ├── tests/ # Test suite ├── spd/ # Main source code +│ ├── agent_swarm/ # Parallel agent investigation (see agent_swarm/CLAUDE.md) │ ├── app/ # Web visualization app (see app/CLAUDE.md) │ ├── autointerp/ # LLM interpretation (see autointerp/CLAUDE.md) │ ├── clustering/ # Component clustering (see clustering/CLAUDE.md) @@ -195,6 +196,7 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: | `spd-autointerp` | `spd/autointerp/scripts/cli.py` | Submit autointerp SLURM job | | `spd-attributions` | `spd/dataset_attributions/scripts/run_slurm_cli.py` | Submit dataset attribution SLURM job | | `spd-clustering` | `spd/clustering/scripts/run_pipeline.py` | Clustering pipeline | +| `spd-swarm` | `spd/agent_swarm/scripts/run_slurm_cli.py` | Launch parallel 
agent swarm | ### Files to Skip When Searching @@ -231,6 +233,9 @@ Use `spd/` as the search root (not repo root) to avoid noise. **Clustering Pipeline:** - `spd-clustering` → `spd/clustering/scripts/run_pipeline.py` → `spd/utils/slurm.py` → `spd/clustering/scripts/run_clustering.py` +**Agent Swarm Pipeline:** +- `spd-swarm` → `spd/agent_swarm/scripts/run_slurm_cli.py` → `spd/utils/slurm.py` → SLURM array → `spd/agent_swarm/scripts/run_agent.py` → Claude Code + ## Common Usage Patterns ### Running Experiments Locally (`spd-local`) @@ -277,6 +282,25 @@ spd-autointerp # Submit SLURM job to interpret component Requires `OPENROUTER_API_KEY` env var. See `spd/autointerp/CLAUDE.md` for details. +### Agent Swarm for Parallel Investigation (`spd-swarm`) + +Launch a swarm of Claude Code agents to investigate behaviors in an SPD model: + +```bash +spd-swarm --n_agents 10 # Launch 10 parallel agents +spd-swarm --n_agents 5 --time 4:00:00 # Custom time limit +``` + +Each agent: +- Runs in its own SLURM job with 1 GPU +- Starts an isolated app backend instance +- Investigates behaviors using the SPD app API +- Writes findings to append-only JSONL files + +Output: `SPD_OUT_DIR/agent_swarm//task_*/explanations.jsonl` + +See `spd/agent_swarm/CLAUDE.md` for details. 
+ ### Running on SLURM Cluster (`spd-run`) For the core team, `spd-run` provides full-featured SLURM orchestration: diff --git a/pyproject.toml b/pyproject.toml index 76a539454..24f47fe64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ spd-clustering = "spd.clustering.scripts.run_pipeline:cli" spd-harvest = "spd.harvest.scripts.run_slurm_cli:cli" spd-autointerp = "spd.autointerp.scripts.run_slurm_cli:cli" spd-attributions = "spd.dataset_attributions.scripts.run_slurm_cli:cli" +spd-swarm = "spd.agent_swarm.scripts.run_slurm_cli:cli" [build-system] requires = ["setuptools", "wheel"] diff --git a/spd/agent_swarm/CLAUDE.md b/spd/agent_swarm/CLAUDE.md new file mode 100644 index 000000000..ee2e89be2 --- /dev/null +++ b/spd/agent_swarm/CLAUDE.md @@ -0,0 +1,124 @@ +# Agent Swarm Module + +This module provides infrastructure for launching parallel SLURM-based Claude Code agents +that investigate behaviors in SPD model decompositions. + +## Overview + +The agent swarm system allows you to: +1. Launch many parallel agents (each as a SLURM job with 1 GPU) +2. Each agent runs an isolated app backend instance +3. Agents investigate behaviors using the SPD app API +4. 
Findings are written to append-only JSONL files + +## Usage + +```bash +# Launch 10 agents to investigate a decomposition +spd-swarm goodfire-ai/spd/runs/abc123 --n_agents 10 + +# With custom settings +spd-swarm goodfire-ai/spd/runs/abc123 --n_agents 5 --context_length 64 --time 4:00:00 +``` + +## Architecture + +``` +spd/agent_swarm/ +├── __init__.py # Public exports +├── CLAUDE.md # This file +├── schemas.py # Pydantic models for outputs +├── agent_prompt.py # System prompt for agents +└── scripts/ + ├── __init__.py + ├── run_slurm_cli.py # CLI entry point (spd-swarm) + ├── run_slurm.py # SLURM submission logic + └── run_agent.py # Worker script (runs in each SLURM job) +``` + +## Output Structure + +``` +SPD_OUT_DIR/agent_swarm// +├── metadata.json # Swarm configuration +├── task_1/ +│ ├── events.jsonl # Progress and observations +│ ├── explanations.jsonl # Complete behavior explanations +│ ├── app.db # Isolated SQLite database +│ ├── agent_prompt.md # The prompt given to the agent +│ └── claude_output.txt # Raw Claude Code output +├── task_2/ +│ └── ... +└── task_N/ + └── ... +``` + +## Key Files + +| File | Purpose | +|------|---------| +| `schemas.py` | Defines `BehaviorExplanation`, `SwarmEvent`, `Evidence` schemas | +| `agent_prompt.py` | Contains detailed instructions for agents on using the API | +| `run_slurm.py` | Creates git snapshot, generates commands, submits SLURM array | +| `run_agent.py` | Starts backend, loads run, launches Claude Code | + +## Schemas + +### BehaviorExplanation +The primary output - documents a discovered behavior: +- `subject_prompt`: Prompt demonstrating the behavior +- `behavior_description`: What the model does +- `components_involved`: List of components and their roles +- `explanation`: How components work together +- `supporting_evidence`: Ablations, attributions, etc. 
+- `confidence`: high/medium/low +- `alternative_hypotheses`: Other considered explanations +- `limitations`: Known caveats + +### SwarmEvent +General logging: +- `event_type`: start, progress, observation, hypothesis, test_result, error, complete +- `timestamp`: When it occurred +- `message`: Human-readable description +- `details`: Structured data + +## Database Isolation + +Each agent gets its own SQLite database via the `SPD_APP_DB_PATH` environment variable. +This prevents conflicts when multiple agents run on the same machine. + +## Monitoring + +```bash +# Watch events from all agents +tail -f SPD_OUT_DIR/agent_swarm//task_*/events.jsonl + +# View all explanations +cat SPD_OUT_DIR/agent_swarm//task_*/explanations.jsonl | jq . + +# Check SLURM job status +squeue --me + +# View specific job logs +tail -f ~/slurm_logs/slurm-_.out +``` + +## Configuration + +CLI arguments: +- `wandb_path`: Required - WandB run path for the SPD decomposition +- `--n_agents`: Required - Number of parallel agents to launch +- `--context_length`: Token context length (default: 128) +- `--partition`: SLURM partition (default: h200-reserved) +- `--time`: Time limit per agent (default: 8:00:00) +- `--job_suffix`: Optional suffix for job names + +## Extending + +To modify agent behavior: +1. Edit `agent_prompt.py` to change investigation instructions +2. Update `schemas.py` to add new output fields +3. Modify `run_agent.py` to change the worker flow + +The agent prompt is the primary way to guide agent behavior - it contains +detailed API documentation and scientific methodology guidance. diff --git a/spd/agent_swarm/__init__.py b/spd/agent_swarm/__init__.py new file mode 100644 index 000000000..cac91de2d --- /dev/null +++ b/spd/agent_swarm/__init__.py @@ -0,0 +1,22 @@ +"""Agent Swarm: Parallel SLURM-based agent investigation of model behaviors. 
+ +This module provides infrastructure for launching many parallel Claude Code agents, +each investigating behaviors in an SPD model decomposition. Each agent: +1. Starts an isolated app backend instance (separate database, unique port) +2. Receives detailed instructions on using the SPD app API +3. Investigates behaviors and writes findings to append-only JSONL files +""" + +from spd.agent_swarm.schemas import ( + BehaviorExplanation, + ComponentInfo, + Evidence, + SwarmEvent, +) + +__all__ = [ + "BehaviorExplanation", + "ComponentInfo", + "Evidence", + "SwarmEvent", +] diff --git a/spd/agent_swarm/agent_prompt.py b/spd/agent_swarm/agent_prompt.py new file mode 100644 index 000000000..e3cf4fd21 --- /dev/null +++ b/spd/agent_swarm/agent_prompt.py @@ -0,0 +1,330 @@ +"""System prompt for SPD investigation agents. + +This module contains the detailed instructions given to each agent in the swarm. +The prompt explains how to use the SPD app API and the scientific methodology +for investigating model behaviors. +""" + +AGENT_SYSTEM_PROMPT = """ +# SPD Behavior Investigation Agent + +You are a research agent investigating behaviors in a neural network model decomposition. +Your goal is to find interesting behaviors, understand how components interact to produce +them, and document your findings as explanations. + +## Your Mission + +You are part of a swarm of agents, each independently investigating behaviors in the same +model. Your task is to: + +1. **Find a behavior**: Discover a prompt where the model does something interesting + (e.g., predicts the correct gendered pronoun, completes a pattern, etc.) + +2. **Understand the mechanism**: Figure out which components are involved and how they + work together to produce the behavior + +3. 
**Document your findings**: Write a clear explanation with supporting evidence + +## The SPD App Backend + +You have access to an SPD (Stochastic Parameter Decomposition) app backend running at: +`http://localhost:{port}` + +This app provides APIs for: +- Loading decomposed models +- Computing attribution graphs showing how components interact +- Optimizing sparse circuits for specific behaviors +- Running interventions (ablations) to test hypotheses +- Viewing component interpretations and correlations +- Searching the training dataset + +## API Reference + +### Health Check +```bash +curl http://localhost:{port}/api/health +# Returns: {{"status": "ok"}} +``` + +### Load a Run (ALREADY DONE FOR YOU) +The run is pre-loaded. Check status with: +```bash +curl http://localhost:{port}/api/status +``` + +### Create a Custom Prompt +To analyze a specific prompt: +```bash +curl -X POST "http://localhost:{port}/api/prompts/custom?text=The%20boy%20ate%20his" +# Returns: {{"id": , "token_ids": [...], "tokens": [...], "preview": "...", "next_token_probs": [...]}} +``` + +### Compute Optimized Attribution Graph (MOST IMPORTANT) +This optimizes a sparse circuit that achieves a behavior: +```bash +curl -X POST "http://localhost:{port}/api/graphs/optimized/stream?prompt_id=&loss_type=ce&loss_position=&label_token=&steps=100&imp_min_coeff=0.1&pnorm=0.5&mask_type=hard&loss_coeff=1.0&ci_threshold=0.01&normalize=target" +# Streams SSE events, final event has type="complete" with graph data +``` + +Parameters: +- `prompt_id`: ID from creating custom prompt +- `loss_type`: "ce" for cross-entropy (predicting specific token) or "kl" (matching full distribution) +- `loss_position`: Token position to optimize (0-indexed, usually last position) +- `label_token`: Token ID to predict (for CE loss) +- `steps`: Optimization steps (50-200 typical) +- `imp_min_coeff`: Importance minimization coefficient (0.05-0.3) +- `pnorm`: P-norm for sparsity (0.3-1.0, lower = sparser) +- `mask_type`: "hard" 
for binary masks, "soft" for continuous +- `ci_threshold`: Threshold for including nodes in graph (0.01-0.1) +- `normalize`: "target" normalizes by target layer, "none" for raw values + +### Get Component Interpretations +```bash +curl "http://localhost:{port}/api/correlations/interpretations" +# Returns: {{"h.0.mlp.c_fc:5": {{"label": "...", "confidence": "high"}}, ...}} +``` + +Get full interpretation details: +```bash +curl "http://localhost:{port}/api/correlations/interpretations/h.0.mlp.c_fc/5" +# Returns: {{"reasoning": "...", "prompt": "..."}} +``` + +### Get Component Token Statistics +```bash +curl "http://localhost:{port}/api/correlations/token_stats/h.0.mlp.c_fc/5?top_k=20" +# Returns input/output token associations +``` + +### Get Component Correlations +```bash +curl "http://localhost:{port}/api/correlations/components/h.0.mlp.c_fc/5?top_k=20" +# Returns components that frequently co-activate +``` + +### Run Intervention (Ablation) +Test a hypothesis by running the model with only selected components active: +```bash +curl -X POST "http://localhost:{port}/api/intervention/run" \\ + -H "Content-Type: application/json" \\ + -d '{{"graph_id": , "text": "The boy ate his", "selected_nodes": ["h.0.mlp.c_fc:3:5", "h.1.attn.o_proj:3:10"], "top_k": 10}}' +# Returns predictions with only selected components active vs full model +``` + +Node format: "layer:seq_pos:component_idx" +- `layer`: e.g., "h.0.mlp.c_fc", "h.1.attn.o_proj" +- `seq_pos`: Position in sequence (0-indexed) +- `component_idx`: Component index within layer + +### Search Dataset +Find prompts with specific patterns: +```bash +curl -X POST "http://localhost:{port}/api/dataset/search?query=she%20said&split=train" +curl "http://localhost:{port}/api/dataset/results?page=1&page_size=20" +``` + +### Get Random Samples with Loss +Find high/low loss examples: +```bash +curl "http://localhost:{port}/api/dataset/random_with_loss?n_samples=20&seed=42" +``` + +### Probe Component Activation +See how a 
component responds to arbitrary text: +```bash +curl -X POST "http://localhost:{port}/api/activation_contexts/probe" \\ + -H "Content-Type: application/json" \\ + -d '{{"text": "The boy ate his", "layer": "h.0.mlp.c_fc", "component_idx": 5}}' +# Returns CI values and activations at each position +``` + +### Get Dataset Attributions +See which components influence each other across the training data: +```bash +curl "http://localhost:{port}/api/dataset_attributions/h.0.mlp.c_fc/5?k=10" +# Returns positive/negative sources and targets +``` + +## Investigation Methodology + +### Step 1: Find an Interesting Behavior + +Start by exploring the model's behavior: + +1. **Search for patterns**: Use `/api/dataset/search` to find prompts with specific + linguistic patterns (pronouns, verb conjugations, completions, etc.) + +2. **Look at high-loss examples**: Use `/api/dataset/random_with_loss` to find where + the model struggles or succeeds + +3. **Create test prompts**: Use `/api/prompts/custom` to create prompts that test + specific capabilities + +Good behaviors to investigate: +- Gendered pronoun prediction ("The doctor said she" vs "The doctor said he") +- Subject-verb agreement ("The cats are" vs "The cat is") +- Pattern completion ("1, 2, 3," → "4") +- Semantic associations ("The capital of France is" → "Paris") +- Grammatical structure (completing sentences correctly) + +### Step 2: Optimize a Sparse Circuit + +Once you have a behavior: + +1. **Create the prompt** via `/api/prompts/custom` + +2. **Identify the target token**: What token should be predicted? Get its ID from + the tokenizer or from the prompt creation response. + +3. **Run optimization** via `/api/graphs/optimized/stream`: + - Use `loss_type=ce` with the target token + - Set `loss_position` to the position where prediction matters + - Start with `imp_min_coeff=0.1`, `pnorm=0.5`, `steps=100` + - Use `ci_threshold=0.01` to see active components + +4. 
**Examine the graph**: The response shows: + - `nodeCiVals`: Which components are active (high CI = important) + - `edges`: How components connect (gradient flow) + - `outputProbs`: Model predictions + +### Step 3: Understand Component Roles + +For each important component in the graph: + +1. **Check the interpretation**: Use `/api/correlations/interpretations//` + to see if we already have an idea what this component does + +2. **Look at token stats**: Use `/api/correlations/token_stats//` to see + what tokens activate this component (input) and what it predicts (output) + +3. **Check correlations**: Use `/api/correlations/components//` to see + what other components co-activate + +4. **Probe on variations**: Use `/api/activation_contexts/probe` to see how the + component responds to related prompts + +### Step 4: Test with Ablations + +Form hypotheses and test them: + +1. **Hypothesis**: "Component X stores information about gender" + +2. **Test**: Run intervention with and without component X + - If prediction changes as expected → supports hypothesis + - If no change → component may not be necessary for this + - If unexpected change → revise hypothesis + +3. 
**Control**: Try ablating other components to ensure specificity + +### Step 5: Document Your Findings + +Write a `BehaviorExplanation` with: +- Clear subject prompt +- Description of the behavior +- Components and their roles +- How they work together +- Supporting evidence from ablations/attributions +- Confidence level +- Alternative hypotheses you considered +- Limitations + +## Scientific Principles + +### Be Epistemologically Humble +- Your first hypothesis is probably wrong or incomplete +- Always consider alternative explanations +- A single confirming example doesn't prove a theory +- Look for disconfirming evidence + +### Be Bayesian +- Start with priors from component interpretations +- Update beliefs based on evidence +- Consider the probability of the evidence under different hypotheses +- Don't anchor too strongly on initial observations + +### Triangulate Evidence +- Don't rely on a single type of evidence +- Ablation results + attribution patterns + token stats together are stronger +- Look for convergent evidence from multiple sources + +### Document Uncertainty +- Be explicit about what you're confident in vs. 
uncertain about +- Note when evidence is weak or ambiguous +- Identify what additional tests would strengthen the explanation + +## Output Format + +Write your findings by appending to the output files: + +### events.jsonl +Log progress and observations: +```json +{{"event_type": "observation", "message": "Component h.0.mlp.c_fc:5 has high CI when subject is male", "details": {{"ci_value": 0.85}}, "timestamp": "..."}} +``` + +### explanations.jsonl +When you have a complete explanation: +```json +{{ + "subject_prompt": "The boy ate his lunch", + "behavior_description": "Correctly predicts gendered pronoun 'his' after male subject", + "components_involved": [ + {{"component_key": "h.0.mlp.c_fc:5", "role": "Encodes subject gender as male", "interpretation": "male names/subjects"}}, + {{"component_key": "h.1.attn.o_proj:10", "role": "Transmits gender information to output", "interpretation": null}} + ], + "explanation": "Component h.0.mlp.c_fc:5 activates on male subjects and stores gender information...", + "supporting_evidence": [ + {{"evidence_type": "ablation", "description": "Removing component causes prediction to change from 'his' to 'her'", "details": {{"without_component": {{"his": 0.1, "her": 0.6}}, "with_component": {{"his": 0.8, "her": 0.1}}}}}} + ], + "confidence": "medium", + "alternative_hypotheses": ["Component might encode broader concept of masculine entities, not just humans"], + "limitations": ["Only tested on simple subject-pronoun sentences"] +}} +``` + +## Getting Started + +1. Check the current status: `curl http://localhost:{port}/api/status` +2. Explore available interpretations: `curl http://localhost:{port}/api/correlations/interpretations` +3. Search for interesting prompts or create your own +4. Optimize a sparse circuit for a behavior you find +5. Investigate the components involved +6. Test hypotheses with ablations +7. Document your findings + +Remember: You are exploring! Not every investigation will lead to a clear explanation. 
+Document what you learn, even if it's "this was more complicated than expected." + +Good luck, and happy investigating! +""" + + +def get_agent_prompt(port: int, wandb_path: str, task_id: int, output_dir: str) -> str: + """Generate the full agent prompt with runtime parameters filled in. + + Args: + port: The port the backend is running on. + wandb_path: The WandB path of the loaded run. + task_id: The SLURM task ID for this agent. + output_dir: Path to the agent's output directory. + + Returns: + The complete agent prompt with parameters substituted. + """ + prompt = AGENT_SYSTEM_PROMPT.format(port=port) + + runtime_context = f""" +## Runtime Context + +- **Backend URL**: http://localhost:{port} +- **Loaded Run**: {wandb_path} +- **Task ID**: {task_id} +- **Output Directory**: {output_dir} + +Your output files: +- `{output_dir}/events.jsonl` - Log events and observations here +- `{output_dir}/explanations.jsonl` - Write complete explanations here + +To append to these files, use the Write tool or shell redirection. +""" + return prompt + runtime_context diff --git a/spd/agent_swarm/schemas.py b/spd/agent_swarm/schemas.py new file mode 100644 index 000000000..d554db855 --- /dev/null +++ b/spd/agent_swarm/schemas.py @@ -0,0 +1,120 @@ +"""Schemas for agent swarm outputs. + +All agent outputs are append-only JSONL files. Each line is a JSON object +conforming to one of the schemas defined here. 
+""" + +from datetime import UTC, datetime +from typing import Any, Literal + +from pydantic import BaseModel, Field + + +class ComponentInfo(BaseModel): + """Information about a component involved in a behavior.""" + + component_key: str = Field( + ..., + description="Component key in format 'layer:component_idx' (e.g., 'h.0.mlp.c_fc:5')", + ) + role: str = Field( + ..., + description="The role this component plays in the behavior (e.g., 'stores subject gender')", + ) + interpretation: str | None = Field( + default=None, + description="Auto-interp label for this component if available", + ) + + +class Evidence(BaseModel): + """A piece of supporting evidence for an explanation.""" + + evidence_type: Literal["ablation", "attribution", "activation_pattern", "correlation", "other"] + description: str = Field( + ..., + description="Description of the evidence", + ) + details: dict[str, Any] = Field( + default_factory=dict, + description="Additional structured details (e.g., ablation results, attribution values)", + ) + + +class BehaviorExplanation(BaseModel): + """A candidate explanation for a behavior discovered by an agent. + + This is the primary output schema for agent investigations. Each explanation + describes a behavior (demonstrated by a subject prompt), the components involved, + and supporting evidence. 
+ """ + + subject_prompt: str = Field( + ..., + description="A prompt that demonstrates the behavior being explained", + ) + behavior_description: str = Field( + ..., + description="Clear description of the behavior (e.g., 'correctly predicts gendered pronoun')", + ) + components_involved: list[ComponentInfo] = Field( + ..., + description="List of components involved in this behavior and their roles", + ) + explanation: str = Field( + ..., + description="Explanation of how the components work together to produce the behavior", + ) + supporting_evidence: list[Evidence] = Field( + default_factory=list, + description="Evidence supporting this explanation (ablations, attributions, etc.)", + ) + confidence: Literal["high", "medium", "low"] = Field( + ..., + description="Agent's confidence in this explanation", + ) + alternative_hypotheses: list[str] = Field( + default_factory=list, + description="Alternative hypotheses that were considered but not fully supported", + ) + limitations: list[str] = Field( + default_factory=list, + description="Known limitations of this explanation", + ) + + +class SwarmEvent(BaseModel): + """A generic event logged by an agent during investigation. + + Used for logging progress, observations, and other non-explanation events. + """ + + event_type: Literal[ + "start", + "progress", + "observation", + "hypothesis", + "test_result", + "explanation", + "error", + "complete", + ] + timestamp: datetime = Field(default_factory=lambda: datetime.now(UTC)) + message: str + details: dict[str, Any] = Field(default_factory=dict) + + +class AgentOutput(BaseModel): + """Container for all outputs from a single agent run. + + Written to the agent's output directory as output.json upon completion. 
+ """ + + task_id: int + wandb_path: str + started_at: datetime + completed_at: datetime | None = None + explanations: list[BehaviorExplanation] = Field(default_factory=list) + events: list[SwarmEvent] = Field(default_factory=list) + status: Literal["running", "completed", "failed"] = "running" + error: str | None = None diff --git a/spd/agent_swarm/scripts/__init__.py b/spd/agent_swarm/scripts/__init__.py new file mode 100644 index 000000000..9d0e8ed1b --- /dev/null +++ b/spd/agent_swarm/scripts/__init__.py @@ -0,0 +1 @@ +"""Agent swarm SLURM scripts.""" diff --git a/spd/agent_swarm/scripts/run_agent.py b/spd/agent_swarm/scripts/run_agent.py new file mode 100644 index 000000000..83a32752c --- /dev/null +++ b/spd/agent_swarm/scripts/run_agent.py @@ -0,0 +1,284 @@ +"""Worker script that runs inside each SLURM job. + +This script: +1. Creates an isolated output directory for this agent +2. Starts the app backend with an isolated database +3. Loads the SPD run +4. Launches Claude Code with investigation instructions +5. 
Handles cleanup on exit +""" + +import os +import signal +import socket +import subprocess +import sys +import time +from pathlib import Path +from types import FrameType + +import fire +import requests + +from spd.agent_swarm.agent_prompt import get_agent_prompt +from spd.agent_swarm.schemas import SwarmEvent +from spd.agent_swarm.scripts.run_slurm import get_swarm_output_dir +from spd.log import logger + + +def find_available_port(start_port: int = 8000, max_attempts: int = 100) -> int: + """Find an available port starting from start_port.""" + for offset in range(max_attempts): + port = start_port + offset + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(("localhost", port)) + return port + except OSError: + continue + raise RuntimeError( + f"Could not find available port in range {start_port}-{start_port + max_attempts}" + ) + + +def wait_for_backend(port: int, timeout: float = 120.0) -> bool: + """Wait for the backend to become healthy.""" + url = f"http://localhost:{port}/api/health" + start = time.time() + while time.time() - start < timeout: + try: + resp = requests.get(url, timeout=5) + if resp.status_code == 200: + return True + except requests.exceptions.ConnectionError: + pass + time.sleep(1) + return False + + +def load_run(port: int, wandb_path: str, context_length: int) -> bool: + """Load the SPD run into the backend.""" + url = f"http://localhost:{port}/api/runs/load" + params = {"wandb_path": wandb_path, "context_length": context_length} + try: + resp = requests.post(url, params=params, timeout=300) + return resp.status_code == 200 + except Exception as e: + logger.error(f"Failed to load run: {e}") + return False + + +def log_event(events_path: Path, event: SwarmEvent) -> None: + """Append an event to the events log.""" + with open(events_path, "a") as f: + f.write(event.model_dump_json() + "\n") + + +def run_agent( + wandb_path: str, + task_id: int, + swarm_id: str, + context_length: int = 128, +) -> None: + """Run a 
single investigation agent. + + Args: + wandb_path: WandB path of the SPD run. + task_id: SLURM task ID (1-indexed). + swarm_id: Unique identifier for this swarm. + context_length: Context length for prompts. + """ + # Setup output directory + swarm_dir = get_swarm_output_dir(swarm_id) + task_dir = swarm_dir / f"task_{task_id}" + task_dir.mkdir(parents=True, exist_ok=True) + + events_path = task_dir / "events.jsonl" + explanations_path = task_dir / "explanations.jsonl" + db_path = task_dir / "app.db" + + # Initialize empty output files + explanations_path.touch() + + log_event( + events_path, + SwarmEvent( + event_type="start", + message=f"Agent {task_id} starting", + details={"wandb_path": wandb_path, "swarm_id": swarm_id}, + ), + ) + + # Find available port (offset by task_id to reduce collisions) + port = find_available_port(start_port=8000 + (task_id - 1) * 10) + logger.info(f"[Task {task_id}] Using port {port}") + + log_event( + events_path, + SwarmEvent( + event_type="progress", + message=f"Starting backend on port {port}", + details={"port": port, "db_path": str(db_path)}, + ), + ) + + # Start backend with isolated database + env = os.environ.copy() + env["SPD_APP_DB_PATH"] = str(db_path) + + backend_cmd = [ + sys.executable, + "-m", + "spd.app.backend.server", + "--port", + str(port), + ] + + backend_proc = subprocess.Popen( + backend_cmd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + + # Setup cleanup handler + def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: + _ = frame # Unused but required by signal handler signature + logger.info(f"[Task {task_id}] Cleaning up...") + if backend_proc.poll() is None: + backend_proc.terminate() + try: + backend_proc.wait(timeout=5) + except subprocess.TimeoutExpired: + backend_proc.kill() + if signum is not None: + sys.exit(1) + + signal.signal(signal.SIGTERM, cleanup) + signal.signal(signal.SIGINT, cleanup) + + try: + # Wait for backend to be ready + 
logger.info(f"[Task {task_id}] Waiting for backend...") + if not wait_for_backend(port): + log_event( + events_path, + SwarmEvent( + event_type="error", + message="Backend failed to start", + ), + ) + raise RuntimeError("Backend failed to start") + + logger.info(f"[Task {task_id}] Backend ready, loading run...") + log_event( + events_path, + SwarmEvent( + event_type="progress", + message="Backend ready, loading run", + ), + ) + + # Load the SPD run + if not load_run(port, wandb_path, context_length): + log_event( + events_path, + SwarmEvent( + event_type="error", + message="Failed to load run", + details={"wandb_path": wandb_path}, + ), + ) + raise RuntimeError(f"Failed to load run: {wandb_path}") + + logger.info(f"[Task {task_id}] Run loaded, launching Claude Code...") + log_event( + events_path, + SwarmEvent( + event_type="progress", + message="Run loaded, launching Claude Code agent", + ), + ) + + # Generate agent prompt + agent_prompt = get_agent_prompt( + port=port, + wandb_path=wandb_path, + task_id=task_id, + output_dir=str(task_dir), + ) + + # Write prompt to file for reference + prompt_path = task_dir / "agent_prompt.md" + prompt_path.write_text(agent_prompt) + + # Launch Claude Code + # The agent will investigate behaviors and write to the output files + claude_cmd = [ + "claude", + "--print", # Print output to stdout + "--dangerously-skip-permissions", # Allow file writes + ] + + logger.info(f"[Task {task_id}] Starting Claude Code session...") + + claude_proc = subprocess.Popen( + claude_cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + cwd=str(task_dir), + ) + + # Send the investigation prompt + investigation_request = f""" +{agent_prompt} + +--- + +Please begin your investigation. Start by checking the backend status and exploring +available component interpretations. Then find an interesting behavior and investigate it. 
+ +Remember to log your progress to events.jsonl and write complete explanations to +explanations.jsonl when you discover something. +""" + + stdout, _ = claude_proc.communicate(input=investigation_request) + + # Save Claude's output + output_path = task_dir / "claude_output.txt" + output_path.write_text(stdout or "") + + log_event( + events_path, + SwarmEvent( + event_type="complete", + message="Investigation complete", + details={"exit_code": claude_proc.returncode}, + ), + ) + + logger.info(f"[Task {task_id}] Investigation complete") + + except Exception as e: + log_event( + events_path, + SwarmEvent( + event_type="error", + message=f"Agent failed: {e}", + details={"error_type": type(e).__name__}, + ), + ) + logger.error(f"[Task {task_id}] Failed: {e}") + raise + finally: + cleanup() + + +def cli() -> None: + fire.Fire(run_agent) + + +if __name__ == "__main__": + cli() diff --git a/spd/agent_swarm/scripts/run_slurm.py b/spd/agent_swarm/scripts/run_slurm.py new file mode 100644 index 000000000..8b99e253d --- /dev/null +++ b/spd/agent_swarm/scripts/run_slurm.py @@ -0,0 +1,119 @@ +"""SLURM launcher for agent swarm. + +Submits a SLURM array job where each task runs an independent agent investigating +behaviors in an SPD model decomposition. + +Each agent: +1. Starts an isolated app backend (unique port, isolated database) +2. Launches Claude Code with investigation instructions +3. 
Writes findings to append-only JSONL files +""" + +import secrets +from pathlib import Path + +from spd.log import logger +from spd.settings import SPD_OUT_DIR +from spd.utils.git_utils import create_git_snapshot +from spd.utils.slurm import ( + SlurmArrayConfig, + generate_array_script, + submit_slurm_job, +) + + +def get_swarm_output_dir(swarm_id: str) -> Path: + """Get the output directory for a swarm run.""" + return SPD_OUT_DIR / "agent_swarm" / swarm_id + + +def launch_agent_swarm( + wandb_path: str, + n_agents: int, + context_length: int = 128, + partition: str = "h200-reserved", + time: str = "8:00:00", + job_suffix: str | None = None, +) -> None: + """Launch a swarm of agents to investigate behaviors. + + Args: + wandb_path: WandB run path for the SPD decomposition. + n_agents: Number of agents to launch. + context_length: Context length for prompts. + partition: SLURM partition. + time: Time limit per agent. + job_suffix: Optional suffix for job names. + """ + swarm_id = f"swarm-{secrets.token_hex(4)}" + output_dir = get_swarm_output_dir(swarm_id) + output_dir.mkdir(parents=True, exist_ok=True) + + snapshot_branch, commit_hash = create_git_snapshot(swarm_id) + logger.info(f"Created git snapshot: {snapshot_branch} ({commit_hash[:8]})") + + suffix = f"-{job_suffix}" if job_suffix else "" + job_name = f"spd-swarm{suffix}" + + # Write swarm metadata + metadata_path = output_dir / "metadata.json" + import json + + metadata = { + "swarm_id": swarm_id, + "wandb_path": wandb_path, + "n_agents": n_agents, + "context_length": context_length, + "snapshot_branch": snapshot_branch, + "commit_hash": commit_hash, + } + metadata_path.write_text(json.dumps(metadata, indent=2)) + + # Build worker commands (SLURM arrays are 1-indexed) + worker_commands = [] + for task_id in range(1, n_agents + 1): + cmd = ( + f"python -m spd.agent_swarm.scripts.run_agent " + f'"{wandb_path}" ' + f"--task_id {task_id} " + f"--swarm_id {swarm_id} " + f"--context_length {context_length}" + ) + 
worker_commands.append(cmd) + + array_config = SlurmArrayConfig( + job_name=job_name, + partition=partition, + n_gpus=1, + time=time, + snapshot_branch=snapshot_branch, + max_concurrent_tasks=min(n_agents, 8), # Respect cluster limits + ) + array_script = generate_array_script(array_config, worker_commands) + array_result = submit_slurm_job( + array_script, + "agent_swarm", + is_array=True, + n_array_tasks=n_agents, + ) + + logger.section("Agent swarm jobs submitted!") + logger.values( + { + "Swarm ID": swarm_id, + "WandB path": wandb_path, + "N agents": n_agents, + "Context length": context_length, + "Output directory": str(output_dir), + "Snapshot": f"{snapshot_branch} ({commit_hash[:8]})", + "Job ID": array_result.job_id, + "Logs": array_result.log_pattern, + "Script": str(array_result.script_path), + } + ) + logger.info("") + logger.info("Monitor progress:") + logger.info(f" tail -f {output_dir}/task_*/events.jsonl") + logger.info("") + logger.info("View explanations:") + logger.info(f" cat {output_dir}/task_*/explanations.jsonl | jq .") diff --git a/spd/agent_swarm/scripts/run_slurm_cli.py b/spd/agent_swarm/scripts/run_slurm_cli.py new file mode 100644 index 000000000..20a6d8457 --- /dev/null +++ b/spd/agent_swarm/scripts/run_slurm_cli.py @@ -0,0 +1,62 @@ +"""CLI entry point for agent swarm SLURM launcher. + +Thin wrapper for fast --help. Heavy imports deferred to run_slurm.py. 
+ +Usage: + spd-swarm <wandb_path> --n_agents 10 + spd-swarm <wandb_path> --n_agents 5 --context_length 128 + +Examples: + # Launch 10 agents to investigate a decomposition + spd-swarm goodfire-ai/spd/runs/abc123 --n_agents 10 + + # Launch 5 agents with custom context length and time limit + spd-swarm goodfire-ai/spd/runs/abc123 --n_agents 5 --context_length 64 --time 4:00:00 +""" + +import fire + +from spd.settings import DEFAULT_PARTITION_NAME + + +def main( + wandb_path: str, + n_agents: int, + context_length: int = 128, + partition: str = DEFAULT_PARTITION_NAME, + time: str = "8:00:00", + job_suffix: str | None = None, +) -> None: + """Launch a swarm of agents to investigate behaviors in an SPD model. + + Each agent runs in its own SLURM job with an isolated app backend instance. + Agents use Claude Code to investigate behaviors and write findings to + append-only JSONL files. + + Args: + wandb_path: WandB run path for the SPD decomposition to investigate. + Format: "entity/project/runs/run_id" or "wandb:entity/project/run_id" + n_agents: Number of agents to launch (each gets 1 GPU). + context_length: Context length for prompts (default 128). + partition: SLURM partition name. + time: Job time limit per agent (default 8 hours). + job_suffix: Optional suffix for SLURM job names. 
+ """ + from spd.agent_swarm.scripts.run_slurm import launch_agent_swarm + + launch_agent_swarm( + wandb_path=wandb_path, + n_agents=n_agents, + context_length=context_length, + partition=partition, + time=time, + job_suffix=job_suffix, + ) + + +def cli() -> None: + fire.Fire(main) + + +if __name__ == "__main__": + cli() diff --git a/spd/app/backend/database.py b/spd/app/backend/database.py index e5ee4db59..1ee06ce5a 100644 --- a/spd/app/backend/database.py +++ b/spd/app/backend/database.py @@ -8,6 +8,7 @@ import hashlib import json +import os import sqlite3 from dataclasses import asdict from pathlib import Path @@ -23,8 +24,21 @@ GraphType = Literal["standard", "optimized", "manual"] # Persistent data directories +# Can be overridden via SPD_APP_DB_PATH environment variable for isolation _APP_DATA_DIR = REPO_ROOT / ".data" / "app" -DEFAULT_DB_PATH = _APP_DATA_DIR / "prompt_attr.db" +_DEFAULT_DB_PATH = _APP_DATA_DIR / "prompt_attr.db" + + +def get_default_db_path() -> Path: + """Get the default database path, respecting SPD_APP_DB_PATH env var.""" + env_path = os.environ.get("SPD_APP_DB_PATH") + if env_path: + return Path(env_path) + return _DEFAULT_DB_PATH + + +# For backwards compatibility +DEFAULT_DB_PATH = _DEFAULT_DB_PATH class Run(BaseModel): @@ -107,7 +121,7 @@ class PromptAttrDB: """ def __init__(self, db_path: Path | None = None, check_same_thread: bool = True): - self.db_path = db_path or DEFAULT_DB_PATH + self.db_path = db_path or get_default_db_path() self._check_same_thread = check_same_thread self._conn: sqlite3.Connection | None = None From 498d459e89f1360464dbdfce4a8681e9ab75f093 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 30 Jan 2026 21:04:25 +0000 Subject: [PATCH 02/62] Stream Claude Code output to file in real-time Previously used communicate() which buffers all output until process completes. 
Now streams directly to claude_output.txt so you can monitor agent activity with: tail -f <task_dir>/claude_output.txt https://claude.ai/code/session_01UMpYFZ3A98vsPkqoq6zvT6 --- spd/agent_swarm/scripts/run_agent.py | 39 +++++++++++++++------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/spd/agent_swarm/scripts/run_agent.py b/spd/agent_swarm/scripts/run_agent.py index 83a32752c..072d8a8c6 100644 --- a/spd/agent_swarm/scripts/run_agent.py +++ b/spd/agent_swarm/scripts/run_agent.py @@ -212,8 +212,8 @@ def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: prompt_path = task_dir / "agent_prompt.md" prompt_path.write_text(agent_prompt) - # Launch Claude Code - # The agent will investigate behaviors and write to the output files + # Launch Claude Code with output streaming to file + claude_output_path = task_dir / "claude_output.txt" claude_cmd = [ "claude", "--print", # Print output to stdout @@ -221,18 +221,21 @@ def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: ] logger.info(f"[Task {task_id}] Starting Claude Code session...") + logger.info(f"[Task {task_id}] Monitor with: tail -f {claude_output_path}") + + # Open output file for streaming writes + with open(claude_output_path, "w") as output_file: + claude_proc = subprocess.Popen( + claude_cmd, + stdin=subprocess.PIPE, + stdout=output_file, + stderr=subprocess.STDOUT, + text=True, + cwd=str(task_dir), + ) - claude_proc = subprocess.Popen( - claude_cmd, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - cwd=str(task_dir), - ) - - # Send the investigation prompt - investigation_request = f""" + # Send the investigation prompt and close stdin + investigation_request = f""" {agent_prompt} --- @@ -243,12 +246,12 @@ def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: Remember to log your progress to events.jsonl and write complete explanations to explanations.jsonl when you discover 
something. """ + assert claude_proc.stdin is not None + claude_proc.stdin.write(investigation_request) + claude_proc.stdin.close() - stdout, _ = claude_proc.communicate(input=investigation_request) - - # Save Claude's output - output_path = task_dir / "claude_output.txt" - output_path.write_text(stdout or "") + # Wait for Claude to finish (output streams to file in real-time) + claude_proc.wait() log_event( events_path, From efe5928ebafca30d30d9dcc1b9e40a1412b1fa19 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 30 Jan 2026 21:31:05 +0000 Subject: [PATCH 03/62] Use stream-json output format and add max_turns limit - Switch to --output-format stream-json for structured JSONL output - Add --max-turns parameter (default 50) to prevent runaway agents - Output file changed from claude_output.txt to claude_output.jsonl - Updated monitoring commands in logs to use jq for parsing Monitor with: tail -f task_*/claude_output.jsonl | jq -r '.result // empty' https://claude.ai/code/session_01UMpYFZ3A98vsPkqoq6zvT6 --- spd/agent_swarm/scripts/run_agent.py | 15 ++++++++++----- spd/agent_swarm/scripts/run_slurm.py | 10 +++++++++- spd/agent_swarm/scripts/run_slurm_cli.py | 9 ++++++--- 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/spd/agent_swarm/scripts/run_agent.py b/spd/agent_swarm/scripts/run_agent.py index 072d8a8c6..f048335e3 100644 --- a/spd/agent_swarm/scripts/run_agent.py +++ b/spd/agent_swarm/scripts/run_agent.py @@ -79,6 +79,7 @@ def run_agent( task_id: int, swarm_id: str, context_length: int = 128, + max_turns: int = 50, ) -> None: """Run a single investigation agent. @@ -87,6 +88,7 @@ def run_agent( task_id: SLURM task ID (1-indexed). swarm_id: Unique identifier for this swarm. context_length: Context length for prompts. + max_turns: Maximum agentic turns before stopping (prevents runaway agents). 
""" # Setup output directory swarm_dir = get_swarm_output_dir(swarm_id) @@ -212,16 +214,19 @@ def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: prompt_path = task_dir / "agent_prompt.md" prompt_path.write_text(agent_prompt) - # Launch Claude Code with output streaming to file - claude_output_path = task_dir / "claude_output.txt" + # Launch Claude Code with streaming JSON output + claude_output_path = task_dir / "claude_output.jsonl" claude_cmd = [ "claude", - "--print", # Print output to stdout - "--dangerously-skip-permissions", # Allow file writes + "--print", + "--output-format", "stream-json", # Structured JSONL for parsing + "--max-turns", str(max_turns), # Prevent runaway agents + "--dangerously-skip-permissions", ] - logger.info(f"[Task {task_id}] Starting Claude Code session...") + logger.info(f"[Task {task_id}] Starting Claude Code (max_turns={max_turns})...") logger.info(f"[Task {task_id}] Monitor with: tail -f {claude_output_path}") + logger.info(f"[Task {task_id}] Parse with: tail -f {claude_output_path} | jq -r '.result // empty'") # Open output file for streaming writes with open(claude_output_path, "w") as output_file: diff --git a/spd/agent_swarm/scripts/run_slurm.py b/spd/agent_swarm/scripts/run_slurm.py index 8b99e253d..f596e1ed9 100644 --- a/spd/agent_swarm/scripts/run_slurm.py +++ b/spd/agent_swarm/scripts/run_slurm.py @@ -31,6 +31,7 @@ def launch_agent_swarm( wandb_path: str, n_agents: int, context_length: int = 128, + max_turns: int = 50, partition: str = "h200-reserved", time: str = "8:00:00", job_suffix: str | None = None, @@ -41,6 +42,7 @@ def launch_agent_swarm( wandb_path: WandB run path for the SPD decomposition. n_agents: Number of agents to launch. context_length: Context length for prompts. + max_turns: Maximum agentic turns per agent (prevents runaway). partition: SLURM partition. time: Time limit per agent. job_suffix: Optional suffix for job names. 
@@ -64,6 +66,7 @@ def launch_agent_swarm( "wandb_path": wandb_path, "n_agents": n_agents, "context_length": context_length, + "max_turns": max_turns, "snapshot_branch": snapshot_branch, "commit_hash": commit_hash, } @@ -77,7 +80,8 @@ def launch_agent_swarm( f'"{wandb_path}" ' f"--task_id {task_id} " f"--swarm_id {swarm_id} " - f"--context_length {context_length}" + f"--context_length {context_length} " + f"--max_turns {max_turns}" ) worker_commands.append(cmd) @@ -104,6 +108,7 @@ def launch_agent_swarm( "WandB path": wandb_path, "N agents": n_agents, "Context length": context_length, + "Max turns": max_turns, "Output directory": str(output_dir), "Snapshot": f"{snapshot_branch} ({commit_hash[:8]})", "Job ID": array_result.job_id, @@ -115,5 +120,8 @@ def launch_agent_swarm( logger.info("Monitor progress:") logger.info(f" tail -f {output_dir}/task_*/events.jsonl") logger.info("") + logger.info("Monitor Claude output (stream-json):") + logger.info(f" tail -f {output_dir}/task_*/claude_output.jsonl | jq -r '.result // empty'") + logger.info("") logger.info("View explanations:") logger.info(f" cat {output_dir}/task_*/explanations.jsonl | jq .") diff --git a/spd/agent_swarm/scripts/run_slurm_cli.py b/spd/agent_swarm/scripts/run_slurm_cli.py index 20a6d8457..9b75ce95f 100644 --- a/spd/agent_swarm/scripts/run_slurm_cli.py +++ b/spd/agent_swarm/scripts/run_slurm_cli.py @@ -4,14 +4,14 @@ Usage: spd-swarm <wandb_path> --n_agents 10 - spd-swarm <wandb_path> --n_agents 5 --context_length 128 + spd-swarm <wandb_path> --n_agents 5 --max_turns 30 Examples: # Launch 10 agents to investigate a decomposition spd-swarm goodfire-ai/spd/runs/abc123 --n_agents 10 - # Launch 5 agents with custom context length and time limit - spd-swarm goodfire-ai/spd/runs/abc123 --n_agents 5 --context_length 64 --time 4:00:00 + # Launch 5 agents with custom settings + spd-swarm goodfire-ai/spd/runs/abc123 --n_agents 5 --max_turns 30 --time 4:00:00 """ import fire from spd.settings import DEFAULT_PARTITION_NAME def main( wandb_path: str, n_agents: int, context_length: int 
= 128, + max_turns: int = 50, partition: str = DEFAULT_PARTITION_NAME, time: str = "8:00:00", job_suffix: str | None = None, @@ -38,6 +39,7 @@ def main( Format: "entity/project/runs/run_id" or "wandb:entity/project/run_id" n_agents: Number of agents to launch (each gets 1 GPU). context_length: Context length for prompts (default 128). + max_turns: Maximum agentic turns per agent (default 50, prevents runaway). partition: SLURM partition name. time: Job time limit per agent (default 8 hours). job_suffix: Optional suffix for SLURM job names. @@ -48,6 +50,7 @@ def main( wandb_path=wandb_path, n_agents=n_agents, context_length=context_length, + max_turns=max_turns, partition=partition, time=time, job_suffix=job_suffix, From ef5b0fd80ee91eb9ed5c392ef00581693de964c8 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 30 Jan 2026 22:08:49 +0000 Subject: [PATCH 04/62] Fix stream-json output requiring --verbose flag Claude Code requires --verbose when using --output-format=stream-json with --print mode. 
https://claude.ai/code/session_01UMpYFZ3A98vsPkqoq6zvT6 --- spd/agent_swarm/scripts/run_agent.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/spd/agent_swarm/scripts/run_agent.py b/spd/agent_swarm/scripts/run_agent.py index f048335e3..b37f4e451 100644 --- a/spd/agent_swarm/scripts/run_agent.py +++ b/spd/agent_swarm/scripts/run_agent.py @@ -219,8 +219,9 @@ def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: claude_cmd = [ "claude", "--print", - "--output-format", "stream-json", # Structured JSONL for parsing - "--max-turns", str(max_turns), # Prevent runaway agents + "--verbose", # Required for stream-json output + "--output-format", "stream-json", + "--max-turns", str(max_turns), "--dangerously-skip-permissions", ] From f40f02e443bdd7099fd11d1bdf56915e797ea09f Mon Sep 17 00:00:00 2001 From: Claude SPD1 Date: Fri, 30 Jan 2026 23:24:50 +0000 Subject: [PATCH 05/62] Add GPU lock to prevent concurrent GPU operations When multiple GPU-intensive requests are made concurrently (graph computation, optimization, intervention), the backend would hang. This adds a lock that returns HTTP 503 immediately if a GPU operation is already in progress, allowing clients to retry later. 
Co-Authored-By: Claude Opus 4.5 --- spd/agent_swarm/scripts/run_agent.py | 10 ++- spd/app/backend/routers/graphs.py | 61 +++++++++------ spd/app/backend/routers/intervention.py | 100 +++++++++++++----------- spd/app/backend/state.py | 23 ++++++ 4 files changed, 121 insertions(+), 73 deletions(-) diff --git a/spd/agent_swarm/scripts/run_agent.py b/spd/agent_swarm/scripts/run_agent.py index b37f4e451..c41c8f30c 100644 --- a/spd/agent_swarm/scripts/run_agent.py +++ b/spd/agent_swarm/scripts/run_agent.py @@ -220,14 +220,18 @@ def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: "claude", "--print", "--verbose", # Required for stream-json output - "--output-format", "stream-json", - "--max-turns", str(max_turns), + "--output-format", + "stream-json", + "--max-turns", + str(max_turns), "--dangerously-skip-permissions", ] logger.info(f"[Task {task_id}] Starting Claude Code (max_turns={max_turns})...") logger.info(f"[Task {task_id}] Monitor with: tail -f {claude_output_path}") - logger.info(f"[Task {task_id}] Parse with: tail -f {claude_output_path} | jq -r '.result // empty'") + logger.info( + f"[Task {task_id}] Parse with: tail -f {claude_output_path} | jq -r '.result // empty'" + ) # Open output file for streaming writes with open(claude_output_path, "w") as output_file: diff --git a/spd/app/backend/routers/graphs.py b/spd/app/backend/routers/graphs.py index c0f478744..1ea02e5c6 100644 --- a/spd/app/backend/routers/graphs.py +++ b/spd/app/backend/routers/graphs.py @@ -240,8 +240,20 @@ def build_out_probs( def stream_computation( work: Callable[[ProgressCallback], GraphData | GraphDataWithOptimization], + gpu_lock: threading.Lock, ) -> StreamingResponse: - """Run graph computation in a thread with SSE streaming for progress updates.""" + """Run graph computation in a thread with SSE streaming for progress updates. + + Acquires gpu_lock before starting and holds it until computation completes. 
+ Raises 503 if the lock is already held by another operation. + """ + # Try to acquire lock non-blocking - fail fast if GPU is busy + if not gpu_lock.acquire(blocking=False): + raise HTTPException( + status_code=503, + detail="GPU operation already in progress. Please wait and retry.", + ) + progress_queue: queue.Queue[dict[str, Any]] = queue.Queue() def on_progress(current: int, total: int, stage: str) -> None: @@ -256,28 +268,31 @@ def compute_thread() -> None: progress_queue.put({"type": "error", "error": str(e)}) def generate() -> Generator[str]: - thread = threading.Thread(target=compute_thread) - thread.start() - - while True: - try: - msg = progress_queue.get(timeout=0.1) - except queue.Empty: - if not thread.is_alive(): + try: + thread = threading.Thread(target=compute_thread) + thread.start() + + while True: + try: + msg = progress_queue.get(timeout=0.1) + except queue.Empty: + if not thread.is_alive(): + break + continue + + if msg["type"] == "progress": + yield f"data: {json.dumps(msg)}\n\n" + elif msg["type"] == "error": + yield f"data: {json.dumps(msg)}\n\n" + break + elif msg["type"] == "result": + complete_data = {"type": "complete", "data": msg["result"].model_dump()} + yield f"data: {json.dumps(complete_data)}\n\n" break - continue - - if msg["type"] == "progress": - yield f"data: {json.dumps(msg)}\n\n" - elif msg["type"] == "error": - yield f"data: {json.dumps(msg)}\n\n" - break - elif msg["type"] == "result": - complete_data = {"type": "complete", "data": msg["result"].model_dump()} - yield f"data: {json.dumps(complete_data)}\n\n" - break - thread.join() + thread.join() + finally: + gpu_lock.release() return StreamingResponse(generate(), media_type="text/event-stream") @@ -456,7 +471,7 @@ def work(on_progress: ProgressCallback) -> GraphData: l0_total=len(filtered_node_ci_vals), ) - return stream_computation(work) + return stream_computation(work, manager._gpu_lock) def _edge_to_edge_data(edge: Edge) -> EdgeData: @@ -660,7 +675,7 @@ def 
work(on_progress: ProgressCallback) -> GraphDataWithOptimization: ), ) - return stream_computation(work) + return stream_computation(work, manager._gpu_lock) def _add_pseudo_layer_nodes( diff --git a/spd/app/backend/routers/intervention.py b/spd/app/backend/routers/intervention.py index 4c46e136c..a8fcb3fbc 100644 --- a/spd/app/backend/routers/intervention.py +++ b/spd/app/backend/routers/intervention.py @@ -148,45 +148,48 @@ def _run_intervention_forward( @router.post("") @log_errors -def run_intervention(request: InterventionRequest, loaded: DepLoadedRun) -> InterventionResponse: +def run_intervention( + request: InterventionRequest, loaded: DepLoadedRun, manager: DepStateManager +) -> InterventionResponse: """Run intervention forward pass with specified nodes active (legacy endpoint).""" - token_ids = loaded.tokenizer.encode(request.text, add_special_tokens=False) - tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) - - active_nodes = [(n.layer, n.seq_pos, n.component_idx) for n in request.nodes] - - seq_len = tokens.shape[1] - for _, seq_pos, _ in active_nodes: - if seq_pos >= seq_len: - raise ValueError(f"seq_pos {seq_pos} out of bounds for text with {seq_len} tokens") - - result = compute_intervention_forward( - model=loaded.model, - tokens=tokens, - active_nodes=active_nodes, - top_k=request.top_k, - tokenizer=loaded.tokenizer, - ) + with manager.gpu_lock(): + token_ids = loaded.tokenizer.encode(request.text, add_special_tokens=False) + tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) + + active_nodes = [(n.layer, n.seq_pos, n.component_idx) for n in request.nodes] + + seq_len = tokens.shape[1] + for _, seq_pos, _ in active_nodes: + if seq_pos >= seq_len: + raise ValueError(f"seq_pos {seq_pos} out of bounds for text with {seq_len} tokens") + + result = compute_intervention_forward( + model=loaded.model, + tokens=tokens, + active_nodes=active_nodes, + top_k=request.top_k, + tokenizer=loaded.tokenizer, + ) - 
predictions_per_position = [ - [ - TokenPrediction( - token=token, - token_id=token_id, - spd_prob=spd_prob, - target_prob=target_prob, - logit=logit, - target_logit=target_logit, - ) - for token, token_id, spd_prob, logit, target_prob, target_logit in pos_predictions + predictions_per_position = [ + [ + TokenPrediction( + token=token, + token_id=token_id, + spd_prob=spd_prob, + target_prob=target_prob, + logit=logit, + target_logit=target_logit, + ) + for token, token_id, spd_prob, logit, target_prob, target_logit in pos_predictions + ] + for pos_predictions in result.predictions_per_position ] - for pos_predictions in result.predictions_per_position - ] - return InterventionResponse( - input_tokens=result.input_tokens, - predictions_per_position=predictions_per_position, - ) + return InterventionResponse( + input_tokens=result.input_tokens, + predictions_per_position=predictions_per_position, + ) @router.post("/run") @@ -195,14 +198,16 @@ def run_and_save_intervention( request: RunInterventionRequest, loaded: DepLoadedRun, db: DepDB, + manager: DepStateManager, ) -> InterventionRunSummary: """Run an intervention and save the result.""" - response = _run_intervention_forward( - text=request.text, - selected_nodes=request.selected_nodes, - top_k=request.top_k, - loaded=loaded, - ) + with manager.gpu_lock(): + response = _run_intervention_forward( + text=request.text, + selected_nodes=request.selected_nodes, + top_k=request.top_k, + loaded=loaded, + ) run_id = db.save_intervention_run( graph_id=request.graph_id, @@ -310,12 +315,13 @@ def fork_intervention_run( modified_text = loaded.tokenizer.decode(modified_token_ids) # Run the intervention forward pass with modified tokens but same selected nodes - response = _run_intervention_forward( - text=modified_text, - selected_nodes=parent_run.selected_nodes, - top_k=request.top_k, - loaded=loaded, - ) + with manager.gpu_lock(): + response = _run_intervention_forward( + text=modified_text, + 
selected_nodes=parent_run.selected_nodes, + top_k=request.top_k, + loaded=loaded, + ) # Save the forked run fork_id = db.save_forked_intervention_run( diff --git a/spd/app/backend/state.py b/spd/app/backend/state.py index 47dacfe51..7364ff1d1 100644 --- a/spd/app/backend/state.py +++ b/spd/app/backend/state.py @@ -5,9 +5,13 @@ - StateManager: Singleton managing app-wide state with proper lifecycle """ +import threading +from collections.abc import Generator +from contextlib import contextmanager from dataclasses import dataclass, field from typing import Any +from fastapi import HTTPException from transformers.tokenization_utils_base import PreTrainedTokenizerBase from spd.app.backend.database import PromptAttrDB, Run @@ -147,6 +151,7 @@ class StateManager: def __init__(self) -> None: self._state: AppState | None = None + self._gpu_lock = threading.Lock() @classmethod def get(cls) -> "StateManager": @@ -189,3 +194,21 @@ def close(self) -> None: """Clean up resources.""" if self._state is not None: self._state.db.close() + + @contextmanager + def gpu_lock(self) -> Generator[None]: + """Acquire GPU lock or fail with 503 if another GPU operation is in progress. + + Use this for GPU-intensive endpoints to prevent concurrent operations + that would cause the server to hang. + """ + acquired = self._gpu_lock.acquire(blocking=False) + if not acquired: + raise HTTPException( + status_code=503, + detail="GPU operation already in progress. Please wait and retry.", + ) + try: + yield + finally: + self._gpu_lock.release() From 567fb198c938513c2a348c61dacc97f9435e45b8 Mon Sep 17 00:00:00 2001 From: Claude SPD1 Date: Fri, 30 Jan 2026 23:38:11 +0000 Subject: [PATCH 06/62] Add research_log.md for human-readable agent progress Agents now create and update a research_log.md file with readable progress updates. This makes it easy to follow what the agent is doing and discovering without parsing JSONL files. 
Co-Authored-By: Claude Opus 4.5 --- spd/agent_swarm/CLAUDE.md | 11 ++++- spd/agent_swarm/agent_prompt.py | 70 +++++++++++++++++++++++----- spd/agent_swarm/scripts/run_agent.py | 14 ++++-- 3 files changed, 77 insertions(+), 18 deletions(-) diff --git a/spd/agent_swarm/CLAUDE.md b/spd/agent_swarm/CLAUDE.md index ee2e89be2..ee4a57db4 100644 --- a/spd/agent_swarm/CLAUDE.md +++ b/spd/agent_swarm/CLAUDE.md @@ -42,11 +42,12 @@ spd/agent_swarm/ SPD_OUT_DIR/agent_swarm/<swarm_id>/ ├── metadata.json # Swarm configuration ├── task_1/ -│ ├── events.jsonl # Progress and observations +│ ├── research_log.md # Human-readable progress log (PRIMARY OUTPUT) +│ ├── events.jsonl # Structured progress and observations │ ├── explanations.jsonl # Complete behavior explanations │ ├── app.db # Isolated SQLite database │ ├── agent_prompt.md # The prompt given to the agent -│ └── claude_output.txt # Raw Claude Code output +│ └── claude_output.jsonl # Raw Claude Code output (stream-json format) ├── task_2/ │ └── ... └── task_N/ @@ -90,6 +91,12 @@ This prevents conflicts when multiple agents run on the same machine. ## Monitoring ```bash +# Watch research logs (best way to follow agent progress) +tail -f SPD_OUT_DIR/agent_swarm/<swarm_id>/task_*/research_log.md + +# Watch a specific agent's research log +cat SPD_OUT_DIR/agent_swarm/<swarm_id>/task_1/research_log.md + +# Watch events from all agents tail -f SPD_OUT_DIR/agent_swarm/<swarm_id>/task_*/events.jsonl diff --git a/spd/agent_swarm/agent_prompt.py b/spd/agent_swarm/agent_prompt.py index e3cf4fd21..8ee02d6dd 100644 --- a/spd/agent_swarm/agent_prompt.py +++ b/spd/agent_swarm/agent_prompt.py @@ -254,10 +254,50 @@ ## Output Format -Write your findings by appending to the output files: +Write your findings to the output files. **The research log is your primary output for humans to read.** + +### research_log.md (MOST IMPORTANT - Write here frequently!) +This is a human-readable log of your investigation. Write here often so someone can follow your progress. 
+Use clear markdown formatting: + +```markdown +## [HH:MM] Starting Investigation + +Looking at component interpretations to find interesting patterns... + +## [HH:MM] Hypothesis: Gendered Pronoun Circuit + +Found components that seem related to pronouns: +- h.0.mlp.c_fc:42 - "he/his pronouns after male subjects" +- h.0.mlp.c_fc:89 - "she/her pronouns after female subjects" + +Testing with prompt: "The boy said that he" + +## [HH:MM] Optimization Results + +Ran optimization for "he" prediction at position 4: +- Found 15 active components +- Key components: h.0.mlp.c_fc:42 (CI=0.92), h.1.attn.o_proj:156 (CI=0.78) + +## [HH:MM] Ablation Test + +Ablating h.0.mlp.c_fc:42: +- Before: P(he)=0.82, P(she)=0.11 +- After: P(he)=0.23, P(she)=0.45 + +This confirms the component is important for masculine pronoun prediction! + +## [HH:MM] Conclusion + +Found a circuit for gendered pronoun prediction. Components h.0.mlp.c_fc:42 and +h.1.attn.o_proj:156 work together to predict masculine pronouns after male subjects. +``` + +**IMPORTANT**: Update the research log every few minutes with your current progress, +findings, and next steps. This is how humans monitor your work! ### events.jsonl -Log progress and observations: +Log structured progress and observations: ```json {{"event_type": "observation", "message": "Component h.0.mlp.c_fc:5 has high CI when subject is male", "details": {{"ci_value": 0.85}}, "timestamp": "..."}} ``` @@ -284,15 +324,20 @@ ## Getting Started -1. Check the current status: `curl http://localhost:{port}/api/status` -2. Explore available interpretations: `curl http://localhost:{port}/api/correlations/interpretations` -3. Search for interesting prompts or create your own -4. Optimize a sparse circuit for a behavior you find -5. Investigate the components involved -6. Test hypotheses with ablations -7. Document your findings +1. **Create your research log**: Start by creating `research_log.md` with a header +2. 
Check the current status: `curl http://localhost:{port}/api/status` +3. Explore available interpretations: `curl http://localhost:{port}/api/correlations/interpretations` +4. Search for interesting prompts or create your own +5. **Update research_log.md** with what you're investigating +6. Optimize a sparse circuit for a behavior you find +7. Investigate the components involved +8. Test hypotheses with ablations +9. **Update research_log.md** with findings +10. Document complete explanations in `explanations.jsonl` + +**Remember to update research_log.md frequently** - this is how humans follow your progress! -Remember: You are exploring! Not every investigation will lead to a clear explanation. +You are exploring! Not every investigation will lead to a clear explanation. Document what you learn, even if it's "this was more complicated than expected." Good luck, and happy investigating! @@ -322,9 +367,10 @@ def get_agent_prompt(port: int, wandb_path: str, task_id: int, output_dir: str) - **Output Directory**: {output_dir} Your output files: -- `{output_dir}/events.jsonl` - Log events and observations here +- `{output_dir}/research_log.md` - **PRIMARY OUTPUT** - Write readable progress updates here frequently! +- `{output_dir}/events.jsonl` - Log structured events and observations here - `{output_dir}/explanations.jsonl` - Write complete explanations here -To append to these files, use the Write tool or shell redirection. +**Start by creating research_log.md with a header, then update it every few minutes!** """ return prompt + runtime_context diff --git a/spd/agent_swarm/scripts/run_agent.py b/spd/agent_swarm/scripts/run_agent.py index c41c8f30c..3c5a78449 100644 --- a/spd/agent_swarm/scripts/run_agent.py +++ b/spd/agent_swarm/scripts/run_agent.py @@ -250,11 +250,17 @@ def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: --- -Please begin your investigation. 
Start by checking the backend status and exploring -available component interpretations. Then find an interesting behavior and investigate it. +Please begin your investigation: -Remember to log your progress to events.jsonl and write complete explanations to -explanations.jsonl when you discover something. +1. **FIRST**: Create `{task_dir}/research_log.md` with a header like "# Research Log - Task {task_id}" +2. Check the backend status and explore component interpretations +3. Find an interesting behavior to investigate +4. **Update research_log.md frequently** with your progress, findings, and next steps + +Remember: +- research_log.md is your PRIMARY output - humans will read this to follow your work +- Update it every few minutes with what you're doing and discovering +- Write complete explanations to explanations.jsonl when you finish investigating a behavior """ assert claude_proc.stdin is not None claude_proc.stdin.write(investigation_request) From 4c4a843b2bb6a9b6e7fb67508e6a6c63960494fc Mon Sep 17 00:00:00 2001 From: Claude SPD1 Date: Fri, 30 Jan 2026 23:56:57 +0000 Subject: [PATCH 07/62] Add full timestamps to research log examples Show YYYY-MM-DD HH:MM:SS format and provide tip for getting timestamps. Co-Authored-By: Claude Opus 4.5 --- spd/agent_swarm/agent_prompt.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/spd/agent_swarm/agent_prompt.py b/spd/agent_swarm/agent_prompt.py index 8ee02d6dd..b469b47b5 100644 --- a/spd/agent_swarm/agent_prompt.py +++ b/spd/agent_swarm/agent_prompt.py @@ -261,11 +261,11 @@ Use clear markdown formatting: ```markdown -## [HH:MM] Starting Investigation +## [2026-01-30 14:23:15] Starting Investigation Looking at component interpretations to find interesting patterns... 
-## [HH:MM] Hypothesis: Gendered Pronoun Circuit +## [2026-01-30 14:25:42] Hypothesis: Gendered Pronoun Circuit Found components that seem related to pronouns: - h.0.mlp.c_fc:42 - "he/his pronouns after male subjects" @@ -273,13 +273,13 @@ Testing with prompt: "The boy said that he" -## [HH:MM] Optimization Results +## [2026-01-30 14:28:03] Optimization Results Ran optimization for "he" prediction at position 4: - Found 15 active components - Key components: h.0.mlp.c_fc:42 (CI=0.92), h.1.attn.o_proj:156 (CI=0.78) -## [HH:MM] Ablation Test +## [2026-01-30 14:31:17] Ablation Test Ablating h.0.mlp.c_fc:42: - Before: P(he)=0.82, P(she)=0.11 @@ -287,12 +287,14 @@ This confirms the component is important for masculine pronoun prediction! -## [HH:MM] Conclusion +## [2026-01-30 14:35:44] Conclusion Found a circuit for gendered pronoun prediction. Components h.0.mlp.c_fc:42 and h.1.attn.o_proj:156 work together to predict masculine pronouns after male subjects. ``` +**TIP**: Get the current timestamp with `date '+%Y-%m-%d %H:%M:%S'` for your log entries. + **IMPORTANT**: Update the research log every few minutes with your current progress, findings, and next steps. This is how humans monitor your work! 
From cb6e6f063808af3f46f976f41531dd9ff92b4351 Mon Sep 17 00:00:00 2001 From: Claude SPD1 Date: Sat, 31 Jan 2026 19:08:07 +0000 Subject: [PATCH 08/62] wip: Integrate agent swarm with MCP for Claude Code tool access --- spd/agent_swarm/CLAUDE.md | 69 +- spd/agent_swarm/agent_prompt.py | 369 ++---- spd/agent_swarm/scripts/run_agent.py | 52 +- spd/app/CLAUDE.md | 3 +- spd/app/backend/routers/__init__.py | 4 + spd/app/backend/routers/investigations.py | 262 ++++ spd/app/backend/routers/mcp.py | 1171 +++++++++++++++++ spd/app/backend/server.py | 29 + .../src/components/InvestigationsTab.svelte | 497 +++++++ .../frontend/src/components/RunView.svelte | 15 +- spd/app/frontend/src/lib/api/index.ts | 1 + .../frontend/src/lib/api/investigations.ts | 55 + 12 files changed, 2211 insertions(+), 316 deletions(-) create mode 100644 spd/app/backend/routers/investigations.py create mode 100644 spd/app/backend/routers/mcp.py create mode 100644 spd/app/frontend/src/components/InvestigationsTab.svelte create mode 100644 spd/app/frontend/src/lib/api/investigations.ts diff --git a/spd/agent_swarm/CLAUDE.md b/spd/agent_swarm/CLAUDE.md index ee4a57db4..48fabc504 100644 --- a/spd/agent_swarm/CLAUDE.md +++ b/spd/agent_swarm/CLAUDE.md @@ -7,9 +7,10 @@ that investigate behaviors in SPD model decompositions. The agent swarm system allows you to: 1. Launch many parallel agents (each as a SLURM job with 1 GPU) -2. Each agent runs an isolated app backend instance -3. Agents investigate behaviors using the SPD app API -4. Findings are written to append-only JSONL files +2. Each agent runs an isolated app backend instance with MCP support +3. Agents investigate behaviors using SPD tools via MCP (Model Context Protocol) +4. Progress is streamed in real-time via MCP SSE events +5. 
Findings are written to append-only JSONL files ## Usage @@ -36,22 +37,55 @@ spd/agent_swarm/ └── run_agent.py # Worker script (runs in each SLURM job) ``` +## MCP Tools + +Agents access ALL SPD functionality via MCP (Model Context Protocol). The backend exposes +these tools at `/mcp`. Agents don't need file system access - everything is done through MCP. + +**Analysis Tools:** + +| Tool | Description | +|------|-------------| +| `optimize_graph` | Find minimal circuit for a behavior (streams progress) | +| `get_component_info` | Get component interpretation, token stats, correlations | +| `run_ablation` | Test circuit by running with selected components only | +| `search_dataset` | Search SimpleStories training data for patterns | +| `create_prompt` | Tokenize text and get next-token probabilities | + +**Output Tools:** + +| Tool | Description | +|------|-------------| +| `update_research_log` | Append content to the agent's research log (PRIMARY OUTPUT) | +| `save_explanation` | Save a complete, validated behavior explanation | +| `set_investigation_summary` | Set title and summary shown in the investigations UI | +| `submit_suggestion` | Submit ideas for improving the tools or system | + +The `optimize_graph` tool streams progress events via SSE, giving real-time visibility +into long-running optimization operations. + +Suggestions from all agents are collected in `SPD_OUT_DIR/agent_swarm/suggestions.jsonl` (global file). + ## Output Structure ``` -SPD_OUT_DIR/agent_swarm// -├── metadata.json # Swarm configuration -├── task_1/ -│ ├── research_log.md # Human-readable progress log (PRIMARY OUTPUT) -│ ├── events.jsonl # Structured progress and observations -│ ├── explanations.jsonl # Complete behavior explanations -│ ├── app.db # Isolated SQLite database -│ ├── agent_prompt.md # The prompt given to the agent -│ └── claude_output.jsonl # Raw Claude Code output (stream-json format) -├── task_2/ -│ └── ... -└── task_N/ - └── ... 
+SPD_OUT_DIR/agent_swarm/ +├── suggestions.jsonl # System improvement suggestions from ALL agents (global) +└── / + ├── metadata.json # Swarm configuration + ├── task_1/ + │ ├── research_log.md # Human-readable progress log (PRIMARY OUTPUT) + │ ├── events.jsonl # Structured progress and observations + │ ├── explanations.jsonl # Complete behavior explanations + │ ├── summary.json # Agent-provided title and summary for UI + │ ├── app.db # Isolated SQLite database + │ ├── agent_prompt.md # The prompt given to the agent + │ ├── mcp_config.json # MCP server configuration for Claude Code + │ └── claude_output.jsonl # Raw Claude Code output (stream-json format) + ├── task_2/ + │ └── ... + └── task_N/ + └── ... ``` ## Key Files @@ -103,6 +137,9 @@ tail -f SPD_OUT_DIR/agent_swarm//task_*/events.jsonl # View all explanations cat SPD_OUT_DIR/agent_swarm//task_*/explanations.jsonl | jq . +# View agent suggestions for system improvement (global file) +cat SPD_OUT_DIR/agent_swarm/suggestions.jsonl | jq . + # Check SLURM job status squeue --me diff --git a/spd/agent_swarm/agent_prompt.py b/spd/agent_swarm/agent_prompt.py index b469b47b5..44424c190 100644 --- a/spd/agent_swarm/agent_prompt.py +++ b/spd/agent_swarm/agent_prompt.py @@ -1,8 +1,7 @@ """System prompt for SPD investigation agents. This module contains the detailed instructions given to each agent in the swarm. -The prompt explains how to use the SPD app API and the scientific methodology -for investigating model behaviors. +The agent has access to SPD tools via MCP - tools are self-documenting. """ AGENT_SYSTEM_PROMPT = """ @@ -10,7 +9,7 @@ You are a research agent investigating behaviors in a neural network model decomposition. Your goal is to find interesting behaviors, understand how components interact to produce -them, and document your findings as explanations. +them, and document your findings. ## Your Mission @@ -23,326 +22,126 @@ 2. 
**Understand the mechanism**: Figure out which components are involved and how they work together to produce the behavior -3. **Document your findings**: Write a clear explanation with supporting evidence +3. **Document your findings**: Write clear explanations with supporting evidence -## The SPD App Backend +## Available Tools (via MCP) -You have access to an SPD (Stochastic Parameter Decomposition) app backend running at: -`http://localhost:{port}` +You have access to SPD analysis tools. Use them directly - they have full documentation. -This app provides APIs for: -- Loading decomposed models -- Computing attribution graphs showing how components interact -- Optimizing sparse circuits for specific behaviors -- Running interventions (ablations) to test hypotheses -- Viewing component interpretations and correlations -- Searching the training dataset +**Analysis Tools:** +- **optimize_graph**: Find the minimal circuit for a behavior (e.g., "boy" → "he") +- **get_component_info**: Get interpretation and token stats for a component +- **run_ablation**: Test a circuit by running with only selected components +- **search_dataset**: Find examples in the training data +- **create_prompt**: Tokenize text for analysis -## API Reference - -### Health Check -```bash -curl http://localhost:{port}/api/health -# Returns: {{"status": "ok"}} -``` - -### Load a Run (ALREADY DONE FOR YOU) -The run is pre-loaded. 
Check status with: -```bash -curl http://localhost:{port}/api/status -``` - -### Create a Custom Prompt -To analyze a specific prompt: -```bash -curl -X POST "http://localhost:{port}/api/prompts/custom?text=The%20boy%20ate%20his" -# Returns: {{"id": , "token_ids": [...], "tokens": [...], "preview": "...", "next_token_probs": [...]}} -``` - -### Compute Optimized Attribution Graph (MOST IMPORTANT) -This optimizes a sparse circuit that achieves a behavior: -```bash -curl -X POST "http://localhost:{port}/api/graphs/optimized/stream?prompt_id=&loss_type=ce&loss_position=&label_token=&steps=100&imp_min_coeff=0.1&pnorm=0.5&mask_type=hard&loss_coeff=1.0&ci_threshold=0.01&normalize=target" -# Streams SSE events, final event has type="complete" with graph data -``` - -Parameters: -- `prompt_id`: ID from creating custom prompt -- `loss_type`: "ce" for cross-entropy (predicting specific token) or "kl" (matching full distribution) -- `loss_position`: Token position to optimize (0-indexed, usually last position) -- `label_token`: Token ID to predict (for CE loss) -- `steps`: Optimization steps (50-200 typical) -- `imp_min_coeff`: Importance minimization coefficient (0.05-0.3) -- `pnorm`: P-norm for sparsity (0.3-1.0, lower = sparser) -- `mask_type`: "hard" for binary masks, "soft" for continuous -- `ci_threshold`: Threshold for including nodes in graph (0.01-0.1) -- `normalize`: "target" normalizes by target layer, "none" for raw values - -### Get Component Interpretations -```bash -curl "http://localhost:{port}/api/correlations/interpretations" -# Returns: {{"h.0.mlp.c_fc:5": {{"label": "...", "confidence": "high"}}, ...}} -``` - -Get full interpretation details: -```bash -curl "http://localhost:{port}/api/correlations/interpretations/h.0.mlp.c_fc/5" -# Returns: {{"reasoning": "...", "prompt": "..."}} -``` - -### Get Component Token Statistics -```bash -curl "http://localhost:{port}/api/correlations/token_stats/h.0.mlp.c_fc/5?top_k=20" -# Returns input/output token 
associations -``` - -### Get Component Correlations -```bash -curl "http://localhost:{port}/api/correlations/components/h.0.mlp.c_fc/5?top_k=20" -# Returns components that frequently co-activate -``` - -### Run Intervention (Ablation) -Test a hypothesis by running the model with only selected components active: -```bash -curl -X POST "http://localhost:{port}/api/intervention/run" \\ - -H "Content-Type: application/json" \\ - -d '{{"graph_id": , "text": "The boy ate his", "selected_nodes": ["h.0.mlp.c_fc:3:5", "h.1.attn.o_proj:3:10"], "top_k": 10}}' -# Returns predictions with only selected components active vs full model -``` - -Node format: "layer:seq_pos:component_idx" -- `layer`: e.g., "h.0.mlp.c_fc", "h.1.attn.o_proj" -- `seq_pos`: Position in sequence (0-indexed) -- `component_idx`: Component index within layer - -### Search Dataset -Find prompts with specific patterns: -```bash -curl -X POST "http://localhost:{port}/api/dataset/search?query=she%20said&split=train" -curl "http://localhost:{port}/api/dataset/results?page=1&page_size=20" -``` - -### Get Random Samples with Loss -Find high/low loss examples: -```bash -curl "http://localhost:{port}/api/dataset/random_with_loss?n_samples=20&seed=42" -``` - -### Probe Component Activation -See how a component responds to arbitrary text: -```bash -curl -X POST "http://localhost:{port}/api/activation_contexts/probe" \\ - -H "Content-Type: application/json" \\ - -d '{{"text": "The boy ate his", "layer": "h.0.mlp.c_fc", "component_idx": 5}}' -# Returns CI values and activations at each position -``` - -### Get Dataset Attributions -See which components influence each other across the training data: -```bash -curl "http://localhost:{port}/api/dataset_attributions/h.0.mlp.c_fc/5?k=10" -# Returns positive/negative sources and targets -``` +**Output Tools:** +- **update_research_log**: Append to your research log (PRIMARY OUTPUT - use frequently!) 
+- **save_explanation**: Save a complete, validated behavior explanation +- **set_investigation_summary**: Set a title and summary for your investigation (shown in UI) +- **submit_suggestion**: Submit ideas for improving the tools or system ## Investigation Methodology ### Step 1: Find an Interesting Behavior -Start by exploring the model's behavior: - -1. **Search for patterns**: Use `/api/dataset/search` to find prompts with specific - linguistic patterns (pronouns, verb conjugations, completions, etc.) - -2. **Look at high-loss examples**: Use `/api/dataset/random_with_loss` to find where - the model struggles or succeeds - -3. **Create test prompts**: Use `/api/prompts/custom` to create prompts that test - specific capabilities - -Good behaviors to investigate: -- Gendered pronoun prediction ("The doctor said she" vs "The doctor said he") -- Subject-verb agreement ("The cats are" vs "The cat is") -- Pattern completion ("1, 2, 3," → "4") -- Semantic associations ("The capital of France is" → "Paris") -- Grammatical structure (completing sentences correctly) +Start by exploring: +- Search for linguistic patterns: pronouns, verb agreement, completions +- Create test prompts that show clear model behavior +- Good targets: gendered pronouns, subject-verb agreement, semantic associations ### Step 2: Optimize a Sparse Circuit Once you have a behavior: - -1. **Create the prompt** via `/api/prompts/custom` - -2. **Identify the target token**: What token should be predicted? Get its ID from - the tokenizer or from the prompt creation response. - -3. **Run optimization** via `/api/graphs/optimized/stream`: - - Use `loss_type=ce` with the target token - - Set `loss_position` to the position where prediction matters - - Start with `imp_min_coeff=0.1`, `pnorm=0.5`, `steps=100` - - Use `ci_threshold=0.01` to see active components - -4. 
**Examine the graph**: The response shows: - - `nodeCiVals`: Which components are active (high CI = important) - - `edges`: How components connect (gradient flow) - - `outputProbs`: Model predictions +1. Use `optimize_graph` with your prompt and target token +2. Examine which components have high CI values +3. Note the circuit size (fewer = cleaner mechanism) ### Step 3: Understand Component Roles -For each important component in the graph: - -1. **Check the interpretation**: Use `/api/correlations/interpretations//` - to see if we already have an idea what this component does - -2. **Look at token stats**: Use `/api/correlations/token_stats//` to see - what tokens activate this component (input) and what it predicts (output) - -3. **Check correlations**: Use `/api/correlations/components//` to see - what other components co-activate - -4. **Probe on variations**: Use `/api/activation_contexts/probe` to see how the - component responds to related prompts +For each important component: +1. Use `get_component_info` to see its interpretation and token stats +2. Look at what tokens activate it (input) and what it predicts (output) +3. Check correlated components ### Step 4: Test with Ablations Form hypotheses and test them: - -1. **Hypothesis**: "Component X stores information about gender" - -2. **Test**: Run intervention with and without component X - - If prediction changes as expected → supports hypothesis - - If no change → component may not be necessary for this - - If unexpected change → revise hypothesis - -3. **Control**: Try ablating other components to ensure specificity +1. Use `run_ablation` with the circuit's components +2. Verify predictions match expectations +3. 
Try removing individual components to find critical ones ### Step 5: Document Your Findings -Write a `BehaviorExplanation` with: -- Clear subject prompt -- Description of the behavior -- Components and their roles -- How they work together -- Supporting evidence from ablations/attributions -- Confidence level -- Alternative hypotheses you considered -- Limitations +Use `update_research_log` frequently - this is how humans monitor your work! +When you complete an investigation, use `save_explanation` to create a structured record. ## Scientific Principles -### Be Epistemologically Humble -- Your first hypothesis is probably wrong or incomplete -- Always consider alternative explanations -- A single confirming example doesn't prove a theory -- Look for disconfirming evidence - -### Be Bayesian -- Start with priors from component interpretations -- Update beliefs based on evidence -- Consider the probability of the evidence under different hypotheses -- Don't anchor too strongly on initial observations - -### Triangulate Evidence -- Don't rely on a single type of evidence -- Ablation results + attribution patterns + token stats together are stronger -- Look for convergent evidence from multiple sources - -### Document Uncertainty -- Be explicit about what you're confident in vs. uncertain about -- Note when evidence is weak or ambiguous -- Identify what additional tests would strengthen the explanation +- **Be skeptical**: Your first hypothesis is probably incomplete +- **Triangulate**: Don't rely on a single type of evidence +- **Document uncertainty**: Note what you're confident in vs. uncertain about +- **Consider alternatives**: What else could explain the behavior? ## Output Format -Write your findings to the output files. **The research log is your primary output for humans to read.** +### Research Log (PRIMARY OUTPUT - Update frequently!) -### research_log.md (MOST IMPORTANT - Write here frequently!) -This is a human-readable log of your investigation. 
Write here often so someone can follow your progress. -Use clear markdown formatting: +Use `update_research_log` with markdown content. Call it every few minutes to show progress: -```markdown -## [2026-01-30 14:23:15] Starting Investigation - -Looking at component interpretations to find interesting patterns... - -## [2026-01-30 14:25:42] Hypothesis: Gendered Pronoun Circuit - -Found components that seem related to pronouns: -- h.0.mlp.c_fc:42 - "he/his pronouns after male subjects" -- h.0.mlp.c_fc:89 - "she/her pronouns after female subjects" - -Testing with prompt: "The boy said that he" - -## [2026-01-30 14:28:03] Optimization Results - -Ran optimization for "he" prediction at position 4: -- Found 15 active components -- Key components: h.0.mlp.c_fc:42 (CI=0.92), h.1.attn.o_proj:156 (CI=0.78) +Example calls: +``` +update_research_log("# Research Log - Task 1\n\nStarting investigation...\n\n") -## [2026-01-30 14:31:17] Ablation Test +update_research_log("## [14:25:42] Hypothesis: Gendered Pronoun Circuit\n\nTesting prompt: 'The boy said that' → expecting ' he'\n\nUsed optimize_graph - found 15 active components:\n- h.0.mlp.c_fc:407 (CI=0.95) - 'male subjects'\n- h.3.attn.o_proj:262 (CI=0.92) - 'masculine pronouns'\n\n") -Ablating h.0.mlp.c_fc:42: -- Before: P(he)=0.82, P(she)=0.11 -- After: P(he)=0.23, P(she)=0.45 +update_research_log("## [14:28:03] Ablation Test\n\nResult: P(he) = 0.89 (vs 0.22 baseline)\n\nThis confirms the circuit is sufficient!\n\n") +``` -This confirms the component is important for masculine pronoun prediction! +### Saving Explanations -## [2026-01-30 14:35:44] Conclusion +When you have a complete explanation, use `save_explanation`: -Found a circuit for gendered pronoun prediction. Components h.0.mlp.c_fc:42 and -h.1.attn.o_proj:156 work together to predict masculine pronouns after male subjects. 
+
```
+save_explanation(
+    subject_prompt="The boy said that",
+    behavior_description="Predicts masculine pronoun 'he' after male subject",
+    components_involved=[
+        {"component_key": "h.0.mlp.c_fc:407", "role": "Male subject detector"},
+        {"component_key": "h.3.attn.o_proj:262", "role": "Masculine pronoun promoter"}
+    ],
+    explanation="Component h.0.mlp.c_fc:407 activates on male subjects...",
+    confidence="medium",
+    limitations=["Only tested on simple sentences"]
+)
```
 
-**TIP**: Get the current timestamp with `date '+%Y-%m-%d %H:%M:%S'` for your log entries.
+### Submitting Suggestions
 
-**IMPORTANT**: Update the research log every few minutes with your current progress,
-findings, and next steps. This is how humans monitor your work!
+If you have ideas for improving the system, use `submit_suggestion`:
 
-### events.jsonl
-Log structured progress and observations:
-```json
-{{"event_type": "observation", "message": "Component h.0.mlp.c_fc:5 has high CI when subject is male", "details": {{"ci_value": 0.85}}, "timestamp": "..."}}
 ```
-
-### explanations.jsonl
-When you have a complete explanation:
-```json
-{{
-  "subject_prompt": "The boy ate his lunch",
-  "behavior_description": "Correctly predicts gendered pronoun 'his' after male subject",
-  "components_involved": [
-    {{"component_key": "h.0.mlp.c_fc:5", "role": "Encodes subject gender as male", "interpretation": "male names/subjects"}},
-    {{"component_key": "h.1.attn.o_proj:10", "role": "Transmits gender information to output", "interpretation": null}}
-  ],
-  "explanation": "Component h.0.mlp.c_fc:5 activates on male subjects and stores gender information...",
-  "supporting_evidence": [
-    {{"evidence_type": "ablation", "description": "Removing component causes prediction to change from 'his' to 'her'", "details": {{"without_component": {{"his": 0.1, "her": 0.6}}, "with_component": {{"his": 0.8, "her": 0.1}}}}}}
-  ],
-  "confidence": "medium",
-  "alternative_hypotheses": ["Component might encode broader 
concept of masculine entities, not just humans"], - "limitations": ["Only tested on simple subject-pronoun sentences"] -}} +submit_suggestion( + category="tool_improvement", + title="Add batch ablation support", + description="It would be faster to test multiple ablations at once...", + context="I was testing 10 different component subsets one at a time" +) ``` ## Getting Started -1. **Create your research log**: Start by creating `research_log.md` with a header -2. Check the current status: `curl http://localhost:{port}/api/status` -3. Explore available interpretations: `curl http://localhost:{port}/api/correlations/interpretations` -4. Search for interesting prompts or create your own -5. **Update research_log.md** with what you're investigating -6. Optimize a sparse circuit for a behavior you find -7. Investigate the components involved -8. Test hypotheses with ablations -9. **Update research_log.md** with findings -10. Document complete explanations in `explanations.jsonl` - -**Remember to update research_log.md frequently** - this is how humans follow your progress! +1. **Create your research log** with `update_research_log("# Research Log - Task N\n\n...")` +2. Use analysis tools to explore the model +3. Find an interesting behavior to investigate +4. **Call `update_research_log` frequently** - humans are watching! +5. Use `save_explanation` for complete findings +6. **Call `set_investigation_summary`** with a title and summary when done (or periodically for updates) You are exploring! Not every investigation will lead to a clear explanation. Document what you learn, even if it's "this was more complicated than expected." -Good luck, and happy investigating! +Good luck! """ @@ -350,7 +149,7 @@ def get_agent_prompt(port: int, wandb_path: str, task_id: int, output_dir: str) """Generate the full agent prompt with runtime parameters filled in. Args: - port: The port the backend is running on. 
+ port: The port the backend is running on (for reference, tools use MCP). wandb_path: The WandB path of the loaded run. task_id: The SLURM task ID for this agent. output_dir: Path to the agent's output directory. @@ -358,21 +157,19 @@ def get_agent_prompt(port: int, wandb_path: str, task_id: int, output_dir: str) Returns: The complete agent prompt with parameters substituted. """ - prompt = AGENT_SYSTEM_PROMPT.format(port=port) - runtime_context = f""" ## Runtime Context -- **Backend URL**: http://localhost:{port} -- **Loaded Run**: {wandb_path} +- **Model Run**: {wandb_path} - **Task ID**: {task_id} -- **Output Directory**: {output_dir} -Your output files: -- `{output_dir}/research_log.md` - **PRIMARY OUTPUT** - Write readable progress updates here frequently! -- `{output_dir}/events.jsonl` - Log structured events and observations here -- `{output_dir}/explanations.jsonl` - Write complete explanations here +Use the MCP tools for ALL output: +- `update_research_log` → **PRIMARY OUTPUT** - Update frequently with your progress! +- `save_explanation` → Save complete, validated behavior explanations +- `submit_suggestion` → Share ideas for improving the system -**Start by creating research_log.md with a header, then update it every few minutes!** +**Start by calling update_research_log to create your log, then investigate!** """ - return prompt + runtime_context + # Note: output_dir and port are available but agents shouldn't need them + _ = output_dir, port + return AGENT_SYSTEM_PROMPT + runtime_context diff --git a/spd/agent_swarm/scripts/run_agent.py b/spd/agent_swarm/scripts/run_agent.py index 3c5a78449..627b6d473 100644 --- a/spd/agent_swarm/scripts/run_agent.py +++ b/spd/agent_swarm/scripts/run_agent.py @@ -4,10 +4,12 @@ 1. Creates an isolated output directory for this agent 2. Starts the app backend with an isolated database 3. Loads the SPD run -4. Launches Claude Code with investigation instructions -5. Handles cleanup on exit +4. 
Configures MCP server for Claude Code +5. Launches Claude Code with investigation instructions +6. Handles cleanup on exit """ +import json import os import signal import socket @@ -26,6 +28,21 @@ from spd.log import logger +def write_mcp_config(task_dir: Path, port: int) -> Path: + """Write MCP configuration file for Claude Code.""" + mcp_config = { + "mcpServers": { + "spd": { + "type": "http", + "url": f"http://localhost:{port}/mcp", + } + } + } + config_path = task_dir / "mcp_config.json" + config_path.write_text(json.dumps(mcp_config, indent=2)) + return config_path + + def find_available_port(start_port: int = 8000, max_attempts: int = 100) -> int: """Find an available port starting from start_port.""" for offset in range(max_attempts): @@ -124,9 +141,13 @@ def run_agent( ), ) - # Start backend with isolated database + # Start backend with isolated database and swarm configuration env = os.environ.copy() env["SPD_APP_DB_PATH"] = str(db_path) + env["SPD_MCP_EVENTS_PATH"] = str(events_path) + env["SPD_MCP_TASK_DIR"] = str(task_dir) + # Suggestions go to a global file (one level above swarm dirs) + env["SPD_MCP_SUGGESTIONS_PATH"] = str(swarm_dir.parent / "suggestions.jsonl") backend_cmd = [ sys.executable, @@ -214,7 +235,12 @@ def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: prompt_path = task_dir / "agent_prompt.md" prompt_path.write_text(agent_prompt) - # Launch Claude Code with streaming JSON output + # Write MCP config for Claude Code + mcp_config_path = write_mcp_config(task_dir, port) + logger.info(f"[Task {task_id}] MCP config written to {mcp_config_path}") + + # Launch Claude Code with streaming JSON output and MCP + # No --dangerously-skip-permissions needed - agents use MCP tools for all I/O claude_output_path = task_dir / "claude_output.jsonl" claude_cmd = [ "claude", @@ -224,7 +250,8 @@ def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: "stream-json", "--max-turns", str(max_turns), - 
"--dangerously-skip-permissions", + "--mcp-config", + str(mcp_config_path), ] logger.info(f"[Task {task_id}] Starting Claude Code (max_turns={max_turns})...") @@ -252,15 +279,16 @@ def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: Please begin your investigation: -1. **FIRST**: Create `{task_dir}/research_log.md` with a header like "# Research Log - Task {task_id}" -2. Check the backend status and explore component interpretations -3. Find an interesting behavior to investigate -4. **Update research_log.md frequently** with your progress, findings, and next steps +1. **FIRST**: Use the `update_research_log` tool to create your research log with a header like: + "# Research Log - Task {task_id}\\n\\nStarting investigation of {wandb_path}\\n\\n" +2. Explore component interpretations using `get_component_info` +3. Find an interesting behavior to investigate with `optimize_graph` +4. **Use `update_research_log` frequently** to document your progress, findings, and next steps Remember: -- research_log.md is your PRIMARY output - humans will read this to follow your work -- Update it every few minutes with what you're doing and discovering -- Write complete explanations to explanations.jsonl when you finish investigating a behavior +- The research log is your PRIMARY output - use `update_research_log` every few minutes +- Use `save_explanation` to record complete, validated explanations +- Use `submit_suggestion` if you have ideas for improving the tools or system """ assert claude_proc.stdin is not None claude_proc.stdin.write(investigation_request) diff --git a/spd/app/CLAUDE.md b/spd/app/CLAUDE.md index 95d6bc1b3..c1eed4f49 100644 --- a/spd/app/CLAUDE.md +++ b/spd/app/CLAUDE.md @@ -48,7 +48,8 @@ backend/ ├── correlations.py # Component correlations + token stats + interpretations ├── clusters.py # Component clustering ├── dataset_search.py # SimpleStories dataset search - └── agents.py # Various useful endpoints that AI agents should look 
at when helping + ├── agents.py # Various useful endpoints that AI agents should look at when helping + └── mcp.py # MCP (Model Context Protocol) endpoint for Claude Code ``` Note: Activation contexts, correlations, and token stats are now loaded from pre-harvested data (see `spd/harvest/`). The app no longer computes these on-the-fly. diff --git a/spd/app/backend/routers/__init__.py b/spd/app/backend/routers/__init__.py index 79cea1087..83f6e8ac2 100644 --- a/spd/app/backend/routers/__init__.py +++ b/spd/app/backend/routers/__init__.py @@ -9,6 +9,8 @@ from spd.app.backend.routers.dataset_search import router as dataset_search_router from spd.app.backend.routers.graphs import router as graphs_router from spd.app.backend.routers.intervention import router as intervention_router +from spd.app.backend.routers.investigations import router as investigations_router +from spd.app.backend.routers.mcp import router as mcp_router from spd.app.backend.routers.prompts import router as prompts_router from spd.app.backend.routers.runs import router as runs_router @@ -22,6 +24,8 @@ "dataset_search_router", "graphs_router", "intervention_router", + "investigations_router", + "mcp_router", "prompts_router", "runs_router", ] diff --git a/spd/app/backend/routers/investigations.py b/spd/app/backend/routers/investigations.py new file mode 100644 index 000000000..3ea6244c3 --- /dev/null +++ b/spd/app/backend/routers/investigations.py @@ -0,0 +1,262 @@ +"""Investigations endpoint for viewing agent swarm results. + +Lists and serves investigation data from SPD_OUT_DIR/agent_swarm/. +Each task is treated as an independent investigation (flattened across swarms). 
+""" + +import json +from datetime import datetime +from pathlib import Path +from typing import Any + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from spd.settings import SPD_OUT_DIR + +router = APIRouter(prefix="/api/investigations", tags=["investigations"]) + +SWARM_DIR = SPD_OUT_DIR / "agent_swarm" + + +class InvestigationSummary(BaseModel): + """Summary of a single investigation (task).""" + + id: str # swarm_id/task_id + swarm_id: str + task_id: int + wandb_path: str | None + created_at: str + has_research_log: bool + has_explanations: bool + event_count: int + last_event_time: str | None + last_event_message: str | None + # Agent-provided summary + title: str | None + summary: str | None + status: str | None # in_progress, completed, inconclusive + + +class EventEntry(BaseModel): + """A single event from events.jsonl.""" + + event_type: str + timestamp: str + message: str + details: dict[str, Any] | None = None + + +class InvestigationDetail(BaseModel): + """Full detail of an investigation including logs.""" + + id: str + swarm_id: str + task_id: int + wandb_path: str | None + created_at: str + research_log: str | None + events: list[EventEntry] + explanations: list[dict[str, Any]] + # Agent-provided summary + title: str | None + summary: str | None + status: str | None + + +def _parse_swarm_metadata(swarm_path: Path) -> dict[str, Any] | None: + """Parse metadata.json from a swarm directory.""" + metadata_path = swarm_path / "metadata.json" + if not metadata_path.exists(): + return None + try: + data: dict[str, Any] = json.loads(metadata_path.read_text()) + return data + except Exception: + return None + + +def _get_last_event(events_path: Path) -> tuple[str | None, str | None, int]: + """Get the last event timestamp, message, and total count from events.jsonl.""" + if not events_path.exists(): + return None, None, 0 + + last_time = None + last_msg = None + count = 0 + + try: + with open(events_path) as f: + for line in 
f: + line = line.strip() + if not line: + continue + count += 1 + try: + event = json.loads(line) + last_time = event.get("timestamp") + last_msg = event.get("message") + except json.JSONDecodeError: + continue + except Exception: + pass + + return last_time, last_msg, count + + +def _parse_task_summary(task_path: Path) -> tuple[str | None, str | None, str | None]: + """Parse summary.json from a task directory. Returns (title, summary, status).""" + summary_path = task_path / "summary.json" + if not summary_path.exists(): + return None, None, None + try: + data: dict[str, Any] = json.loads(summary_path.read_text()) + return data.get("title"), data.get("summary"), data.get("status") + except Exception: + return None, None, None + + +def _get_task_created_at(task_path: Path, swarm_metadata: dict[str, Any] | None) -> str: + """Get creation time for a task.""" + # Try to get from first event + events_path = task_path / "events.jsonl" + if events_path.exists(): + try: + with open(events_path) as f: + first_line = f.readline().strip() + if first_line: + event = json.loads(first_line) + if "timestamp" in event: + return event["timestamp"] + except Exception: + pass + + # Fall back to swarm metadata + if swarm_metadata and "created_at" in swarm_metadata: + return swarm_metadata["created_at"] + + # Fall back to directory mtime + return datetime.fromtimestamp(task_path.stat().st_mtime).isoformat() + + +@router.get("") +def list_investigations() -> list[InvestigationSummary]: + """List all investigations (tasks) flattened across swarms.""" + if not SWARM_DIR.exists(): + return [] + + results = [] + + for swarm_path in SWARM_DIR.iterdir(): + if not swarm_path.is_dir() or not swarm_path.name.startswith("swarm-"): + continue + + swarm_id = swarm_path.name + metadata = _parse_swarm_metadata(swarm_path) + wandb_path = metadata.get("wandb_path") if metadata else None + + for task_path in swarm_path.iterdir(): + if not task_path.is_dir() or not task_path.name.startswith("task_"): + 
continue + + try: + task_id = int(task_path.name.split("_")[1]) + except (ValueError, IndexError): + continue + + events_path = task_path / "events.jsonl" + last_time, last_msg, event_count = _get_last_event(events_path) + title, summary, status = _parse_task_summary(task_path) + + results.append( + InvestigationSummary( + id=f"{swarm_id}/{task_id}", + swarm_id=swarm_id, + task_id=task_id, + wandb_path=wandb_path, + created_at=_get_task_created_at(task_path, metadata), + has_research_log=(task_path / "research_log.md").exists(), + has_explanations=(task_path / "explanations.jsonl").exists() + and (task_path / "explanations.jsonl").stat().st_size > 0, + event_count=event_count, + last_event_time=last_time, + last_event_message=last_msg, + title=title, + summary=summary, + status=status, + ) + ) + + # Sort by creation time, newest first + results.sort(key=lambda x: x.created_at, reverse=True) + return results + + +@router.get("/{swarm_id}/{task_id}") +def get_investigation(swarm_id: str, task_id: int) -> InvestigationDetail: + """Get full details of an investigation.""" + swarm_path = SWARM_DIR / swarm_id + task_path = swarm_path / f"task_{task_id}" + + if not task_path.exists() or not task_path.is_dir(): + raise HTTPException(status_code=404, detail=f"Investigation {swarm_id}/{task_id} not found") + + metadata = _parse_swarm_metadata(swarm_path) + + # Read research log + research_log = None + research_log_path = task_path / "research_log.md" + if research_log_path.exists(): + research_log = research_log_path.read_text() + + # Read events + events = [] + events_path = task_path / "events.jsonl" + if events_path.exists(): + with open(events_path) as f: + for line in f: + line = line.strip() + if not line: + continue + try: + event = json.loads(line) + events.append( + EventEntry( + event_type=event.get("event_type", "unknown"), + timestamp=event.get("timestamp", ""), + message=event.get("message", ""), + details=event.get("details"), + ) + ) + except 
json.JSONDecodeError: + continue + + # Read explanations + explanations: list[dict[str, Any]] = [] + explanations_path = task_path / "explanations.jsonl" + if explanations_path.exists(): + with open(explanations_path) as f: + for line in f: + line = line.strip() + if not line: + continue + try: + explanations.append(json.loads(line)) + except json.JSONDecodeError: + continue + + title, summary, status = _parse_task_summary(task_path) + + return InvestigationDetail( + id=f"{swarm_id}/{task_id}", + swarm_id=swarm_id, + task_id=task_id, + wandb_path=metadata.get("wandb_path") if metadata else None, + created_at=_get_task_created_at(task_path, metadata), + research_log=research_log, + events=events, + explanations=explanations, + title=title, + summary=summary, + status=status, + ) diff --git a/spd/app/backend/routers/mcp.py b/spd/app/backend/routers/mcp.py new file mode 100644 index 000000000..c986109a0 --- /dev/null +++ b/spd/app/backend/routers/mcp.py @@ -0,0 +1,1171 @@ +"""MCP (Model Context Protocol) endpoint for Claude Code integration. + +This router implements the MCP JSON-RPC protocol over HTTP, allowing Claude Code +to use SPD tools directly with proper schemas and streaming progress. 
+ +MCP Spec: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports +""" + +import inspect +import json +import queue +import threading +import traceback +from collections.abc import Generator +from datetime import UTC, datetime +from pathlib import Path +from typing import Any, Literal + +import torch +from fastapi import APIRouter, Request +from fastapi.responses import JSONResponse, StreamingResponse +from pydantic import BaseModel + +from spd.app.backend.compute import ( + compute_intervention_forward, + compute_prompt_attributions_optimized, +) +from spd.app.backend.database import StoredGraph +from spd.app.backend.optim_cis import CELossConfig, OptimCIConfig +from spd.app.backend.routers.graphs import build_out_probs +from spd.app.backend.state import StateManager +from spd.configs import ImportanceMinimalityLossConfig +from spd.harvest import analysis +from spd.log import logger +from spd.utils.distributed_utils import get_device + +router = APIRouter(tags=["mcp"]) + +DEVICE = get_device() + +# MCP protocol version +MCP_PROTOCOL_VERSION = "2024-11-05" + +# Optional paths for swarm integration (set via environment at runtime) +_events_log_path: Path | None = None +_task_dir: Path | None = None +_suggestions_path: Path | None = None + + +def set_events_log_path(path: Path | None) -> None: + """Set the path for logging MCP tool events (for swarm monitoring).""" + global _events_log_path + _events_log_path = path + + +def set_task_dir(path: Path | None) -> None: + """Set the task directory for research log and explanations output.""" + global _task_dir + _task_dir = path + + +def set_suggestions_path(path: Path | None) -> None: + """Set the path for the central suggestions file.""" + global _suggestions_path + _suggestions_path = path + + +def _log_event(event_type: str, message: str, details: dict[str, Any] | None = None) -> None: + """Log an event to the events file if configured.""" + if _events_log_path is None: + return + event = { + 
"event_type": event_type, + "timestamp": datetime.now(UTC).isoformat(), + "message": message, + "details": details or {}, + } + with open(_events_log_path, "a") as f: + f.write(json.dumps(event) + "\n") + + +# ============================================================================= +# MCP Protocol Types +# ============================================================================= + + +class MCPRequest(BaseModel): + """JSON-RPC 2.0 request.""" + + jsonrpc: Literal["2.0"] + id: int | str | None = None + method: str + params: dict[str, Any] | None = None + + +class MCPResponse(BaseModel): + """JSON-RPC 2.0 response.""" + + jsonrpc: Literal["2.0"] = "2.0" + id: int | str | None + result: Any | None = None + error: dict[str, Any] | None = None + + +class ToolDefinition(BaseModel): + """MCP tool definition.""" + + name: str + description: str + inputSchema: dict[str, Any] + + +# ============================================================================= +# Tool Definitions +# ============================================================================= + +TOOLS: list[ToolDefinition] = [ + ToolDefinition( + name="optimize_graph", + description="""Optimize a sparse circuit for a specific behavior. + +Given a prompt and target token, finds the minimal set of components that produce the target prediction. +Returns the optimized graph with component CI values and edges showing information flow. + +This is the primary tool for understanding how the model produces a specific output.""", + inputSchema={ + "type": "object", + "properties": { + "prompt_text": { + "type": "string", + "description": "The input text to analyze (e.g., 'The boy said that')", + }, + "target_token": { + "type": "string", + "description": "The token to predict (e.g., ' he'). Include leading space if needed.", + }, + "loss_position": { + "type": "integer", + "description": "Position to optimize prediction at (0-indexed, usually last position). 
If not specified, uses the last position.", + }, + "steps": { + "type": "integer", + "description": "Optimization steps (default: 100, more = sparser but slower)", + "default": 100, + }, + "ci_threshold": { + "type": "number", + "description": "CI threshold for including components (default: 0.5, lower = more components)", + "default": 0.5, + }, + }, + "required": ["prompt_text", "target_token"], + }, + ), + ToolDefinition( + name="get_component_info", + description="""Get detailed information about a component. + +Returns the component's interpretation (what it does), token statistics (what tokens +activate it and what it predicts), and correlated components. + +Use this to understand what role a component plays in a circuit.""", + inputSchema={ + "type": "object", + "properties": { + "layer": { + "type": "string", + "description": "Layer name (e.g., 'h.0.mlp.c_fc', 'h.2.attn.o_proj')", + }, + "component_idx": { + "type": "integer", + "description": "Component index within the layer", + }, + "top_k": { + "type": "integer", + "description": "Number of top tokens/correlations to return (default: 20)", + "default": 20, + }, + }, + "required": ["layer", "component_idx"], + }, + ), + ToolDefinition( + name="run_ablation", + description="""Run an ablation experiment with only selected components active. + +Tests a hypothesis by running the model with a sparse set of components. +Returns predictions showing what the circuit produces vs the full model. 
+ +Use this to verify that identified components are necessary and sufficient.""", + inputSchema={ + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "Input text for the ablation", + }, + "selected_nodes": { + "type": "array", + "items": {"type": "string"}, + "description": "Node keys to keep active (format: 'layer:seq_pos:component_idx')", + }, + "top_k": { + "type": "integer", + "description": "Number of top predictions to return per position (default: 10)", + "default": 10, + }, + }, + "required": ["text", "selected_nodes"], + }, + ), + ToolDefinition( + name="search_dataset", + description="""Search the SimpleStories training dataset for patterns. + +Finds stories containing the query string. Use this to find examples of +specific linguistic patterns (pronouns, verb forms, etc.) for investigation.""", + inputSchema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Text to search for (case-insensitive)", + }, + "limit": { + "type": "integer", + "description": "Maximum results to return (default: 20)", + "default": 20, + }, + }, + "required": ["query"], + }, + ), + ToolDefinition( + name="create_prompt", + description="""Create a prompt for analysis. + +Tokenizes the text and returns token IDs and next-token probabilities. +The returned prompt_id can be used with other tools.""", + inputSchema={ + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "The text to create a prompt from", + }, + }, + "required": ["text"], + }, + ), + ToolDefinition( + name="update_research_log", + description="""Append content to your research log. + +Use this to document your investigation progress, findings, and next steps. +The research log is your primary output for humans to follow your work. 
+ +Call this frequently (every few minutes) with updates on what you're doing.""", + inputSchema={ + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "Markdown content to append to the research log", + }, + }, + "required": ["content"], + }, + ), + ToolDefinition( + name="save_explanation", + description="""Save a complete behavior explanation. + +Use this when you have finished investigating a behavior and want to document +your findings. This creates a structured record of the behavior, the components +involved, and your explanation of how they work together. + +Only call this for complete, validated explanations - not preliminary hypotheses.""", + inputSchema={ + "type": "object", + "properties": { + "subject_prompt": { + "type": "string", + "description": "A prompt that demonstrates the behavior", + }, + "behavior_description": { + "type": "string", + "description": "Clear description of the behavior", + }, + "components_involved": { + "type": "array", + "items": { + "type": "object", + "properties": { + "component_key": { + "type": "string", + "description": "Component key (e.g., 'h.0.mlp.c_fc:5')", + }, + "role": { + "type": "string", + "description": "The role this component plays", + }, + "interpretation": { + "type": "string", + "description": "Auto-interp label if available", + }, + }, + "required": ["component_key", "role"], + }, + "description": "List of components and their roles", + }, + "explanation": { + "type": "string", + "description": "How the components work together", + }, + "supporting_evidence": { + "type": "array", + "items": { + "type": "object", + "properties": { + "evidence_type": { + "type": "string", + "enum": [ + "ablation", + "attribution", + "activation_pattern", + "correlation", + "other", + ], + }, + "description": {"type": "string"}, + "details": {"type": "object"}, + }, + "required": ["evidence_type", "description"], + }, + "description": "Evidence supporting this explanation", + }, + 
"confidence": { + "type": "string", + "enum": ["high", "medium", "low"], + "description": "Your confidence level", + }, + "alternative_hypotheses": { + "type": "array", + "items": {"type": "string"}, + "description": "Other hypotheses you considered", + }, + "limitations": { + "type": "array", + "items": {"type": "string"}, + "description": "Known limitations of this explanation", + }, + }, + "required": [ + "subject_prompt", + "behavior_description", + "components_involved", + "explanation", + "confidence", + ], + }, + ), + ToolDefinition( + name="submit_suggestion", + description="""Submit a suggestion for improving the SPD system. + +Use this when you encounter limitations, have ideas for new tools, or think +of ways the system could better support investigation work. + +Suggestions are collected centrally and reviewed by humans to improve the system.""", + inputSchema={ + "type": "object", + "properties": { + "category": { + "type": "string", + "enum": ["tool_improvement", "new_tool", "documentation", "bug", "other"], + "description": "Category of suggestion", + }, + "title": { + "type": "string", + "description": "Brief title for the suggestion", + }, + "description": { + "type": "string", + "description": "Detailed description of the suggestion", + }, + "context": { + "type": "string", + "description": "What you were trying to do when you had this idea", + }, + }, + "required": ["category", "title", "description"], + }, + ), + ToolDefinition( + name="set_investigation_summary", + description="""Set a title and summary for your investigation. + +Call this when you've completed your investigation (or periodically as you make progress) +to provide a human-readable title and summary that will be shown in the investigations UI. + +The title should be short and descriptive. 
The summary should be 1-3 sentences +explaining what you investigated and what you found.""", + inputSchema={ + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "Short title for the investigation (e.g., 'Gendered Pronoun Circuit')", + }, + "summary": { + "type": "string", + "description": "Brief summary of findings (1-3 sentences)", + }, + "status": { + "type": "string", + "enum": ["in_progress", "completed", "inconclusive"], + "description": "Current status of the investigation", + "default": "in_progress", + }, + }, + "required": ["title", "summary"], + }, + ), +] + + +# ============================================================================= +# Tool Implementations +# ============================================================================= + + +def _get_state(): + """Get state manager and loaded run, raising clear errors if not available.""" + manager = StateManager.get() + if manager.run_state is None: + raise ValueError("No run loaded. The backend must load a run first.") + return manager, manager.run_state + + +def _tool_optimize_graph(params: dict[str, Any]) -> Generator[dict[str, Any]]: + """Optimize a sparse circuit for a behavior. Yields progress events.""" + manager, loaded = _get_state() + + prompt_text = params["prompt_text"] + target_token = params["target_token"] + steps = params.get("steps", 100) + ci_threshold = params.get("ci_threshold", 0.5) + + # Tokenize prompt + token_ids = loaded.tokenizer.encode(prompt_text, add_special_tokens=False) + if not token_ids: + raise ValueError("Prompt text produced no tokens") + + # Find target token ID + target_token_ids = loaded.tokenizer.encode(target_token, add_special_tokens=False) + if len(target_token_ids) != 1: + raise ValueError( + f"Target token '{target_token}' tokenizes to {len(target_token_ids)} tokens, expected 1. 
" + f"Token IDs: {target_token_ids}" + ) + label_token = target_token_ids[0] + + # Determine loss position + loss_position = params.get("loss_position") + if loss_position is None: + loss_position = len(token_ids) - 1 + + if loss_position >= len(token_ids): + raise ValueError( + f"loss_position {loss_position} out of bounds for prompt with {len(token_ids)} tokens" + ) + + _log_event( + "tool_start", + f"optimize_graph: '{prompt_text}' → '{target_token}'", + {"steps": steps, "loss_position": loss_position}, + ) + + yield {"type": "progress", "current": 0, "total": steps, "stage": "starting optimization"} + + # Create prompt in DB + prompt_id = manager.db.add_custom_prompt( + run_id=loaded.run.id, + token_ids=token_ids, + context_length=loaded.context_length, + ) + + # Build optimization config + loss_config = CELossConfig(coeff=1.0, position=loss_position, label_token=label_token) + + optim_config = OptimCIConfig( + seed=0, + lr=1e-2, + steps=steps, + weight_decay=0.0, + lr_schedule="cosine", + lr_exponential_halflife=None, + lr_warmup_pct=0.01, + log_freq=max(1, steps // 10), + imp_min_config=ImportanceMinimalityLossConfig(coeff=0.1, pnorm=0.5, beta=0.0), + loss_config=loss_config, + sampling=loaded.config.sampling, + ce_kl_rounding_threshold=0.5, + mask_type="ci", + ) + + tokens_tensor = torch.tensor([token_ids], device=DEVICE) + progress_queue: queue.Queue[dict[str, Any]] = queue.Queue() + + def on_progress(current: int, total: int, stage: str) -> None: + progress_queue.put({"current": current, "total": total, "stage": stage}) + + # Run optimization in thread + result_holder: list[Any] = [] + error_holder: list[Exception] = [] + + def compute(): + try: + with manager.gpu_lock(): + result = compute_prompt_attributions_optimized( + model=loaded.model, + tokens=tokens_tensor, + sources_by_target=loaded.sources_by_target, + optim_config=optim_config, + output_prob_threshold=0.01, + device=DEVICE, + on_progress=on_progress, + ) + result_holder.append(result) + except 
Exception as e: + error_holder.append(e) + + thread = threading.Thread(target=compute) + thread.start() + + # Yield progress events (throttle logging to every 10% or 10 steps) + last_logged_step = -1 + log_interval = max(1, steps // 10) + + while thread.is_alive() or not progress_queue.empty(): + try: + progress = progress_queue.get(timeout=0.1) + current = progress["current"] + # Log to events.jsonl at intervals (for human monitoring) + if current - last_logged_step >= log_interval or current == progress["total"]: + _log_event( + "optimization_progress", + f"optimize_graph: step {current}/{progress['total']} ({progress['stage']})", + {"prompt": prompt_text, "target": target_token, **progress}, + ) + last_logged_step = current + # Always yield to SSE stream (for Claude) + yield {"type": "progress", **progress} + except queue.Empty: + continue + + thread.join() + + if error_holder: + raise error_holder[0] + + if not result_holder: + raise RuntimeError("Optimization completed but no result was produced") + + result = result_holder[0] + + # Build output + out_probs = build_out_probs( + ci_masked_out_probs=result.ci_masked_out_probs.cpu(), + ci_masked_out_logits=result.ci_masked_out_logits.cpu(), + target_out_probs=result.target_out_probs.cpu(), + target_out_logits=result.target_out_logits.cpu(), + output_prob_threshold=0.01, + token_strings=loaded.token_strings, + ) + + # Save graph to DB + from spd.app.backend.database import OptimizationParams + + opt_params = OptimizationParams( + imp_min_coeff=0.1, + steps=steps, + pnorm=0.5, + beta=0.0, + mask_type="ci", + loss=loss_config, + ) + graph_id = manager.db.save_graph( + prompt_id=prompt_id, + graph=StoredGraph( + graph_type="optimized", + edges=result.edges, + out_probs=out_probs, + node_ci_vals=result.node_ci_vals, + node_subcomp_acts=result.node_subcomp_acts, + optimization_params=opt_params, + ), + ) + + # Filter nodes by CI threshold + active_components = {k: v for k, v in result.node_ci_vals.items() if v >= 
ci_threshold} + + # Get target token probability + target_key = f"{loss_position}:{label_token}" + target_prob = out_probs.get(target_key) + + token_strings = [loaded.token_strings[t] for t in token_ids] + + final_result = { + "graph_id": graph_id, + "prompt_id": prompt_id, + "tokens": token_strings, + "target_token": target_token, + "target_token_id": label_token, + "target_position": loss_position, + "target_probability": target_prob.prob if target_prob else None, + "target_probability_baseline": target_prob.target_prob if target_prob else None, + "active_components": active_components, + "total_active": len(active_components), + "output_probs": {k: {"prob": v.prob, "token": v.token} for k, v in out_probs.items()}, + } + + _log_event( + "tool_complete", + f"optimize_graph complete: {len(active_components)} active components", + {"graph_id": graph_id, "target_prob": target_prob.prob if target_prob else None}, + ) + + yield {"type": "result", "data": final_result} + + +def _tool_get_component_info(params: dict[str, Any]) -> dict[str, Any]: + """Get detailed information about a component.""" + _, loaded = _get_state() + + layer = params["layer"] + component_idx = params["component_idx"] + top_k = params.get("top_k", 20) + component_key = f"{layer}:{component_idx}" + + _log_event( + "tool_call", f"get_component_info: {component_key}", {"layer": layer, "idx": component_idx} + ) + + result: dict[str, Any] = {"component_key": component_key} + + # Get interpretation + interpretations = loaded.harvest.interpretations + if component_key in interpretations: + interp = interpretations[component_key] + result["interpretation"] = { + "label": interp.label, + "confidence": interp.confidence, + "reasoning": interp.reasoning, + } + else: + result["interpretation"] = None + + # Get token stats + token_stats = loaded.harvest.token_stats + input_stats = analysis.get_input_token_stats( + token_stats, component_key, loaded.tokenizer, top_k + ) + output_stats = 
analysis.get_output_token_stats( + token_stats, component_key, loaded.tokenizer, top_k + ) + + if input_stats and output_stats: + result["token_stats"] = { + "input": { + "top_recall": input_stats.top_recall, + "top_precision": input_stats.top_precision, + "top_pmi": input_stats.top_pmi, + }, + "output": { + "top_recall": output_stats.top_recall, + "top_precision": output_stats.top_precision, + "top_pmi": output_stats.top_pmi, + "bottom_pmi": output_stats.bottom_pmi, + }, + } + else: + result["token_stats"] = None + + # Get correlations + correlations = loaded.harvest.correlations + if analysis.has_component(correlations, component_key): + result["correlated_components"] = { + "precision": [ + {"key": c.component_key, "score": c.score} + for c in analysis.get_correlated_components( + correlations, component_key, "precision", top_k + ) + ], + "pmi": [ + {"key": c.component_key, "score": c.score} + for c in analysis.get_correlated_components( + correlations, component_key, "pmi", top_k + ) + ], + } + else: + result["correlated_components"] = None + + return result + + +def _tool_run_ablation(params: dict[str, Any]) -> dict[str, Any]: + """Run ablation with selected components.""" + manager, loaded = _get_state() + + text = params["text"] + selected_nodes = params["selected_nodes"] + top_k = params.get("top_k", 10) + + _log_event( + "tool_call", + f"run_ablation: '{text[:50]}...' 
with {len(selected_nodes)} nodes", + {"text": text, "n_nodes": len(selected_nodes)}, + ) + + token_ids = loaded.tokenizer.encode(text, add_special_tokens=False) + tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) + + # Parse node keys + active_nodes = [] + for key in selected_nodes: + parts = key.split(":") + if len(parts) != 3: + raise ValueError(f"Invalid node key format: {key!r} (expected 'layer:seq:cIdx')") + layer, seq_str, cidx_str = parts + if layer in ("wte", "output"): + raise ValueError(f"Cannot intervene on {layer!r} nodes - only internal layers allowed") + active_nodes.append((layer, int(seq_str), int(cidx_str))) + + with manager.gpu_lock(): + result = compute_intervention_forward( + model=loaded.model, + tokens=tokens, + active_nodes=active_nodes, + top_k=top_k, + tokenizer=loaded.tokenizer, + ) + + predictions = [] + for pos_predictions in result.predictions_per_position: + pos_result = [] + for token, token_id, spd_prob, _logit, target_prob, _target_logit in pos_predictions: + pos_result.append( + { + "token": token, + "token_id": token_id, + "circuit_prob": round(spd_prob, 6), + "full_model_prob": round(target_prob, 6), + } + ) + predictions.append(pos_result) + + return { + "input_tokens": result.input_tokens, + "predictions_per_position": predictions, + "selected_nodes": selected_nodes, + } + + +def _tool_search_dataset(params: dict[str, Any]) -> dict[str, Any]: + """Search the SimpleStories dataset.""" + import time + + from datasets import Dataset, load_dataset + + query = params["query"] + limit = params.get("limit", 20) + search_query = query.lower() + + _log_event("tool_call", f"search_dataset: '{query}'", {"query": query, "limit": limit}) + + start_time = time.time() + dataset = load_dataset("lennart-finke/SimpleStories", split="train") + assert isinstance(dataset, Dataset) + + filtered = dataset.filter( + lambda x: search_query in x["story"].lower(), + num_proc=4, + ) + + results = [] + for i, item in enumerate(filtered): 
+ if i >= limit: + break + item_dict: dict[str, Any] = dict(item) + story: str = item_dict["story"] + results.append( + { + "story": story[:500] + "..." if len(story) > 500 else story, + "occurrence_count": story.lower().count(search_query), + } + ) + + return { + "query": query, + "total_matches": len(filtered), + "returned": len(results), + "search_time_seconds": round(time.time() - start_time, 2), + "results": results, + } + + +def _tool_create_prompt(params: dict[str, Any]) -> dict[str, Any]: + """Create a prompt from text.""" + manager, loaded = _get_state() + + text = params["text"] + + _log_event("tool_call", f"create_prompt: '{text[:50]}...'", {"text": text}) + + token_ids = loaded.tokenizer.encode(text, add_special_tokens=False) + if not token_ids: + raise ValueError("Text produced no tokens") + + prompt_id = manager.db.add_custom_prompt( + run_id=loaded.run.id, + token_ids=token_ids, + context_length=loaded.context_length, + ) + + # Compute next token probs + tokens_tensor = torch.tensor([token_ids], device=DEVICE) + with torch.no_grad(): + logits = loaded.model(tokens_tensor) + probs = torch.softmax(logits, dim=-1) + + next_token_probs = [] + for i in range(len(token_ids) - 1): + next_token_id = token_ids[i + 1] + prob = probs[0, i, next_token_id].item() + next_token_probs.append(round(prob, 6)) + next_token_probs.append(None) + + token_strings = [loaded.token_strings[t] for t in token_ids] + + return { + "prompt_id": prompt_id, + "text": text, + "tokens": token_strings, + "token_ids": token_ids, + "next_token_probs": next_token_probs, + } + + +def _tool_update_research_log(params: dict[str, Any]) -> dict[str, Any]: + """Append content to the research log.""" + if _task_dir is None: + raise ValueError("Research log not available - not running in swarm mode") + + content = params["content"] + research_log_path = _task_dir / "research_log.md" + + _log_event( + "tool_call", f"update_research_log: {len(content)} chars", {"preview": content[:100]} + ) + + # 
Append content with a newline separator + with open(research_log_path, "a") as f: + f.write(content) + if not content.endswith("\n"): + f.write("\n") + + return {"status": "ok", "path": str(research_log_path)} + + +def _tool_save_explanation(params: dict[str, Any]) -> dict[str, Any]: + """Save a behavior explanation to explanations.jsonl.""" + from spd.agent_swarm.schemas import BehaviorExplanation, ComponentInfo, Evidence + + if _task_dir is None: + raise ValueError("Explanations file not available - not running in swarm mode") + + _log_event( + "tool_call", + f"save_explanation: '{params['behavior_description'][:50]}...'", + {"prompt": params["subject_prompt"]}, + ) + + # Build components + components = [ + ComponentInfo( + component_key=c["component_key"], + role=c["role"], + interpretation=c.get("interpretation"), + ) + for c in params["components_involved"] + ] + + # Build evidence + evidence = [ + Evidence( + evidence_type=e["evidence_type"], + description=e["description"], + details=e.get("details", {}), + ) + for e in params.get("supporting_evidence", []) + ] + + explanation = BehaviorExplanation( + subject_prompt=params["subject_prompt"], + behavior_description=params["behavior_description"], + components_involved=components, + explanation=params["explanation"], + supporting_evidence=evidence, + confidence=params["confidence"], + alternative_hypotheses=params.get("alternative_hypotheses", []), + limitations=params.get("limitations", []), + ) + + explanations_path = _task_dir / "explanations.jsonl" + with open(explanations_path, "a") as f: + f.write(explanation.model_dump_json() + "\n") + + _log_event( + "explanation", + f"Saved explanation: {params['behavior_description']}", + {"confidence": params["confidence"], "n_components": len(components)}, + ) + + return {"status": "ok", "path": str(explanations_path)} + + +def _tool_submit_suggestion(params: dict[str, Any]) -> dict[str, Any]: + """Submit a suggestion for system improvement.""" + if 
_suggestions_path is None: + raise ValueError("Suggestions not available - not running in swarm mode") + + suggestion = { + "timestamp": datetime.now(UTC).isoformat(), + "category": params["category"], + "title": params["title"], + "description": params["description"], + "context": params.get("context"), + } + + _log_event( + "tool_call", + f"submit_suggestion: [{params['category']}] {params['title']}", + suggestion, + ) + + # Ensure parent directory exists + _suggestions_path.parent.mkdir(parents=True, exist_ok=True) + + with open(_suggestions_path, "a") as f: + f.write(json.dumps(suggestion) + "\n") + + return {"status": "ok", "message": "Suggestion recorded. Thank you!"} + + +def _tool_set_investigation_summary(params: dict[str, Any]) -> dict[str, Any]: + """Set the investigation title and summary.""" + if _task_dir is None: + raise ValueError("Summary not available - not running in swarm mode") + + summary = { + "title": params["title"], + "summary": params["summary"], + "status": params.get("status", "in_progress"), + "updated_at": datetime.now(UTC).isoformat(), + } + + _log_event( + "tool_call", + f"set_investigation_summary: {params['title']}", + summary, + ) + + summary_path = _task_dir / "summary.json" + summary_path.write_text(json.dumps(summary, indent=2)) + + return {"status": "ok", "path": str(summary_path)} + + +# ============================================================================= +# MCP Protocol Handler +# ============================================================================= + + +def _handle_initialize(_params: dict[str, Any] | None) -> dict[str, Any]: + """Handle initialize request.""" + return { + "protocolVersion": MCP_PROTOCOL_VERSION, + "capabilities": {"tools": {}}, + "serverInfo": {"name": "spd-app", "version": "1.0.0"}, + } + + +def _handle_tools_list() -> dict[str, Any]: + """Handle tools/list request.""" + return {"tools": [t.model_dump() for t in TOOLS]} + + +def _handle_tools_call( + params: dict[str, Any], +) -> 
Generator[dict[str, Any]] | dict[str, Any]: + """Handle tools/call request. May return generator for streaming tools.""" + name = params.get("name") + arguments = params.get("arguments", {}) + + if name == "optimize_graph": + # This tool streams progress + return _tool_optimize_graph(arguments) + elif name == "get_component_info": + return { + "content": [ + {"type": "text", "text": json.dumps(_tool_get_component_info(arguments), indent=2)} + ] + } + elif name == "run_ablation": + return { + "content": [ + {"type": "text", "text": json.dumps(_tool_run_ablation(arguments), indent=2)} + ] + } + elif name == "search_dataset": + return { + "content": [ + {"type": "text", "text": json.dumps(_tool_search_dataset(arguments), indent=2)} + ] + } + elif name == "create_prompt": + return { + "content": [ + {"type": "text", "text": json.dumps(_tool_create_prompt(arguments), indent=2)} + ] + } + elif name == "update_research_log": + return { + "content": [ + {"type": "text", "text": json.dumps(_tool_update_research_log(arguments), indent=2)} + ] + } + elif name == "save_explanation": + return { + "content": [ + {"type": "text", "text": json.dumps(_tool_save_explanation(arguments), indent=2)} + ] + } + elif name == "submit_suggestion": + return { + "content": [ + {"type": "text", "text": json.dumps(_tool_submit_suggestion(arguments), indent=2)} + ] + } + elif name == "set_investigation_summary": + return { + "content": [ + { + "type": "text", + "text": json.dumps(_tool_set_investigation_summary(arguments), indent=2), + } + ] + } + else: + raise ValueError(f"Unknown tool: {name}") + + +@router.post("/mcp") +async def mcp_endpoint(request: Request): + """MCP JSON-RPC endpoint. + + Handles initialize, tools/list, and tools/call methods. + Returns SSE stream for streaming tools, JSON for others. 
+ """ + try: + body = await request.json() + mcp_request = MCPRequest(**body) + except Exception as e: + return JSONResponse( + status_code=400, + content=MCPResponse( + id=None, error={"code": -32700, "message": f"Parse error: {e}"} + ).model_dump(), + ) + + logger.info(f"[MCP] {mcp_request.method} (id={mcp_request.id})") + + try: + if mcp_request.method == "initialize": + result = _handle_initialize(mcp_request.params) + return JSONResponse( + content=MCPResponse(id=mcp_request.id, result=result).model_dump(), + headers={"Mcp-Session-Id": "spd-session"}, + ) + + elif mcp_request.method == "notifications/initialized": + # Client confirms initialization + return JSONResponse(status_code=202, content={}) + + elif mcp_request.method == "tools/list": + result = _handle_tools_list() + return JSONResponse(content=MCPResponse(id=mcp_request.id, result=result).model_dump()) + + elif mcp_request.method == "tools/call": + if mcp_request.params is None: + raise ValueError("tools/call requires params") + + result = _handle_tools_call(mcp_request.params) + + # Check if result is a generator (streaming) + if inspect.isgenerator(result): + # Streaming response via SSE + gen = result # Capture for closure + + def generate_sse() -> Generator[str]: + try: + final_result = None + for event in gen: + if event.get("type") == "progress": + # Send progress notification + progress_msg = { + "jsonrpc": "2.0", + "method": "notifications/progress", + "params": event, + } + yield f"data: {json.dumps(progress_msg)}\n\n" + elif event.get("type") == "result": + final_result = event["data"] + + # Send final response + response = MCPResponse( + id=mcp_request.id, + result={ + "content": [ + {"type": "text", "text": json.dumps(final_result, indent=2)} + ] + }, + ) + yield f"data: {json.dumps(response.model_dump())}\n\n" + except Exception as e: + tb = traceback.format_exc() + logger.error(f"[MCP] Tool error: {e}\n{tb}") + error_response = MCPResponse( + id=mcp_request.id, + error={"code": -32000, 
"message": str(e)}, + ) + yield f"data: {json.dumps(error_response.model_dump())}\n\n" + + return StreamingResponse(generate_sse(), media_type="text/event-stream") + + else: + # Non-streaming response + return JSONResponse( + content=MCPResponse(id=mcp_request.id, result=result).model_dump() + ) + + else: + return JSONResponse( + content=MCPResponse( + id=mcp_request.id, + error={"code": -32601, "message": f"Method not found: {mcp_request.method}"}, + ).model_dump() + ) + + except Exception as e: + tb = traceback.format_exc() + logger.error(f"[MCP] Error handling {mcp_request.method}: {e}\n{tb}") + return JSONResponse( + content=MCPResponse( + id=mcp_request.id, + error={"code": -32000, "message": str(e)}, + ).model_dump() + ) diff --git a/spd/app/backend/server.py b/spd/app/backend/server.py index 45f5d9afb..68316bb69 100644 --- a/spd/app/backend/server.py +++ b/spd/app/backend/server.py @@ -34,6 +34,8 @@ dataset_search_router, graphs_router, intervention_router, + investigations_router, + mcp_router, prompts_router, runs_router, ) @@ -47,6 +49,15 @@ @asynccontextmanager async def lifespan(app: FastAPI): # pyright: ignore[reportUnusedParameter] """Initialize DB connection at startup. 
Model loaded on-demand via /api/runs/load.""" + import os + from pathlib import Path + + from spd.app.backend.routers.mcp import ( + set_events_log_path, + set_suggestions_path, + set_task_dir, + ) + manager = StateManager.get() db = PromptAttrDB(check_same_thread=False) @@ -57,6 +68,22 @@ async def lifespan(app: FastAPI): # pyright: ignore[reportUnusedParameter] logger.info(f"[STARTUP] Device: {DEVICE}") logger.info(f"[STARTUP] CUDA available: {torch.cuda.is_available()}") + # Configure MCP for agent swarm mode + mcp_events_path = os.environ.get("SPD_MCP_EVENTS_PATH") + if mcp_events_path: + set_events_log_path(Path(mcp_events_path)) + logger.info(f"[STARTUP] MCP events logging to: {mcp_events_path}") + + mcp_task_dir = os.environ.get("SPD_MCP_TASK_DIR") + if mcp_task_dir: + set_task_dir(Path(mcp_task_dir)) + logger.info(f"[STARTUP] MCP task dir: {mcp_task_dir}") + + mcp_suggestions_path = os.environ.get("SPD_MCP_SUGGESTIONS_PATH") + if mcp_suggestions_path: + set_suggestions_path(Path(mcp_suggestions_path)) + logger.info(f"[STARTUP] MCP suggestions file: {mcp_suggestions_path}") + yield manager.close() @@ -157,6 +184,8 @@ async def global_exception_handler(request: Request, exc: Exception) -> JSONResp app.include_router(dataset_attributions_router) app.include_router(agents_router) app.include_router(component_data_router) +app.include_router(investigations_router) +app.include_router(mcp_router) def cli(port: int = 8000) -> None: diff --git a/spd/app/frontend/src/components/InvestigationsTab.svelte b/spd/app/frontend/src/components/InvestigationsTab.svelte new file mode 100644 index 000000000..e9b4a7cb8 --- /dev/null +++ b/spd/app/frontend/src/components/InvestigationsTab.svelte @@ -0,0 +1,497 @@ + + +
+ {#if selected?.status === "loaded"} + +
+ +

{selected.data.title || formatId(selected.data.id)}

+ {#if selected.data.status} + + {selected.data.status} + + {/if} +
+ + {#if selected.data.summary} +

{selected.data.summary}

+ {/if} + + +

+ {formatId(selected.data.id)} · Started {formatDate(selected.data.created_at)} + {#if selected.data.wandb_path} + · {selected.data.wandb_path} + {/if} +

+ +
+ + +
+ +
+ {#if activeTab === "research"} +
+ {#if selected.data.research_log} +
{selected.data.research_log}
+ {:else} +

No research log available

+ {/if} +
+ {:else} +
+ {#each selected.data.events as event, i (i)} +
+ + {event.event_type} + + {formatDate(event.timestamp)} + {event.message} + {#if event.details && Object.keys(event.details).length > 0} +
+ Details +
{JSON.stringify(event.details, null, 2)}
+
+ {/if} +
+ {:else} +

No events recorded

+ {/each} +
+ {/if} +
+ {:else if selected?.status === "loading"} +
Loading investigation...
+ {:else} + +
+

Investigations

+ +
+ + {#if investigations.status === "loading"} +
Loading investigations...
+ {:else if investigations.status === "error"} +
{investigations.error}
+ {:else if investigations.status === "loaded"} +
+ {#each investigations.data as inv (inv.id)} + + {:else} +

No investigations found. Run spd-swarm to create one.

+ {/each} +
+ {/if} + {/if} +
+ + diff --git a/spd/app/frontend/src/components/RunView.svelte b/spd/app/frontend/src/components/RunView.svelte index 734c9657d..06fd2ebbe 100644 --- a/spd/app/frontend/src/components/RunView.svelte +++ b/spd/app/frontend/src/components/RunView.svelte @@ -3,13 +3,14 @@ import { RUN_KEY, type RunContext } from "../lib/useRun.svelte"; import ClusterPathInput from "./ClusterPathInput.svelte"; import DatasetExplorerTab from "./DatasetExplorerTab.svelte"; + import InvestigationsTab from "./InvestigationsTab.svelte"; import PromptAttributionsTab from "./PromptAttributionsTab.svelte"; import DisplaySettingsDropdown from "./ui/DisplaySettingsDropdown.svelte"; import ActivationContextsTab from "./ActivationContextsTab.svelte"; const runState = getContext(RUN_KEY); - let activeTab = $state<"prompts" | "components" | "dataset-search" | null>(null); + let activeTab = $state<"prompts" | "components" | "dataset-search" | "investigations" | null>(null); let showRunMenu = $state(false); @@ -32,6 +33,14 @@ {/if}