diff --git a/.claude/skills/gpudash.md b/.claude/skills/gpudash.md new file mode 100644 index 000000000..5e60a2b6a --- /dev/null +++ b/.claude/skills/gpudash.md @@ -0,0 +1,12 @@ +--- +name: gpudash +description: Check GPU availability across the SLURM cluster +user_invocable: true +--- + +# gpudash + +Run the `gpudash` command to show GPU availability across the cluster. + +## Steps +1. Run `gpudash` and show the output to the user. diff --git a/.claude/worktrees/bold-elm-8kpb b/.claude/worktrees/bold-elm-8kpb new file mode 160000 index 000000000..356f8cfed --- /dev/null +++ b/.claude/worktrees/bold-elm-8kpb @@ -0,0 +1 @@ +Subproject commit 356f8cfedab14e621c70de460971d6d148f44e80 diff --git a/.claude/worktrees/bright-fox-a4i0 b/.claude/worktrees/bright-fox-a4i0 new file mode 160000 index 000000000..356f8cfed --- /dev/null +++ b/.claude/worktrees/bright-fox-a4i0 @@ -0,0 +1 @@ +Subproject commit 356f8cfedab14e621c70de460971d6d148f44e80 diff --git a/.claude/worktrees/calm-owl-v4pj b/.claude/worktrees/calm-owl-v4pj new file mode 160000 index 000000000..dbe0668a4 --- /dev/null +++ b/.claude/worktrees/calm-owl-v4pj @@ -0,0 +1 @@ +Subproject commit dbe0668a4119885b7fe952ed820b4ba8b4a3d693 diff --git a/.claude/worktrees/cozy-frolicking-stream b/.claude/worktrees/cozy-frolicking-stream new file mode 160000 index 000000000..356f8cfed --- /dev/null +++ b/.claude/worktrees/cozy-frolicking-stream @@ -0,0 +1 @@ +Subproject commit 356f8cfedab14e621c70de460971d6d148f44e80 diff --git a/.claude/worktrees/stateless-dancing-blanket b/.claude/worktrees/stateless-dancing-blanket new file mode 160000 index 000000000..356f8cfed --- /dev/null +++ b/.claude/worktrees/stateless-dancing-blanket @@ -0,0 +1 @@ +Subproject commit 356f8cfedab14e621c70de460971d6d148f44e80 diff --git a/.claude/worktrees/swift-owl-yep9 b/.claude/worktrees/swift-owl-yep9 new file mode 160000 index 000000000..356f8cfed --- /dev/null +++ b/.claude/worktrees/swift-owl-yep9 @@ -0,0 +1 @@ +Subproject commit 
356f8cfedab14e621c70de460971d6d148f44e80 diff --git a/.claude/worktrees/swift-ray-amfs b/.claude/worktrees/swift-ray-amfs new file mode 160000 index 000000000..356f8cfed --- /dev/null +++ b/.claude/worktrees/swift-ray-amfs @@ -0,0 +1 @@ +Subproject commit 356f8cfedab14e621c70de460971d6d148f44e80 diff --git a/.claude/worktrees/vectorized-wiggling-whisper b/.claude/worktrees/vectorized-wiggling-whisper new file mode 160000 index 000000000..cb18c86a7 --- /dev/null +++ b/.claude/worktrees/vectorized-wiggling-whisper @@ -0,0 +1 @@ +Subproject commit cb18c86a77720f94a292e7421a19694082813c8c diff --git a/.claude/worktrees/xenodochial-germain b/.claude/worktrees/xenodochial-germain new file mode 160000 index 000000000..5c9f344eb --- /dev/null +++ b/.claude/worktrees/xenodochial-germain @@ -0,0 +1 @@ +Subproject commit 5c9f344eb490e90bed9db5102325459d42c3c0f4 diff --git a/.gitignore b/.gitignore index 4780cbd03..b5601daf4 100644 --- a/.gitignore +++ b/.gitignore @@ -177,4 +177,6 @@ cython_debug/ #.idea/ **/*.db -**/*.db* \ No newline at end of file +**/*.db* + +.claude/worktrees \ No newline at end of file diff --git a/.mcp.json b/.mcp.json index fefb52c9a..700113020 100644 --- a/.mcp.json +++ b/.mcp.json @@ -1,8 +1,3 @@ { - "mcpServers": { - "svelte-llm": { - "type": "http", - "url": "https://svelte-llm.stanislav.garden/mcp/mcp" - } - } + "mcpServers": {} } \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md index 6bb73e8b2..bd70b3d9d 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -3,14 +3,17 @@ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. 
## Environment Setup + **IMPORTANT**: Always activate the virtual environment before running Python or git operations: + ```bash source .venv/bin/activate ``` -Repo requires `.env` file with WandB credentials (see `.env.example`) +Repo requires `.env` file with WandB credentials (see `.env.example`) ## Project Overview + SPD (Stochastic Parameter Decomposition) is a research framework for analyzing neural network components and their interactions through sparse parameter decomposition techniques. - Target model parameters are decomposed as a sum of `parameter components` @@ -46,7 +49,7 @@ This repository implements methods from two key research papers on parameter dec **Stochastic Parameter Decomposition (SPD)** - [`papers/Stochastic_Parameter_Decomposition/spd_paper.md`](papers/Stochastic_Parameter_Decomposition/spd_paper.md) -- A version of this repository was used to run the experiments in this paper. But we continue to develop on the code, so it no longer is limited to the implementation used for this paper. +- A version of this repository was used to run the experiments in this paper. But we continue to develop on the code, so it no longer is limited to the implementation used for this paper. - Introduces the core SPD framework - Details the stochastic masking approach and optimization techniques used throughout the codebase - Useful reading for understanding the implementation details, though may be outdated. @@ -95,6 +98,7 @@ This repository implements methods from two key research papers on parameter dec ## Architecture Overview **Core SPD Framework:** + - `spd/run_spd.py` - Main SPD optimization logic called by all experiments - `spd/configs.py` - Pydantic config classes for all experiment types - `spd/registry.py` - Centralized experiment registry with all experiment configurations @@ -105,15 +109,17 @@ This repository implements methods from two key research papers on parameter dec - `spd/figures.py` - Figures for logging to WandB (e.g. 
CI histograms, Identity plots, etc.) **Terminology: Sources vs Masks:** + - **Sources** (`adv_sources`, `PPGDSources`, `self.sources`): The raw values that PGD optimizes adversarially. These are interpolated with CI to produce component masks: `mask = ci + (1 - ci) * source`. Used in both regular PGD (`spd/metrics/pgd_utils.py`) and persistent PGD (`spd/persistent_pgd.py`). - **Masks** (`component_masks`, `RoutingMasks`, `make_mask_infos`, `n_mask_samples`): The materialized per-component masks used during forward passes. These are produced from sources (in PGD) or from stochastic sampling, and are a general SPD concept across the whole codebase. **Experiment Structure:** Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: + - `models.py` - Experiment-specific model classes and pretrained loading - `*_decomposition.py` - Main SPD execution script -- `train_*.py` - Training script for target models +- `train_*.py` - Training script for target models - `*_config.yaml` - Configuration files - `plotting.py` - Visualization utilities @@ -127,7 +133,7 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: **Configuration System:** - YAML configs define all experiment parameters -- Pydantic models provide type safety and validation +- Pydantic models provide type safety and validation - WandB integration for experiment tracking and model storage - Supports both local paths and `wandb:project/runs/run_id` format for model loading - Centralized experiment registry (`spd/registry.py`) manages all experiment configurations @@ -137,8 +143,9 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: - `spd/harvest/` - Offline GPU pipeline for collecting component statistics (correlations, token stats, activation examples) - `spd/autointerp/` - LLM-based automated interpretation of components - `spd/dataset_attributions/` - Multi-GPU pipeline for computing component-to-component attribution strengths aggregated over training data -- Data 
stored at `SPD_OUT_DIR/{harvest,autointerp,dataset_attributions}//` -- See `spd/harvest/CLAUDE.md`, `spd/autointerp/CLAUDE.md`, and `spd/dataset_attributions/CLAUDE.md` for details +- `spd/graph_interp/` - Context-aware component labeling using graph structure (attributions + correlations) +- Data stored at `SPD_OUT_DIR/{harvest,autointerp,dataset_attributions,graph_interp}//` +- See `spd/harvest/CLAUDE.md`, `spd/autointerp/CLAUDE.md`, `spd/dataset_attributions/CLAUDE.md`, and `spd/graph_interp/CLAUDE.md` for details **Output Directory (`SPD_OUT_DIR`):** @@ -160,12 +167,14 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: ├── scripts/ # Standalone utility scripts ├── tests/ # Test suite ├── spd/ # Main source code +│ ├── investigate/ # Agent investigation (see investigate/CLAUDE.md) │ ├── app/ # Web visualization app (see app/CLAUDE.md) │ ├── autointerp/ # LLM interpretation (see autointerp/CLAUDE.md) │ ├── clustering/ # Component clustering (see clustering/CLAUDE.md) │ ├── dataset_attributions/ # Dataset attributions (see dataset_attributions/CLAUDE.md) │ ├── harvest/ # Statistics collection (see harvest/CLAUDE.md) │ ├── postprocess/ # Unified postprocessing pipeline (harvest + attributions + autointerp) +│ ├── graph_interp/ # Context-aware interpretation (see graph_interp/CLAUDE.md) │ ├── pretrain/ # Target model pretraining (see pretrain/CLAUDE.md) │ ├── experiments/ # Experiment implementations │ │ ├── tms/ # Toy Model of Superposition @@ -201,14 +210,17 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: | `spd-autointerp` | `spd/autointerp/scripts/run_slurm_cli.py` | Submit autointerp SLURM job | | `spd-attributions` | `spd/dataset_attributions/scripts/run_slurm_cli.py` | Submit dataset attribution SLURM job | | `spd-postprocess` | `spd/postprocess/cli.py` | Unified postprocessing pipeline (harvest + attributions + interpret + evals) | +| `spd-graph-interp` | `spd/graph_interp/scripts/run_slurm_cli.py` | Submit graph 
interpretation SLURM job | | `spd-clustering` | `spd/clustering/scripts/run_pipeline.py` | Clustering pipeline | | `spd-pretrain` | `spd/pretrain/scripts/run_slurm_cli.py` | Pretrain target models | +| `spd-investigate` | `spd/investigate/scripts/run_slurm_cli.py` | Launch investigation agent | ### Files to Skip When Searching Use `spd/` as the search root (not repo root) to avoid noise. **Always skip:** + - `.venv/` - Virtual environment - `__pycache__/`, `.pytest_cache/`, `.ruff_cache/` - Build artifacts - `node_modules/` - Frontend dependencies @@ -218,27 +230,37 @@ Use `spd/` as the search root (not repo root) to avoid noise. - `wandb/` - WandB local files **Usually skip unless relevant:** + - `tests/` - Test files (unless debugging test failures) - `papers/` - Research paper drafts ### Common Call Chains **Running Experiments:** + - `spd-run` → `spd/scripts/run.py` → `spd/utils/slurm.py` → SLURM → `spd/run_spd.py` - `spd-local` → `spd/scripts/run_local.py` → `spd/run_spd.py` directly **Harvest Pipeline:** + - `spd-harvest` → `spd/harvest/scripts/run_slurm_cli.py` → `spd/utils/slurm.py` → SLURM array → `spd/harvest/scripts/run.py` → `spd/harvest/harvest.py` **Autointerp Pipeline:** + - `spd-autointerp` → `spd/autointerp/scripts/run_slurm_cli.py` → `spd/utils/slurm.py` → `spd/autointerp/interpret.py` **Dataset Attributions Pipeline:** + - `spd-attributions` → `spd/dataset_attributions/scripts/run_slurm_cli.py` → `spd/utils/slurm.py` → SLURM array → `spd/dataset_attributions/harvest.py` **Clustering Pipeline:** + - `spd-clustering` → `spd/clustering/scripts/run_pipeline.py` → `spd/utils/slurm.py` → `spd/clustering/scripts/run_clustering.py` +**Investigation Pipeline:** + +- `spd-investigate` → `spd/investigate/scripts/run_slurm_cli.py` → `spd/utils/slurm.py` → SLURM → `spd/investigate/scripts/run_agent.py` → Claude Code + ## Common Usage Patterns ### Running Experiments Locally (`spd-local`) @@ -285,6 +307,28 @@ spd-autointerp # Submit SLURM job to interpret 
component Requires `OPENROUTER_API_KEY` env var. See `spd/autointerp/CLAUDE.md` for details. +### Agent Investigation (`spd-investigate`) + +Launch a Claude Code agent to investigate a specific question about an SPD model: + +```bash +spd-investigate "How does the model handle gendered pronouns?" +spd-investigate "What components are involved in verb agreement?" --time 4:00:00 +``` + +Each investigation: + +- Runs in its own SLURM job with 1 GPU +- Starts an isolated app backend instance +- Investigates the specific research question using SPD tools via MCP +- Writes findings to append-only JSONL files + +Output: `SPD_OUT_DIR/investigations//` + +For parallel investigations, run the command multiple times with different prompts. + +See `spd/investigate/CLAUDE.md` for details. + ### Unified Postprocessing (`spd-postprocess`) Run all postprocessing steps for a completed SPD run with a single command: @@ -295,6 +339,7 @@ spd-postprocess --config custom_config.yaml # Use custom config ``` Defaults are defined in `PostprocessConfig` (`spd/postprocess/config.py`). Pass a custom YAML/JSON config to override. 
Set any section to `null` to skip it: + - `attributions: null` — skip dataset attributions - `autointerp: null` — skip autointerp entirely (interpret + evals) - `autointerp.evals: null` — skip evals but still run interpret @@ -323,6 +368,7 @@ spd-run # Run all experiments ``` All `spd-run` executions: + - Submit jobs to SLURM - Create a git snapshot for reproducibility - Create W&B workspace views @@ -343,6 +389,7 @@ spd-run --experiments --sweep --n_agents [--cpu] ``` Examples: + ```bash spd-run --experiments tms_5-2 --sweep --n_agents 4 # Run TMS 5-2 sweep with 4 GPU agents spd-run --experiments resid_mlp2 --sweep --n_agents 3 --cpu # Run ResidualMLP2 sweep with 3 CPU agents @@ -364,6 +411,7 @@ spd-run --experiments tms_5-2 --sweep custom.yaml --n_agents 2 # Use custom swee - Default sweep parameters are loaded from `spd/scripts/sweep_params.yaml` - You can specify a custom sweep parameters file by passing its path to `--sweep` - Sweep parameters support both experiment-specific and global configurations: + ```yaml # Global parameters applied to all experiments global: @@ -376,7 +424,7 @@ spd-run --experiments tms_5-2 --sweep custom.yaml --n_agents 2 # Use custom swee # Experiment-specific parameters (override global) tms_5-2: seed: - values: [100, 200] # Overrides global seed + values: [100, 200] # Overrides global seed task_config: feature_probability: values: [0.05, 0.1] @@ -402,6 +450,7 @@ model = ComponentModel.from_run_info(run_info) # Local paths work too model = ComponentModel.from_pretrained("/path/to/checkpoint.pt") ``` + **Path Formats:** - WandB: `wandb:entity/project/run_id` or `wandb:entity/project/runs/run_id` @@ -415,12 +464,12 @@ Downloaded runs are cached in `SPD_OUT_DIR/runs/-/`. - This includes not setting off multiple sweeps/evals that total >8 GPUs - Monitor jobs with: `squeue --format="%.18i %.9P %.15j %.12u %.12T %.10M %.9l %.6D %b %R" --me` - ## Coding Guidelines & Software Engineering Principles **This is research code, not production. 
Prioritize simplicity and fail-fast over defensive programming.** Core principles: + - **Fail fast** - assert assumptions, crash on violations, don't silently recover - **No legacy support** - delete unused code, don't add fallbacks for old formats or migration shims - **Narrow types** - avoid `| None` unless null is semantically meaningful; use discriminated unions over bags of optional fields @@ -451,11 +500,12 @@ config = get_config(path) value = config.key ``` - ### Tests + - The point of tests in this codebase is to ensure that the code is working as expected, not to prevent production outages - there's no deployment here. Therefore, don't worry about lots of larger integration/end-to-end tests. These often require too much overhead for what it's worth in our case, and this codebase is interactively run so often that issues will likely be caught by the user at very little cost. ### Assertions and error handling + - If you have an invariant in your head, assert it. Are you afraid to assert? Sounds like your program might already be broken. Assert, assert, assert. Never soft fail. - Do not write: `if everythingIsOk: continueHappyPath()`. Instead do `assert everythingIsOk` - You should have a VERY good reason to handle an error gracefully. If your program isn't working like it should then it shouldn't be running, you should be fixing it. @@ -463,11 +513,13 @@ value = config.key - **Write for the golden path.** Never let edge cases bloat the code. Before handling them, just raise an exception. If an edge case becomes annoying enough, we'll handle it then — but write first and foremost for the common case. ### Control Flow + - Keep I/O as high up as possible. Make as many functions as possible pure. - Prefer `match` over `if/elif/else` chains when dispatching on conditions - more declarative and makes cases explicit - If you either have (a and b) or neither, don't make them both independently optional. 
Instead, put them in an optional tuple ### Types, Arguments, and Defaults + - Write your invariants into types as much as possible. - Use jaxtyping for tensor shapes (though for now we don't do runtime checking) - Always use the PEP 604 typing format of `|` for unions and `type | None` over `Optional`. @@ -483,12 +535,11 @@ value = config.key - Don't use `from __future__ import annotations` — use string quotes for forward references instead. ### Tensor Operations + - Try to use einops by default for clarity. - Assert shapes liberally - Document complex tensor manipulations - - ### Comments - Comments hide sloppy code. If you feel the need to write a comment, consider that you should instead @@ -498,10 +549,11 @@ value = config.key - separate an inlined computation into a meaningfully named variable - Don’t write dialogic / narrativised comments or code. Instead, write comments that describe the code as is, not the diff you're making. Examples of narrativising comments: - - `# the function now uses y instead of x` - - `# changed to be faster` - - `# we now traverse in reverse` + - `# the function now uses y instead of x` + - `# changed to be faster` + - `# we now traverse in reverse` - Here's an example of a bad diff, where the new comment makes reference to a change in code, not just the state of the code: + ``` 95 - # Reservoir states 96 - reservoir_states: list[ReservoirState] @@ -509,14 +561,15 @@ value = config.key 96 + reservoir: TensorReservoirState ``` - ### Other Important Software Development Practices + - Don't add legacy fallbacks or migration code - just change it and let old data be manually migrated if needed. -- Delete unused code. +- Delete unused code. - If an argument is always x, strongly consider removing as an argument and just inlining - **Update CLAUDE.md files** when changing code structure, adding/removing files, or modifying key interfaces. Update the CLAUDE.md in the same directory (or nearest parent) as the changed files. 
### GitHub + - To view github issues and PRs, use the github cli (e.g. `gh issue view 28` or `gh pr view 30`). - When making PRs, use the github template defined in `.github/pull_request_template.md`. - Before committing, ALWAYS ensure you are on the correct branch and do not use `git add .` to add all unstaged files. Instead, add only the individual files you changed, don't commit all files. diff --git a/pyproject.toml b/pyproject.toml index 88c3405a8..38c516f7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,9 @@ spd-clustering = "spd.clustering.scripts.run_pipeline:cli" spd-harvest = "spd.harvest.scripts.run_slurm_cli:cli" spd-autointerp = "spd.autointerp.scripts.run_slurm_cli:cli" spd-attributions = "spd.dataset_attributions.scripts.run_slurm_cli:cli" +spd-investigate = "spd.investigate.scripts.run_slurm_cli:cli" spd-postprocess = "spd.postprocess.cli:cli" +spd-graph-interp = "spd.graph_interp.scripts.run_slurm_cli:cli" [build-system] requires = ["setuptools", "wheel"] @@ -69,7 +71,7 @@ include = ["spd*"] [tool.ruff] line-length = 100 fix = true -extend-exclude = ["spd/app/frontend"] +extend-exclude = ["spd/app/frontend", ".circuits-ref"] [tool.ruff.lint] ignore = [ diff --git a/spd/app/CLAUDE.md b/spd/app/CLAUDE.md index 0f95d54a8..d1ef6b58d 100644 --- a/spd/app/CLAUDE.md +++ b/spd/app/CLAUDE.md @@ -15,6 +15,7 @@ This is a **rapidly iterated research tool**. Key implications: - **Database is disposable**: Delete `.data/app/prompt_attr.db` if schema changes break things - **Prefer simplicity**: Avoid over-engineering for hypothetical future needs - **Fail loud and fast**: The users are a small team of highly technical people. Errors are good. We want to know immediately if something is wrong. No soft failing, assert, assert, assert +- **Token display**: Always ship token strings rendered server-side via `AppTokenizer`, never raw token IDs. 
For embed/output layers, `component_idx` is a token ID — resolve it to a display string in the backend response. ## Running the App @@ -50,6 +51,9 @@ backend/ ├── intervention.py # Selective component activation ├── correlations.py # Component correlations + token stats + interpretations ├── clusters.py # Component clustering + ├── dataset_search.py # SimpleStories dataset search + ├── agents.py # Various useful endpoints that AI agents should look at when helping + ├── mcp.py # MCP (Model Context Protocol) endpoint for Claude Code ├── dataset_search.py # Dataset search (reads dataset from run config) └── agents.py # Various useful endpoints that AI agents should look at when helping ``` diff --git a/spd/app/backend/database.py b/spd/app/backend/database.py index 6b2b09552..d0740dafd 100644 --- a/spd/app/backend/database.py +++ b/spd/app/backend/database.py @@ -9,6 +9,7 @@ import hashlib import io import json +import os import sqlite3 from pathlib import Path from typing import Literal @@ -24,7 +25,24 @@ # Persistent data directories _APP_DATA_DIR = REPO_ROOT / ".data" / "app" -DEFAULT_DB_PATH = _APP_DATA_DIR / "prompt_attr.db" +_DEFAULT_DB_PATH = _APP_DATA_DIR / "prompt_attr.db" + + +def get_default_db_path() -> Path: + """Get the default database path. + + Checks env vars in order: + 1. SPD_INVESTIGATION_DIR - investigation mode, db at dir/app.db + 2. SPD_APP_DB_PATH - explicit override + 3. 
Default: .data/app/prompt_attr.db + """ + investigation_dir = os.environ.get("SPD_INVESTIGATION_DIR") + if investigation_dir: + return Path(investigation_dir) / "app.db" + env_path = os.environ.get("SPD_APP_DB_PATH") + if env_path: + return Path(env_path) + return _DEFAULT_DB_PATH class Run(BaseModel): @@ -111,7 +129,7 @@ class PromptAttrDB: """ def __init__(self, db_path: Path | None = None, check_same_thread: bool = True): - self.db_path = db_path or DEFAULT_DB_PATH + self.db_path = db_path or get_default_db_path() self._check_same_thread = check_same_thread self._conn: sqlite3.Connection | None = None diff --git a/spd/app/backend/routers/__init__.py b/spd/app/backend/routers/__init__.py index b7a6f8ed3..49dc651f2 100644 --- a/spd/app/backend/routers/__init__.py +++ b/spd/app/backend/routers/__init__.py @@ -7,8 +7,11 @@ from spd.app.backend.routers.data_sources import router as data_sources_router from spd.app.backend.routers.dataset_attributions import router as dataset_attributions_router from spd.app.backend.routers.dataset_search import router as dataset_search_router +from spd.app.backend.routers.graph_interp import router as graph_interp_router from spd.app.backend.routers.graphs import router as graphs_router from spd.app.backend.routers.intervention import router as intervention_router +from spd.app.backend.routers.investigations import router as investigations_router +from spd.app.backend.routers.mcp import router as mcp_router from spd.app.backend.routers.pretrain_info import router as pretrain_info_router from spd.app.backend.routers.prompts import router as prompts_router from spd.app.backend.routers.runs import router as runs_router @@ -20,9 +23,12 @@ "correlations_router", "data_sources_router", "dataset_attributions_router", + "graph_interp_router", "dataset_search_router", "graphs_router", "intervention_router", + "investigations_router", + "mcp_router", "pretrain_info_router", "prompts_router", "runs_router", diff --git 
a/spd/app/backend/routers/data_sources.py b/spd/app/backend/routers/data_sources.py index 5287d91bd..6888b339f 100644 --- a/spd/app/backend/routers/data_sources.py +++ b/spd/app/backend/routers/data_sources.py @@ -28,15 +28,21 @@ class AutointerpInfo(BaseModel): class AttributionsInfo(BaseModel): subrun_id: str - n_batches_processed: int n_tokens_processed: int ci_threshold: float +class GraphInterpInfo(BaseModel): + subrun_id: str + config: dict[str, Any] | None + label_counts: dict[str, int] + + class DataSourcesResponse(BaseModel): harvest: HarvestInfo | None autointerp: AutointerpInfo | None attributions: AttributionsInfo | None + graph_interp: GraphInterpInfo | None router = APIRouter(prefix="/api/data_sources", tags=["data_sources"]) @@ -70,13 +76,21 @@ def get_data_sources(loaded: DepLoadedRun) -> DataSourcesResponse: storage = loaded.attributions.get_attributions() attributions_info = AttributionsInfo( subrun_id=loaded.attributions.subrun_id, - n_batches_processed=storage.n_batches_processed, n_tokens_processed=storage.n_tokens_processed, ci_threshold=storage.ci_threshold, ) + graph_interp_info: GraphInterpInfo | None = None + if loaded.graph_interp is not None: + graph_interp_info = GraphInterpInfo( + subrun_id=loaded.graph_interp.subrun_id, + config=loaded.graph_interp.get_config(), + label_counts=loaded.graph_interp.get_label_counts(), + ) + return DataSourcesResponse( harvest=harvest_info, autointerp=autointerp_info, attributions=attributions_info, + graph_interp=graph_interp_info, ) diff --git a/spd/app/backend/routers/dataset_attributions.py b/spd/app/backend/routers/dataset_attributions.py index 4c3d07753..178eefc72 100644 --- a/spd/app/backend/routers/dataset_attributions.py +++ b/spd/app/backend/routers/dataset_attributions.py @@ -7,46 +7,43 @@ from typing import Annotated, Literal from fastapi import APIRouter, HTTPException, Query -from jaxtyping import Float from pydantic import BaseModel -from torch import Tensor from 
spd.app.backend.dependencies import DepLoadedRun from spd.app.backend.utils import log_errors +from spd.dataset_attributions.storage import AttrMetric, DatasetAttributionStorage from spd.dataset_attributions.storage import DatasetAttributionEntry as StorageEntry -from spd.dataset_attributions.storage import DatasetAttributionStorage +ATTR_METRICS: list[AttrMetric] = ["attr", "attr_abs"] -class DatasetAttributionEntry(BaseModel): - """A single entry in attribution results.""" +class DatasetAttributionEntry(BaseModel): component_key: str layer: str component_idx: int value: float + token_str: str | None = None class DatasetAttributionMetadata(BaseModel): - """Metadata about dataset attributions availability.""" - available: bool - n_batches_processed: int | None n_tokens_processed: int | None n_component_layer_keys: int | None - vocab_size: int | None - d_model: int | None ci_threshold: float | None class ComponentAttributions(BaseModel): - """All attribution data for a single component (sources and targets, positive and negative).""" - positive_sources: list[DatasetAttributionEntry] negative_sources: list[DatasetAttributionEntry] positive_targets: list[DatasetAttributionEntry] negative_targets: list[DatasetAttributionEntry] +class AllMetricAttributions(BaseModel): + attr: ComponentAttributions + attr_abs: ComponentAttributions + + router = APIRouter(prefix="/api/dataset_attributions", tags=["dataset_attributions"]) NOT_AVAILABLE_MSG = ( @@ -54,91 +51,67 @@ class ComponentAttributions(BaseModel): ) -def _to_concrete_key(canonical_layer: str, component_idx: int, loaded: DepLoadedRun) -> str: - """Translate canonical layer + idx to concrete storage key. - - "embed" maps to the concrete embedding path (e.g. "wte") in storage. - "output" is a pseudo-layer used as-is in storage. 
- """ - if canonical_layer == "output": - return f"output:{component_idx}" - concrete = loaded.topology.canon_to_target(canonical_layer) - return f"{concrete}:{component_idx}" - - def _require_storage(loaded: DepLoadedRun) -> DatasetAttributionStorage: - """Get storage or raise 404.""" if loaded.attributions is None: raise HTTPException(status_code=404, detail=NOT_AVAILABLE_MSG) return loaded.attributions.get_attributions() -def _require_source(storage: DatasetAttributionStorage, component_key: str) -> None: - """Validate component exists as a source or raise 404.""" - if not storage.has_source(component_key): - raise HTTPException( - status_code=404, - detail=f"Component {component_key} not found as source in attributions", - ) - - -def _require_target(storage: DatasetAttributionStorage, component_key: str) -> None: - """Validate component exists as a target or raise 404.""" - if not storage.has_target(component_key): - raise HTTPException( - status_code=404, - detail=f"Component {component_key} not found as target in attributions", - ) - - -def _get_w_unembed(loaded: DepLoadedRun) -> Float[Tensor, "d_model vocab"]: - """Get the unembedding matrix from the loaded model.""" - return loaded.topology.get_unembed_weight() - - def _to_api_entries( - loaded: DepLoadedRun, entries: list[StorageEntry] + entries: list[StorageEntry], loaded: DepLoadedRun ) -> list[DatasetAttributionEntry]: - """Convert storage entries to API response format with canonical keys.""" - - def _canonicalize_layer(layer: str) -> str: - if layer == "output": - return layer - return loaded.topology.target_to_canon(layer) - return [ DatasetAttributionEntry( - component_key=f"{_canonicalize_layer(e.layer)}:{e.component_idx}", - layer=_canonicalize_layer(e.layer), + component_key=e.component_key, + layer=e.layer, component_idx=e.component_idx, value=e.value, + token_str=loaded.tokenizer.decode([e.component_idx]) + if e.layer in ("embed", "output") + else None, ) for e in entries ] +def 
_get_component_attributions_for_metric( + storage: DatasetAttributionStorage, + loaded: DepLoadedRun, + component_key: str, + k: int, + metric: AttrMetric, +) -> ComponentAttributions: + return ComponentAttributions( + positive_sources=_to_api_entries( + storage.get_top_sources(component_key, k, "positive", metric), loaded + ), + negative_sources=_to_api_entries( + storage.get_top_sources(component_key, k, "negative", metric), loaded + ), + positive_targets=_to_api_entries( + storage.get_top_targets(component_key, k, "positive", metric), loaded + ), + negative_targets=_to_api_entries( + storage.get_top_targets(component_key, k, "negative", metric), loaded + ), + ) + + @router.get("/metadata") @log_errors def get_attribution_metadata(loaded: DepLoadedRun) -> DatasetAttributionMetadata: - """Get metadata about dataset attributions availability.""" if loaded.attributions is None: return DatasetAttributionMetadata( available=False, - n_batches_processed=None, n_tokens_processed=None, n_component_layer_keys=None, - vocab_size=None, - d_model=None, ci_threshold=None, ) storage = loaded.attributions.get_attributions() return DatasetAttributionMetadata( available=True, - n_batches_processed=storage.n_batches_processed, n_tokens_processed=storage.n_tokens_processed, n_component_layer_keys=storage.n_components, - vocab_size=storage.vocab_size, - d_model=storage.d_model, ci_threshold=storage.ci_threshold, ) @@ -150,58 +123,18 @@ def get_component_attributions( component_idx: int, loaded: DepLoadedRun, k: Annotated[int, Query(ge=1)] = 10, -) -> ComponentAttributions: - """Get all attribution data for a component (sources and targets, positive and negative).""" +) -> AllMetricAttributions: + """Get all attribution data for a component across all metrics.""" storage = _require_storage(loaded) - component_key = _to_concrete_key(layer, component_idx, loaded) - - # Component can be both a source and a target, so we need to check both - is_source = storage.has_source(component_key) 
- is_target = storage.has_target(component_key) - - if not is_source and not is_target: - raise HTTPException( - status_code=404, - detail=f"Component {component_key} not found in attributions", - ) - - w_unembed = _get_w_unembed(loaded) if is_source else None - - return ComponentAttributions( - positive_sources=_to_api_entries( - loaded, storage.get_top_sources(component_key, k, "positive") - ) - if is_target - else [], - negative_sources=_to_api_entries( - loaded, storage.get_top_sources(component_key, k, "negative") - ) - if is_target - else [], - positive_targets=_to_api_entries( - loaded, - storage.get_top_targets( - component_key, - k, - "positive", - w_unembed=w_unembed, - include_outputs=w_unembed is not None, - ), - ) - if is_source - else [], - negative_targets=_to_api_entries( - loaded, - storage.get_top_targets( - component_key, - k, - "negative", - w_unembed=w_unembed, - include_outputs=w_unembed is not None, - ), - ) - if is_source - else [], + component_key = f"{layer}:{component_idx}" + + return AllMetricAttributions( + **{ + metric: _get_component_attributions_for_metric( + storage, loaded, component_key, k, metric + ) + for metric in ATTR_METRICS + } ) @@ -213,16 +146,11 @@ def get_attribution_sources( loaded: DepLoadedRun, k: Annotated[int, Query(ge=1)] = 10, sign: Literal["positive", "negative"] = "positive", + metric: AttrMetric = "attr", ) -> list[DatasetAttributionEntry]: - """Get top-k source components that attribute TO this target over the dataset.""" storage = _require_storage(loaded) - target_key = _to_concrete_key(layer, component_idx, loaded) - _require_target(storage, target_key) - - w_unembed = _get_w_unembed(loaded) if layer == "output" else None - return _to_api_entries( - loaded, storage.get_top_sources(target_key, k, sign, w_unembed=w_unembed) + storage.get_top_sources(f"{layer}:{component_idx}", k, sign, metric), loaded ) @@ -234,35 +162,9 @@ def get_attribution_targets( loaded: DepLoadedRun, k: Annotated[int, Query(ge=1)] = 10, 
sign: Literal["positive", "negative"] = "positive", + metric: AttrMetric = "attr", ) -> list[DatasetAttributionEntry]: - """Get top-k target components this source attributes TO over the dataset.""" storage = _require_storage(loaded) - source_key = _to_concrete_key(layer, component_idx, loaded) - _require_source(storage, source_key) - - w_unembed = _get_w_unembed(loaded) - return _to_api_entries( - loaded, storage.get_top_targets(source_key, k, sign, w_unembed=w_unembed) + storage.get_top_targets(f"{layer}:{component_idx}", k, sign, metric), loaded ) - - -@router.get("/between/{source_layer}/{source_idx}/{target_layer}/{target_idx}") -@log_errors -def get_attribution_between( - source_layer: str, - source_idx: int, - target_layer: str, - target_idx: int, - loaded: DepLoadedRun, -) -> float: - """Get attribution strength from source component to target component.""" - storage = _require_storage(loaded) - source_key = _to_concrete_key(source_layer, source_idx, loaded) - target_key = _to_concrete_key(target_layer, target_idx, loaded) - _require_source(storage, source_key) - _require_target(storage, target_key) - - w_unembed = _get_w_unembed(loaded) if target_layer == "output" else None - - return storage.get_attribution(source_key, target_key, w_unembed=w_unembed) diff --git a/spd/app/backend/routers/graph_interp.py b/spd/app/backend/routers/graph_interp.py new file mode 100644 index 000000000..272bd4d80 --- /dev/null +++ b/spd/app/backend/routers/graph_interp.py @@ -0,0 +1,373 @@ +"""Graph interpretation endpoints. + +Serves context-aware component labels (output/input/unified) and the +prompt-edge graph produced by the graph_interp pipeline. 
+""" + +import random + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from spd.app.backend.dependencies import DepLoadedRun +from spd.app.backend.utils import log_errors +from spd.graph_interp.schemas import LabelResult +from spd.topology import TransformerTopology + +# TODO(oli): Remove MOCK_MODE after real data is available +MOCK_MODE = False + +MAX_GRAPH_NODES = 500 + + +_ALREADY_CANONICAL = {"embed", "output"} + + +def _concrete_to_canonical_key(concrete_key: str, topology: TransformerTopology) -> str: + layer, idx = concrete_key.rsplit(":", 1) + if layer in _ALREADY_CANONICAL: + return concrete_key + canonical = topology.target_to_canon(layer) + return f"{canonical}:{idx}" + + +def _canonical_to_concrete_key( + canonical_layer: str, component_idx: int, topology: TransformerTopology +) -> str: + concrete = topology.canon_to_target(canonical_layer) + return f"{concrete}:{component_idx}" + + +# -- Schemas ------------------------------------------------------------------- + + +class GraphInterpHeadline(BaseModel): + label: str + confidence: str + output_label: str | None + input_label: str | None + + +class LabelDetail(BaseModel): + label: str + confidence: str + reasoning: str + prompt: str + + +class GraphInterpDetail(BaseModel): + output: LabelDetail | None + input: LabelDetail | None + unified: LabelDetail | None + + +class PromptEdgeResponse(BaseModel): + related_key: str + pass_name: str + attribution: float + related_label: str | None + related_confidence: str | None + token_str: str | None + + +class GraphInterpComponentDetail(BaseModel): + output: LabelDetail | None + input: LabelDetail | None + unified: LabelDetail | None + edges: list[PromptEdgeResponse] + + +class GraphNode(BaseModel): + component_key: str + label: str + confidence: str + + +class GraphEdge(BaseModel): + source: str + target: str + attribution: float + pass_name: str + + +class ModelGraphResponse(BaseModel): + nodes: list[GraphNode] + edges: 
list[GraphEdge] + + +# -- Router -------------------------------------------------------------------- + +router = APIRouter(prefix="/api/graph_interp", tags=["graph_interp"]) + + +@router.get("/labels") +@log_errors +def get_all_labels(loaded: DepLoadedRun) -> dict[str, GraphInterpHeadline]: + if MOCK_MODE: + return _mock_all_labels(loaded) + + repo = loaded.graph_interp + if repo is None: + return {} + + topology = loaded.topology + unified = repo.get_all_unified_labels() + output = repo.get_all_output_labels() + input_ = repo.get_all_input_labels() + + all_keys = set(unified) | set(output) | set(input_) + result: dict[str, GraphInterpHeadline] = {} + + for concrete_key in all_keys: + u = unified.get(concrete_key) + o = output.get(concrete_key) + i = input_.get(concrete_key) + + label = u or o or i + assert label is not None + canonical_key = _concrete_to_canonical_key(concrete_key, topology) + + result[canonical_key] = GraphInterpHeadline( + label=label.label, + confidence=label.confidence, + output_label=o.label if o else None, + input_label=i.label if i else None, + ) + + return result + + +def _to_detail(label: LabelResult | None) -> LabelDetail | None: + if label is None: + return None + return LabelDetail( + label=label.label, + confidence=label.confidence, + reasoning=label.reasoning, + prompt=label.prompt, + ) + + +@router.get("/labels/{layer}/{c_idx}") +@log_errors +def get_label_detail(layer: str, c_idx: int, loaded: DepLoadedRun) -> GraphInterpDetail: + if MOCK_MODE: + return _mock_label_detail(layer, c_idx) + + repo = loaded.graph_interp + if repo is None: + raise HTTPException(status_code=404, detail="Graph interp data not available") + + concrete_key = _canonical_to_concrete_key(layer, c_idx, loaded.topology) + + return GraphInterpDetail( + output=_to_detail(repo.get_output_label(concrete_key)), + input=_to_detail(repo.get_input_label(concrete_key)), + unified=_to_detail(repo.get_unified_label(concrete_key)), + ) + + 
+@router.get("/detail/{layer}/{c_idx}") +@log_errors +def get_component_detail( + layer: str, c_idx: int, loaded: DepLoadedRun +) -> GraphInterpComponentDetail: + repo = loaded.graph_interp + if repo is None: + raise HTTPException(status_code=404, detail="Graph interp data not available") + + topology = loaded.topology + concrete_key = _canonical_to_concrete_key(layer, c_idx, topology) + + raw_edges = repo.get_prompt_edges(concrete_key) + tokenizer = loaded.tokenizer + edges = [] + for e in raw_edges: + rel_layer, rel_idx = e.related_key.rsplit(":", 1) + token_str = tokenizer.decode([int(rel_idx)]) if rel_layer in ("embed", "output") else None + edges.append( + PromptEdgeResponse( + related_key=_concrete_to_canonical_key(e.related_key, topology), + pass_name=e.pass_name, + attribution=e.attribution, + related_label=e.related_label, + related_confidence=e.related_confidence, + token_str=token_str, + ) + ) + + return GraphInterpComponentDetail( + output=_to_detail(repo.get_output_label(concrete_key)), + input=_to_detail(repo.get_input_label(concrete_key)), + unified=_to_detail(repo.get_unified_label(concrete_key)), + edges=edges, + ) + + +@router.get("/graph") +@log_errors +def get_model_graph(loaded: DepLoadedRun) -> ModelGraphResponse: + if MOCK_MODE: + return _mock_model_graph(loaded) + + repo = loaded.graph_interp + if repo is None: + raise HTTPException(status_code=404, detail="Graph interp data not available") + + topology = loaded.topology + + unified = repo.get_all_unified_labels() + nodes = [] + for concrete_key, label in unified.items(): + canonical_key = _concrete_to_canonical_key(concrete_key, topology) + nodes.append( + GraphNode( + component_key=canonical_key, + label=label.label, + confidence=label.confidence, + ) + ) + + nodes = nodes[:MAX_GRAPH_NODES] + node_keys = {n.component_key for n in nodes} + + raw_edges = repo.get_all_prompt_edges() + edges = [] + for e in raw_edges: + comp_canon = _concrete_to_canonical_key(e.component_key, topology) + 
rel_canon = _concrete_to_canonical_key(e.related_key, topology) + + match e.pass_name: + case "output": + source, target = comp_canon, rel_canon + case "input": + source, target = rel_canon, comp_canon + + if source not in node_keys or target not in node_keys: + continue + + edges.append( + GraphEdge( + source=source, + target=target, + attribution=e.attribution, + pass_name=e.pass_name, + ) + ) + + return ModelGraphResponse(nodes=nodes, edges=edges) + + +# -- Mock data (TODO: remove after real data available) ------------------------ + +_MOCK_LABELS = [ + "sentence-final punctuation", + "proper noun completion", + "emotional adjective selection", + "temporal adverb prediction", + "morphological suffix (-ing/-ed)", + "determiner before noun", + "dialogue quotation marks", + "plural noun suffix", + "clause boundary detection", + "verb tense agreement", + "spatial preposition", + "possessive pronoun", + "narrative action verb", + "abstract emotion noun", + "comparative adjective form", + "subject-verb agreement", + "article selection (a/the)", + "comma splice detection", + "pronoun resolution", + "negation scope", +] + +_MOCK_INPUT_LABELS = [ + "sentence-initial capitals", + "mid-sentence verb position", + "adjective-noun boundary", + "clause-final position", + "article-noun sequence", + "subject pronoun at boundary", + "preposition-object pair", + "verb stem before suffix", + "quotation boundary", + "comma-separated items", +] + + +def _mock_all_labels(loaded: DepLoadedRun) -> dict[str, GraphInterpHeadline]: + rng = random.Random(42) + topology = loaded.topology + confidences = ["high", "high", "high", "medium", "medium", "low"] + + result: dict[str, GraphInterpHeadline] = {} + for target_path, components in loaded.model.components.items(): + canon = topology.target_to_canon(target_path) + n_components = components.C + n_mock = min(n_components, rng.randint(5, 20)) + indices = sorted(rng.sample(range(n_components), n_mock)) + for idx in indices: + key = 
f"{canon}:{idx}" + result[key] = GraphInterpHeadline( + label=rng.choice(_MOCK_LABELS), + confidence=rng.choice(confidences), + output_label=rng.choice(_MOCK_LABELS), + input_label=rng.choice(_MOCK_INPUT_LABELS), + ) + return result + + +def _mock_label_detail(layer: str, c_idx: int) -> GraphInterpDetail: + rng = random.Random(hash((layer, c_idx))) + conf = rng.choice(["high", "medium", "low"]) + return GraphInterpDetail( + output=LabelDetail( + label=rng.choice(_MOCK_LABELS), + confidence=conf, + reasoning=f"Output: Component {layer}:{c_idx} writes {rng.choice(_MOCK_LABELS).lower()} tokens to the residual stream.", + prompt="(mock prompt)", + ), + input=LabelDetail( + label=rng.choice(_MOCK_INPUT_LABELS), + confidence=conf, + reasoning=f"Input: Component {layer}:{c_idx} fires on {rng.choice(_MOCK_INPUT_LABELS).lower()} patterns.", + prompt="(mock prompt)", + ), + unified=LabelDetail( + label=rng.choice(_MOCK_LABELS), + confidence=conf, + reasoning=f"Unified: Combines output ({rng.choice(_MOCK_LABELS).lower()}) and input ({rng.choice(_MOCK_INPUT_LABELS).lower()}) functions.", + prompt="(mock prompt)", + ), + ) + + +def _mock_model_graph(loaded: DepLoadedRun) -> ModelGraphResponse: + labels = _mock_all_labels(loaded) + + nodes = [ + GraphNode(component_key=key, label=h.label, confidence=h.confidence) + for key, h in labels.items() + ] + + rng = random.Random(42) + keys = list(labels.keys()) + edges: list[GraphEdge] = [] + + for key in keys: + layer = key.rsplit(":", 1)[0] + later_keys = [k for k in keys if k.rsplit(":", 1)[0] != layer] + n_edges = rng.randint(1, 4) + for target in rng.sample(later_keys, min(n_edges, len(later_keys))): + edges.append( + GraphEdge( + source=key, + target=target, + attribution=rng.uniform(-1.0, 1.0), + pass_name=rng.choice(["output", "input"]), + ) + ) + + return ModelGraphResponse(nodes=nodes, edges=edges) diff --git a/spd/app/backend/routers/graphs.py b/spd/app/backend/routers/graphs.py index a51b1649c..238c8f7fe 100644 --- 
a/spd/app/backend/routers/graphs.py +++ b/spd/app/backend/routers/graphs.py @@ -264,8 +264,20 @@ def _build_out_probs( def stream_computation( work: Callable[[ProgressCallback], GraphData | GraphDataWithOptimization], + gpu_lock: threading.Lock, ) -> StreamingResponse: - """Run graph computation in a thread with SSE streaming for progress updates.""" + """Run graph computation in a thread with SSE streaming for progress updates. + + Acquires gpu_lock before starting and holds it until computation completes. + Raises 503 if the lock is already held by another operation. + """ + # Try to acquire lock non-blocking - fail fast if GPU is busy + if not gpu_lock.acquire(blocking=False): + raise HTTPException( + status_code=503, + detail="GPU operation already in progress. Please wait and retry.", + ) + progress_queue: queue.Queue[dict[str, Any]] = queue.Queue() def on_progress(current: int, total: int, stage: str) -> None: @@ -280,28 +292,31 @@ def compute_thread() -> None: progress_queue.put({"type": "error", "error": str(e)}) def generate() -> Generator[str]: - thread = threading.Thread(target=compute_thread) - thread.start() - - while True: - try: - msg = progress_queue.get(timeout=0.1) - except queue.Empty: - if not thread.is_alive(): + try: + thread = threading.Thread(target=compute_thread) + thread.start() + + while True: + try: + msg = progress_queue.get(timeout=0.1) + except queue.Empty: + if not thread.is_alive(): + break + continue + + if msg["type"] == "progress": + yield f"data: {json.dumps(msg)}\n\n" + elif msg["type"] == "error": + yield f"data: {json.dumps(msg)}\n\n" + break + elif msg["type"] == "result": + complete_data = {"type": "complete", "data": msg["result"].model_dump()} + yield f"data: {json.dumps(complete_data)}\n\n" break - continue - - if msg["type"] == "progress": - yield f"data: {json.dumps(msg)}\n\n" - elif msg["type"] == "error": - yield f"data: {json.dumps(msg)}\n\n" - break - elif msg["type"] == "result": - complete_data = {"type": 
"complete", "data": msg["result"].model_dump()} - yield f"data: {json.dumps(complete_data)}\n\n" - break - thread.join() + thread.join() + finally: + gpu_lock.release() return StreamingResponse(generate(), media_type="text/event-stream") @@ -513,7 +528,7 @@ def work(on_progress: ProgressCallback) -> GraphData: l0_total=fg.l0_total, ) - return stream_computation(work) + return stream_computation(work, manager._gpu_lock) def _edge_to_edge_data(edge: Edge) -> EdgeData: @@ -738,7 +753,7 @@ def work(on_progress: ProgressCallback) -> GraphDataWithOptimization: ), ) - return stream_computation(work) + return stream_computation(work, manager._gpu_lock) @dataclass diff --git a/spd/app/backend/routers/intervention.py b/spd/app/backend/routers/intervention.py index 1ccdbf86c..bf5fc981c 100644 --- a/spd/app/backend/routers/intervention.py +++ b/spd/app/backend/routers/intervention.py @@ -152,52 +152,55 @@ def _run_intervention_forward( @router.post("") @log_errors -def run_intervention(request: InterventionRequest, loaded: DepLoadedRun) -> InterventionResponse: +def run_intervention( + request: InterventionRequest, loaded: DepLoadedRun, manager: DepStateManager +) -> InterventionResponse: """Run intervention forward pass with specified nodes active (legacy endpoint).""" - token_ids = loaded.tokenizer.encode(request.text) - tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) + with manager.gpu_lock(): + token_ids = loaded.tokenizer.encode(request.text) + tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) + + active_nodes = [ + ( + loaded.topology.canon_to_target(n.layer), + n.seq_pos, + n.component_idx, + ) + for n in request.nodes + ] - active_nodes = [ - ( - loaded.topology.canon_to_target(n.layer), - n.seq_pos, - n.component_idx, + seq_len = tokens.shape[1] + for _, seq_pos, _ in active_nodes: + if seq_pos >= seq_len: + raise ValueError(f"seq_pos {seq_pos} out of bounds for text with {seq_len} tokens") + + result = 
compute_intervention_forward( + model=loaded.model, + tokens=tokens, + active_nodes=active_nodes, + top_k=request.top_k, + tokenizer=loaded.tokenizer, ) - for n in request.nodes - ] - seq_len = tokens.shape[1] - for _, seq_pos, _ in active_nodes: - if seq_pos >= seq_len: - raise ValueError(f"seq_pos {seq_pos} out of bounds for text with {seq_len} tokens") - - result = compute_intervention_forward( - model=loaded.model, - tokens=tokens, - active_nodes=active_nodes, - top_k=request.top_k, - tokenizer=loaded.tokenizer, - ) - - predictions_per_position = [ - [ - TokenPrediction( - token=token, - token_id=token_id, - spd_prob=spd_prob, - target_prob=target_prob, - logit=logit, - target_logit=target_logit, - ) - for token, token_id, spd_prob, logit, target_prob, target_logit in pos_predictions + predictions_per_position = [ + [ + TokenPrediction( + token=token, + token_id=token_id, + spd_prob=spd_prob, + target_prob=target_prob, + logit=logit, + target_logit=target_logit, + ) + for token, token_id, spd_prob, logit, target_prob, target_logit in pos_predictions + ] + for pos_predictions in result.predictions_per_position ] - for pos_predictions in result.predictions_per_position - ] - return InterventionResponse( - input_tokens=result.input_tokens, - predictions_per_position=predictions_per_position, - ) + return InterventionResponse( + input_tokens=result.input_tokens, + predictions_per_position=predictions_per_position, + ) @router.post("/run") @@ -206,14 +209,16 @@ def run_and_save_intervention( request: RunInterventionRequest, loaded: DepLoadedRun, db: DepDB, + manager: DepStateManager, ) -> InterventionRunSummary: """Run an intervention and save the result.""" - response = _run_intervention_forward( - text=request.text, - selected_nodes=request.selected_nodes, - top_k=request.top_k, - loaded=loaded, - ) + with manager.gpu_lock(): + response = _run_intervention_forward( + text=request.text, + selected_nodes=request.selected_nodes, + top_k=request.top_k, + 
loaded=loaded, + ) run_id = db.save_intervention_run( graph_id=request.graph_id, @@ -321,12 +326,13 @@ def fork_intervention_run( modified_text = loaded.tokenizer.decode(modified_token_ids) # Run the intervention forward pass with modified tokens but same selected nodes - response = _run_intervention_forward( - text=modified_text, - selected_nodes=parent_run.selected_nodes, - top_k=request.top_k, - loaded=loaded, - ) + with manager.gpu_lock(): + response = _run_intervention_forward( + text=modified_text, + selected_nodes=parent_run.selected_nodes, + top_k=request.top_k, + loaded=loaded, + ) # Save the forked run fork_id = db.save_forked_intervention_run( diff --git a/spd/app/backend/routers/investigations.py b/spd/app/backend/routers/investigations.py new file mode 100644 index 000000000..0784a0eb4 --- /dev/null +++ b/spd/app/backend/routers/investigations.py @@ -0,0 +1,317 @@ +"""Investigations endpoint for viewing agent investigation results. + +Lists and serves investigation data from SPD_OUT_DIR/investigations/. +Each investigation directory contains findings from a single agent run. 
+""" + +import json +from datetime import datetime +from pathlib import Path +from typing import Any + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from spd.app.backend.dependencies import DepLoadedRun +from spd.settings import SPD_OUT_DIR +from spd.utils.wandb_utils import parse_wandb_run_path + +router = APIRouter(prefix="/api/investigations", tags=["investigations"]) + +INVESTIGATIONS_DIR = SPD_OUT_DIR / "investigations" + + +class InvestigationSummary(BaseModel): + """Summary of a single investigation.""" + + id: str + wandb_path: str | None + prompt: str | None + created_at: str + has_research_log: bool + has_explanations: bool + event_count: int + last_event_time: str | None + last_event_message: str | None + title: str | None + summary: str | None + status: str | None + + +class EventEntry(BaseModel): + """A single event from events.jsonl.""" + + event_type: str + timestamp: str + message: str + details: dict[str, Any] | None = None + + +class InvestigationDetail(BaseModel): + """Full detail of an investigation including logs.""" + + id: str + wandb_path: str | None + prompt: str | None + created_at: str + research_log: str | None + events: list[EventEntry] + explanations: list[dict[str, Any]] + artifact_ids: list[str] + title: str | None + summary: str | None + status: str | None + + +def _parse_metadata(inv_path: Path) -> dict[str, Any] | None: + """Parse metadata.json from an investigation directory.""" + metadata_path = inv_path / "metadata.json" + if not metadata_path.exists(): + return None + try: + data: dict[str, Any] = json.loads(metadata_path.read_text()) + return data + except Exception: + return None + + +def _get_last_event(events_path: Path) -> tuple[str | None, str | None, int]: + """Get the last event timestamp, message, and total count from events.jsonl.""" + if not events_path.exists(): + return None, None, 0 + + last_time = None + last_msg = None + count = 0 + + try: + with open(events_path) as f: + for 
line in f: + line = line.strip() + if not line: + continue + count += 1 + try: + event = json.loads(line) + last_time = event.get("timestamp") + last_msg = event.get("message") + except json.JSONDecodeError: + continue + except Exception: + pass + + return last_time, last_msg, count + + +def _parse_task_summary(inv_path: Path) -> tuple[str | None, str | None, str | None]: + """Parse summary.json from an investigation directory. Returns (title, summary, status).""" + summary_path = inv_path / "summary.json" + if not summary_path.exists(): + return None, None, None + try: + data: dict[str, Any] = json.loads(summary_path.read_text()) + return data.get("title"), data.get("summary"), data.get("status") + except Exception: + return None, None, None + + +def _list_artifact_ids(inv_path: Path) -> list[str]: + """List all artifact IDs for an investigation.""" + artifacts_dir = inv_path / "artifacts" + if not artifacts_dir.exists(): + return [] + return [f.stem for f in sorted(artifacts_dir.glob("graph_*.json"))] + + +def _get_created_at(inv_path: Path, metadata: dict[str, Any] | None) -> str: + """Get creation time for an investigation.""" + events_path = inv_path / "events.jsonl" + if events_path.exists(): + try: + with open(events_path) as f: + first_line = f.readline().strip() + if first_line: + event = json.loads(first_line) + if "timestamp" in event: + return event["timestamp"] + except Exception: + pass + + if metadata and "created_at" in metadata: + return metadata["created_at"] + + return datetime.fromtimestamp(inv_path.stat().st_mtime).isoformat() + + +@router.get("") +def list_investigations(loaded: DepLoadedRun) -> list[InvestigationSummary]: + """List investigations for the currently loaded run.""" + if not INVESTIGATIONS_DIR.exists(): + return [] + + wandb_path = loaded.run.wandb_path + results = [] + + for inv_path in INVESTIGATIONS_DIR.iterdir(): + if not inv_path.is_dir() or not inv_path.name.startswith("inv-"): + continue + + inv_id = inv_path.name + 
metadata = _parse_metadata(inv_path) + + meta_wandb_path = metadata.get("wandb_path") if metadata else None + if meta_wandb_path is None: + continue + # Normalize to canonical form for comparison (strips "runs/", "wandb:" prefix, etc.) + try: + e, p, r = parse_wandb_run_path(meta_wandb_path) + canonical_meta_path = f"{e}/{p}/{r}" + except ValueError: + continue + if canonical_meta_path != wandb_path: + continue + + events_path = inv_path / "events.jsonl" + last_time, last_msg, event_count = _get_last_event(events_path) + title, summary, status = _parse_task_summary(inv_path) + + explanations_path = inv_path / "explanations.jsonl" + + results.append( + InvestigationSummary( + id=inv_id, + wandb_path=meta_wandb_path, + prompt=metadata.get("prompt") if metadata else None, + created_at=_get_created_at(inv_path, metadata), + has_research_log=(inv_path / "research_log.md").exists(), + has_explanations=explanations_path.exists() + and explanations_path.stat().st_size > 0, + event_count=event_count, + last_event_time=last_time, + last_event_message=last_msg, + title=title, + summary=summary, + status=status, + ) + ) + + results.sort(key=lambda x: x.created_at, reverse=True) + return results + + +@router.get("/{inv_id}") +def get_investigation(inv_id: str) -> InvestigationDetail: + """Get full details of an investigation.""" + inv_path = INVESTIGATIONS_DIR / inv_id + + if not inv_path.exists() or not inv_path.is_dir(): + raise HTTPException(status_code=404, detail=f"Investigation {inv_id} not found") + + metadata = _parse_metadata(inv_path) + + research_log = None + research_log_path = inv_path / "research_log.md" + if research_log_path.exists(): + research_log = research_log_path.read_text() + + events = [] + events_path = inv_path / "events.jsonl" + if events_path.exists(): + with open(events_path) as f: + for line in f: + line = line.strip() + if not line: + continue + try: + event = json.loads(line) + events.append( + EventEntry( + event_type=event.get("event_type", 
"unknown"), + timestamp=event.get("timestamp", ""), + message=event.get("message", ""), + details=event.get("details"), + ) + ) + except json.JSONDecodeError: + continue + + explanations: list[dict[str, Any]] = [] + explanations_path = inv_path / "explanations.jsonl" + if explanations_path.exists(): + with open(explanations_path) as f: + for line in f: + line = line.strip() + if not line: + continue + try: + explanations.append(json.loads(line)) + except json.JSONDecodeError: + continue + + title, summary, status = _parse_task_summary(inv_path) + artifact_ids = _list_artifact_ids(inv_path) + + return InvestigationDetail( + id=inv_id, + wandb_path=metadata.get("wandb_path") if metadata else None, + prompt=metadata.get("prompt") if metadata else None, + created_at=_get_created_at(inv_path, metadata), + research_log=research_log, + events=events, + explanations=explanations, + artifact_ids=artifact_ids, + title=title, + summary=summary, + status=status, + ) + + +class LaunchRequest(BaseModel): + prompt: str + + +class LaunchResponse(BaseModel): + inv_id: str + job_id: str + + +@router.post("/launch") +def launch_investigation_endpoint(request: LaunchRequest, loaded: DepLoadedRun) -> LaunchResponse: + """Launch a new investigation for the currently loaded run.""" + from spd.investigate.scripts.run_slurm import launch_investigation + + result = launch_investigation( + wandb_path=loaded.run.wandb_path, + prompt=request.prompt, + context_length=loaded.context_length, + max_turns=50, + partition="h200-reserved", + time="8:00:00", + job_suffix=None, + ) + return LaunchResponse(inv_id=result.inv_id, job_id=result.job_id) + + +@router.get("/{inv_id}/artifacts") +def list_artifacts(inv_id: str) -> list[str]: + """List all artifact IDs for an investigation.""" + inv_path = INVESTIGATIONS_DIR / inv_id + if not inv_path.exists(): + raise HTTPException(status_code=404, detail=f"Investigation {inv_id} not found") + return _list_artifact_ids(inv_path) + + 
+@router.get("/{inv_id}/artifacts/{artifact_id}") +def get_artifact(inv_id: str, artifact_id: str) -> dict[str, Any]: + """Get a specific artifact by ID.""" + inv_path = INVESTIGATIONS_DIR / inv_id + artifact_path = inv_path / "artifacts" / f"{artifact_id}.json" + + if not artifact_path.exists(): + raise HTTPException( + status_code=404, + detail=f"Artifact {artifact_id} not found in {inv_id}", + ) + + data: dict[str, Any] = json.loads(artifact_path.read_text()) + return data diff --git a/spd/app/backend/routers/mcp.py b/spd/app/backend/routers/mcp.py new file mode 100644 index 000000000..ad2beeb13 --- /dev/null +++ b/spd/app/backend/routers/mcp.py @@ -0,0 +1,1626 @@ +"""MCP (Model Context Protocol) endpoint for Claude Code integration. + +This router implements the MCP JSON-RPC protocol over HTTP, allowing Claude Code +to use SPD tools directly with proper schemas and streaming progress. + +MCP Spec: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports +""" + +import inspect +import json +import queue +import threading +import traceback +from collections.abc import Callable, Generator +from dataclasses import dataclass +from datetime import UTC, datetime +from pathlib import Path +from typing import Any, Literal + +import torch +from fastapi import APIRouter, Request +from fastapi.responses import JSONResponse, StreamingResponse +from pydantic import BaseModel + +from spd.app.backend.compute import ( + compute_ci_only, + compute_intervention_forward, + compute_prompt_attributions_optimized, +) +from spd.app.backend.database import StoredGraph +from spd.app.backend.optim_cis import CELossConfig, OptimCIConfig +from spd.app.backend.routers.graphs import _build_out_probs +from spd.app.backend.routers.pretrain_info import _get_pretrain_info +from spd.app.backend.state import StateManager +from spd.configs import ImportanceMinimalityLossConfig +from spd.harvest import analysis +from spd.log import logger +from spd.utils.distributed_utils import 
get_device + +router = APIRouter(tags=["mcp"]) + +DEVICE = get_device() + +# MCP protocol version +MCP_PROTOCOL_VERSION = "2024-11-05" + + +@dataclass +class InvestigationConfig: + """Configuration for investigation mode. All paths are required when in investigation mode.""" + + events_log_path: Path + investigation_dir: Path + + +_investigation_config: InvestigationConfig | None = None + + +def set_investigation_config(config: InvestigationConfig) -> None: + """Configure MCP for investigation mode.""" + global _investigation_config + _investigation_config = config + + +def _log_event(event_type: str, message: str, details: dict[str, Any] | None = None) -> None: + """Log an event to the events file if in investigation mode.""" + if _investigation_config is None: + return + event = { + "event_type": event_type, + "timestamp": datetime.now(UTC).isoformat(), + "message": message, + "details": details or {}, + } + with open(_investigation_config.events_log_path, "a") as f: + f.write(json.dumps(event) + "\n") + + +# ============================================================================= +# MCP Protocol Types +# ============================================================================= + + +class MCPRequest(BaseModel): + """JSON-RPC 2.0 request.""" + + jsonrpc: Literal["2.0"] + id: int | str | None = None + method: str + params: dict[str, Any] | None = None + + +class MCPResponse(BaseModel): + """JSON-RPC 2.0 response. + + Per JSON-RPC 2.0 spec, exactly one of result/error must be present (not both, not neither). + Use model_dump(exclude_none=True) when serializing to avoid including null fields. 
class MCPResponse(BaseModel):
    """JSON-RPC 2.0 response.

    Per JSON-RPC 2.0 spec, exactly one of result/error must be present (not both, not neither).
    Use model_dump(exclude_none=True) when serializing to avoid including null fields.
    """

    jsonrpc: Literal["2.0"] = "2.0"
    id: int | str | None
    result: Any | None = None
    error: dict[str, Any] | None = None


class ToolDefinition(BaseModel):
    """MCP tool definition."""

    name: str
    description: str
    # JSON Schema describing the tool's arguments (camelCase per MCP spec).
    inputSchema: dict[str, Any]


# =============================================================================
# Tool Definitions
# =============================================================================

# Static registry of the tools exposed over MCP. Each inputSchema is JSON
# Schema; descriptions are surfaced verbatim to the MCP client (Claude).
TOOLS: list[ToolDefinition] = [
    ToolDefinition(
        name="optimize_graph",
        description="""Optimize a sparse circuit for a specific behavior.

Given a prompt and target token, finds the minimal set of components that produce the target prediction.
Returns the optimized graph with component CI values and edges showing information flow.

This is the primary tool for understanding how the model produces a specific output.""",
        inputSchema={
            "type": "object",
            "properties": {
                "prompt_text": {
                    "type": "string",
                    "description": "The input text to analyze (e.g., 'The boy said that')",
                },
                "target_token": {
                    "type": "string",
                    "description": "The token to predict (e.g., ' he'). Include leading space if needed.",
                },
                "loss_position": {
                    "type": "integer",
                    "description": "Position to optimize prediction at (0-indexed, usually last position). If not specified, uses the last position.",
                },
                "steps": {
                    "type": "integer",
                    "description": "Optimization steps (default: 100, more = sparser but slower)",
                    "default": 100,
                },
                "ci_threshold": {
                    "type": "number",
                    "description": "CI threshold for including components (default: 0.5, lower = more components)",
                    "default": 0.5,
                },
            },
            "required": ["prompt_text", "target_token"],
        },
    ),
    ToolDefinition(
        name="get_component_info",
        description="""Get detailed information about a component.

Returns the component's interpretation (what it does), token statistics (what tokens
activate it and what it predicts), and correlated components.

Use this to understand what role a component plays in a circuit.""",
        inputSchema={
            "type": "object",
            "properties": {
                "layer": {
                    "type": "string",
                    "description": "Canonical layer name (e.g., '0.mlp.up', '2.attn.o')",
                },
                "component_idx": {
                    "type": "integer",
                    "description": "Component index within the layer",
                },
                "top_k": {
                    "type": "integer",
                    "description": "Number of top tokens/correlations to return (default: 20)",
                    "default": 20,
                },
            },
            "required": ["layer", "component_idx"],
        },
    ),
    ToolDefinition(
        name="run_ablation",
        description="""Run an ablation experiment with only selected components active.

Tests a hypothesis by running the model with a sparse set of components.
Returns predictions showing what the circuit produces vs the full model.

Use this to verify that identified components are necessary and sufficient.""",
        inputSchema={
            "type": "object",
            "properties": {
                "text": {
                    "type": "string",
                    "description": "Input text for the ablation",
                },
                "selected_nodes": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Node keys to keep active (format: 'layer:seq_pos:component_idx')",
                },
                "top_k": {
                    "type": "integer",
                    "description": "Number of top predictions to return per position (default: 10)",
                    "default": 10,
                },
            },
            "required": ["text", "selected_nodes"],
        },
    ),
    ToolDefinition(
        name="search_dataset",
        description="""Search the SimpleStories training dataset for patterns.

Finds stories containing the query string. Use this to find examples of
specific linguistic patterns (pronouns, verb forms, etc.) for investigation.""",
        inputSchema={
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Text to search for (case-insensitive)",
                },
                "limit": {
                    "type": "integer",
                    "description": "Maximum results to return (default: 20)",
                    "default": 20,
                },
            },
            "required": ["query"],
        },
    ),
    ToolDefinition(
        name="create_prompt",
        description="""Create a prompt for analysis.

Tokenizes the text and returns token IDs and next-token probabilities.
The returned prompt_id can be used with other tools.""",
        inputSchema={
            "type": "object",
            "properties": {
                "text": {
                    "type": "string",
                    "description": "The text to create a prompt from",
                },
            },
            "required": ["text"],
        },
    ),
    ToolDefinition(
        name="update_research_log",
        description="""Append content to your research log.

Use this to document your investigation progress, findings, and next steps.
The research log is your primary output for humans to follow your work.

Call this frequently (every few minutes) with updates on what you're doing.""",
        inputSchema={
            "type": "object",
            "properties": {
                "content": {
                    "type": "string",
                    "description": "Markdown content to append to the research log",
                },
            },
            "required": ["content"],
        },
    ),
    ToolDefinition(
        name="save_explanation",
        description="""Save a complete behavior explanation.

Use this when you have finished investigating a behavior and want to document
your findings. This creates a structured record of the behavior, the components
involved, and your explanation of how they work together.

Only call this for complete, validated explanations - not preliminary hypotheses.""",
        inputSchema={
            "type": "object",
            "properties": {
                "subject_prompt": {
                    "type": "string",
                    "description": "A prompt that demonstrates the behavior",
                },
                "behavior_description": {
                    "type": "string",
                    "description": "Clear description of the behavior",
                },
                "components_involved": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "component_key": {
                                "type": "string",
                                "description": "Component key (e.g., '0.mlp.up:5')",
                            },
                            "role": {
                                "type": "string",
                                "description": "The role this component plays",
                            },
                            "interpretation": {
                                "type": "string",
                                "description": "Auto-interp label if available",
                            },
                        },
                        "required": ["component_key", "role"],
                    },
                    "description": "List of components and their roles",
                },
                "explanation": {
                    "type": "string",
                    "description": "How the components work together",
                },
                "supporting_evidence": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "evidence_type": {
                                "type": "string",
                                "enum": [
                                    "ablation",
                                    "attribution",
                                    "activation_pattern",
                                    "correlation",
                                    "other",
                                ],
                            },
                            "description": {"type": "string"},
                            "details": {"type": "object"},
                        },
                        "required": ["evidence_type", "description"],
                    },
                    "description": "Evidence supporting this explanation",
                },
                "confidence": {
                    "type": "string",
                    "enum": ["high", "medium", "low"],
                    "description": "Your confidence level",
                },
                "alternative_hypotheses": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Other hypotheses you considered",
                },
                "limitations": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Known limitations of this explanation",
                },
            },
            "required": [
                "subject_prompt",
                "behavior_description",
                "components_involved",
                "explanation",
                "confidence",
            ],
        },
    ),
    ToolDefinition(
        name="set_investigation_summary",
        description="""Set a title and summary for your investigation.

Call this when you've completed your investigation (or periodically as you make progress)
to provide a human-readable title and summary that will be shown in the investigations UI.

The title should be short and descriptive. The summary should be 1-3 sentences
explaining what you investigated and what you found.""",
        inputSchema={
            "type": "object",
            "properties": {
                "title": {
                    "type": "string",
                    "description": "Short title for the investigation (e.g., 'Gendered Pronoun Circuit')",
                },
                "summary": {
                    "type": "string",
                    "description": "Brief summary of findings (1-3 sentences)",
                },
                "status": {
                    "type": "string",
                    "enum": ["in_progress", "completed", "inconclusive"],
                    "description": "Current status of the investigation",
                    "default": "in_progress",
                },
            },
            "required": ["title", "summary"],
        },
    ),
    ToolDefinition(
        name="save_graph_artifact",
        description="""Save a graph as an artifact for inclusion in your research report.

After calling optimize_graph and getting a graph_id, call this to save the graph
as an artifact. Then reference it in your research log using the spd:graph syntax:

```spd:graph
artifact: graph_001
```

This allows humans reviewing your investigation to see interactive circuit visualizations
inline with your research notes.""",
        inputSchema={
            "type": "object",
            "properties": {
                "graph_id": {
                    "type": "integer",
                    "description": "The graph ID returned by optimize_graph",
                },
                "caption": {
                    "type": "string",
                    "description": "Optional caption describing what this graph shows",
                },
            },
            "required": ["graph_id"],
        },
    ),
    ToolDefinition(
        name="probe_component",
        description="""Fast CI probing on custom text.

Computes causal importance values and subcomponent activations for a specific component
across all positions in the input text. Also returns next-token probabilities.

Use this for quick, targeted analysis of how a component responds to specific inputs.""",
        inputSchema={
            "type": "object",
            "properties": {
                "text": {
                    "type": "string",
                    "description": "The input text to probe",
                },
                "layer": {
                    "type": "string",
                    "description": "Canonical layer name (e.g., '0.mlp.up')",
                },
                "component_idx": {
                    "type": "integer",
                    "description": "Component index within the layer",
                },
            },
            "required": ["text", "layer", "component_idx"],
        },
    ),
    ToolDefinition(
        name="get_component_activation_examples",
        description="""Get activation examples from harvest data for a component.

Returns examples showing token windows where the component fires, along with
CI values and activation strengths at each position.

Use this to understand what inputs activate a component.""",
        inputSchema={
            "type": "object",
            "properties": {
                "layer": {
                    "type": "string",
                    "description": "Canonical layer name (e.g., '0.mlp.up')",
                },
                "component_idx": {
                    "type": "integer",
                    "description": "Component index within the layer",
                },
                "limit": {
                    "type": "integer",
                    "description": "Maximum number of examples to return (default: 10)",
                    "default": 10,
                },
            },
            "required": ["layer", "component_idx"],
        },
    ),
    # NOTE(review): the implementation of this tool appears to be commented out
    # further down in this file — confirm the dispatcher still handles it.
    ToolDefinition(
        name="get_component_attributions",
        description="""Get dataset-level component dependencies from pre-computed attributions.

Returns the top source and target components that this component attributes to/from,
aggregated over the training dataset. Both positive and negative attributions are returned.

Use this to understand a component's role in the broader network.""",
        inputSchema={
            "type": "object",
            "properties": {
                "layer": {
                    "type": "string",
                    "description": "Canonical layer name (e.g., '0.mlp.up') or 'output'",
                },
                "component_idx": {
                    "type": "integer",
                    "description": "Component index within the layer",
                },
                "k": {
                    "type": "integer",
                    "description": "Number of top attributions to return per direction (default: 10)",
                    "default": 10,
                },
            },
            "required": ["layer", "component_idx"],
        },
    ),
    ToolDefinition(
        name="get_model_info",
        description="""Get architecture details about the pretrained model.

Returns model type, summary, target model config, topology, and pretrain info.
No parameters required.""",
        inputSchema={
            "type": "object",
            "properties": {},
        },
    ),
    ToolDefinition(
        name="get_attribution_strength",
        description="""Query the attribution strength between two specific components.

Returns the dataset-aggregated attribution value from source to target.

Use this to check if two components have a strong connection.""",
        inputSchema={
            "type": "object",
            "properties": {
                "source_layer": {
                    "type": "string",
                    "description": "Canonical layer name of source component (e.g., '0.mlp.up')",
                },
                "source_idx": {
                    "type": "integer",
                    "description": "Source component index",
                },
                "target_layer": {
                    "type": "string",
                    "description": "Canonical layer name of target component (e.g., '1.attn.q') or 'output'",
                },
                "target_idx": {
                    "type": "integer",
                    "description": "Target component index",
                },
            },
            "required": ["source_layer", "source_idx", "target_layer", "target_idx"],
        },
    ),
]


# =============================================================================
# Tool Implementations
# =============================================================================


def _get_state():
    """Get state manager and loaded run, raising clear errors if not available."""
    manager = StateManager.get()
    if manager.run_state is None:
        raise ValueError("No run loaded. The backend must load a run first.")
    return manager, manager.run_state


def _concrete_key(canonical_layer: str, component_idx: int, loaded: Any) -> str:
    """Translate canonical layer + idx to concrete storage key.

    'output' is a pseudo-layer with no concrete counterpart, so it passes
    through unchanged.
    """
    if canonical_layer == "output":
        return f"output:{component_idx}"
    concrete = loaded.topology.canon_to_target(canonical_layer)
    return f"{concrete}:{component_idx}"


def _canonicalize_layer(layer: str, loaded: Any) -> str:
    """Translate concrete layer name to canonical, passing through 'output'."""
    if layer == "output":
        return layer
    return loaded.topology.target_to_canon(layer)


def _canonicalize_key(concrete_key: str, loaded: Any) -> str:
    """Translate concrete component key (e.g. 'h.0.mlp.c_fc:444') to canonical ('0.mlp.up:444')."""
    # rsplit from the right: layer names themselves may contain no ':' but this
    # is the safe direction if they ever do.
    layer, idx = concrete_key.rsplit(":", 1)
    return f"{_canonicalize_layer(layer, loaded)}:{idx}"
def _tool_optimize_graph(params: dict[str, Any]) -> Generator[dict[str, Any]]:
    """Optimize a sparse circuit for a behavior. Yields progress events.

    Yields dicts of two shapes: {"type": "progress", ...} during optimization
    and a final {"type": "result", "data": ...}. The heavy GPU work runs on a
    worker thread so this generator can stream progress (e.g. over SSE) while
    the optimization is in flight.

    Raises ValueError for an empty prompt, a multi-token target, or an
    out-of-range loss_position.
    """
    manager, loaded = _get_state()

    prompt_text = params["prompt_text"]
    target_token = params["target_token"]
    steps = params.get("steps", 100)
    ci_threshold = params.get("ci_threshold", 0.5)

    # Tokenize prompt
    token_ids = loaded.tokenizer.encode(prompt_text)
    if not token_ids:
        raise ValueError("Prompt text produced no tokens")

    # Find target token ID — the loss is defined on exactly one label token.
    target_token_ids = loaded.tokenizer.encode(target_token)
    if len(target_token_ids) != 1:
        raise ValueError(
            f"Target token '{target_token}' tokenizes to {len(target_token_ids)} tokens, expected 1. "
            f"Token IDs: {target_token_ids}"
        )
    label_token = target_token_ids[0]

    # Determine loss position (defaults to the last prompt position)
    loss_position = params.get("loss_position")
    if loss_position is None:
        loss_position = len(token_ids) - 1

    if loss_position >= len(token_ids):
        raise ValueError(
            f"loss_position {loss_position} out of bounds for prompt with {len(token_ids)} tokens"
        )

    _log_event(
        "tool_start",
        f"optimize_graph: '{prompt_text}' → '{target_token}'",
        {"steps": steps, "loss_position": loss_position},
    )

    yield {"type": "progress", "current": 0, "total": steps, "stage": "starting optimization"}

    # Create prompt in DB
    prompt_id = manager.db.add_custom_prompt(
        run_id=loaded.run.id,
        token_ids=token_ids,
        context_length=loaded.context_length,
    )

    # Build optimization config
    loss_config = CELossConfig(coeff=1.0, position=loss_position, label_token=label_token)

    optim_config = OptimCIConfig(
        adv_pgd=None,  # AdvPGDConfig(n_steps=10, step_size=0.01, init="random"),
        seed=0,
        lr=1e-2,
        steps=steps,
        weight_decay=0.0,
        lr_schedule="cosine",
        lr_exponential_halflife=None,
        lr_warmup_pct=0.01,
        log_freq=max(1, steps // 10),
        imp_min_config=ImportanceMinimalityLossConfig(coeff=0.1, pnorm=0.5, beta=0.0),
        loss_config=loss_config,
        sampling=loaded.config.sampling,
        ce_kl_rounding_threshold=0.5,
        mask_type="ci",
    )

    tokens_tensor = torch.tensor([token_ids], device=DEVICE)
    # Worker thread pushes progress dicts here; the generator drains them below.
    progress_queue: queue.Queue[dict[str, Any]] = queue.Queue()

    def on_progress(current: int, total: int, stage: str) -> None:
        progress_queue.put({"current": current, "total": total, "stage": stage})

    # Run optimization in thread; single-element lists carry the result/error
    # across the thread boundary.
    result_holder: list[Any] = []
    error_holder: list[Exception] = []

    def compute():
        try:
            with manager.gpu_lock():
                result = compute_prompt_attributions_optimized(
                    model=loaded.model,
                    topology=loaded.topology,
                    tokens=tokens_tensor,
                    sources_by_target=loaded.sources_by_target,
                    optim_config=optim_config,
                    output_prob_threshold=0.01,
                    device=DEVICE,
                    on_progress=on_progress,
                )
                result_holder.append(result)
        except Exception as e:
            error_holder.append(e)

    thread = threading.Thread(target=compute)
    thread.start()

    # Yield progress events (throttle logging to every 10% or 10 steps)
    last_logged_step = -1
    log_interval = max(1, steps // 10)

    while thread.is_alive() or not progress_queue.empty():
        try:
            progress = progress_queue.get(timeout=0.1)
            current = progress["current"]
            # Log to events.jsonl at intervals (for human monitoring)
            if current - last_logged_step >= log_interval or current == progress["total"]:
                _log_event(
                    "optimization_progress",
                    f"optimize_graph: step {current}/{progress['total']} ({progress['stage']})",
                    {"prompt": prompt_text, "target": target_token, **progress},
                )
                last_logged_step = current
            # Always yield to SSE stream (for Claude)
            yield {"type": "progress", **progress}
        except queue.Empty:
            continue

    thread.join()

    # Re-raise any worker-thread failure on the caller's thread.
    if error_holder:
        raise error_holder[0]

    if not result_holder:
        raise RuntimeError("Optimization completed but no result was produced")

    result = result_holder[0]

    ci_masked_out_logits = result.ci_masked_out_logits.cpu()
    target_out_logits = result.target_out_logits.cpu()

    # Build output probs for response
    out_probs = _build_out_probs(
        ci_masked_out_logits,
        target_out_logits,
        loaded.tokenizer.get_tok_display,
    )

    # Save graph to DB
    from spd.app.backend.database import OptimizationParams

    # NOTE(review): these values duplicate the ImportanceMinimalityLossConfig /
    # OptimCIConfig literals above (0.1, 0.5, 0.0, "ci") — keep them in sync.
    opt_params = OptimizationParams(
        imp_min_coeff=0.1,
        steps=steps,
        pnorm=0.5,
        beta=0.0,
        mask_type="ci",
        loss=loss_config,
    )
    graph_id = manager.db.save_graph(
        prompt_id=prompt_id,
        graph=StoredGraph(
            graph_type="optimized",
            edges=result.edges,
            ci_masked_out_logits=ci_masked_out_logits,
            target_out_logits=target_out_logits,
            node_ci_vals=result.node_ci_vals,
            node_subcomp_acts=result.node_subcomp_acts,
            optimization_params=opt_params,
        ),
    )

    # Filter nodes by CI threshold (inclusive: >=)
    active_components = {k: v for k, v in result.node_ci_vals.items() if v >= ci_threshold}

    # Get target token probability
    target_key = f"{loss_position}:{label_token}"
    target_prob = out_probs.get(target_key)

    token_strings = [loaded.tokenizer.get_tok_display(t) for t in token_ids]

    final_result = {
        "graph_id": graph_id,
        "prompt_id": prompt_id,
        "tokens": token_strings,
        "target_token": target_token,
        "target_token_id": label_token,
        "target_position": loss_position,
        "target_probability": target_prob.prob if target_prob else None,
        "target_probability_baseline": target_prob.target_prob if target_prob else None,
        "active_components": active_components,
        "total_active": len(active_components),
        "output_probs": {k: {"prob": v.prob, "token": v.token} for k, v in out_probs.items()},
    }

    _log_event(
        "tool_complete",
        f"optimize_graph complete: {len(active_components)} active components",
        {"graph_id": graph_id, "target_prob": target_prob.prob if target_prob else None},
    )

    yield {"type": "result", "data": final_result}
def _tool_get_component_info(params: dict[str, Any]) -> dict[str, Any]:
    """Get detailed information about a component.

    Returns a dict with interpretation, token_stats and correlated_components;
    each field is None (rather than an error) when the underlying data is not
    available for this component.
    """
    _, loaded = _get_state()

    layer = params["layer"]
    component_idx = params["component_idx"]
    top_k = params.get("top_k", 20)
    canonical_key = f"{layer}:{component_idx}"

    # Harvest/interp repos store concrete keys (e.g. "h.0.mlp.c_fc:444")
    concrete_layer = loaded.topology.canon_to_target(layer)
    concrete_key = f"{concrete_layer}:{component_idx}"

    _log_event(
        "tool_call",
        f"get_component_info: {canonical_key}",
        {"layer": layer, "idx": component_idx},
    )

    result: dict[str, Any] = {"component_key": canonical_key}

    # Get interpretation (auto-interp label), if the interp repo is loaded.
    if loaded.interp is not None:
        interp = loaded.interp.get_interpretation(concrete_key)
        if interp is not None:
            result["interpretation"] = {
                "label": interp.label,
                "confidence": interp.confidence,
                "reasoning": interp.reasoning,
            }
        else:
            result["interpretation"] = None
    else:
        result["interpretation"] = None

    # Get token stats
    assert loaded.harvest is not None, "harvest data not loaded"
    token_stats = loaded.harvest.get_token_stats()
    if token_stats is not None:
        input_stats = analysis.get_input_token_stats(
            token_stats, concrete_key, loaded.tokenizer, top_k
        )
        output_stats = analysis.get_output_token_stats(
            token_stats, concrete_key, loaded.tokenizer, top_k
        )
        if input_stats and output_stats:
            result["token_stats"] = {
                "input": {
                    "top_recall": input_stats.top_recall,
                    "top_precision": input_stats.top_precision,
                    "top_pmi": input_stats.top_pmi,
                },
                "output": {
                    "top_recall": output_stats.top_recall,
                    "top_precision": output_stats.top_precision,
                    "top_pmi": output_stats.top_pmi,
                    "bottom_pmi": output_stats.bottom_pmi,
                },
            }
        else:
            result["token_stats"] = None
    else:
        result["token_stats"] = None

    # Get correlations (return canonical keys)
    correlations = loaded.harvest.get_correlations()
    if correlations is not None and analysis.has_component(correlations, concrete_key):
        result["correlated_components"] = {
            "precision": [
                {"key": _canonicalize_key(c.component_key, loaded), "score": c.score}
                for c in analysis.get_correlated_components(
                    correlations, concrete_key, "precision", top_k
                )
            ],
            "pmi": [
                {"key": _canonicalize_key(c.component_key, loaded), "score": c.score}
                for c in analysis.get_correlated_components(
                    correlations, concrete_key, "pmi", top_k
                )
            ],
        }
    else:
        result["correlated_components"] = None

    return result


def _tool_run_ablation(params: dict[str, Any]) -> dict[str, Any]:
    """Run ablation with selected components.

    Raises ValueError for malformed node keys or attempts to intervene on
    embedding/output pseudo-layers.
    """
    manager, loaded = _get_state()

    text = params["text"]
    selected_nodes = params["selected_nodes"]
    top_k = params.get("top_k", 10)

    _log_event(
        "tool_call",
        f"run_ablation: '{text[:50]}...' with {len(selected_nodes)} nodes",
        {"text": text, "n_nodes": len(selected_nodes)},
    )

    token_ids = loaded.tokenizer.encode(text)
    tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE)

    # Parse node keys of the form 'layer:seq_pos:component_idx'
    active_nodes = []
    for key in selected_nodes:
        parts = key.split(":")
        if len(parts) != 3:
            raise ValueError(f"Invalid node key format: {key!r} (expected 'layer:seq:cIdx')")
        layer, seq_str, cidx_str = parts
        if layer in ("wte", "embed", "output"):
            raise ValueError(f"Cannot intervene on {layer!r} nodes - only internal layers allowed")
        active_nodes.append((layer, int(seq_str), int(cidx_str)))

    with manager.gpu_lock():
        result = compute_intervention_forward(
            model=loaded.model,
            tokens=tokens,
            active_nodes=active_nodes,
            top_k=top_k,
            tokenizer=loaded.tokenizer,
        )

    # Reshape per-position prediction tuples into JSON-friendly dicts;
    # logits are dropped (underscored) — only probabilities are returned.
    predictions = []
    for pos_predictions in result.predictions_per_position:
        pos_result = []
        for token, token_id, spd_prob, _logit, target_prob, _target_logit in pos_predictions:
            pos_result.append(
                {
                    "token": token,
                    "token_id": token_id,
                    "circuit_prob": round(spd_prob, 6),
                    "full_model_prob": round(target_prob, 6),
                }
            )
        predictions.append(pos_result)

    return {
        "input_tokens": result.input_tokens,
        "predictions_per_position": predictions,
        "selected_nodes": selected_nodes,
    }
def _tool_search_dataset(params: dict[str, Any]) -> dict[str, Any]:
    """Search the SimpleStories dataset.

    NOTE(review): load_dataset is called on every invocation; HF datasets
    caches on disk, but the filter pass still scans the full split each call.
    """
    import time

    from datasets import Dataset, load_dataset

    query = params["query"]
    limit = params.get("limit", 20)
    search_query = query.lower()

    _log_event("tool_call", f"search_dataset: '{query}'", {"query": query, "limit": limit})

    start_time = time.time()
    dataset = load_dataset("lennart-finke/SimpleStories", split="train")
    assert isinstance(dataset, Dataset)

    # Case-insensitive substring match over the 'story' column.
    filtered = dataset.filter(
        lambda x: search_query in x["story"].lower(),
        num_proc=4,
    )

    results = []
    for i, item in enumerate(filtered):
        if i >= limit:
            break
        item_dict: dict[str, Any] = dict(item)
        story: str = item_dict["story"]
        results.append(
            {
                # Truncate long stories for the response payload.
                "story": story[:500] + "..." if len(story) > 500 else story,
                "occurrence_count": story.lower().count(search_query),
            }
        )

    return {
        "query": query,
        "total_matches": len(filtered),
        "returned": len(results),
        "search_time_seconds": round(time.time() - start_time, 2),
        "results": results,
    }


def _tool_create_prompt(params: dict[str, Any]) -> dict[str, Any]:
    """Create a prompt from text.

    Stores the tokenized prompt in the DB and returns its id along with
    per-position next-token probabilities from the full model.
    """
    manager, loaded = _get_state()

    text = params["text"]

    _log_event("tool_call", f"create_prompt: '{text[:50]}...'", {"text": text})

    token_ids = loaded.tokenizer.encode(text)
    if not token_ids:
        raise ValueError("Text produced no tokens")

    prompt_id = manager.db.add_custom_prompt(
        run_id=loaded.run.id,
        token_ids=token_ids,
        context_length=loaded.context_length,
    )

    # Compute next token probs
    tokens_tensor = torch.tensor([token_ids], device=DEVICE)
    with torch.no_grad():
        logits = loaded.model(tokens_tensor)
        probs = torch.softmax(logits, dim=-1)

    # probs[0, i, token_ids[i+1]] = model's probability of the actual next token.
    next_token_probs = []
    for i in range(len(token_ids) - 1):
        next_token_id = token_ids[i + 1]
        prob = probs[0, i, next_token_id].item()
        next_token_probs.append(round(prob, 6))
    # Last position has no observed next token.
    next_token_probs.append(None)

    token_strings = [loaded.tokenizer.get_tok_display(t) for t in token_ids]

    return {
        "prompt_id": prompt_id,
        "text": text,
        "tokens": token_strings,
        "token_ids": token_ids,
        "next_token_probs": next_token_probs,
    }


def _require_investigation_config() -> InvestigationConfig:
    """Get investigation config, raising if not in investigation mode."""
    assert _investigation_config is not None, "Not running in investigation mode"
    return _investigation_config
def _tool_update_research_log(params: dict[str, Any]) -> dict[str, Any]:
    """Append content to the research log.

    Appends markdown verbatim to research_log.md, ensuring a trailing newline
    so successive appends don't run together.
    """
    config = _require_investigation_config()
    content = params["content"]
    research_log_path = config.investigation_dir / "research_log.md"

    _log_event(
        "tool_call", f"update_research_log: {len(content)} chars", {"preview": content[:100]}
    )

    with open(research_log_path, "a") as f:
        f.write(content)
        if not content.endswith("\n"):
            f.write("\n")

    return {"status": "ok", "path": str(research_log_path)}


def _tool_save_explanation(params: dict[str, Any]) -> dict[str, Any]:
    """Save a behavior explanation to explanations.jsonl.

    Validates the incoming params through the pydantic schemas before
    appending one JSON line per explanation.
    """
    from spd.investigate.schemas import BehaviorExplanation, ComponentInfo, Evidence

    config = _require_investigation_config()

    _log_event(
        "tool_call",
        f"save_explanation: '{params['behavior_description'][:50]}...'",
        {"prompt": params["subject_prompt"]},
    )

    components = [
        ComponentInfo(
            component_key=c["component_key"],
            role=c["role"],
            interpretation=c.get("interpretation"),
        )
        for c in params["components_involved"]
    ]

    evidence = [
        Evidence(
            evidence_type=e["evidence_type"],
            description=e["description"],
            details=e.get("details", {}),
        )
        for e in params.get("supporting_evidence", [])
    ]

    explanation = BehaviorExplanation(
        subject_prompt=params["subject_prompt"],
        behavior_description=params["behavior_description"],
        components_involved=components,
        explanation=params["explanation"],
        supporting_evidence=evidence,
        confidence=params["confidence"],
        alternative_hypotheses=params.get("alternative_hypotheses", []),
        limitations=params.get("limitations", []),
    )

    explanations_path = config.investigation_dir / "explanations.jsonl"
    with open(explanations_path, "a") as f:
        f.write(explanation.model_dump_json() + "\n")

    _log_event(
        "explanation",
        f"Saved explanation: {params['behavior_description']}",
        {"confidence": params["confidence"], "n_components": len(components)},
    )

    return {"status": "ok", "path": str(explanations_path)}


def _tool_set_investigation_summary(params: dict[str, Any]) -> dict[str, Any]:
    """Set the investigation title and summary.

    Overwrites summary.json in the investigation directory (last call wins).
    """
    config = _require_investigation_config()

    summary = {
        "title": params["title"],
        "summary": params["summary"],
        "status": params.get("status", "in_progress"),
        "updated_at": datetime.now(UTC).isoformat(),
    }

    _log_event(
        "tool_call",
        f"set_investigation_summary: {params['title']}",
        summary,
    )

    summary_path = config.investigation_dir / "summary.json"
    summary_path.write_text(json.dumps(summary, indent=2))

    return {"status": "ok", "path": str(summary_path)}
def _tool_save_graph_artifact(params: dict[str, Any]) -> dict[str, Any]:
    """Save a graph as an artifact for the research report.

    Uses the same filtering logic as the main graph API:
    1. Filter nodes by CI threshold
    2. Add pseudo nodes (wte, output)
    3. Filter edges to only active nodes
    4. Apply edge limit
    """
    config = _require_investigation_config()
    manager, loaded = _get_state()

    graph_id = params["graph_id"]
    caption = params.get("caption")
    # NOTE(review): ci_threshold/edge_limit are accepted here but are not
    # declared in save_graph_artifact's inputSchema — confirm intended.
    ci_threshold = params.get("ci_threshold", 0.5)
    edge_limit = params.get("edge_limit", 5000)

    _log_event(
        "tool_call",
        f"save_graph_artifact: graph_id={graph_id}",
        {"graph_id": graph_id, "caption": caption},
    )

    # Fetch graph from DB
    result = manager.db.get_graph(graph_id)
    if result is None:
        raise ValueError(f"Graph with id={graph_id} not found")

    graph, prompt_id = result

    # Get tokens from prompt
    prompt_record = manager.db.get_prompt(prompt_id)
    if prompt_record is None:
        raise ValueError(f"Prompt with id={prompt_id} not found")

    tokens = [loaded.tokenizer.get_tok_display(tid) for tid in prompt_record.token_ids]
    num_tokens = len(tokens)

    # Create artifacts directory
    artifacts_dir = config.investigation_dir / "artifacts"
    artifacts_dir.mkdir(exist_ok=True)

    # Generate artifact ID (find max existing number to avoid collisions)
    existing_nums = []
    for f in artifacts_dir.glob("graph_*.json"):
        try:
            num = int(f.stem.split("_")[1])
            existing_nums.append(num)
        except (IndexError, ValueError):
            # Ignore files that don't match the graph_NNN naming scheme.
            continue
    artifact_num = max(existing_nums, default=0) + 1
    artifact_id = f"graph_{artifact_num:03d}"

    # Compute out_probs from stored logits
    out_probs = _build_out_probs(
        graph.ci_masked_out_logits,
        graph.target_out_logits,
        loaded.tokenizer.get_tok_display,
    )

    # Step 1: Filter nodes by CI threshold (same as main graph API)
    # NOTE(review): strict '>' here vs '>=' in optimize_graph's
    # active-component filter — confirm the asymmetry is intended.
    filtered_ci_vals = {k: v for k, v in graph.node_ci_vals.items() if v > ci_threshold}
    l0_total = len(filtered_ci_vals)

    # Step 2: Add pseudo nodes (embed and output) - same as _add_pseudo_layer_nodes
    node_ci_vals_with_pseudo = dict(filtered_ci_vals)
    for seq_pos in range(num_tokens):
        node_ci_vals_with_pseudo[f"embed:{seq_pos}:0"] = 1.0
    for key, out_prob in out_probs.items():
        seq_pos, token_id = key.split(":")
        node_ci_vals_with_pseudo[f"output:{seq_pos}:{token_id}"] = out_prob.prob

    # Step 3: Filter edges to only active nodes
    active_node_keys = set(node_ci_vals_with_pseudo.keys())
    filtered_edges = [
        e
        for e in graph.edges
        if str(e.source) in active_node_keys and str(e.target) in active_node_keys
    ]

    # Step 4: Sort by strength and apply edge limit
    filtered_edges.sort(key=lambda e: abs(e.strength), reverse=True)
    filtered_edges = filtered_edges[:edge_limit]

    # Build edges data
    edges_data = [
        {
            "src": str(e.source),
            "tgt": str(e.target),
            "val": e.strength,
        }
        for e in filtered_edges
    ]

    # Compute max abs attr from filtered edges
    max_abs_attr = max((abs(e.strength) for e in filtered_edges), default=0.0)

    # Filter nodeSubcompActs to match nodeCiVals
    filtered_subcomp_acts = {
        k: v for k, v in graph.node_subcomp_acts.items() if k in node_ci_vals_with_pseudo
    }

    # Build artifact data (self-contained GraphData, same structure as API response)
    artifact = {
        "type": "graph",
        "id": artifact_id,
        "caption": caption,
        "graph_id": graph_id,
        "data": {
            "tokens": tokens,
            "edges": edges_data,
            "outputProbs": {
                k: {
                    "prob": v.prob,
                    "logit": v.logit,
                    "target_prob": v.target_prob,
                    "target_logit": v.target_logit,
                    "token": v.token,
                }
                for k, v in out_probs.items()
            },
            "nodeCiVals": node_ci_vals_with_pseudo,
            "nodeSubcompActs": filtered_subcomp_acts,
            "maxAbsAttr": max_abs_attr,
            "l0_total": l0_total,
        },
    }

    # Save artifact
    artifact_path = artifacts_dir / f"{artifact_id}.json"
    artifact_path.write_text(json.dumps(artifact, indent=2))

    _log_event(
        "artifact_saved",
        f"Saved graph artifact: {artifact_id}",
        {"artifact_id": artifact_id, "graph_id": graph_id, "path": str(artifact_path)},
    )

    return {"artifact_id": artifact_id, "path": str(artifact_path)}


def _tool_probe_component(params: dict[str, Any]) -> dict[str, Any]:
    """Fast CI probing on custom text for a specific component.

    Returns per-position CI values and subcomponent activations for one
    component, plus next-token probabilities from the target model.
    """
    manager, loaded = _get_state()

    text = params["text"]
    layer = params["layer"]
    component_idx = params["component_idx"]

    _log_event(
        "tool_call",
        f"probe_component: '{text[:50]}...' layer={layer} idx={component_idx}",
        {"text": text, "layer": layer, "component_idx": component_idx},
    )

    token_ids = loaded.tokenizer.encode(text)
    assert token_ids, "Text produced no tokens"
    tokens_tensor = torch.tensor([token_ids], device=DEVICE)

    concrete_layer = loaded.topology.canon_to_target(layer)

    with manager.gpu_lock():
        result = compute_ci_only(
            model=loaded.model, tokens=tokens_tensor, sampling=loaded.config.sampling
        )

    # Slice out this component's values across all sequence positions.
    ci_values = result.ci_lower_leaky[concrete_layer][0, :, component_idx].tolist()
    subcomp_acts = result.component_acts[concrete_layer][0, :, component_idx].tolist()

    # Get next token probs from target model output
    next_token_probs = []
    for i in range(len(token_ids) - 1):
        next_token_id = token_ids[i + 1]
        prob = result.target_out_probs[0, i, next_token_id].item()
        next_token_probs.append(round(prob, 6))
    # Last position has no observed next token.
    next_token_probs.append(None)

    token_strings = [loaded.tokenizer.get_tok_display(t) for t in token_ids]

    return {
        "tokens": token_strings,
        "ci_values": ci_values,
        "subcomp_acts": subcomp_acts,
        "next_token_probs": next_token_probs,
    }
next_token_probs, + } + + +def _tool_get_component_activation_examples(params: dict[str, Any]) -> dict[str, Any]: + """Get activation examples from harvest data.""" + _, loaded = _get_state() + + layer = params["layer"] + component_idx = params["component_idx"] + limit = params.get("limit", 10) + + concrete_layer = loaded.topology.canon_to_target(layer) + component_key = f"{concrete_layer}:{component_idx}" + + _log_event( + "tool_call", + f"get_component_activation_examples: {component_key}", + {"layer": layer, "component_idx": component_idx, "limit": limit}, + ) + + assert loaded.harvest is not None, "harvest data not loaded" + canonical_key = f"{layer}:{component_idx}" + comp = loaded.harvest.get_component(component_key) + if comp is None: + return {"component_key": canonical_key, "examples": [], "total": 0} + + examples = [] + for ex in comp.activation_examples[:limit]: + token_strings = [loaded.tokenizer.get_tok_display(t) for t in ex.token_ids] + examples.append( + { + "tokens": token_strings, + "ci_values": ex.activations["causal_importance"], + "component_acts": ex.activations["component_activation"], + } + ) + + return { + "component_key": canonical_key, + "examples": examples, + "total": len(comp.activation_examples), + "mean_ci": comp.mean_activations["causal_importance"], + } + + +# def _tool_get_component_attributions(params: dict[str, Any]) -> dict[str, Any]: +# """Get dataset-level component dependencies.""" +# _, loaded = _get_state() + +# layer = params["layer"] +# component_idx = params["component_idx"] +# k = params.get("k", 10) + +# assert loaded.attributions is not None, "dataset attributions not loaded" +# storage = loaded.attributions.get_attributions() + +# concrete_layer = loaded.topology.canon_to_target(layer) if layer != "output" else "output" +# component_key = f"{concrete_layer}:{component_idx}" + +# _log_event( +# "tool_call", +# f"get_component_attributions: {component_key}", +# {"layer": layer, "component_idx": component_idx, "k": k}, 
+# ) + +# is_source = storage.has_source(component_key) +# is_target = storage.has_target(component_key) + +# assert is_source or is_target, f"Component {component_key} not found in attributions" + +# w_unembed = loaded.topology.get_unembed_weight() if is_source else None + +# def _entries_to_dicts( +# entries: list[Any], +# ) -> list[dict[str, Any]]: +# return [ +# { +# "component_key": f"{_canonicalize_layer(e.layer, loaded)}:{e.component_idx}", +# "layer": _canonicalize_layer(e.layer, loaded), +# "component_idx": e.component_idx, +# "value": e.value, +# } +# for e in entries +# ] + +# positive_sources = ( +# _entries_to_dicts(storage.get_top_sources(component_key, k, "positive")) +# if is_target +# else [] +# ) +# negative_sources = ( +# _entries_to_dicts(storage.get_top_sources(component_key, k, "negative")) +# if is_target +# else [] +# ) +# positive_targets = ( +# _entries_to_dicts( +# storage.get_top_targets( +# component_key, +# k, +# "positive", +# w_unembed=w_unembed, +# include_outputs=w_unembed is not None, +# ) +# ) +# if is_source +# else [] +# ) +# negative_targets = ( +# _entries_to_dicts( +# storage.get_top_targets( +# component_key, +# k, +# "negative", +# w_unembed=w_unembed, +# include_outputs=w_unembed is not None, +# ) +# ) +# if is_source +# else [] +# ) + +# return { +# "component_key": component_key, +# "positive_sources": positive_sources, +# "negative_sources": negative_sources, +# "positive_targets": positive_targets, +# "negative_targets": negative_targets, +# } + + +def _tool_get_model_info(_params: dict[str, Any]) -> dict[str, Any]: + """Get architecture details about the pretrained model.""" + _, loaded = _get_state() + + _log_event("tool_call", "get_model_info", {}) + + info = _get_pretrain_info(loaded.config) + return info.model_dump() + + +def _tool_get_attribution_strength(params: dict[str, Any]) -> dict[str, Any]: + """Query attribution between two specific components.""" + _, loaded = _get_state() + + source_layer = 
params["source_layer"] + source_idx = params["source_idx"] + target_layer = params["target_layer"] + target_idx = params["target_idx"] + + assert loaded.attributions is not None, "dataset attributions not loaded" + storage = loaded.attributions.get_attributions() + + source_key = _concrete_key(source_layer, source_idx, loaded) + target_key = _concrete_key(target_layer, target_idx, loaded) + + _log_event( + "tool_call", + f"get_attribution_strength: {source_key} → {target_key}", + {"source": source_key, "target": target_key}, + ) + + value = storage.get_attribution(source_key, target_key) + + return {"value": value} + + +# ============================================================================= +# MCP Protocol Handler +# ============================================================================= + + +_STREAMING_TOOLS: dict[str, Callable[..., Generator[dict[str, Any]]]] = { + "optimize_graph": _tool_optimize_graph, +} + +_SIMPLE_TOOLS: dict[str, Callable[..., dict[str, Any]]] = { + "get_component_info": _tool_get_component_info, + "run_ablation": _tool_run_ablation, + "search_dataset": _tool_search_dataset, + "create_prompt": _tool_create_prompt, + "update_research_log": _tool_update_research_log, + "save_explanation": _tool_save_explanation, + "set_investigation_summary": _tool_set_investigation_summary, + "save_graph_artifact": _tool_save_graph_artifact, + "probe_component": _tool_probe_component, + "get_component_activation_examples": _tool_get_component_activation_examples, + # "get_component_attributions": _tool_get_component_attributions, + "get_model_info": _tool_get_model_info, + "get_attribution_strength": _tool_get_attribution_strength, +} + + +def _handle_initialize(_params: dict[str, Any] | None) -> dict[str, Any]: + """Handle initialize request.""" + return { + "protocolVersion": MCP_PROTOCOL_VERSION, + "capabilities": {"tools": {}}, + "serverInfo": {"name": "spd-app", "version": "1.0.0"}, + } + + +def _handle_tools_list() -> dict[str, Any]: + 
"""Handle tools/list request.""" + return {"tools": [t.model_dump() for t in TOOLS]} + + +def _handle_tools_call( + params: dict[str, Any], +) -> Generator[dict[str, Any]] | dict[str, Any]: + """Handle tools/call request. May return generator for streaming tools.""" + name = params.get("name") + arguments = params.get("arguments", {}) + + if name in _STREAMING_TOOLS: + return _STREAMING_TOOLS[name](arguments) + + if name in _SIMPLE_TOOLS: + result = _SIMPLE_TOOLS[name](arguments) + return {"content": [{"type": "text", "text": json.dumps(result, indent=2)}]} + + raise ValueError(f"Unknown tool: {name}") + + +@router.post("/mcp") +async def mcp_endpoint(request: Request): + """MCP JSON-RPC endpoint. + + Handles initialize, tools/list, and tools/call methods. + Returns SSE stream for streaming tools, JSON for others. + """ + try: + body = await request.json() + mcp_request = MCPRequest(**body) + except Exception as e: + return JSONResponse( + status_code=400, + content=MCPResponse( + id=None, error={"code": -32700, "message": f"Parse error: {e}"} + ).model_dump(exclude_none=True), + ) + + logger.info(f"[MCP] {mcp_request.method} (id={mcp_request.id})") + + try: + if mcp_request.method == "initialize": + result = _handle_initialize(mcp_request.params) + return JSONResponse( + content=MCPResponse(id=mcp_request.id, result=result).model_dump(exclude_none=True), + headers={"Mcp-Session-Id": "spd-session"}, + ) + + elif mcp_request.method == "notifications/initialized": + # Client confirms initialization + return JSONResponse(status_code=202, content={}) + + elif mcp_request.method == "tools/list": + result = _handle_tools_list() + return JSONResponse( + content=MCPResponse(id=mcp_request.id, result=result).model_dump(exclude_none=True) + ) + + elif mcp_request.method == "tools/call": + if mcp_request.params is None: + raise ValueError("tools/call requires params") + + result = _handle_tools_call(mcp_request.params) + + # Check if result is a generator (streaming) + if 
inspect.isgenerator(result): + # Streaming response via SSE + gen = result # Capture for closure + + def generate_sse() -> Generator[str]: + try: + final_result = None + for event in gen: + if event.get("type") == "progress": + # Send progress notification + progress_msg = { + "jsonrpc": "2.0", + "method": "notifications/progress", + "params": event, + } + yield f"data: {json.dumps(progress_msg)}\n\n" + elif event.get("type") == "result": + final_result = event["data"] + + # Send final response + response = MCPResponse( + id=mcp_request.id, + result={ + "content": [ + {"type": "text", "text": json.dumps(final_result, indent=2)} + ] + }, + ) + yield f"data: {json.dumps(response.model_dump(exclude_none=True))}\n\n" + except Exception as e: + tb = traceback.format_exc() + logger.error(f"[MCP] Tool error: {e}\n{tb}") + error_response = MCPResponse( + id=mcp_request.id, + error={"code": -32000, "message": str(e)}, + ) + yield f"data: {json.dumps(error_response.model_dump(exclude_none=True))}\n\n" + + return StreamingResponse(generate_sse(), media_type="text/event-stream") + + else: + # Non-streaming response + return JSONResponse( + content=MCPResponse(id=mcp_request.id, result=result).model_dump( + exclude_none=True + ) + ) + + else: + return JSONResponse( + content=MCPResponse( + id=mcp_request.id, + error={"code": -32601, "message": f"Method not found: {mcp_request.method}"}, + ).model_dump(exclude_none=True) + ) + + except Exception as e: + tb = traceback.format_exc() + logger.error(f"[MCP] Error handling {mcp_request.method}: {e}\n{tb}") + return JSONResponse( + content=MCPResponse( + id=mcp_request.id, + error={"code": -32000, "message": str(e)}, + ).model_dump(exclude_none=True) + ) diff --git a/spd/app/backend/routers/pretrain_info.py b/spd/app/backend/routers/pretrain_info.py index 2872423c9..424f7b035 100644 --- a/spd/app/backend/routers/pretrain_info.py +++ b/spd/app/backend/routers/pretrain_info.py @@ -38,6 +38,7 @@ class TopologyInfo(BaseModel): class 
PretrainInfoResponse(BaseModel): model_type: str summary: str + dataset_short: str | None target_model_config: dict[str, Any] | None pretrain_config: dict[str, Any] | None pretrain_wandb_path: str | None @@ -161,6 +162,27 @@ def _build_summary(model_type: str, target_model_config: dict[str, Any] | None) return " · ".join(parts) +_DATASET_SHORT_NAMES: dict[str, str] = { + "simplestories": "SS", + "pile": "Pile", + "tinystories": "TS", +} + + +def _get_dataset_short(pretrain_config: dict[str, Any] | None) -> str | None: + """Extract a short dataset label from the pretrain config.""" + if pretrain_config is None: + return None + dataset_name: str = ( + pretrain_config.get("train_dataset_config", {}).get("name", "") + or pretrain_config.get("dataset", "") + ).lower() + for key, short in _DATASET_SHORT_NAMES.items(): + if key in dataset_name: + return short + return None + + def _get_pretrain_info(spd_config: Config) -> PretrainInfoResponse: """Extract pretrain info from an SPD config.""" model_class_name = spd_config.pretrained_model_class @@ -190,10 +212,12 @@ def _get_pretrain_info(spd_config: Config) -> PretrainInfoResponse: n_blocks = target_model_config.get("n_layer", 0) if target_model_config else 0 topology = _build_topology(model_type, n_blocks) summary = _build_summary(model_type, target_model_config) + dataset_short = _get_dataset_short(pretrain_config) return PretrainInfoResponse( model_type=model_type, summary=summary, + dataset_short=dataset_short, target_model_config=target_model_config, pretrain_config=pretrain_config, pretrain_wandb_path=pretrain_wandb_path, diff --git a/spd/app/backend/routers/runs.py b/spd/app/backend/routers/runs.py index 0989cea54..daac71373 100644 --- a/spd/app/backend/routers/runs.py +++ b/spd/app/backend/routers/runs.py @@ -10,11 +10,13 @@ from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.dependencies import DepStateManager +from spd.app.backend.routers.graph_interp import MOCK_MODE as 
_GRAPH_INTERP_MOCK_MODE from spd.app.backend.state import RunState from spd.app.backend.utils import log_errors from spd.autointerp.repo import InterpRepo from spd.configs import LMTaskConfig from spd.dataset_attributions.repo import AttributionRepo +from spd.graph_interp.repo import GraphInterpRepo from spd.harvest.repo import HarvestRepo from spd.log import logger from spd.models.component_model import ComponentModel, SPDRunInfo @@ -42,6 +44,7 @@ class LoadedRun(BaseModel): backend_user: str dataset_attributions_available: bool dataset_search_enabled: bool + graph_interp_available: bool router = APIRouter(prefix="/api", tags=["runs"]) @@ -128,6 +131,7 @@ def load_run(wandb_path: str, context_length: int, manager: DepStateManager): harvest=HarvestRepo.open_most_recent(run_id), interp=InterpRepo.open(run_id), attributions=AttributionRepo.open(run_id), + graph_interp=GraphInterpRepo.open(run_id), ) logger.info(f"[API] Run {run.id} loaded on {DEVICE}") @@ -165,6 +169,9 @@ def get_status(manager: DepStateManager) -> LoadedRun | None: backend_user=getpass.getuser(), dataset_attributions_available=manager.run_state.attributions is not None, dataset_search_enabled=dataset_search_enabled, + # TODO(oli): Remove MOCK_MODE import after real data available + graph_interp_available=manager.run_state.graph_interp is not None + or _GRAPH_INTERP_MOCK_MODE, ) diff --git a/spd/app/backend/server.py b/spd/app/backend/server.py index 3804ce756..23cf47333 100644 --- a/spd/app/backend/server.py +++ b/spd/app/backend/server.py @@ -32,14 +32,18 @@ data_sources_router, dataset_attributions_router, dataset_search_router, + graph_interp_router, graphs_router, intervention_router, + investigations_router, + mcp_router, pretrain_info_router, prompts_router, runs_router, ) from spd.app.backend.state import StateManager from spd.log import logger +from spd.settings import SPD_APP_DEFAULT_RUN from spd.utils.distributed_utils import get_device DEVICE = get_device() @@ -48,6 +52,11 @@ 
@asynccontextmanager async def lifespan(app: FastAPI): # pyright: ignore[reportUnusedParameter] """Initialize DB connection at startup. Model loaded on-demand via /api/runs/load.""" + import os + from pathlib import Path + + from spd.app.backend.routers.mcp import InvestigationConfig, set_investigation_config + manager = StateManager.get() db = PromptAttrDB(check_same_thread=False) @@ -58,6 +67,24 @@ async def lifespan(app: FastAPI): # pyright: ignore[reportUnusedParameter] logger.info(f"[STARTUP] Device: {DEVICE}") logger.info(f"[STARTUP] CUDA available: {torch.cuda.is_available()}") + # Configure MCP for investigation mode (derives paths from investigation dir) + investigation_dir = os.environ.get("SPD_INVESTIGATION_DIR") + if investigation_dir: + inv_dir = Path(investigation_dir) + set_investigation_config( + InvestigationConfig( + events_log_path=inv_dir / "events.jsonl", + investigation_dir=inv_dir, + ) + ) + logger.info(f"[STARTUP] Investigation mode enabled: dir={investigation_dir}") + + if SPD_APP_DEFAULT_RUN is not None: + from spd.app.backend.routers.runs import load_run + + logger.info(f"[STARTUP] Auto-loading default run: {SPD_APP_DEFAULT_RUN}") + load_run(SPD_APP_DEFAULT_RUN, context_length=512, manager=manager) + yield manager.close() @@ -157,7 +184,10 @@ async def global_exception_handler(request: Request, exc: Exception) -> JSONResp app.include_router(dataset_search_router) app.include_router(dataset_attributions_router) app.include_router(agents_router) +app.include_router(investigations_router) +app.include_router(mcp_router) app.include_router(data_sources_router) +app.include_router(graph_interp_router) app.include_router(pretrain_info_router) diff --git a/spd/app/backend/state.py b/spd/app/backend/state.py index cf71c2bc6..2cdabda73 100644 --- a/spd/app/backend/state.py +++ b/spd/app/backend/state.py @@ -5,14 +5,20 @@ - StateManager: Singleton managing app-wide state with proper lifecycle """ +import threading +from collections.abc import 
Generator +from contextlib import contextmanager from dataclasses import dataclass, field from typing import Any +from fastapi import HTTPException + from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.database import PromptAttrDB, Run from spd.autointerp.repo import InterpRepo from spd.configs import Config from spd.dataset_attributions.repo import AttributionRepo +from spd.graph_interp.repo import GraphInterpRepo from spd.harvest.repo import HarvestRepo from spd.models.component_model import ComponentModel from spd.topology import TransformerTopology @@ -32,6 +38,7 @@ class RunState: harvest: HarvestRepo | None interp: InterpRepo | None attributions: AttributionRepo | None + graph_interp: GraphInterpRepo | None @dataclass @@ -62,6 +69,7 @@ class StateManager: def __init__(self) -> None: self._state: AppState | None = None + self._gpu_lock = threading.Lock() @classmethod def get(cls) -> "StateManager": @@ -104,3 +112,21 @@ def close(self) -> None: """Clean up resources.""" if self._state is not None: self._state.db.close() + + @contextmanager + def gpu_lock(self) -> Generator[None]: + """Acquire GPU lock or fail with 503 if another GPU operation is in progress. + + Use this for GPU-intensive endpoints to prevent concurrent operations + that would cause the server to hang. + """ + acquired = self._gpu_lock.acquire(blocking=False) + if not acquired: + raise HTTPException( + status_code=503, + detail="GPU operation already in progress. 
Please wait and retry.", + ) + try: + yield + finally: + self._gpu_lock.release() diff --git a/spd/app/frontend/package-lock.json b/spd/app/frontend/package-lock.json index b6c451303..32da0218c 100644 --- a/spd/app/frontend/package-lock.json +++ b/spd/app/frontend/package-lock.json @@ -7,6 +7,9 @@ "": { "name": "frontend", "version": "0.0.0", + "dependencies": { + "marked": "^17.0.1" + }, "devDependencies": { "@eslint/js": "^9.38.0", "@sveltejs/vite-plugin-svelte": "^6.2.1", @@ -2347,6 +2350,18 @@ "@jridgewell/sourcemap-codec": "^1.5.5" } }, + "node_modules/marked": { + "version": "17.0.1", + "resolved": "https://registry.npmjs.org/marked/-/marked-17.0.1.tgz", + "integrity": "sha512-boeBdiS0ghpWcSwoNm/jJBwdpFaMnZWRzjA6SkUMYb40SVaN1x7mmfGKp0jvexGcx+7y2La5zRZsYFZI6Qpypg==", + "license": "MIT", + "bin": { + "marked": "bin/marked.js" + }, + "engines": { + "node": ">= 20" + } + }, "node_modules/merge2": { "version": "1.4.1", "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", diff --git a/spd/app/frontend/package.json b/spd/app/frontend/package.json index f54e1bb3d..f298885ce 100644 --- a/spd/app/frontend/package.json +++ b/spd/app/frontend/package.json @@ -27,5 +27,8 @@ "typescript": "~5.9.3", "typescript-eslint": "^8.46.2", "vite": "^7.1.7" + }, + "dependencies": { + "marked": "^17.0.1" } } diff --git a/spd/app/frontend/src/components/ActivationContextsViewer.svelte b/spd/app/frontend/src/components/ActivationContextsViewer.svelte index d20831c1a..a82c71ce1 100644 --- a/spd/app/frontend/src/components/ActivationContextsViewer.svelte +++ b/spd/app/frontend/src/components/ActivationContextsViewer.svelte @@ -3,13 +3,13 @@ import { computeMaxAbsComponentAct } from "../lib/colors"; import { COMPONENT_CARD_CONSTANTS } from "../lib/componentCardConstants"; import { anyCorrelationStatsEnabled, displaySettings } from "../lib/displaySettings.svelte"; - import { getLayerAlias } from "../lib/layerAliasing"; import type { ActivationContextsSummary, 
SubcomponentMetadata } from "../lib/promptAttributionsTypes"; import { useComponentData } from "../lib/useComponentData.svelte"; import { RUN_KEY, type RunContext } from "../lib/useRun.svelte"; import ActivationContextsPagedTable from "./ActivationContextsPagedTable.svelte"; import ComponentProbeInput from "./ComponentProbeInput.svelte"; import ComponentCorrelationMetrics from "./ui/ComponentCorrelationMetrics.svelte"; + import GraphInterpBadge from "./ui/GraphInterpBadge.svelte"; import InterpretationBadge from "./ui/InterpretationBadge.svelte"; import SectionHeader from "./ui/SectionHeader.svelte"; import StatusText from "./ui/StatusText.svelte"; @@ -38,6 +38,9 @@ let currentIntruderScore = $derived( currentMetadata ? runState.getIntruderScore(`${selectedLayer}:${currentMetadata.subcomponent_idx}`) : null, ); + let currentGraphInterpLabel = $derived( + currentMetadata ? runState.getGraphInterpLabel(`${selectedLayer}:${currentMetadata.subcomponent_idx}`) : null, + ); // Component data hook - call load() explicitly when component changes const componentData = useComponentData(); @@ -288,7 +291,7 @@ @@ -412,11 +415,16 @@ {/if} - +
+ + {#if currentGraphInterpLabel && componentData.graphInterpDetail.status === "loaded" && componentData.graphInterpDetail.data} + + {/if} +
{#if componentData.componentDetail.status === "loading"} @@ -701,6 +709,12 @@ gap: var(--space-2); } + .interpretation-badges { + display: flex; + flex-direction: column; + gap: var(--space-2); + } + .dataset-attributions-loading { display: flex; flex-direction: column; diff --git a/spd/app/frontend/src/components/DataSourcesTab.svelte b/spd/app/frontend/src/components/DataSourcesTab.svelte index bc9282c27..6dd4d31c1 100644 --- a/spd/app/frontend/src/components/DataSourcesTab.svelte +++ b/spd/app/frontend/src/components/DataSourcesTab.svelte @@ -104,9 +104,9 @@ {:else if data.status === "error"}

Failed to load data sources: {data.error}

{:else if data.status === "loaded"} - {@const { harvest, autointerp, attributions } = data.data} + {@const { harvest, autointerp, attributions, graph_interp } = data.data} - {#if !harvest && !autointerp && !attributions} + {#if !harvest && !autointerp && !attributions && !graph_interp}

No pipeline data available for this run.

{/if} @@ -138,9 +138,6 @@ Subrun {attributions.subrun_id} - Batches - {attributions.n_batches_processed.toLocaleString()} - Tokens {attributions.n_tokens_processed.toLocaleString()} @@ -150,6 +147,28 @@ {/if} + {#if graph_interp} +
+

Graph Interp

+
+ Subrun + {graph_interp.subrun_id} + + {#each Object.entries(graph_interp.label_counts) as [key, value] (key)} + {key} labels + {value.toLocaleString()} + {/each} + + {#if graph_interp.config} + {#each Object.entries(graph_interp.config) as [key, value] (key)} + {key} + {formatConfigValue(value)} + {/each} + {/if} +
+
+ {/if} + {#if autointerp}

Autointerp

diff --git a/spd/app/frontend/src/components/InvestigationsTab.svelte b/spd/app/frontend/src/components/InvestigationsTab.svelte new file mode 100644 index 000000000..b7752cb5f --- /dev/null +++ b/spd/app/frontend/src/components/InvestigationsTab.svelte @@ -0,0 +1,645 @@ + + +
+ {#if selected?.status === "loaded"} + +
+ +

{selected.data.title || formatId(selected.data.id)}

+ + {#if selected.data.status} + + {selected.data.status} + + {/if} +
+ + {#if selected.data.summary} +

{selected.data.summary}

+ {/if} + + +

+ {formatId(selected.data.id)} · Started {formatDate(selected.data.created_at)} + {#if selected.data.wandb_path} + · {selected.data.wandb_path} + {/if} +

+ +
+ + +
+ +
+ {#if activeTab === "research"} + {#if selected.data.research_log} + + {:else} +

No research log available

+ {/if} + {:else} +
+ {#each selected.data.events as event, i (i)} +
+ + {event.event_type} + + {formatDate(event.timestamp)} + {event.message} + {#if event.details && Object.keys(event.details).length > 0} +
+ Details +
{JSON.stringify(event.details, null, 2)}
+
+ {/if} +
+ {:else} +

No events recorded

+ {/each} +
+ {/if} +
+ {:else if selected?.status === "loading"} +
Loading investigation...
+ {:else} + +
+

Investigations

+ +
+ +
{ + e.preventDefault(); + handleLaunch(); + }} + > + + +
+ {#if launchState.status === "error"} +
{launchState.error}
+ {/if} + + {#if investigations.status === "loading"} +
Loading investigations...
+ {:else if investigations.status === "error"} +
{investigations.error}
+ {:else if investigations.status === "loaded"} +
+ {#each investigations.data as inv (inv.id)} + + {:else} +

+ No investigations found. Run spd-investigate to create one. +

+ {/each} +
+ {/if} + {/if} +
+ + diff --git a/spd/app/frontend/src/components/ModelGraph.svelte b/spd/app/frontend/src/components/ModelGraph.svelte new file mode 100644 index 000000000..ae49a25e3 --- /dev/null +++ b/spd/app/frontend/src/components/ModelGraph.svelte @@ -0,0 +1,520 @@ + + +
+ +
+
+ + + +
+
+ +
+
+ +
+
+ +
+
+ {filteredNodes.length} nodes, {visibleEdges.length} edges +
+
+ + + +
+
+ + + + + + {#if tooltipNode} +
+
{tooltipNode.label}
+
+ {tooltipNode.confidence} + {tooltipNode.key} +
+
+ {/if} + + + {#if selectedNodeKey} + {@const selectedNode = layout.nodes.get(selectedNodeKey)} + {#if selectedNode} +
+
+ {selectedNode.label} + {selectedNode.confidence} + +
+
{selectedNode.key}
+
+ {#if selectedNodeEdges.length > 0} +
+ {#each selectedNodeEdges as e, i (i)} + {@const other = e.source === selectedNodeKey ? e.target : e.source} + {@const otherNode = layout.nodes.get(other)} +
+ {e.source === selectedNodeKey ? "to" : "from"} + {otherNode?.label ?? other} + {e.attribution.toFixed(3)} +
+ {/each} +
+ {:else} + No edges + {/if} +
+
+ {/if} + {/if} +
+ + diff --git a/spd/app/frontend/src/components/ModelGraphTab.svelte b/spd/app/frontend/src/components/ModelGraphTab.svelte new file mode 100644 index 000000000..2e95e315d --- /dev/null +++ b/spd/app/frontend/src/components/ModelGraphTab.svelte @@ -0,0 +1,53 @@ + + +
+ {#if data.status === "loading"} +
Loading model graph...
+ {:else if data.status === "error"} +
Failed to load graph: {String(data.error)}
+ {:else if data.status === "loaded"} + + {:else} +
Initializing...
+ {/if} +
+ + diff --git a/spd/app/frontend/src/components/RunSelector.svelte b/spd/app/frontend/src/components/RunSelector.svelte index aa4728bd8..54f83829f 100644 --- a/spd/app/frontend/src/components/RunSelector.svelte +++ b/spd/app/frontend/src/components/RunSelector.svelte @@ -17,6 +17,20 @@ // Architecture info fetched in real-time for each canonical run let archInfo = $state>({}); + function formatArchLabel(info: PretrainInfoResponse): string { + const cfg = info.target_model_config; + const parts: string[] = []; + if (info.dataset_short) parts.push(info.dataset_short); + parts.push(info.model_type); + if (cfg) { + const nLayer = cfg.n_layer as number | undefined; + const nEmbd = cfg.n_embd as number | undefined; + if (nLayer != null) parts.push(`${nLayer}L`); + if (nEmbd != null) parts.push(`d${nEmbd}`); + } + return parts.join(" "); + } + onMount(() => { for (const entry of CANONICAL_RUNS) { archInfo[entry.wandbRunId] = "loading"; @@ -63,16 +77,17 @@ {#each CANONICAL_RUNS as entry (entry.wandbRunId)} {@const info = archInfo[entry.wandbRunId]} + + {#if runState.run.status === "loaded" && runState.run.data} + {/if} + + {#if expanded && detail} +
+
+
+ Input + {#if detail.input?.reasoning} +

{detail.input.reasoning}

+ {/if} + {#each incomingEdges as edge (edge.related_key)} +
+ {formatComponentKey(edge.related_key, edge.token_str)} + 0} + class:negative={edge.attribution < 0} + > + {edge.attribution > 0 ? "+" : ""}{edge.attribution.toFixed(3)} + + {#if edge.related_label} + {edge.related_label} + {/if} +
+ {/each} +
+
+ Output + {#if detail.output?.reasoning} +

{detail.output.reasoning}

+ {/if} + {#each outgoingEdges as edge (edge.related_key)} +
+ {formatComponentKey(edge.related_key, edge.token_str)} + 0} + class:negative={edge.attribution < 0} + > + {edge.attribution > 0 ? "+" : ""}{edge.attribution.toFixed(3)} + + {#if edge.related_label} + {edge.related_label} + {/if} +
+ {/each} +
+
+
+ {/if} + + + diff --git a/spd/app/frontend/src/lib/api/dataSources.ts b/spd/app/frontend/src/lib/api/dataSources.ts index e715af1b1..ac20b7220 100644 --- a/spd/app/frontend/src/lib/api/dataSources.ts +++ b/spd/app/frontend/src/lib/api/dataSources.ts @@ -20,15 +20,21 @@ export type AutointerpInfo = { export type AttributionsInfo = { subrun_id: string; - n_batches_processed: number; n_tokens_processed: number; ci_threshold: number; }; +export type GraphInterpInfoDS = { + subrun_id: string; + config: Record | null; + label_counts: Record; +}; + export type DataSourcesResponse = { harvest: HarvestInfo | null; autointerp: AutointerpInfo | null; attributions: AttributionsInfo | null; + graph_interp: GraphInterpInfoDS | null; }; export async function fetchDataSources(): Promise { diff --git a/spd/app/frontend/src/lib/api/datasetAttributions.ts b/spd/app/frontend/src/lib/api/datasetAttributions.ts index f995a33f6..030eae6c6 100644 --- a/spd/app/frontend/src/lib/api/datasetAttributions.ts +++ b/spd/app/frontend/src/lib/api/datasetAttributions.ts @@ -9,15 +9,23 @@ export type DatasetAttributionEntry = { layer: string; component_idx: number; value: number; + token_str: string | null; }; -export type ComponentAttributions = { +export type SignedAttributions = { positive_sources: DatasetAttributionEntry[]; negative_sources: DatasetAttributionEntry[]; positive_targets: DatasetAttributionEntry[]; negative_targets: DatasetAttributionEntry[]; }; +export type AttrMetric = "attr" | "attr_abs"; + +export type AllMetricAttributions = { + attr: SignedAttributions; + attr_abs: SignedAttributions; +}; + export type DatasetAttributionsMetadata = { available: boolean; }; @@ -30,8 +38,8 @@ export async function getComponentAttributions( layer: string, componentIdx: number, k: number = 10, -): Promise { +): Promise { const url = apiUrl(`/api/dataset_attributions/${encodeURIComponent(layer)}/${componentIdx}`); url.searchParams.set("k", String(k)); - return fetchJson(url.toString()); + return 
fetchJson(url.toString()); } diff --git a/spd/app/frontend/src/lib/api/graphInterp.ts b/spd/app/frontend/src/lib/api/graphInterp.ts new file mode 100644 index 000000000..8229e757c --- /dev/null +++ b/spd/app/frontend/src/lib/api/graphInterp.ts @@ -0,0 +1,81 @@ +/** + * API client for /api/graph_interp endpoints. + */ + +import { fetchJson } from "./index"; + +export type GraphInterpHeadline = { + label: string; + confidence: string; + output_label: string | null; + input_label: string | null; +}; + +export type LabelDetail = { + label: string; + confidence: string; + reasoning: string; + prompt: string; +}; + +export type GraphInterpDetail = { + output: LabelDetail | null; + input: LabelDetail | null; + unified: LabelDetail | null; +}; + +export type PromptEdgeResponse = { + related_key: string; + pass_name: string; + attribution: number; + related_label: string | null; + related_confidence: string | null; + token_str: string | null; +}; + +export type GraphInterpComponentDetail = { + output: LabelDetail | null; + input: LabelDetail | null; + unified: LabelDetail | null; + edges: PromptEdgeResponse[]; +}; + +export type GraphNode = { + component_key: string; + label: string; + confidence: string; +}; + +export type GraphEdge = { + source: string; + target: string; + attribution: number; + pass_name: string; +}; + +export type ModelGraphResponse = { + nodes: GraphNode[]; + edges: GraphEdge[]; +}; + +export type GraphInterpInfo = { + subrun_id: string; + config: Record | null; + label_counts: Record; +}; + +export async function getAllGraphInterpLabels(): Promise> { + return fetchJson>("/api/graph_interp/labels"); +} + +export async function getGraphInterpDetail(layer: string, cIdx: number): Promise { + return fetchJson(`/api/graph_interp/labels/${layer}/${cIdx}`); +} + +export async function getGraphInterpComponentDetail(layer: string, cIdx: number): Promise { + return fetchJson(`/api/graph_interp/detail/${layer}/${cIdx}`); +} + +export async function 
getModelGraph(): Promise { + return fetchJson("/api/graph_interp/graph"); +} diff --git a/spd/app/frontend/src/lib/api/graphs.ts b/spd/app/frontend/src/lib/api/graphs.ts index 42490d531..7d2338fdb 100644 --- a/spd/app/frontend/src/lib/api/graphs.ts +++ b/spd/app/frontend/src/lib/api/graphs.ts @@ -4,7 +4,6 @@ import type { GraphData, TokenizeResponse, TokenInfo } from "../promptAttributionsTypes"; import { buildEdgeIndexes } from "../promptAttributionsTypes"; -import { setArchitecture } from "../layerAliasing"; import { apiUrl, ApiError, fetchJson } from "./index"; export type NormalizeType = "none" | "target" | "layer"; @@ -59,14 +58,6 @@ async function parseGraphSSEStream( } else if (data.type === "error") { throw new ApiError(data.error, 500); } else if (data.type === "complete") { - // Extract all unique layer names from edges to detect architecture - const layerNames = new Set(); - for (const edge of data.data.edges) { - layerNames.add(edge.src.split(":")[0]); - layerNames.add(edge.tgt.split(":")[0]); - } - setArchitecture(Array.from(layerNames)); - const { edgesBySource, edgesByTarget } = buildEdgeIndexes(data.data.edges); result = { ...data.data, edgesBySource, edgesByTarget }; await reader.cancel(); @@ -166,14 +157,6 @@ export async function getGraphs(promptId: number, normalize: NormalizeType, ciTh url.searchParams.set("ci_threshold", String(ciThreshold)); const graphs = await fetchJson[]>(url.toString()); return graphs.map((g) => { - // Extract all unique layer names from edges to detect architecture - const layerNames = new Set(); - for (const edge of g.edges) { - layerNames.add(edge.src.split(":")[0]); - layerNames.add(edge.tgt.split(":")[0]); - } - setArchitecture(Array.from(layerNames)); - const { edgesBySource, edgesByTarget } = buildEdgeIndexes(g.edges); return { ...g, edgesBySource, edgesByTarget }; }); diff --git a/spd/app/frontend/src/lib/api/index.ts b/spd/app/frontend/src/lib/api/index.ts index 773663636..88187c5d2 100644 --- 
a/spd/app/frontend/src/lib/api/index.ts +++ b/spd/app/frontend/src/lib/api/index.ts @@ -51,5 +51,7 @@ export * from "./datasetAttributions"; export * from "./intervention"; export * from "./dataset"; export * from "./clusters"; +export * from "./investigations"; export * from "./dataSources"; +export * from "./graphInterp"; export * from "./pretrainInfo"; diff --git a/spd/app/frontend/src/lib/api/investigations.ts b/spd/app/frontend/src/lib/api/investigations.ts new file mode 100644 index 000000000..42f1fb1f3 --- /dev/null +++ b/spd/app/frontend/src/lib/api/investigations.ts @@ -0,0 +1,101 @@ +/** + * API client for investigation results. + */ + +export interface InvestigationSummary { + id: string; // inv_id (e.g., "inv-abc12345") + wandb_path: string | null; + prompt: string | null; + created_at: string; + has_research_log: boolean; + has_explanations: boolean; + event_count: number; + last_event_time: string | null; + last_event_message: string | null; + // Agent-provided summary + title: string | null; + summary: string | null; + status: string | null; // in_progress, completed, inconclusive +} + +export interface EventEntry { + event_type: string; + timestamp: string; + message: string; + details: Record | null; +} + +export interface InvestigationDetail { + id: string; + wandb_path: string | null; + prompt: string | null; + created_at: string; + research_log: string | null; + events: EventEntry[]; + explanations: Record[]; + artifact_ids: string[]; // List of artifact IDs available for this investigation + // Agent-provided summary + title: string | null; + summary: string | null; + status: string | null; +} + +import type { EdgeData, OutputProbability } from "../promptAttributionsTypes"; + +/** Data for a graph artifact (subset of GraphData, self-contained for offline viewing) */ +export interface ArtifactGraphData { + tokens: string[]; + edges: EdgeData[]; + outputProbs: Record; + nodeCiVals: Record; + nodeSubcompActs: Record; + maxAbsAttr: number; + 
l0_total: number; +} + +export interface GraphArtifact { + type: "graph"; + id: string; + caption: string | null; + graph_id: number; + data: ArtifactGraphData; +} + +export interface LaunchResponse { + inv_id: string; + job_id: string; +} + +export async function launchInvestigation(prompt: string): Promise { + const res = await fetch("/api/investigations/launch", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ prompt }), + }); + if (!res.ok) throw new Error(`Failed to launch investigation: ${res.statusText}`); + return res.json(); +} + +export async function listInvestigations(): Promise { + const res = await fetch("/api/investigations"); + if (!res.ok) throw new Error(`Failed to list investigations: ${res.statusText}`); + return res.json(); +} + +export async function getInvestigation(invId: string): Promise { + const res = await fetch(`/api/investigations/${invId}`); + if (!res.ok) throw new Error(`Failed to get investigation: ${res.statusText}`); + return res.json(); +} + +export async function listArtifacts(invId: string): Promise { + const res = await fetch(`/api/investigations/${invId}/artifacts`); + if (!res.ok) throw new Error(`Failed to list artifacts: ${res.statusText}`); + return res.json(); +} + +export async function getArtifact(invId: string, artifactId: string): Promise { + const res = await fetch(`/api/investigations/${invId}/artifacts/${artifactId}`); + if (!res.ok) throw new Error(`Failed to get artifact: ${res.statusText}`); + return res.json(); +} diff --git a/spd/app/frontend/src/lib/api/pretrainInfo.ts b/spd/app/frontend/src/lib/api/pretrainInfo.ts index 7092c735a..0cd66bd97 100644 --- a/spd/app/frontend/src/lib/api/pretrainInfo.ts +++ b/spd/app/frontend/src/lib/api/pretrainInfo.ts @@ -20,6 +20,7 @@ export type TopologyInfo = { export type PretrainInfoResponse = { model_type: string; summary: string; + dataset_short: string | null; target_model_config: Record | null; pretrain_config: Record | 
null; pretrain_wandb_path: string | null; diff --git a/spd/app/frontend/src/lib/api/runs.ts b/spd/app/frontend/src/lib/api/runs.ts index 1430632a4..bb385a7ed 100644 --- a/spd/app/frontend/src/lib/api/runs.ts +++ b/spd/app/frontend/src/lib/api/runs.ts @@ -14,6 +14,7 @@ export type LoadedRun = { backend_user: string; dataset_attributions_available: boolean; dataset_search_enabled: boolean; + graph_interp_available: boolean; }; export async function getStatus(): Promise { diff --git a/spd/app/frontend/src/lib/componentKeys.ts b/spd/app/frontend/src/lib/componentKeys.ts new file mode 100644 index 000000000..ff83bda06 --- /dev/null +++ b/spd/app/frontend/src/lib/componentKeys.ts @@ -0,0 +1,17 @@ +/** + * Utilities for component key display (e.g. rendering embed/output keys with token strings). + */ + +export function isTokenNode(key: string): boolean { + const layer = key.split(":")[0]; + return layer === "embed" || layer === "output"; +} + +export function formatComponentKey(key: string, tokenStr: string | null): string { + if (tokenStr && isTokenNode(key)) { + const layer = key.split(":")[0]; + const label = layer === "embed" ? "input" : "output"; + return `'${tokenStr}' (${label})`; + } + return key; +} diff --git a/spd/app/frontend/src/lib/layerAliasing.ts b/spd/app/frontend/src/lib/layerAliasing.ts deleted file mode 100644 index 2c5269543..000000000 --- a/spd/app/frontend/src/lib/layerAliasing.ts +++ /dev/null @@ -1,219 +0,0 @@ -/** - * Layer aliasing system - transforms internal module names to human-readable aliases. 
- * - * Formats: - * - Internal: "h.0.mlp.c_fc", "h.1.attn.q_proj" - * - Aliased: "L0.mlp.in", "L1.attn.q" - * - * Handles multiple architectures: - * - GPT-2: c_fc -> mlp.in, down_proj -> mlp.out - * - Llama SwiGLU: gate_proj -> mlp.gate, up_proj -> mlp.up, down_proj -> mlp.down - * - Attention: q_proj -> attn.q, k_proj -> attn.k, v_proj -> attn.v, o_proj -> attn.o - * - Special: lm_head -> W_U, embed/output unchanged - */ - -type Architecture = "gpt2" | "llama" | "unknown"; - -/** Mapping of internal module names to aliases by architecture */ -const ALIASES: Record> = { - gpt2: { - // MLP - c_fc: "in", - down_proj: "out", - // Attention - q_proj: "q", - k_proj: "k", - v_proj: "v", - o_proj: "o", - }, - llama: { - // MLP (SwiGLU) - gate_proj: "gate", - up_proj: "up", - down_proj: "down", - // Attention - q_proj: "q", - k_proj: "k", - v_proj: "v", - o_proj: "o", - }, - unknown: { - // Fallback - just do attention mappings - q_proj: "q", - k_proj: "k", - v_proj: "v", - o_proj: "o", - }, -}; - -/** Special layers with fixed display names */ -const SPECIAL_LAYERS: Record = { - lm_head: "W_U", - embed: "embed", - output: "output", -}; - -// Cache for detected architecture from the full model -let cachedArchitecture: Architecture | null = null; - -/** - * Detect architecture from a collection of layer names. - * Llama has gate_proj/up_proj, GPT-2 has c_fc. - * - * This should be called once with all available layer names to establish - * the architecture for the session, ensuring down_proj is aliased correctly. - */ -export function detectArchitectureFromLayers(layers: string[]): Architecture { - const hasLlamaLayers = layers.some((layer) => layer.includes("gate_proj") || layer.includes("up_proj")); - if (hasLlamaLayers) { - return "llama"; - } - - const hasGPT2Layers = layers.some((layer) => layer.includes("c_fc")); - if (hasGPT2Layers) { - return "gpt2"; - } - - return "unknown"; -} - -/** - * Set the architecture for aliasing operations. 
- * Call this when you have access to all layer names (e.g., when loading a graph). - */ -export function setArchitecture(layers: string[]): void { - cachedArchitecture = detectArchitectureFromLayers(layers); -} - -/** - * Detect architecture from layer name. - * Uses cached architecture if available (set via setArchitecture()), - * otherwise falls back to single-layer detection. - * - * Note: down_proj appears in both architectures with different meanings: - * - GPT-2: down_proj -> "out" (second MLP projection) - * - Llama: down_proj -> "down" (third MLP projection after gate/up) - * - * Single-layer detection cannot distinguish these cases reliably. - */ -function detectArchitecture(layer: string): Architecture { - // Use cached architecture if available - if (cachedArchitecture !== null) { - return cachedArchitecture; - } - - // Fallback: single-layer detection (less reliable for down_proj) - if (layer.includes("gate_proj") || layer.includes("up_proj")) { - return "llama"; - } - if (layer.includes("c_fc")) { - return "gpt2"; - } - // down_proj is ambiguous without context, default to GPT-2 - if (layer.includes("down_proj")) { - return "gpt2"; - } - return "unknown"; -} - -/** - * Parse a layer name into components. - * Returns null for special layers (embed, output, lm_head) or unrecognized formats. - */ -function parseLayerName(layer: string): { block: number; moduleType: string; submodule: string } | null { - if (layer in SPECIAL_LAYERS) { - return null; - } - - const match = layer.match(/^h\.(\d+)\.(attn|mlp)\.(\w+)$/); - if (!match) { - return null; - } - - const [, blockStr, moduleType, submodule] = match; - return { - block: parseInt(blockStr), - moduleType, - submodule, - }; -} - -/** - * Transform a layer name to its aliased form. 
- * - * Examples: - * - "h.0.mlp.c_fc" -> "L0.mlp.in" - * - "h.2.attn.q_proj" -> "L2.attn.q" - * - "lm_head" -> "W_U" - * - "embed" -> "embed" - */ -export function getLayerAlias(layer: string): string { - if (layer in SPECIAL_LAYERS) { - return SPECIAL_LAYERS[layer]; - } - - const parsed = parseLayerName(layer); - if (!parsed) { - return layer; - } - - const arch = detectArchitecture(layer); - const alias = ALIASES[arch][parsed.submodule]; - - if (!alias) { - return `L${parsed.block}.${parsed.moduleType}.${parsed.submodule}`; - } - - return `L${parsed.block}.${parsed.moduleType}.${alias}`; -} - -/** - * Get a row label for grouped display in graphs. - * - * @param layer - Internal layer name (e.g., "h.0.mlp.c_fc") - * @param isQkvGroup - Whether this represents a grouped QKV row - * @returns Label (e.g., "L0.mlp.in", "L2.attn.qkv") - * - * @example - * getAliasedRowLabel("h.0.mlp.c_fc") // => "L0.mlp.in" - * getAliasedRowLabel("h.2.attn.q_proj", true) // => "L2.attn.qkv" - */ -export function getAliasedRowLabel(layer: string, isQkvGroup = false): string { - if (layer in SPECIAL_LAYERS) { - return SPECIAL_LAYERS[layer]; - } - - const parsed = parseLayerName(layer); - if (!parsed) { - return layer; - } - - if (isQkvGroup) { - return `L${parsed.block}.${parsed.moduleType}.qkv`; - } - - const arch = detectArchitecture(layer); - const alias = ALIASES[arch][parsed.submodule]; - - if (!alias) { - return `L${parsed.block}.${parsed.moduleType}.${parsed.submodule}`; - } - - return `L${parsed.block}.${parsed.moduleType}.${alias}`; -} - -/** - * Format a node key with aliased layer names. - * - * Node keys are "layer:seq:cIdx" or "layer:cIdx" format. 
- * - * Examples: - * - "h.0.mlp.c_fc:3:5" -> "L0.mlp.in:3:5" - * - "h.1.attn.q_proj:2:10" -> "L1.attn.q:2:10" - */ -export function formatNodeKeyWithAliases(nodeKey: string): string { - const parts = nodeKey.split(":"); - const layer = parts[0]; - const aliasedLayer = getLayerAlias(layer); - return [aliasedLayer, ...parts.slice(1)].join(":"); -} diff --git a/spd/app/frontend/src/lib/promptAttributionsTypes.ts b/spd/app/frontend/src/lib/promptAttributionsTypes.ts index fc705fad0..6f29030aa 100644 --- a/spd/app/frontend/src/lib/promptAttributionsTypes.ts +++ b/spd/app/frontend/src/lib/promptAttributionsTypes.ts @@ -20,6 +20,7 @@ export type EdgeAttribution = { key: string; // "layer:seq:cIdx" for prompt or "layer:cIdx" for dataset value: number; // raw attribution value (positive or negative) normalizedMagnitude: number; // |value| / maxAbsValue, for color intensity (0-1) + tokenStr: string | null; // resolved token string for embed/output layers }; export type OutputProbability = { @@ -233,7 +234,7 @@ export function formatNodeKeyForDisplay(nodeKey: string, displayNames: Record>({ status: "uninitialized" }); let tokenStats = $state>({ status: "uninitialized" }); - let datasetAttributions = $state>({ status: "uninitialized" }); + let datasetAttributions = $state>({ status: "uninitialized" }); let interpretationDetail = $state>({ status: "uninitialized" }); + let graphInterpDetail = $state>({ status: "uninitialized" }); // Current coords being loaded/displayed (for interpretation lookup) let currentCoords = $state(null); @@ -146,6 +148,26 @@ export function useComponentData() { interpretationDetail = { status: "error", error }; } }); + + // Fetch graph interp detail (skip if not available for this run) + if (runState.graphInterpAvailable) { + graphInterpDetail = { status: "loading" }; + getGraphInterpComponentDetail(layer, cIdx) + .then((data) => { + if (isStale()) return; + graphInterpDetail = { status: "loaded", data }; + }) + .catch((error) => { + if (isStale()) 
return; + if (error instanceof ApiError && error.status === 404) { + graphInterpDetail = { status: "loaded", data: null }; + } else { + graphInterpDetail = { status: "error", error }; + } + }); + } else { + graphInterpDetail = { status: "loaded", data: null }; + } } /** @@ -159,6 +181,7 @@ export function useComponentData() { tokenStats = { status: "uninitialized" }; datasetAttributions = { status: "uninitialized" }; interpretationDetail = { status: "uninitialized" }; + graphInterpDetail = { status: "uninitialized" }; } // Interpretation is derived from the global cache - reactive to both coords and cache @@ -212,6 +235,9 @@ export function useComponentData() { get interpretationDetail() { return interpretationDetail; }, + get graphInterpDetail() { + return graphInterpDetail; + }, load, reset, generateInterpretation, diff --git a/spd/app/frontend/src/lib/useComponentDataExpectCached.svelte.ts b/spd/app/frontend/src/lib/useComponentDataExpectCached.svelte.ts index f32dab70a..4371fa91b 100644 --- a/spd/app/frontend/src/lib/useComponentDataExpectCached.svelte.ts +++ b/spd/app/frontend/src/lib/useComponentDataExpectCached.svelte.ts @@ -14,10 +14,11 @@ import { getComponentAttributions, getComponentCorrelations, getComponentTokenStats, + getGraphInterpComponentDetail, getInterpretationDetail, requestComponentInterpretation, } from "./api"; -import type { ComponentAttributions, InterpretationDetail } from "./api"; +import type { AllMetricAttributions, GraphInterpComponentDetail, InterpretationDetail } from "./api"; import type { SubcomponentCorrelationsResponse, SubcomponentActivationContexts, @@ -29,7 +30,7 @@ const DATASET_ATTRIBUTIONS_TOP_K = 20; /** Fetch more activation examples in background after initial cached load */ const ACTIVATION_EXAMPLES_FULL_LIMIT = 200; -export type { ComponentAttributions as DatasetAttributions }; +export type { AllMetricAttributions as DatasetAttributions }; export type ComponentCoords = { layer: string; cIdx: number }; @@ -39,8 +40,9 
@@ export function useComponentDataExpectCached() { let componentDetail = $state>({ status: "uninitialized" }); let correlations = $state>({ status: "uninitialized" }); let tokenStats = $state>({ status: "uninitialized" }); - let datasetAttributions = $state>({ status: "uninitialized" }); + let datasetAttributions = $state>({ status: "uninitialized" }); let interpretationDetail = $state>({ status: "uninitialized" }); + let graphInterpDetail = $state>({ status: "uninitialized" }); let currentCoords = $state(null); let requestId = 0; @@ -102,6 +104,26 @@ export function useComponentDataExpectCached() { interpretationDetail = { status: "error", error }; } }); + + // Fetch graph interp detail + if (runState.graphInterpAvailable) { + graphInterpDetail = { status: "loading" }; + getGraphInterpComponentDetail(layer, cIdx) + .then((data) => { + if (isStale()) return; + graphInterpDetail = { status: "loaded", data }; + }) + .catch((error) => { + if (isStale()) return; + if (error instanceof ApiError && error.status === 404) { + graphInterpDetail = { status: "loaded", data: null }; + } else { + graphInterpDetail = { status: "error", error }; + } + }); + } else { + graphInterpDetail = { status: "loaded", data: null }; + } } function load(layer: string, cIdx: number) { @@ -144,6 +166,7 @@ export function useComponentDataExpectCached() { tokenStats = { status: "uninitialized" }; datasetAttributions = { status: "uninitialized" }; interpretationDetail = { status: "uninitialized" }; + graphInterpDetail = { status: "uninitialized" }; } // Interpretation is derived from the global cache @@ -197,6 +220,9 @@ export function useComponentDataExpectCached() { get interpretationDetail() { return interpretationDetail; }, + get graphInterpDetail() { + return graphInterpDetail; + }, load, reset, generateInterpretation, diff --git a/spd/app/frontend/src/lib/useRun.svelte.ts b/spd/app/frontend/src/lib/useRun.svelte.ts index de6d20c7d..32c9eb4b4 100644 --- 
a/spd/app/frontend/src/lib/useRun.svelte.ts +++ b/spd/app/frontend/src/lib/useRun.svelte.ts @@ -7,7 +7,7 @@ import type { Loadable } from "."; import * as api from "./api"; -import type { LoadedRun as RunData, InterpretationHeadline } from "./api"; +import type { LoadedRun as RunData, InterpretationHeadline, GraphInterpHeadline } from "./api"; import type { PromptPreview, SubcomponentActivationContexts, @@ -46,6 +46,9 @@ export function useRun() { /** Intruder eval scores keyed by component key */ let intruderScores = $state>>({ status: "uninitialized" }); + /** Graph interp labels keyed by component key (layer:cIdx) */ + let graphInterpLabels = $state>>({ status: "uninitialized" }); + /** Cluster mapping for the current run */ let clusterMapping = $state(null); @@ -71,6 +74,7 @@ export function useRun() { allTokens = { status: "uninitialized" }; interpretations = { status: "uninitialized" }; intruderScores = { status: "uninitialized" }; + graphInterpLabels = { status: "uninitialized" }; activationContextsSummary = { status: "uninitialized" }; _componentDetailsCache = {}; clusterMapping = null; @@ -88,6 +92,9 @@ export function useRun() { api.getIntruderScores() .then((data) => (intruderScores = { status: "loaded", data })) .catch((error) => (intruderScores = { status: "error", error })); + api.getAllGraphInterpLabels() + .then((data) => (graphInterpLabels = { status: "loaded", data })) + .catch((error) => (graphInterpLabels = { status: "error", error })); api.getAllInterpretations() .then((i) => { interpretations = { @@ -230,6 +237,11 @@ export function useRun() { return clusterMapping?.data[key] ?? null; } + function getGraphInterpLabel(componentKey: string): GraphInterpHeadline | null { + if (graphInterpLabels.status !== "loaded") return null; + return graphInterpLabels.data[componentKey] ?? 
null; + } + return { get run() { return run; @@ -237,6 +249,9 @@ export function useRun() { get interpretations() { return interpretations; }, + get graphInterpLabels() { + return graphInterpLabels; + }, get clusterMapping() { return clusterMapping; }, @@ -252,6 +267,9 @@ export function useRun() { get datasetAttributionsAvailable() { return run.status === "loaded" && run.data.dataset_attributions_available; }, + get graphInterpAvailable() { + return run.status === "loaded" && run.data.graph_interp_available; + }, loadRun, clearRun, syncStatus, @@ -259,6 +277,7 @@ export function useRun() { getInterpretation, setInterpretation, getIntruderScore, + getGraphInterpLabel, getActivationContextDetail, loadActivationContextsSummary, setClusterMapping, diff --git a/spd/autointerp/db.py b/spd/autointerp/db.py index aa3aca4ee..66d681333 100644 --- a/spd/autointerp/db.py +++ b/spd/autointerp/db.py @@ -40,7 +40,6 @@ def __init__(self, db_path: Path, readonly: bool = False) -> None: ) else: self._conn = sqlite3.connect(str(db_path), check_same_thread=False) - self._conn.execute("PRAGMA journal_mode=WAL") self._conn.executescript(_SCHEMA) self._conn.row_factory = sqlite3.Row diff --git a/spd/autointerp/interpret.py b/spd/autointerp/interpret.py index 9a0eff3a6..89e992902 100644 --- a/spd/autointerp/interpret.py +++ b/spd/autointerp/interpret.py @@ -190,10 +190,8 @@ def build_jobs() -> Iterable[LLMJob]: f"Error rate {error_rate:.0%} ({n_errors}/{len(remaining)}) exceeds 20% threshold" ) - except Exception as e: - logger.error(f"Error: {type(e).__name__}: {e}") + finally: db.close() - raise e logger.info(f"Completed {len(results)} interpretations -> {db_path}") return results diff --git a/spd/autointerp/prompt_helpers.py b/spd/autointerp/prompt_helpers.py new file mode 100644 index 000000000..63e580591 --- /dev/null +++ b/spd/autointerp/prompt_helpers.py @@ -0,0 +1,161 @@ +"""Shared prompt-building helpers for autointerp and graph interpretation. 
+ +Pure functions for formatting component data into LLM prompt sections. +""" + +import re + +from spd.app.backend.app_tokenizer import AppTokenizer +from spd.app.backend.utils import delimit_tokens +from spd.harvest.analysis import TokenPRLift +from spd.harvest.schemas import ComponentData + +DATASET_DESCRIPTIONS: dict[str, str] = { + "SimpleStories/SimpleStories": ( + "SimpleStories: 2M+ short stories (200-350 words), grade 1-8 reading level. " + "Simple vocabulary, common narrative elements." + ), +} + +WEIGHT_NAMES: dict[str, str] = { + "attn.q": "attention query projection", + "attn.k": "attention key projection", + "attn.v": "attention value projection", + "attn.o": "attention output projection", + "mlp.up": "MLP up-projection", + "mlp.down": "MLP down-projection", + "glu.up": "GLU up-projection", + "glu.down": "GLU down-projection", + "glu.gate": "GLU gate projection", +} + +_ORDINALS = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th"] + + +def ordinal(n: int) -> str: + if 1 <= n <= len(_ORDINALS): + return _ORDINALS[n - 1] + return f"{n}th" + + +def human_layer_desc(canonical: str, n_blocks: int) -> str: + """Convert canonical layer string to human-readable description. + + '0.mlp.up' -> 'MLP up-projection in the 1st of 4 blocks' + '1.attn.q' -> 'attention query projection in the 2nd of 4 blocks' + """ + m = re.match(r"(\d+)\.(.*)", canonical) + if not m: + return canonical + layer_idx = int(m.group(1)) + weight_key = m.group(2) + weight_name = WEIGHT_NAMES.get(weight_key, weight_key) + return f"{weight_name} in the {ordinal(layer_idx + 1)} of {n_blocks} blocks" + + +def layer_position_note(canonical: str, n_blocks: int) -> str: + """Brief note about what layer position means for interpretation.""" + m = re.match(r"(\d+)\.", canonical) + if not m: + return "" + layer_idx = int(m.group(1)) + if layer_idx == n_blocks - 1: + return "This is in the final block, so its output directly influences token predictions." 
+ remaining = n_blocks - 1 - layer_idx + return ( + f"This is {remaining} block{'s' if remaining > 1 else ''} from the output, " + f"so its effect on token predictions is indirect — filtered through later layers." + ) + + +def density_note(firing_density: float) -> str: + if firing_density > 0.15: + return ( + "This is a high-density component (fires frequently). " + "High-density components often act as broad biases rather than selective features." + ) + if firing_density < 0.005: + return "This is a very sparse component, likely highly specific." + return "" + + +def build_output_section( + output_stats: TokenPRLift, + output_pmi: list[tuple[str, float]] | None, +) -> str: + section = "" + + if output_pmi: + section += ( + "**Output PMI (pointwise mutual information, in nats: how much more likely " + "a token is to be produced when this component fires, vs its base rate. " + "0 = no association, 1 = ~3x more likely, 2 = ~7x, 3 = ~20x):**\n" + ) + for tok, pmi in output_pmi[:10]: + section += f"- {repr(tok)}: {pmi:.2f}\n" + + if output_stats.top_precision: + section += "\n**Output precision — of all probability mass for token X, what fraction is at positions where this component fires?**\n" + for tok, prec in output_stats.top_precision[:10]: + section += f"- {repr(tok)}: {prec * 100:.0f}%\n" + + return section + + +def build_input_section( + input_stats: TokenPRLift, + input_pmi: list[tuple[str, float]] | None, +) -> str: + section = "" + + if input_pmi: + section += "**Input PMI (same metric as above, for input tokens):**\n" + for tok, pmi in input_pmi[:6]: + section += f"- {repr(tok)}: {pmi:.2f}\n" + + if input_stats.top_recall: + section += "\n**Input recall — most common tokens when the component fires:**\n" + for tok, recall in input_stats.top_recall[:8]: + section += f"- {repr(tok)}: {recall * 100:.0f}%\n" + + if input_stats.top_precision: + section += "\n**Input precision — probability the component fires given the current token is X:**\n" + for tok, prec 
in input_stats.top_precision[:8]: + section += f"- {repr(tok)}: {prec * 100:.0f}%\n" + + return section + + +def build_fires_on_examples( + component: ComponentData, + app_tok: AppTokenizer, + max_examples: int, +) -> str: + section = "" + examples = component.activation_examples[:max_examples] + + for i, ex in enumerate(examples): + if any(ex.firings): + spans = app_tok.get_spans(ex.token_ids) + tokens = list(zip(spans, ex.firings, strict=True)) + section += f"{i + 1}. {delimit_tokens(tokens)}\n" + + return section + + +def build_says_examples( + component: ComponentData, + app_tok: AppTokenizer, + max_examples: int, +) -> str: + section = "" + examples = component.activation_examples[:max_examples] + + for i, ex in enumerate(examples): + if any(ex.firings): + spans = app_tok.get_spans(ex.token_ids) + shifted_firings = [False] + ex.firings[:-1] + tokens = list(zip(spans, shifted_firings, strict=True)) + section += f"{i + 1}. {delimit_tokens(tokens)}\n" + + return section diff --git a/spd/autointerp/scoring/detection.py b/spd/autointerp/scoring/detection.py index 7863d1598..afc00a163 100644 --- a/spd/autointerp/scoring/detection.py +++ b/spd/autointerp/scoring/detection.py @@ -16,6 +16,7 @@ from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.utils import delimit_tokens from spd.autointerp.config import DetectionEvalConfig +from spd.autointerp.db import InterpDB from spd.autointerp.llm_api import LLMError, LLMJob, LLMResult, map_llm_calls from spd.autointerp.repo import InterpRepo from spd.harvest.schemas import ActivationExample, ComponentData @@ -122,6 +123,7 @@ class _TrialGroundTruth: async def run_detection_scoring( components: list[ComponentData], interp_repo: InterpRepo, + score_db: InterpDB, model: str, reasoning_effort: Effort, openrouter_api_key: str, @@ -144,7 +146,7 @@ async def run_detection_scoring( if limit is not None: eligible = eligible[:limit] - existing_scores = interp_repo.get_scores("detection") + existing_scores = 
score_db.get_scores("detection") completed = set(existing_scores.keys()) if completed: logger.info(f"Resuming: {len(completed)} already scored") @@ -237,7 +239,7 @@ async def run_detection_scoring( score = sum(t.balanced_acc for t in trials) / len(trials) if trials else 0.0 result = DetectionResult(component_key=ck, score=score, trials=trials, n_errors=n_err) results.append(result) - interp_repo.save_score(ck, "detection", score, json.dumps(asdict(result))) + score_db.save_score(ck, "detection", score, json.dumps(asdict(result))) logger.info(f"Scored {len(results)} components") return results diff --git a/spd/autointerp/scoring/fuzzing.py b/spd/autointerp/scoring/fuzzing.py index cafb8e746..9669a46cf 100644 --- a/spd/autointerp/scoring/fuzzing.py +++ b/spd/autointerp/scoring/fuzzing.py @@ -17,6 +17,7 @@ from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.utils import delimit_tokens from spd.autointerp.config import FuzzingEvalConfig +from spd.autointerp.db import InterpDB from spd.autointerp.llm_api import LLMError, LLMJob, LLMResult, map_llm_calls from spd.autointerp.repo import InterpRepo from spd.harvest.schemas import ActivationExample, ComponentData @@ -116,6 +117,7 @@ class _TrialGroundTruth: async def run_fuzzing_scoring( components: list[ComponentData], interp_repo: InterpRepo, + score_db: InterpDB, model: str, reasoning_effort: Effort, openrouter_api_key: str, @@ -140,7 +142,7 @@ async def run_fuzzing_scoring( if limit is not None: eligible = eligible[:limit] - existing_scores = interp_repo.get_scores("fuzzing") + existing_scores = score_db.get_scores("fuzzing") completed = set(existing_scores.keys()) if completed: logger.info(f"Resuming: {len(completed)} already scored") @@ -235,7 +237,7 @@ async def run_fuzzing_scoring( score = (tpr + tnr) / 2 if (total_pos > 0 and total_neg > 0) else 0.0 result = FuzzingResult(component_key=ck, score=score, trials=trials, n_errors=n_err) results.append(result) - interp_repo.save_score(ck, 
"fuzzing", score, json.dumps(asdict(result))) + score_db.save_score(ck, "fuzzing", score, json.dumps(asdict(result))) logger.info(f"Scored {len(results)} components") return results diff --git a/spd/autointerp/scoring/scripts/run_label_scoring.py b/spd/autointerp/scoring/scripts/run_label_scoring.py index f2cf3acf1..be2efa388 100644 --- a/spd/autointerp/scoring/scripts/run_label_scoring.py +++ b/spd/autointerp/scoring/scripts/run_label_scoring.py @@ -12,6 +12,7 @@ from spd.adapters import adapter_from_id from spd.autointerp.config import AutointerpEvalConfig +from spd.autointerp.db import InterpDB from spd.autointerp.repo import InterpRepo from spd.autointerp.scoring.detection import run_detection_scoring from spd.autointerp.scoring.fuzzing import run_fuzzing_scoring @@ -40,14 +41,17 @@ def main( f"No autointerp data for {decomposition_id}. Run autointerp first." ) + # Separate writable DB for saving scores (the repo's DB is readonly/immutable) + score_db = InterpDB(interp_repo._subrun_dir / "interp.db") + if harvest_subrun_id is not None: harvest = HarvestRepo( decomposition_id=decomposition_id, subrun_id=harvest_subrun_id, - readonly=False, + readonly=True, ) else: - harvest = HarvestRepo.open_most_recent(decomposition_id, readonly=False) + harvest = HarvestRepo.open_most_recent(decomposition_id, readonly=True) assert harvest is not None, f"No harvest data for {decomposition_id}" components = harvest.get_all_components() @@ -58,6 +62,7 @@ def main( run_detection_scoring( components=components, interp_repo=interp_repo, + score_db=score_db, model=config.model, reasoning_effort=config.reasoning_effort, openrouter_api_key=openrouter_api_key, @@ -74,6 +79,7 @@ def main( run_fuzzing_scoring( components=components, interp_repo=interp_repo, + score_db=score_db, model=config.model, reasoning_effort=config.reasoning_effort, openrouter_api_key=openrouter_api_key, @@ -86,6 +92,8 @@ def main( ) ) + score_db.close() + def get_command( decomposition_id: str, diff --git 
a/spd/autointerp/scripts/run_slurm.py b/spd/autointerp/scripts/run_slurm.py index c521ac439..bcd83b94c 100644 --- a/spd/autointerp/scripts/run_slurm.py +++ b/spd/autointerp/scripts/run_slurm.py @@ -59,8 +59,7 @@ def submit_autointerp( interpret_slurm = SlurmConfig( job_name="spd-interpret", partition=config.partition, - n_gpus=0, - cpus_per_task=16, + n_gpus=2, time=config.time, snapshot_branch=snapshot_branch, dependency_job_id=dependency_job_id, @@ -98,8 +97,7 @@ def submit_autointerp( eval_slurm = SlurmConfig( job_name=f"spd-{scorer}", partition=config.partition, - n_gpus=0, - cpus_per_task=16, + n_gpus=2, time=config.evals_time, snapshot_branch=snapshot_branch, dependency_job_id=interpret_result.job_id, diff --git a/spd/autointerp/strategies/dual_view.py b/spd/autointerp/strategies/dual_view.py index 430405e41..1206c73f5 100644 --- a/spd/autointerp/strategies/dual_view.py +++ b/spd/autointerp/strategies/dual_view.py @@ -7,83 +7,22 @@ - Task framing asks for functional description, not detection label """ -import re - from spd.app.backend.app_tokenizer import AppTokenizer -from spd.app.backend.utils import delimit_tokens from spd.autointerp.config import DualViewConfig +from spd.autointerp.prompt_helpers import ( + DATASET_DESCRIPTIONS, + build_fires_on_examples, + build_input_section, + build_output_section, + build_says_examples, + density_note, + human_layer_desc, + layer_position_note, +) from spd.autointerp.schemas import ModelMetadata from spd.harvest.analysis import TokenPRLift from spd.harvest.schemas import ComponentData -DATASET_DESCRIPTIONS: dict[str, str] = { - "SimpleStories/SimpleStories": ( - "SimpleStories: 2M+ short stories (200-350 words), grade 1-8 reading level. " - "Simple vocabulary, common narrative elements." 
- ), -} - -WEIGHT_NAMES: dict[str, str] = { - "attn.q": "attention query projection", - "attn.k": "attention key projection", - "attn.v": "attention value projection", - "attn.o": "attention output projection", - "mlp.up": "MLP up-projection", - "mlp.down": "MLP down-projection", - "glu.up": "GLU up-projection", - "glu.down": "GLU down-projection", - "glu.gate": "GLU gate projection", -} - -_ORDINALS = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th"] - - -def _ordinal(n: int) -> str: - if 1 <= n <= len(_ORDINALS): - return _ORDINALS[n - 1] - return f"{n}th" - - -def _human_layer_desc(canonical: str, n_blocks: int) -> str: - """Convert canonical layer string to human-readable description. - - '0.mlp.up' -> 'MLP up-projection in the 1st of 4 blocks' - '1.attn.q' -> 'attention query projection in the 2nd of 4 blocks' - """ - m = re.match(r"(\d+)\.(.*)", canonical) - if not m: - return canonical - layer_idx = int(m.group(1)) - weight_key = m.group(2) - weight_name = WEIGHT_NAMES.get(weight_key, weight_key) - return f"{weight_name} in the {_ordinal(layer_idx + 1)} of {n_blocks} blocks" - - -def _layer_position_note(canonical: str, n_blocks: int) -> str: - """Brief note about what layer position means for interpretation.""" - m = re.match(r"(\d+)\.", canonical) - if not m: - return "" - layer_idx = int(m.group(1)) - if layer_idx == n_blocks - 1: - return "This is in the final block, so its output directly influences token predictions." - remaining = n_blocks - 1 - layer_idx - return ( - f"This is {remaining} block{'s' if remaining > 1 else ''} from the output, " - f"so its effect on token predictions is indirect — filtered through later layers." - ) - - -def _density_note(firing_density: float) -> str: - if firing_density > 0.15: - return ( - "This is a high-density component (fires frequently). " - "High-density components often act as broad biases rather than selective features." 
- ) - if firing_density < 0.005: - return "This is a very sparse component, likely highly specific." - return "" - def format_prompt( config: DualViewConfig, @@ -108,10 +47,10 @@ def format_prompt( else None ) - output_section = _build_output_section(output_token_stats, output_pmi) - input_section = _build_input_section(input_token_stats, input_pmi) - fires_on_examples = _build_fires_on_examples(component, app_tok, config.max_examples) - says_examples = _build_says_examples(component, app_tok, config.max_examples) + output_section = build_output_section(output_token_stats, output_pmi) + input_section = build_input_section(input_token_stats, input_pmi) + fires_on_examples = build_fires_on_examples(component, app_tok, config.max_examples) + says_examples = build_says_examples(component, app_tok, config.max_examples) if component.firing_density > 0.0: rate_str = f"~1 in {int(1 / component.firing_density)} tokens" @@ -119,11 +58,11 @@ def format_prompt( rate_str = "extremely rare" canonical = model_metadata.layer_descriptions.get(component.layer, component.layer) - layer_desc = _human_layer_desc(canonical, model_metadata.n_blocks) - position_note = _layer_position_note(canonical, model_metadata.n_blocks) - density_note = _density_note(component.firing_density) + layer_desc = human_layer_desc(canonical, model_metadata.n_blocks) + position_note = layer_position_note(canonical, model_metadata.n_blocks) + dens_note = density_note(component.firing_density) - context_notes = " ".join(filter(None, [position_note, density_note])) + context_notes = " ".join(filter(None, [position_note, dens_note])) dataset_line = "" if config.include_dataset_description: @@ -183,85 +122,3 @@ def format_prompt( Say "unclear" if the evidence is too weak or diffuse. {forbidden_sentence}Lowercase only. 
""" - - -def _build_output_section( - output_stats: TokenPRLift, - output_pmi: list[tuple[str, float]] | None, -) -> str: - section = "" - - if output_pmi: - section += ( - "**Output PMI (pointwise mutual information, in nats: how much more likely " - "a token is to be produced when this component fires, vs its base rate. " - "0 = no association, 1 = ~3x more likely, 2 = ~7x, 3 = ~20x):**\n" - ) - for tok, pmi in output_pmi[:10]: - section += f"- {repr(tok)}: {pmi:.2f}\n" - - if output_stats.top_precision: - section += "\n**Output precision — of all probability mass for token X, what fraction is at positions where this component fires?**\n" - for tok, prec in output_stats.top_precision[:10]: - section += f"- {repr(tok)}: {prec * 100:.0f}%\n" - - return section - - -def _build_input_section( - input_stats: TokenPRLift, - input_pmi: list[tuple[str, float]] | None, -) -> str: - section = "" - - if input_pmi: - section += "**Input PMI (same metric as above, for input tokens):**\n" - for tok, pmi in input_pmi[:6]: - section += f"- {repr(tok)}: {pmi:.2f}\n" - - if input_stats.top_recall: - section += "\n**Input recall — most common tokens when the component fires:**\n" - for tok, recall in input_stats.top_recall[:8]: - section += f"- {repr(tok)}: {recall * 100:.0f}%\n" - - if input_stats.top_precision: - section += "\n**Input precision — probability the component fires given the current token is X:**\n" - for tok, prec in input_stats.top_precision[:8]: - section += f"- {repr(tok)}: {prec * 100:.0f}%\n" - - return section - - -def _build_fires_on_examples( - component: ComponentData, - app_tok: AppTokenizer, - max_examples: int, -) -> str: - section = "" - examples = component.activation_examples[:max_examples] - - for i, ex in enumerate(examples): - if any(ex.firings): - spans = app_tok.get_spans(ex.token_ids) - tokens = list(zip(spans, ex.firings, strict=True)) - section += f"{i + 1}. 
{delimit_tokens(tokens)}\n" - - return section - - -def _build_says_examples( - component: ComponentData, - app_tok: AppTokenizer, - max_examples: int, -) -> str: - section = "" - examples = component.activation_examples[:max_examples] - - for i, ex in enumerate(examples): - if any(ex.firings): - spans = app_tok.get_spans(ex.token_ids) - shifted_firings = [False] + ex.firings[:-1] - tokens = list(zip(spans, shifted_firings, strict=True)) - section += f"{i + 1}. {delimit_tokens(tokens)}\n" - - return section diff --git a/spd/dataset_attributions/CLAUDE.md b/spd/dataset_attributions/CLAUDE.md index faf3a5373..2e2c0bd4c 100644 --- a/spd/dataset_attributions/CLAUDE.md +++ b/spd/dataset_attributions/CLAUDE.md @@ -5,150 +5,111 @@ Multi-GPU pipeline for computing component-to-component attribution strengths ag ## Usage (SLURM) ```bash -# Process specific number of batches spd-attributions --n_batches 1000 --n_gpus 8 - -# Process entire training dataset (omit --n_batches) -spd-attributions --n_gpus 24 - -# With optional parameters -spd-attributions --n_batches 1000 --n_gpus 8 \ - --batch_size 64 --ci_threshold 1e-6 --time 48:00:00 +spd-attributions --n_gpus 24 # whole dataset ``` The command: -1. Creates a git snapshot branch for reproducibility (jobs may be queued) -2. Submits a SLURM job array with N tasks (one per GPU) +1. Creates a git snapshot branch for reproducibility +2. Submits a SLURM job array (one per GPU) 3. Each task processes batches where `batch_idx % world_size == rank` -4. Submits a merge job (depends on array completion) that combines all worker results - -**Note**: `--n_batches` is optional. If omitted, the pipeline processes the entire training dataset. +4. 
Submits a merge job (depends on array completion) ## Usage (non-SLURM) -For environments without SLURM, run the worker script directly: - ```bash -# Single GPU (defaults from DatasetAttributionConfig, auto-generates subrun ID) -python -m spd.dataset_attributions.scripts.run - -# Single GPU with config file -python -m spd.dataset_attributions.scripts.run --config_path path/to/config.yaml +# Single GPU +python -m spd.dataset_attributions.scripts.run_worker -# Multi-GPU (run in parallel via shell, tmux, etc.) -# All workers and the merge step must share the same --subrun_id +# Multi-GPU SUBRUN="da-$(date +%Y%m%d_%H%M%S)" -python -m spd.dataset_attributions.scripts.run --config_json '{"n_batches": 1000}' --rank 0 --world_size 4 --subrun_id $SUBRUN & -python -m spd.dataset_attributions.scripts.run --config_json '{"n_batches": 1000}' --rank 1 --world_size 4 --subrun_id $SUBRUN & -python -m spd.dataset_attributions.scripts.run --config_json '{"n_batches": 1000}' --rank 2 --world_size 4 --subrun_id $SUBRUN & -python -m spd.dataset_attributions.scripts.run --config_json '{"n_batches": 1000}' --rank 3 --world_size 4 --subrun_id $SUBRUN & +python -m spd.dataset_attributions.scripts.run_worker --config_json '{"n_batches": 1000}' --rank 0 --world_size 4 --subrun_id $SUBRUN & +python -m spd.dataset_attributions.scripts.run_worker --config_json '{"n_batches": 1000}' --rank 1 --world_size 4 --subrun_id $SUBRUN & +# ... wait - -# Merge results after all workers complete -python -m spd.dataset_attributions.scripts.run --merge --subrun_id $SUBRUN +python -m spd.dataset_attributions.scripts.run_merge --wandb_path --subrun_id $SUBRUN ``` -Each worker processes batches where `batch_idx % world_size == rank`, then the merge step combines all partial results. - ## Data Storage -Each attribution invocation creates a timestamped sub-run directory. `AttributionRepo` automatically loads from the latest sub-run. 
- ``` SPD_OUT_DIR/dataset_attributions// -├── da-20260211_120000/ # sub-run 1 -│ ├── dataset_attributions.pt # Final merged attributions -│ └── worker_states/ # cleaned up after merge +├── da-20260223_183250/ # sub-run (latest picked by repo) +│ ├── dataset_attributions.pt # merged result +│ └── worker_states/ │ └── dataset_attributions_rank_*.pt -├── da-20260211_140000/ # sub-run 2 -│ └── ... ``` -Legacy layout (pre sub-run) is still supported as a fallback by `AttributionRepo`: +`AttributionRepo.open(run_id)` loads the latest `da-*` subrun that has a `dataset_attributions.pt`. -``` -SPD_OUT_DIR/dataset_attributions// -└── dataset_attributions.pt -``` +## Attribution Metrics -## Architecture +Two metrics: `AttrMetric = Literal["attr", "attr_abs"]` -### SLURM Launcher (`scripts/run_slurm.py`, `scripts/run_slurm_cli.py`) +| Metric | Formula | Description | +|--------|---------|-------------| +| `attr` | E[∂y/∂x · x] | Signed mean attribution | +| `attr_abs` | E[∂\|y\|/∂x · x] | Attribution to absolute value of target (2 backward passes) | -Entry point via `spd-attributions`. Submits array job + dependent merge job. +Naming convention: modifier *before* `attr` applies to the target (e.g. `attr_abs` = attribution to |target|). -### Worker Script (`scripts/run.py`) +## Architecture -Internal script called by SLURM jobs. Accepts config via `--config_path` (file) or `--config_json` (inline JSON). Supports: -- `--config_path`/`--config_json`: Provide `DatasetAttributionConfig` (defaults used if neither given) -- `--rank R --world_size N`: Process subset of batches -- `--merge`: Combine per-rank results into final file -- `--subrun_id`: Sub-run identifier (auto-generated if not provided) +### Storage (`storage.py`) -### Config (`config.py`) +`DatasetAttributionStorage` stores four structurally distinct edge types: -`DatasetAttributionConfig` (tuning params) and `AttributionsSlurmConfig` (DatasetAttributionConfig + SLURM params). 
`wandb_path` is a runtime arg, not part of config. +| Edge type | Fields | Shape | Has abs? | +|-----------|--------|-------|----------| +| component → component | `regular_attr`, `regular_attr_abs` | `dict[target, dict[source, (tgt_c, src_c)]]` | yes | +| embed → component | `embed_attr`, `embed_attr_abs` | `dict[target, (tgt_c, vocab)]` | yes | +| component → unembed | `unembed_attr` | `dict[source, (d_model, src_c)]` | no | +| embed → unembed | `embed_unembed_attr` | `(d_model, vocab)` | no | -### Harvest Logic (`harvest.py`) +All layer names use **canonical addressing** (`"embed"`, `"0.glu.up"`, `"output"`). -Main harvesting functions: -- `harvest_attributions(wandb_path, config, output_dir, ...)`: Process batches for a single rank -- `merge_attributions(output_dir)`: Combine worker results from `output_dir/worker_states/` into `output_dir` +Unembed edges are stored in residual space (d_model dimensions). `w_unembed` is stored alongside the attribution data, so output token attributions are computed on-the-fly internally — callers never need to provide the projection matrix. No abs variant for unembed edges because abs is a nonlinear operation incompatible with residual-space storage. -### Attribution Harvester (`harvester.py`) +**Normalization**: `normed[t, s] = raw[t, s] / source_denom[s] / target_rms[t]`. Component sources use `ci_sum[s]` as denominator, embed sources use `embed_token_count[s]` (per-token occurrence count). This puts both source types on comparable per-occurrence scales. -Core class that accumulates attribution strengths using gradient × activation formula: +Key methods: `get_top_sources(key, k, sign, metric)`, `get_top_targets(key, k, sign, metric)`. Both return `[]` for nonexistent components. `merge(paths)` classmethod for combining worker results via weighted average by n_tokens. 
-``` -attribution[src, tgt] = Σ_batch Σ_pos (∂out[pos, tgt] / ∂in[pos, src]) × in_act[pos, src] -``` +### Harvester (`harvester.py`) -Key optimizations: +Accumulates attributions using gradient × activation. Uses **concrete module paths** internally (talks to model cache/CI). Four accumulator groups mirror the storage edge types. Key optimizations: 1. Sum outputs over positions before gradients (reduces backward passes) -2. For output targets, store attributions to output residual stream instead of vocab tokens (reduces storage from O((V+C)²) to O((V+C)×(C+d_model))) - -### Storage (`storage.py`) +2. Output-residual storage (O(d_model) instead of O(vocab)) +3. `scatter_add_` for embed sources, vectorized `.add_()` for components (>14x faster than per-element loops) -`DatasetAttributionStorage` class using output-residual-based storage for scalability. +### Harvest (`harvest.py`) -**Storage structure:** -- `source_to_component`: (n_sources, n_components) - direct attributions to component targets -- `source_to_out_residual`: (n_sources, d_model) - attributions to output residual stream for output queries +Orchestrates the pipeline: loads model, builds gradient connectivity, runs batches, translates concrete→canonical at storage boundary via `topology.target_to_canon()`. 
-**Source indexing (rows):** -- `[0, vocab_size)`: wte tokens -- `[vocab_size, vocab_size + n_components)`: component layers +### Scripts -**Target handling:** -- Component targets: direct lookup in `source_to_component` -- Output targets: computed on-the-fly via `source_to_out_residual @ w_unembed[:, token_id]` +- `scripts/run_worker.py` — worker entrypoint (single GPU) +- `scripts/run_merge.py` — merge entrypoint (CPU only, needs ~200G RAM) +- `scripts/run_slurm.py` — SLURM launcher (array + merge jobs) +- `scripts/run_slurm_cli.py` — CLI wrapper for `spd-attributions` -**Why output-residual-based storage?** +### Config (`config.py`) -For large vocab models (V=32K), the naive approach would require O((V+C)²) storage (~4 GB). -The output-residual-based approach requires only O((V+C)×(C+d)) storage (~670 MB for Llama-scale), -a 6.5x reduction. Output attributions are computed on-the-fly at query time with negligible latency. +- `DatasetAttributionConfig`: n_batches, batch_size, ci_threshold +- `AttributionsSlurmConfig`: adds n_gpus, partition, time, merge_time, merge_mem (default 200G) ### Repository (`repo.py`) -`AttributionRepo` provides read access via `AttributionRepo.open(run_id)`. Returns `None` if no data exists. Storage is loaded eagerly at construction. +`AttributionRepo.open(run_id)` → loads latest subrun. Returns `None` if no data. + +## Query Methods -## Key Types +All query methods take `metric: AttrMetric` (`"attr"` or `"attr_abs"`). 
-```python -DatasetAttributionStorage # Main storage class with split matrices -DatasetAttributionEntry # Single entry: component_key, layer, component_idx, value -DatasetAttributionConfig # Config (BaseConfig): n_batches, batch_size, ci_threshold -``` +| Method | Description | +|--------|-------------| +| `get_top_sources(target_key, k, sign, metric)` | Top sources → target | +| `get_top_targets(source_key, k, sign, metric)` | Top targets ← source | -## Query Methods +Key format: `"embed:{token_id}"`, `"0.glu.up:{c_idx}"`, `"output:{token_id}"`. -| Method | w_unembed required? | Description | -|--------|---------------------|-------------| -| `get_top_sources(component_key, k, sign)` | No | Top sources → component target | -| `get_top_sources(output_key, k, sign, w_unembed)` | Yes | Top sources → output token | -| `get_top_component_targets(source_key, k, sign)` | No | Top component targets | -| `get_top_output_targets(source_key, k, sign, w_unembed)` | Yes | Top output token targets | -| `get_top_targets(source_key, k, sign, w_unembed)` | Yes | All targets (components + outputs) | -| `get_attribution(source_key, component_key)` | No | Single component attribution | -| `get_attribution(source_key, output_key, w_unembed)` | Yes | Single output attribution | +Note: `attr_abs` returns empty for output targets (unembed edges have no abs variant). 
diff --git a/spd/dataset_attributions/config.py b/spd/dataset_attributions/config.py index a1de165fb..3d84fcbd8 100644 --- a/spd/dataset_attributions/config.py +++ b/spd/dataset_attributions/config.py @@ -26,3 +26,4 @@ class AttributionsSlurmConfig(BaseConfig): partition: str = DEFAULT_PARTITION_NAME time: str = "48:00:00" merge_time: str = "01:00:00" + merge_mem: str = "200G" diff --git a/spd/dataset_attributions/harvest.py b/spd/dataset_attributions/harvest.py index 15e4f5b19..3b6e18cee 100644 --- a/spd/dataset_attributions/harvest.py +++ b/spd/dataset_attributions/harvest.py @@ -33,27 +33,13 @@ from spd.utils.wandb_utils import parse_wandb_run_path -def _build_component_layer_keys(model: ComponentModel) -> list[str]: - """Build list of component layer keys in canonical order. - - Returns keys like ["h.0.attn.q_proj:0", "h.0.attn.q_proj:1", ...] for all layers. - wte and output keys are not included - they're constructed from vocab_size. - """ - component_layer_keys = [] - for layer in model.target_module_paths: - n_components = model.module_to_c[layer] - for c_idx in range(n_components): - component_layer_keys.append(f"{layer}:{c_idx}") - return component_layer_keys - - def _build_alive_masks( model: ComponentModel, run_id: str, harvest_subrun_id: str | None, - n_components: int, + embed_path: str, vocab_size: int, -) -> tuple[Bool[Tensor, " n_sources"], Bool[Tensor, " n_components"]]: +) -> dict[str, Bool[Tensor, " n_components"]]: """Build masks of alive components (mean_activation > threshold) for sources and targets. Falls back to all-alive if harvest summary not available. 
@@ -63,43 +49,31 @@ def _build_alive_masks( - Targets: [0, n_components) = component layers (output handled via out_residual) """ - n_sources = vocab_size + n_components - - source_alive = torch.zeros(n_sources, dtype=torch.bool) - target_alive = torch.zeros(n_components, dtype=torch.bool) - - # All wte tokens are always alive (source indices [0, vocab_size)) - source_alive[:vocab_size] = True + component_alive = { + embed_path: torch.ones(vocab_size, dtype=torch.bool), # TODO(oli): maybe remove this + **{ + layer: torch.zeros(model.module_to_c[layer], dtype=torch.bool) + for layer in model.target_module_paths + }, + } if harvest_subrun_id is not None: harvest = HarvestRepo(decomposition_id=run_id, subrun_id=harvest_subrun_id, readonly=True) else: harvest = HarvestRepo.open_most_recent(run_id, readonly=True) assert harvest is not None, f"No harvest data for {run_id}" + summary = harvest.get_summary() assert summary is not None, "Harvest summary not available" - # Build masks for component layers - source_idx = vocab_size # Start after wte tokens - target_idx = 0 - for layer in model.target_module_paths: n_layer_components = model.module_to_c[layer] for c_idx in range(n_layer_components): component_key = f"{layer}:{c_idx}" is_alive = component_key in summary and summary[component_key].firing_density > 0.0 - source_alive[source_idx] = is_alive - target_alive[target_idx] = is_alive - source_idx += 1 - target_idx += 1 - - n_source_alive = int(source_alive.sum().item()) - n_target_alive = int(target_alive.sum().item()) - logger.info( - f"Alive components: {n_source_alive}/{n_sources} sources, " - f"{n_target_alive}/{n_components} component targets (firing density > 0.0)" - ) - return source_alive, target_alive + component_alive[layer][c_idx] = is_alive + + return component_alive def harvest_attributions( @@ -134,36 +108,25 @@ def harvest_attributions( model.eval() spd_config = run_info.config - train_loader, tokenizer = train_loader_and_tokenizer(spd_config, 
config.batch_size) - vocab_size = tokenizer.vocab_size - assert isinstance(vocab_size, int), f"vocab_size must be int, got {type(vocab_size)}" - logger.info(f"Vocab size: {vocab_size}") - - # Build component keys and alive masks - component_layer_keys = _build_component_layer_keys(model) - n_components = len(component_layer_keys) - source_alive, target_alive = _build_alive_masks( - model, run_id, harvest_subrun_id, n_components, vocab_size - ) - source_alive = source_alive.to(device) - target_alive = target_alive.to(device) - - n_sources = vocab_size + n_components - logger.info(f"Component layers: {n_components}, Sources: {n_sources}") + train_loader, _ = train_loader_and_tokenizer(spd_config, config.batch_size) # Get gradient connectivity logger.info("Computing sources_by_target...") topology = TransformerTopology(model.target_model) + embed_path = topology.path_schema.embedding_path + unembed_path = topology.path_schema.unembed_path + vocab_size = topology.embedding_module.num_embeddings + logger.info(f"Vocab size: {vocab_size}") sources_by_target_raw = get_sources_by_target(model, topology, str(device), spd_config.sampling) - # Filter sources_by_target: - # - Valid targets: component layers + output - # - Valid sources: wte + component layers + # Filter to valid source/target pairs: + # - Valid sources: embedding + component layers + # - Valid targets: component layers + unembed component_layers = set(model.target_module_paths) - valid_sources = component_layers | {"wte"} - valid_targets = component_layers | {"output"} + valid_sources = component_layers | {embed_path} + valid_targets = component_layers | {unembed_path} - sources_by_target = {} + sources_by_target: dict[str, list[str]] = {} for target, sources in sources_by_target_raw.items(): if target not in valid_targets: continue @@ -172,19 +135,21 @@ def harvest_attributions( sources_by_target[target] = filtered_sources logger.info(f"Found {len(sources_by_target)} target layers with gradient connections") - 
# Create harvester + # Build alive masks + component_alive = _build_alive_masks(model, run_id, harvest_subrun_id, embed_path, vocab_size) + + # Create harvester (all concrete paths internally) harvester = AttributionHarvester( model=model, sources_by_target=sources_by_target, - n_components=n_components, vocab_size=vocab_size, - source_alive=source_alive, - target_alive=target_alive, + component_alive=component_alive, sampling=spd_config.sampling, + embed_path=embed_path, embedding_module=topology.embedding_module, + unembed_path=unembed_path, unembed_module=topology.unembed_module, device=device, - show_progress=True, ) # Process batches @@ -194,37 +159,24 @@ def harvest_attributions( batch_range = range(n_batches) case "whole_dataset": batch_range = itertools.count() + for batch_idx in tqdm.tqdm(batch_range, desc="Attribution batches"): try: batch_data = next(train_iter) except StopIteration: logger.info(f"Dataset exhausted at batch {batch_idx}. Processing complete.") break + # Skip batches not assigned to this rank if world_size is not None and batch_idx % world_size != rank: continue + batch = extract_batch_data(batch_data).to(device) harvester.process_batch(batch) - logger.info( - f"Processing complete. Tokens: {harvester.n_tokens:,}, Batches: {harvester.n_batches}" - ) - - # Normalize by n_tokens to get per-token average attribution - normalized_comp = harvester.comp_accumulator / harvester.n_tokens - normalized_out_residual = harvester.out_residual_accumulator / harvester.n_tokens + logger.info(f"Processing complete. 
Tokens: {harvester.n_tokens:,}") - # Build and save storage - storage = DatasetAttributionStorage( - component_layer_keys=component_layer_keys, - vocab_size=vocab_size, - d_model=harvester.d_model, - source_to_component=normalized_comp.cpu(), - source_to_out_residual=normalized_out_residual.cpu(), - n_batches_processed=harvester.n_batches, - n_tokens_processed=harvester.n_tokens, - ci_threshold=config.ci_threshold, - ) + storage = harvester.finalize(topology, config.ci_threshold) if rank is not None: worker_dir = output_dir / "worker_states" @@ -234,72 +186,24 @@ def harvest_attributions( output_dir.mkdir(parents=True, exist_ok=True) output_path = output_dir / "dataset_attributions.pt" storage.save(output_path) - logger.info(f"Saved dataset attributions to {output_path}") def merge_attributions(output_dir: Path) -> None: - """Merge partial attribution files from parallel workers. - - Looks for worker_states/dataset_attributions_rank_*.pt files and merges them - into dataset_attributions.pt in the output_dir. - - Uses streaming merge to avoid OOM - loads one file at a time instead of all at once. 
- """ + """Merge partial attribution files from parallel workers.""" worker_dir = output_dir / "worker_states" rank_files = sorted(worker_dir.glob("dataset_attributions_rank_*.pt")) assert rank_files, f"No rank files found in {worker_dir}" logger.info(f"Found {len(rank_files)} rank files to merge") - # Load first file to get metadata and initialize accumulators - # Use double precision for accumulation to prevent precision loss with billions of tokens - first = DatasetAttributionStorage.load(rank_files[0]) - total_comp = (first.source_to_component * first.n_tokens_processed).double() - total_out_residual = (first.source_to_out_residual * first.n_tokens_processed).double() - total_tokens = first.n_tokens_processed - total_batches = first.n_batches_processed - logger.info(f"Loaded rank 0: {first.n_tokens_processed:,} tokens") - - # Stream remaining files one at a time - for rank_file in tqdm.tqdm(rank_files[1:], desc="Merging rank files"): - storage = DatasetAttributionStorage.load(rank_file) - - # Validate consistency - assert storage.component_layer_keys == first.component_layer_keys, ( - "Component layer keys mismatch" - ) - assert storage.vocab_size == first.vocab_size, "Vocab size mismatch" - assert storage.d_model == first.d_model, "d_model mismatch" - assert storage.ci_threshold == first.ci_threshold, "CI threshold mismatch" - - # Accumulate de-normalized values - total_comp += storage.source_to_component * storage.n_tokens_processed - total_out_residual += storage.source_to_out_residual * storage.n_tokens_processed - total_tokens += storage.n_tokens_processed - total_batches += storage.n_batches_processed - - # Normalize by total tokens and convert back to float32 for storage - merged_comp = (total_comp / total_tokens).float() - merged_out_residual = (total_out_residual / total_tokens).float() - - # Save merged result - merged = DatasetAttributionStorage( - component_layer_keys=first.component_layer_keys, - vocab_size=first.vocab_size, - 
d_model=first.d_model, - source_to_component=merged_comp, - source_to_out_residual=merged_out_residual, - n_batches_processed=total_batches, - n_tokens_processed=total_tokens, - ci_threshold=first.ci_threshold, - ) + merged = DatasetAttributionStorage.merge(rank_files) output_path = output_dir / "dataset_attributions.pt" merged.save(output_path) - assert output_path.stat().st_size > 0, f"Merge output is empty: {output_path}" - logger.info(f"Merged {len(rank_files)} files -> {output_path}") - logger.info(f"Total: {total_batches} batches, {total_tokens:,} tokens") - - for rank_file in rank_files: - rank_file.unlink() - worker_dir.rmdir() - logger.info(f"Deleted {len(rank_files)} per-rank files and worker_states/") + logger.info(f"Total: {merged.n_tokens_processed:,} tokens") + + # TODO(oli): reenable this + # disabled deletion for testing, posterity and retries + # for rank_file in rank_files: + # rank_file.unlink() + # worker_dir.rmdir() + # logger.info(f"Deleted {len(rank_files)} per-rank files and worker_states/") diff --git a/spd/dataset_attributions/harvester.py b/spd/dataset_attributions/harvester.py index 5bef0af63..84c153847 100644 --- a/spd/dataset_attributions/harvester.py +++ b/spd/dataset_attributions/harvester.py @@ -4,27 +4,33 @@ training dataset using gradient x activation formula, summed over all positions and batches. -Uses residual-based storage for scalability: -- Component targets: accumulated directly to comp_accumulator -- Output targets: accumulated as attributions to output residual stream (source_to_out_residual) - Output attributions computed on-the-fly at query time via w_unembed +Three metrics are accumulated: +- attr: E[∂y/∂x · x] (signed mean attribution) +- attr_abs: E[∂|y|/∂x · x] (attribution to absolute value of target) + +Output (pseudo-) component attributions are handled differently: We accumulate attributions +to the output residual stream, then later project this into token space. 
+ +All layer keys are concrete module paths (e.g. "wte", "h.0.attn.q_proj", "lm_head"). +Translation to canonical names happens at the storage boundary in harvest.py. """ from typing import Any import torch -from jaxtyping import Bool, Float, Int +from jaxtyping import Bool, Int from torch import Tensor, nn -from tqdm.auto import tqdm from spd.configs import SamplingType +from spd.dataset_attributions.storage import DatasetAttributionStorage from spd.models.component_model import ComponentModel, OutputWithCache from spd.models.components import make_mask_infos +from spd.topology import TransformerTopology from spd.utils.general_utils import bf16_autocast class AttributionHarvester: - """Accumulates attribution strengths across batches. + """Accumulates attribution strengths across batches using concrete module paths. The attribution formula is: attribution[src, tgt] = Σ_batch Σ_pos (∂out[pos, tgt] / ∂in[pos, src]) × in_act[pos, src] @@ -35,11 +41,6 @@ class AttributionHarvester: 2. For output targets, store attributions to the pre-unembed residual (d_model dimensions) instead of vocab tokens. This eliminates the expensive O((V+C) × d_model × V) matmul during harvesting and reduces storage. - - Index structure: - - Sources: wte tokens [0, vocab_size) + component layers [vocab_size, ...) 
- - Component targets: [0, n_components) in comp_accumulator - - Output targets: via out_residual_accumulator (computed on-the-fly at query time) """ sampling: SamplingType @@ -48,95 +49,115 @@ def __init__( self, model: ComponentModel, sources_by_target: dict[str, list[str]], - n_components: int, vocab_size: int, - source_alive: Bool[Tensor, " n_sources"], - target_alive: Bool[Tensor, " n_components"], + component_alive: dict[str, Bool[Tensor, " n_components"]], sampling: SamplingType, + embed_path: str, embedding_module: nn.Embedding, + unembed_path: str, unembed_module: nn.Linear, device: torch.device, - show_progress: bool = False, ): self.model = model self.sources_by_target = sources_by_target - self.n_components = n_components - self.vocab_size = vocab_size - self.source_alive = source_alive - self.target_alive = target_alive + self.component_alive = component_alive self.sampling = sampling + self.embed_path = embed_path self.embedding_module = embedding_module + self.unembed_path = unembed_path self.unembed_module = unembed_module + self.output_d_model = unembed_module.in_features self.device = device - self.show_progress = show_progress - - self.n_sources = vocab_size + n_components - self.n_batches = 0 - self.n_tokens = 0 - - # Split accumulators for component and output targets - self.comp_accumulator = torch.zeros(self.n_sources, n_components, device=device) - # For output targets: store attributions to output residual dimensions - self.d_model = unembed_module.in_features - self.out_residual_accumulator = torch.zeros(self.n_sources, self.d_model, device=device) - - # Build per-layer index ranges for sources - self.component_layer_names = list(model.target_module_paths) - self.source_layer_to_idx_range = self._build_source_layer_index_ranges() - self.target_layer_to_idx_range = self._build_target_layer_index_ranges() - - # Pre-compute alive indices per layer - self.alive_source_idxs_per_layer = self._build_alive_indices( - 
self.source_layer_to_idx_range, source_alive + # attribution accumulators + self._straight_through_attr_acc = torch.zeros( + (self.output_d_model, self.embedding_module.num_embeddings), device=self.device ) - self.alive_target_idxs_per_layer = self._build_alive_indices( - self.target_layer_to_idx_range, target_alive + self._embed_tgts_acc = self._get_embed_targets_attr_accumulator(sources_by_target) + self._embed_tgts_acc_abs = self._get_embed_targets_attr_accumulator(sources_by_target) + self._unembed_srcs_acc = self._get_unembed_sources_attr_accumulator(sources_by_target) + self._regular_layers_acc = self._get_regular_layer_attr_accumulator(sources_by_target) + self._regular_layers_acc_abs = self._get_regular_layer_attr_accumulator(sources_by_target) + + # embed token occurrence counts for normalization (analogous to ci_sum for components) + self._embed_token_count = torch.zeros( + (self.embedding_module.num_embeddings,), dtype=torch.long, device=self.device ) - def _build_source_layer_index_ranges(self) -> dict[str, tuple[int, int]]: - """Source order: wte tokens [0, vocab_size), then component layers.""" - ranges: dict[str, tuple[int, int]] = {"wte": (0, self.vocab_size)} - idx = self.vocab_size - for layer in self.component_layer_names: - n = self.model.module_to_c[layer] - ranges[layer] = (idx, idx + n) - idx += n - return ranges - - def _build_target_layer_index_ranges(self) -> dict[str, tuple[int, int]]: - """Target order: component layers [0, n_components). 
Output handled separately.""" - ranges: dict[str, tuple[int, int]] = {} - idx = 0 - for layer in self.component_layer_names: - n = self.model.module_to_c[layer] - ranges[layer] = (idx, idx + n) - idx += n - # Note: "output" not included - handled via out_residual_accumulator - return ranges - - def _build_alive_indices( - self, layer_ranges: dict[str, tuple[int, int]], alive_mask: Bool[Tensor, " n"] - ) -> dict[str, list[int]]: - """Get alive local indices for each layer.""" - return { - layer: torch.where(alive_mask[start:end])[0].tolist() - for layer, (start, end) in layer_ranges.items() + # rms normalization accumulators + self.n_tokens = 0 + self._ci_sum_accumulator = { + layer: torch.zeros((c,), device=self.device) + for layer, c in self.model.module_to_c.items() } + self._square_component_act_accumulator = { + layer: torch.zeros((c,), device=self.device) + for layer, c in self.model.module_to_c.items() + } + self._logit_sq_sum = torch.zeros((self.unembed_module.out_features,), device=self.device) + + def _get_embed_targets_attr_accumulator( + self, sources_by_target: dict[str, list[str]] + ) -> dict[str, Tensor]: + # extract targets who's sources include the embedding + embed_targets_attr_accumulators: dict[str, Tensor] = {} + for target, sources in sources_by_target.items(): + if target == self.unembed_path: + # ignore straight-through edge + continue + if self.embed_path in sources: + embed_targets_attr_accumulators[target] = torch.zeros( + (self.model.module_to_c[target], self.embedding_module.num_embeddings), + device=self.device, + ) + return embed_targets_attr_accumulators + + def _get_unembed_sources_attr_accumulator( + self, sources_by_target: dict[str, list[str]] + ) -> dict[str, Tensor]: + # extract the unembed's sources + unembed_sources_attr_accumulators: dict[str, Tensor] = {} + for source in sources_by_target[self.unembed_path]: + if source == self.embed_path: + # ignore straight-through edge + continue + 
unembed_sources_attr_accumulators[source] = torch.zeros( + (self.output_d_model, self.model.module_to_c[source]), device=self.device + ) + return unembed_sources_attr_accumulators + + def _get_regular_layer_attr_accumulator( + self, sources_by_target: dict[str, list[str]] + ) -> dict[str, dict[str, Tensor]]: + regular_layers_shapes: dict[str, dict[str, Tensor]] = {} + for target, sources in sources_by_target.items(): + if target == self.unembed_path: + continue + regular_layers_shapes[target] = {} + for source in sources: + if source == self.embed_path: + continue + regular_layers_shapes[target][source] = torch.zeros( + (self.model.module_to_c[target], self.model.module_to_c[source]), + device=self.device, + ) + return regular_layers_shapes def process_batch(self, tokens: Int[Tensor, "batch seq"]) -> None: """Accumulate attributions from one batch.""" - self.n_batches += 1 self.n_tokens += tokens.numel() + self._embed_token_count.add_( + torch.bincount(tokens.flatten(), minlength=self.embedding_module.num_embeddings) + ) - # Setup hooks to capture wte output and pre-unembed residual - wte_out: list[Tensor] = [] + # Setup hooks to capture embedding output and pre-unembed residual + embed_out: list[Tensor] = [] pre_unembed: list[Tensor] = [] - def wte_hook(_mod: nn.Module, _args: Any, _kwargs: Any, out: Tensor) -> Tensor: + def embed_hook(_mod: nn.Module, _args: Any, _kwargs: Any, out: Tensor) -> Tensor: out.requires_grad_(True) - wte_out.clear() - wte_out.append(out) + embed_out.clear() + embed_out.append(out) return out def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> None: @@ -144,7 +165,7 @@ def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> No pre_unembed.clear() pre_unembed.append(args[0]) - h1 = self.embedding_module.register_forward_hook(wte_hook, with_kwargs=True) + h1 = self.embedding_module.register_forward_hook(embed_hook, with_kwargs=True) h2 = 
self.unembed_module.register_forward_pre_hook(pre_unembed_hook, with_kwargs=True) # Get masks with all components active @@ -153,6 +174,7 @@ def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> No ci = self.model.calc_causal_importances( pre_weight_acts=out.cache, sampling=self.sampling, detach_inputs=False ) + mask_infos = make_mask_infos( component_masks={k: torch.ones_like(v) for k, v in ci.lower_leaky.items()}, routing_masks="all", @@ -160,100 +182,142 @@ def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> No # Forward pass with gradients with torch.enable_grad(), bf16_autocast(): - comp_output: OutputWithCache = self.model( + model_output: OutputWithCache = self.model( tokens, mask_infos=mask_infos, cache_type="component_acts" ) h1.remove() h2.remove() - cache = comp_output.cache - cache["wte_post_detach"] = wte_out[0] - cache["pre_unembed"] = pre_unembed[0] - cache["tokens"] = tokens - - # Process each target layer - layers = list(self.sources_by_target.items()) - pbar = tqdm(layers, desc="Targets", disable=not self.show_progress, leave=False) - for target_layer, source_layers in pbar: - if target_layer == "output": - self._process_output_targets(source_layers, cache) + cache = model_output.cache + cache[f"{self.embed_path}_post_detach"] = embed_out[0] + cache[f"{self.unembed_path}_pre_detach"] = pre_unembed[0] + + with torch.no_grad(): + for real_layer, ci_vals in ci.lower_leaky.items(): + self._ci_sum_accumulator[real_layer].add_(ci_vals.sum(dim=(0, 1))) + self._logit_sq_sum.add_(model_output.output.detach().square().sum(dim=(0, 1))) + + for target_layer in self.sources_by_target: + if target_layer == self.unembed_path: + self._process_output_targets(cache, tokens, ci.lower_leaky) else: - self._process_component_targets(target_layer, source_layers, cache) + with torch.no_grad(): + sum_sq_acts = cache[f"{target_layer}_post_detach"].square().sum(dim=(0, 1)) + 
self._square_component_act_accumulator[target_layer].add_(sum_sq_acts) + self._process_component_targets(cache, tokens, ci.lower_leaky, target_layer) - def _process_component_targets( + def _process_output_targets( self, - target_layer: str, - source_layers: list[str], cache: dict[str, Tensor], + tokens: Int[Tensor, "batch seq"], + ci: dict[str, Tensor], ) -> None: - """Process attributions to a component layer.""" - target_start, _ = self.target_layer_to_idx_range[target_layer] - alive_targets = self.alive_target_idxs_per_layer[target_layer] - if not alive_targets: - return + """Process output attributions via output-residual-space storage.""" + out_residual = cache[f"{self.unembed_path}_pre_detach"] + + out_residual_sum = out_residual.sum(dim=(0, 1)) + + source_layers = self.sources_by_target[self.unembed_path] + assert self.embed_path in source_layers, "remove me when passed" - # Sum over batch and sequence - target_acts = cache[f"{target_layer}_pre_detach"].sum(dim=(0, 1)) source_acts = [cache[f"{s}_post_detach"] for s in source_layers] - for t_idx in alive_targets: - grads = torch.autograd.grad(target_acts[t_idx], source_acts, retain_graph=True) - self._accumulate_attributions( - self.comp_accumulator[:, target_start + t_idx], - source_layers, - grads, - source_acts, - cache["tokens"], - ) + for d_idx in range(self.output_d_model): + grads = torch.autograd.grad(out_residual_sum[d_idx], source_acts, retain_graph=True) + with torch.no_grad(): + for source_layer, act, grad in zip(source_layers, source_acts, grads, strict=True): + if source_layer == self.embed_path: + token_attr = (grad * act).sum(dim=-1) # (B S) + self._straight_through_attr_acc[d_idx].scatter_add_( + 0, tokens.flatten(), token_attr.flatten() + ) + else: + ci_weighted_attr = (grad * act * ci[source_layer]).sum(dim=(0, 1)) + self._unembed_srcs_acc[source_layer][d_idx].add_(ci_weighted_attr) - def _process_output_targets( + def _process_component_targets( self, - source_layers: list[str], cache: 
dict[str, Tensor], + tokens: Int[Tensor, "batch seq"], + ci: dict[str, Tensor], + target_layer: str, ) -> None: - """Process output attributions via output-residual-space storage. - - Instead of computing and storing attributions to vocab tokens directly, - we store attributions to output residual dimensions. Output attributions are - computed on-the-fly at query time via: attr[src, token] = out_residual[src] @ w_unembed[:, token] - """ - # Sum output residual over batch and sequence -> [d_model] - out_residual = cache["pre_unembed"].sum(dim=(0, 1)) + """Process attributions to a component layer.""" + alive_targets = self.component_alive[target_layer] + if not alive_targets.any(): + return + + target_acts_raw = cache[f"{target_layer}_pre_detach"] + + target_acts = target_acts_raw.sum(dim=(0, 1)) + target_acts_abs = target_acts_raw.abs().sum(dim=(0, 1)) + + source_layers = self.sources_by_target[target_layer] source_acts = [cache[f"{s}_post_detach"] for s in source_layers] - for d_idx in range(self.d_model): - grads = torch.autograd.grad(out_residual[d_idx], source_acts, retain_graph=True) - self._accumulate_attributions( - self.out_residual_accumulator[:, d_idx], - source_layers, - grads, - source_acts, - cache["tokens"], + def _accumulate_grads( + grads: tuple[Tensor, ...], + t_idx: int, + embed_acc: dict[str, Tensor], + regular_acc: dict[str, dict[str, Tensor]], + ) -> None: + with torch.no_grad(): + for source_layer, act, grad in zip(source_layers, source_acts, grads, strict=True): + if source_layer == self.embed_path: + token_attr = (grad * act).sum(dim=-1) # (B S) + embed_acc[target_layer][t_idx].scatter_add_( + 0, tokens.flatten(), token_attr.flatten() + ) + else: + ci_weighted = (grad * act * ci[source_layer]).sum(dim=(0, 1)) # (C,) + regular_acc[target_layer][source_layer][t_idx].add_(ci_weighted) + + for t_idx in torch.where(alive_targets)[0].tolist(): + grads = torch.autograd.grad(target_acts[t_idx], source_acts, retain_graph=True) + _accumulate_grads( + 
grads=grads, + t_idx=t_idx, + embed_acc=self._embed_tgts_acc, + regular_acc=self._regular_layers_acc, ) - def _accumulate_attributions( - self, - target_col: Float[Tensor, " n_sources"], - source_layers: list[str], - grads: tuple[Tensor, ...], - source_acts: list[Tensor], - tokens: Int[Tensor, "batch seq"], - ) -> None: - """Accumulate grad*act attributions from sources to a target column.""" - with torch.no_grad(): - for layer, grad, act in zip(source_layers, grads, source_acts, strict=True): - alive = self.alive_source_idxs_per_layer[layer] - if not alive: - continue + grads_abs = torch.autograd.grad(target_acts_abs[t_idx], source_acts, retain_graph=True) + _accumulate_grads( + grads=grads_abs, + t_idx=t_idx, + embed_acc=self._embed_tgts_acc_abs, + regular_acc=self._regular_layers_acc_abs, + ) - if layer == "wte": - # Per-token: sum grad*act over d_model, scatter by token id - attr = (grad * act).sum(dim=-1).flatten() - target_col.scatter_add_(0, tokens.flatten(), attr) - else: - # Per-component: sum grad*act over batch and sequence - start, _ = self.source_layer_to_idx_range[layer] - attr = (grad * act).sum(dim=(0, 1)) - for c in alive: - target_col[start + c] += attr[c] + def finalize( + self, topology: TransformerTopology, ci_threshold: float + ) -> DatasetAttributionStorage: + """Package raw accumulators into storage. 
No normalization — that happens at query time.""" + assert self.n_tokens > 0, "No batches processed" + + to_canon = topology.target_to_canon + + def _canon_nested(acc: dict[str, dict[str, Tensor]]) -> dict[str, dict[str, Tensor]]: + return { + to_canon(t): {to_canon(s): v for s, v in srcs.items()} for t, srcs in acc.items() + } + + def _canon(acc: dict[str, Tensor]) -> dict[str, Tensor]: + return {to_canon(k): v for k, v in acc.items()} + + return DatasetAttributionStorage( + regular_attr=_canon_nested(self._regular_layers_acc), + regular_attr_abs=_canon_nested(self._regular_layers_acc_abs), + embed_attr=_canon(self._embed_tgts_acc), + embed_attr_abs=_canon(self._embed_tgts_acc_abs), + unembed_attr=_canon(self._unembed_srcs_acc), + embed_unembed_attr=self._straight_through_attr_acc, + w_unembed=topology.get_unembed_weight(), + ci_sum=_canon(self._ci_sum_accumulator), + component_act_sq_sum=_canon(self._square_component_act_accumulator), + logit_sq_sum=self._logit_sq_sum, + embed_token_count=self._embed_token_count, + ci_threshold=ci_threshold, + n_tokens_processed=self.n_tokens, + ) diff --git a/spd/dataset_attributions/repo.py b/spd/dataset_attributions/repo.py index 697036ba3..1175d584e 100644 --- a/spd/dataset_attributions/repo.py +++ b/spd/dataset_attributions/repo.py @@ -42,14 +42,13 @@ def open(cls, run_id: str) -> "AttributionRepo | None": candidates = sorted( [d for d in base_dir.iterdir() if d.is_dir() and d.name.startswith("da-")], key=lambda d: d.name, + reverse=True, ) - if not candidates: - return None - subrun_dir = candidates[-1] - path = subrun_dir / "dataset_attributions.pt" - if not path.exists(): - return None - return cls(DatasetAttributionStorage.load(path), subrun_id=subrun_dir.name) + for subrun_dir in candidates: + path = subrun_dir / "dataset_attributions.pt" + if path.exists(): + return cls(DatasetAttributionStorage.load(path), subrun_id=subrun_dir.name) + return None def get_attributions(self) -> DatasetAttributionStorage: return 
self._storage diff --git a/spd/dataset_attributions/scripts/diagnose_cancellation.py b/spd/dataset_attributions/scripts/diagnose_cancellation.py new file mode 100644 index 000000000..09bb85cbb --- /dev/null +++ b/spd/dataset_attributions/scripts/diagnose_cancellation.py @@ -0,0 +1,529 @@ +"""Diagnostic: does |mean(grad×act)| preserve L2(grad×act) rankings? + +The harvester accumulates signed sums of grad×act across positions. This script +checks whether that signed mean gives the same top-K source ranking as the +magnitude-preserving L2 = sqrt(mean((grad×act)²)) alternative. + +Methodology: + For each target component, iterate through data, find positions where the + target's CI > threshold (i.e. it's actually firing), then compute per-position + grad×act for all source components at those positions. Reduce to |mean| and L2 + per source component, rank them, and compare rankings via top-K overlap and + mean rank displacement. + + The per-position grad×act computation matches the harvester exactly: + - Component sources: grad × act × ci (CI-weighted, per the harvester) + - Embed sources: (grad × act).sum(embed_dim), grouped by token ID + +Usage: + python -m spd.dataset_attributions.scripts.diagnose_cancellation \ + "wandb:goodfire/spd/s-892f140b" \ + --n_targets_per_layer 20 --n_active 100 +""" + +import random +from dataclasses import dataclass +from typing import Any + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import torch +from torch import Tensor, nn + +from spd.configs import LMTaskConfig, SamplingType +from spd.data import train_loader_and_tokenizer +from spd.harvest.repo import HarvestRepo +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.models.components import make_mask_infos +from spd.settings import SPD_OUT_DIR +from spd.topology import TransformerTopology, get_sources_by_target +from spd.utils.distributed_utils import get_device +from spd.utils.general_utils import bf16_autocast, 
extract_batch_data +from spd.utils.wandb_utils import parse_wandb_run_path + +matplotlib.use("Agg") + + +# --------------------------------------------------------------------------- +# Setup +# --------------------------------------------------------------------------- + + +@dataclass +class ModelContext: + model: ComponentModel + topology: TransformerTopology + sampling: SamplingType + sources_by_target: dict[str, list[str]] + device: torch.device + embed_path: str + unembed_path: str + vocab_size: int + + +def setup(wandb_path: str) -> ModelContext: + device = torch.device(get_device()) + run_info = SPDRunInfo.from_path(wandb_path) + model = ComponentModel.from_run_info(run_info).to(device) + model.eval() + topology = TransformerTopology(model.target_model) + + sources_by_target_raw = get_sources_by_target( + model, topology, str(device), run_info.config.sampling + ) + embed_path = topology.path_schema.embedding_path + unembed_path = topology.path_schema.unembed_path + component_layers = set(model.target_module_paths) + valid_sources = component_layers | {embed_path} + valid_targets = component_layers | {unembed_path} + sources_by_target: dict[str, list[str]] = {} + for target, sources in sources_by_target_raw.items(): + if target not in valid_targets: + continue + filtered = [s for s in sources if s in valid_sources] + if filtered: + sources_by_target[target] = filtered + + return ModelContext( + model=model, + topology=topology, + sampling=run_info.config.sampling, + sources_by_target=sources_by_target, + device=device, + embed_path=embed_path, + unembed_path=unembed_path, + vocab_size=topology.embedding_module.num_embeddings, + ) + + +# --------------------------------------------------------------------------- +# Forward pass (matches harvester.process_batch exactly) +# --------------------------------------------------------------------------- + + +def forward_with_caches( + ctx: ModelContext, + tokens: Tensor, +) -> tuple[dict[str, Tensor], dict[str, 
Tensor]]: + """One forward pass → (cache, ci). Reuse across all (target, source) pairs.""" + embed_out: list[Tensor] = [] + pre_unembed: list[Tensor] = [] + + def embed_hook(_mod: nn.Module, _args: Any, _kwargs: Any, out: Tensor) -> Tensor: + out.requires_grad_(True) + embed_out.clear() + embed_out.append(out) + return out + + def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> None: + args[0].requires_grad_(True) + pre_unembed.clear() + pre_unembed.append(args[0]) + + h1 = ctx.topology.embedding_module.register_forward_hook(embed_hook, with_kwargs=True) + h2 = ctx.topology.unembed_module.register_forward_pre_hook(pre_unembed_hook, with_kwargs=True) + + with torch.no_grad(), bf16_autocast(): + out = ctx.model(tokens, cache_type="input") + ci = ctx.model.calc_causal_importances( + pre_weight_acts=out.cache, sampling=ctx.sampling, detach_inputs=False + ) + + mask_infos = make_mask_infos( + component_masks={k: torch.ones_like(v) for k, v in ci.lower_leaky.items()}, + routing_masks="all", + ) + + with torch.enable_grad(), bf16_autocast(): + model_output = ctx.model(tokens, mask_infos=mask_infos, cache_type="component_acts") + + h1.remove() + h2.remove() + + cache = model_output.cache + cache[f"{ctx.embed_path}_post_detach"] = embed_out[0] + cache[f"{ctx.unembed_path}_pre_detach"] = pre_unembed[0] + + return cache, ci.lower_leaky + + +# --------------------------------------------------------------------------- +# Per-position attribution (matches harvester._process_component_targets) +# --------------------------------------------------------------------------- + + +def per_position_grads_at( + ctx: ModelContext, + cache: dict[str, Tensor], + ci: dict[str, Tensor], + target_concrete: str, + t_idx: int, + s: int, +) -> dict[str, Tensor]: + """Compute grad×act for all source layers at a single position (b=0, s=s). 
+ + Returns {source_concrete: value_tensor} where: + - Component source: grad × act × ci, shape (C_source,) + - Embed source: (grad × act).sum(embed_dim), scalar + Matches the harvester's _accumulate_grads exactly, just without the sum. + """ + target_acts_raw = cache[f"{target_concrete}_pre_detach"] + scalar = target_acts_raw[0, s, t_idx] + + source_layers = ctx.sources_by_target[target_concrete] + source_acts = [cache[f"{sc}_post_detach"] for sc in source_layers] + grads = torch.autograd.grad(scalar, source_acts, retain_graph=True) + + result: dict[str, Tensor] = {} + with torch.no_grad(): + for source_layer, act, grad in zip(source_layers, source_acts, grads, strict=True): + if source_layer == ctx.embed_path: + result[source_layer] = (grad[0, s] * act[0, s]).sum().cpu() + else: + result[source_layer] = (grad[0, s] * act[0, s] * ci[source_layer][0, s]).cpu() + return result + + +# --------------------------------------------------------------------------- +# Collect active positions for a target component +# --------------------------------------------------------------------------- + + +def collect_active_attrs( + ctx: ModelContext, + loader_iter: Any, + target_concrete: str, + t_idx: int, + n_active: int, + ci_threshold: float, + max_sequences: int, +) -> tuple[dict[str, list[Tensor]], int, int]: + """Iterate sequences, backward only at positions where target CI > threshold. + + Returns (per_source_vals, n_found, n_sequences_checked) where + per_source_vals[source_concrete] is a list of tensors, one per active position. 
+ """ + source_layers = ctx.sources_by_target[target_concrete] + per_source: dict[str, list[Tensor]] = {sc: [] for sc in source_layers} + n_found = 0 + n_checked = 0 + + for _ in range(max_sequences): + try: + batch_data = next(loader_iter) + except StopIteration: + break + tokens = extract_batch_data(batch_data).to(ctx.device) + n_checked += 1 + + # Cheap CI check (no grad graph needed) + with torch.no_grad(), bf16_autocast(): + out = ctx.model(tokens, cache_type="input") + ci_check = ctx.model.calc_causal_importances( + pre_weight_acts=out.cache, sampling=ctx.sampling, detach_inputs=False + ) + + ci_vals = ci_check.lower_leaky[target_concrete][0, :, t_idx] + active_positions = (ci_vals > ci_threshold).nonzero(as_tuple=True)[0] + if len(active_positions) == 0: + continue + + # Full forward with grad graph + cache, ci = forward_with_caches(ctx, tokens) + + for s in active_positions.tolist(): + if n_found >= n_active: + break + grads = per_position_grads_at(ctx, cache, ci, target_concrete, t_idx, s) + for sc, val in grads.items(): + per_source[sc].append(val) + n_found += 1 + + if n_found >= n_active: + break + + return per_source, n_found, n_checked + + +# --------------------------------------------------------------------------- +# Reduce per-position values to |mean| and L2 per source component +# --------------------------------------------------------------------------- + + +def reduce_to_rankings( + per_source: dict[str, list[Tensor]], + embed_path: str, + tokens_per_pos: list[Tensor] | None, + vocab_size: int, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Reduce per-position attrs to per-source-component |mean| and L2. + + Component sources: each position gives a (C,) vector. |mean| and L2 over positions. + Embed sources: each position gives a scalar. Group by token ID via scatter_add, + then |mean| and L2 per token. + + Returns (abs_means, l2s, is_embed) arrays pooled across all source layers. 
+ """ + all_abs_means: list[np.ndarray] = [] + all_l2s: list[np.ndarray] = [] + all_is_embed: list[np.ndarray] = [] + + for source_layer, vals in per_source.items(): + if not vals: + continue + + if source_layer == embed_path: + assert tokens_per_pos is not None + all_vals = torch.stack(vals).float() + all_toks = torch.cat(tokens_per_pos) + token_sum = torch.zeros(vocab_size) + token_sq_sum = torch.zeros(vocab_size) + token_count = torch.zeros(vocab_size) + token_sum.scatter_add_(0, all_toks, all_vals) + token_sq_sum.scatter_add_(0, all_toks, all_vals.square()) + token_count.scatter_add_(0, all_toks, torch.ones_like(all_vals)) + safe_count = token_count.clamp(min=1) + all_abs_means.append((token_sum / safe_count).abs().numpy()) + all_l2s.append((token_sq_sum / safe_count).sqrt().numpy()) + all_is_embed.append(np.ones(vocab_size, dtype=bool)) + else: + stacked = torch.stack(vals).float() # (N, C) + all_abs_means.append(stacked.mean(dim=0).abs().numpy()) + all_l2s.append(stacked.square().mean(dim=0).sqrt().numpy()) + all_is_embed.append(np.zeros(stacked.shape[1], dtype=bool)) + + return np.concatenate(all_abs_means), np.concatenate(all_l2s), np.concatenate(all_is_embed) + + +# --------------------------------------------------------------------------- +# Target selection +# --------------------------------------------------------------------------- + + +def select_targets( + ctx: ModelContext, + run_id: str, + n_per_layer: int, + fd_range: tuple[float, float], + seed: int, + comp_only: bool, +) -> list[tuple[str, int, float]]: + """Select target components with firing density in range. + + Returns [(concrete_path, c_idx, firing_density), ...]. 
+ """ + harvest = HarvestRepo.open_most_recent(run_id, readonly=True) + assert harvest is not None + summary = harvest.get_summary() + assert summary is not None + + rng = random.Random(seed) + targets: list[tuple[str, int, float]] = [] + + for target_concrete in ctx.sources_by_target: + if target_concrete == ctx.unembed_path: + continue + if comp_only and ctx.embed_path in ctx.sources_by_target[target_concrete]: + continue + + candidates: list[tuple[int, float]] = [] + for c_idx in range(ctx.model.module_to_c[target_concrete]): + key = f"{target_concrete}:{c_idx}" + if key not in summary: + continue + fd = summary[key].firing_density + if fd_range[0] < fd < fd_range[1]: + candidates.append((c_idx, fd)) + + rng.shuffle(candidates) + for c_idx, fd in candidates[:n_per_layer]: + targets.append((target_concrete, c_idx, fd)) + + return targets + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +@dataclass +class TargetResult: + target_layer: str + target_idx: int + firing_density: float + n_active: int + n_sources: int + has_embed: bool + top5_mrd: float + top5_overlap: int + top10_mrd: float + top10_overlap: int + + +def main( + wandb_path: str, + n_targets_per_layer: int = 20, + n_active: int = 100, + ci_threshold: float = 0.01, + fd_min: float = 1e-4, + fd_max: float = 1e-1, + max_sequences: int = 5000, + seed: int = 42, + comp_only: bool = False, +) -> None: + import time + + ctx = setup(wandb_path) + _, _, run_id = parse_wandb_run_path(wandb_path) + to_canon = ctx.topology.target_to_canon + + targets = select_targets(ctx, run_id, n_targets_per_layer, (fd_min, fd_max), seed, comp_only) + print(f"Selected {len(targets)} targets (fd in ({fd_min}, {fd_max}), comp_only={comp_only})") + + spd_config = SPDRunInfo.from_path(wandb_path).config + assert isinstance(spd_config.task_config, LMTaskConfig) + # frozen + # spd_config.task_config.dataset_name = 
"danbraunai/pile-uncopyrighted-tok-shuffled" + train_loader, _ = train_loader_and_tokenizer(spd_config, 1) + loader_iter = iter(train_loader) + + results: list[TargetResult] = [] + t0 = time.time() + + for ti, (tgt_concrete, t_idx, fd) in enumerate(targets): + tgt_canon = to_canon(tgt_concrete) + source_list = ctx.sources_by_target[tgt_concrete] + has_embed = ctx.embed_path in source_list + + per_source, n_found, _ = collect_active_attrs( + ctx, loader_iter, tgt_concrete, t_idx, n_active, ci_threshold, max_sequences + ) + + if n_found < 10: + print( + f"[{ti + 1}/{len(targets)}] {tgt_canon}:{t_idx} fd={fd:.4f} " + f"— only {n_found} active, skipping" + ) + continue + + # Embed sources excluded: collect_active_attrs doesn't store token IDs + # needed for scatter_add grouping. This is fine — embed rankings are + # near-perfect anyway (confirmed in notebook analysis). + abs_means, l2s, _ = reduce_to_rankings( + {sc: v for sc, v in per_source.items() if sc != ctx.embed_path}, + ctx.embed_path, + None, + ctx.vocab_size, + ) + + n = len(abs_means) + rank_mean = np.argsort(np.argsort(-abs_means)) + rank_l2 = np.argsort(np.argsort(-l2s)) + + top5 = np.argsort(-l2s)[:5] + top10 = np.argsort(-l2s)[:10] + + results.append( + TargetResult( + target_layer=tgt_canon, + target_idx=t_idx, + firing_density=fd, + n_active=n_found, + n_sources=n, + has_embed=has_embed, + top5_mrd=np.abs(rank_mean[top5] - rank_l2[top5]).mean(), + top5_overlap=len(set(top5) & set(np.argsort(-abs_means)[:5])), + top10_mrd=np.abs(rank_mean[top10] - rank_l2[top10]).mean(), + top10_overlap=len(set(top10) & set(np.argsort(-abs_means)[:10])), + ) + ) + + if (ti + 1) % 10 == 0: + elapsed = time.time() - t0 + rate = elapsed / (ti + 1) + print( + f"[{ti + 1}/{len(targets)}] {elapsed:.0f}s, ~{rate * (len(targets) - ti - 1):.0f}s left", + flush=True, + ) + + elapsed = time.time() - t0 + print(f"\nDone: {len(results)} targets in {elapsed:.0f}s") + + _print_results(results) + _plot_results(results) + + +def 
_print_results(results: list[TargetResult]) -> None: + print(f"\n{'=' * 70}") + print("CANCELLATION DIAGNOSTIC: |mean| vs L2 ranking agreement") + print(f"{'=' * 70}") + print(f" {len(results)} targets, active positions only (CI > threshold)") + print() + + for label, metric, K in [ + ("Top-5 mean rank displacement", "top5_mrd", 5), + ("Top-5 overlap", "top5_overlap", 5), + ("Top-10 mean rank displacement", "top10_mrd", 10), + ("Top-10 overlap", "top10_overlap", 10), + ]: + vals = [getattr(r, metric) for r in results] + print( + f" {label}: {np.mean(vals):.1f} ± {np.std(vals):.1f}" + f" (median {np.median(vals):.1f})" + (f"/{K}" if "overlap" in metric else "") + ) + + print("\n By target layer:") + layers = sorted(set(r.target_layer for r in results)) + print( + f" {'layer':<18} {'n':>3} {'top5 mrd':>10} {'top5 olap':>10} " + f"{'top10 mrd':>10} {'top10 olap':>10}" + ) + print(f" {'-' * 65}") + for layer in layers: + lr = [r for r in results if r.target_layer == layer] + print( + f" {layer:<18} {len(lr):>3} " + f"{np.mean([r.top5_mrd for r in lr]):>6.1f}±{np.std([r.top5_mrd for r in lr]):<3.1f}" + f"{np.mean([r.top5_overlap for r in lr]):>7.1f}/5 " + f"{np.mean([r.top10_mrd for r in lr]):>6.1f}±{np.std([r.top10_mrd for r in lr]):<3.1f}" + f"{np.mean([r.top10_overlap for r in lr]):>7.1f}/10" + ) + + +def _plot_results(results: list[TargetResult]) -> None: + _, axes = plt.subplots(1, 2, figsize=(14, 5)) + + layers = sorted(set(r.target_layer for r in results)) + colors = {layer: f"C{i}" for i, layer in enumerate(layers)} + + for ax, K in [(axes[0], 5), (axes[1], 10)]: + for layer in layers: + vals = [r for r in results if r.target_layer == layer] + ax.hist( + [getattr(r, f"top{K}_mrd") for r in vals], + bins=np.arange(-0.5, 25.5, 1), + alpha=0.4, + color=colors[layer], + label=f"{layer} (μ={np.mean([getattr(r, f'top{K}_mrd') for r in vals]):.1f})", + ) + ax.set_xlabel(f"Top-{K} mean rank displacement") + ax.set_ylabel("# targets") + ax.set_title( + f"Top-{K}: |mean| 
vs L2 ranking agreement\n{len(results)} targets, active positions only" + ) + ax.legend(fontsize=8) + + plt.tight_layout() + out_path = SPD_OUT_DIR / "www" / "attr_cancellation_diagnostic.png" + out_path.parent.mkdir(parents=True, exist_ok=True) + plt.savefig(out_path, dpi=150) + print(f"\nSaved to {out_path}") + plt.close() + + +if __name__ == "__main__": + import fire + + fire.Fire(main) diff --git a/spd/dataset_attributions/scripts/run_merge.py b/spd/dataset_attributions/scripts/run_merge.py new file mode 100644 index 000000000..913ea5374 --- /dev/null +++ b/spd/dataset_attributions/scripts/run_merge.py @@ -0,0 +1,37 @@ +"""Merge script for dataset attribution rank files. + +Combines per-rank attribution files into a single merged result. + +Usage: + python -m spd.dataset_attributions.scripts.run_merge --wandb_path --subrun_id da-xxx +""" + +from spd.dataset_attributions.harvest import merge_attributions +from spd.dataset_attributions.repo import get_attributions_subrun_dir +from spd.log import logger +from spd.utils.wandb_utils import parse_wandb_run_path + + +def main( + *, + wandb_path: str, + subrun_id: str, +) -> None: + _, _, run_id = parse_wandb_run_path(wandb_path) + output_dir = get_attributions_subrun_dir(run_id, subrun_id) + logger.info(f"Merging attribution results for {wandb_path} (subrun {subrun_id})") + merge_attributions(output_dir) + + +def get_command(wandb_path: str, subrun_id: str) -> str: + return ( + f"python -m spd.dataset_attributions.scripts.run_merge " + f'--wandb_path "{wandb_path}" ' + f"--subrun_id {subrun_id}" + ) + + +if __name__ == "__main__": + import fire + + fire.Fire(main) diff --git a/spd/dataset_attributions/scripts/run_slurm.py b/spd/dataset_attributions/scripts/run_slurm.py index e405b2bd9..6adc4bd52 100644 --- a/spd/dataset_attributions/scripts/run_slurm.py +++ b/spd/dataset_attributions/scripts/run_slurm.py @@ -14,7 +14,7 @@ from datetime import datetime from spd.dataset_attributions.config import AttributionsSlurmConfig 
-from spd.dataset_attributions.scripts import run as attribution_run +from spd.dataset_attributions.scripts import run_merge, run_worker from spd.log import logger from spd.utils.git_utils import create_git_snapshot from spd.utils.slurm import ( @@ -80,12 +80,12 @@ def submit_attributions( suffix = f"-{job_suffix}" if job_suffix else "" array_job_name = f"spd-attr{suffix}" - config_json = config.model_dump_json(exclude_none=True) + config_json = config.config.model_dump_json(exclude_none=True) # SLURM arrays are 1-indexed, so task ID 1 -> rank 0, etc. worker_commands = [] for rank in range(n_gpus): - cmd = attribution_run.get_worker_command( + cmd = run_worker.get_command( wandb_path, config_json, rank=rank, @@ -115,12 +115,13 @@ def submit_attributions( ) # Submit merge job with dependency on array completion - merge_cmd = attribution_run.get_merge_command(wandb_path, subrun_id) + merge_cmd = run_merge.get_command(wandb_path, subrun_id) merge_config = SlurmConfig( job_name="spd-attr-merge", partition=partition, - n_gpus=0, # No GPU needed for merge + n_gpus=0, time=config.merge_time, + mem=config.merge_mem, snapshot_branch=snapshot_branch, dependency_job_id=array_result.job_id, comment=wandb_url, diff --git a/spd/dataset_attributions/scripts/run.py b/spd/dataset_attributions/scripts/run_worker.py similarity index 51% rename from spd/dataset_attributions/scripts/run.py rename to spd/dataset_attributions/scripts/run_worker.py index 5d060767e..1f512fbb7 100644 --- a/spd/dataset_attributions/scripts/run.py +++ b/spd/dataset_attributions/scripts/run_worker.py @@ -4,58 +4,44 @@ Usage: # Single GPU - python -m spd.dataset_attributions.scripts.run --config_json '...' + python -m spd.dataset_attributions.scripts.run_worker + + # Single GPU with config + python -m spd.dataset_attributions.scripts.run_worker --config_json '{"n_batches": 500}' # Multi-GPU (run in parallel) - python -m spd.dataset_attributions.scripts.run --config_json '...' 
--rank 0 --world_size 4 --subrun_id da-20260211_120000 - ... - python -m spd.dataset_attributions.scripts.run --merge --subrun_id da-20260211_120000 + python -m spd.dataset_attributions.scripts.run_worker --rank 0 --world_size 4 --subrun_id da-xxx """ from datetime import datetime from typing import Any from spd.dataset_attributions.config import DatasetAttributionConfig -from spd.dataset_attributions.harvest import harvest_attributions, merge_attributions +from spd.dataset_attributions.harvest import harvest_attributions from spd.dataset_attributions.repo import get_attributions_subrun_dir -from spd.log import logger from spd.utils.wandb_utils import parse_wandb_run_path def main( wandb_path: str, config_json: dict[str, Any], - rank: int | None = None, - world_size: int | None = None, - merge: bool = False, + rank: int, + world_size: int, subrun_id: str | None = None, harvest_subrun_id: str | None = None, ) -> None: - assert isinstance(config_json, dict), f"Expected dict from fire, got {type(config_json)}" _, _, run_id = parse_wandb_run_path(wandb_path) if subrun_id is None: subrun_id = "da-" + datetime.now().strftime("%Y%m%d_%H%M%S") + config = ( + DatasetAttributionConfig.model_validate(config_json) + if config_json + else DatasetAttributionConfig() + ) output_dir = get_attributions_subrun_dir(run_id, subrun_id) - if merge: - assert rank is None and world_size is None, "Cannot specify rank/world_size with --merge" - logger.info(f"Merging attribution results for {wandb_path} (subrun {subrun_id})") - merge_attributions(output_dir) - return - - assert (rank is None) == (world_size is None), "rank and world_size must both be set or unset" - - config = DatasetAttributionConfig.model_validate(config_json) - - if world_size is not None: - logger.info( - f"Distributed harvest: {wandb_path} (rank {rank}/{world_size}, subrun {subrun_id})" - ) - else: - logger.info(f"Single-GPU harvest: {wandb_path} (subrun {subrun_id})") - harvest_attributions( wandb_path=wandb_path, 
config=config, @@ -66,7 +52,7 @@ def main( ) -def get_worker_command( +def get_command( wandb_path: str, config_json: str, rank: int, @@ -75,7 +61,7 @@ def get_worker_command( harvest_subrun_id: str | None = None, ) -> str: cmd = ( - f"python -m spd.dataset_attributions.scripts.run " + f"python -m spd.dataset_attributions.scripts.run_worker " f'"{wandb_path}" ' f"--config_json '{config_json}' " f"--rank {rank} " @@ -87,20 +73,7 @@ def get_worker_command( return cmd -def get_merge_command(wandb_path: str, subrun_id: str) -> str: - return ( - f"python -m spd.dataset_attributions.scripts.run " - f'"{wandb_path}" ' - "--merge " - f"--subrun_id {subrun_id}" - ) - - -def cli() -> None: +if __name__ == "__main__": import fire fire.Fire(main) - - -if __name__ == "__main__": - cli() diff --git a/spd/dataset_attributions/storage.py b/spd/dataset_attributions/storage.py index 16181201d..3d62a1b4a 100644 --- a/spd/dataset_attributions/storage.py +++ b/spd/dataset_attributions/storage.py @@ -1,22 +1,37 @@ """Storage classes for dataset attributions. -Uses a residual-based storage approach for scalability: -- Component targets: stored directly in source_to_component matrix -- Output targets: stored as attributions to residual stream, computed on-the-fly via w_unembed +Stores raw (unnormalized) attribution sums. Normalization happens at query time using +stored metadata (CI sums, activation RMS, logit RMS). + +Four edge types, each with its own shape: +- regular: component → component [tgt_c, src_c] (signed + abs) +- embed: embed → component [tgt_c, vocab] (signed + abs) +- unembed: component → unembed [d_model, src_c] (signed only, residual space) +- embed_unembed: embed → unembed [d_model, vocab] (signed only, residual space) + +Abs variants are unavailable for unembed edges because abs is a nonlinear operation +incompatible with the residual-space storage trick. 
+ +Normalization formula: + normed[t, s] = raw[t, s] / source_denom[s] / target_rms[t] +- source_denom is ci_sum[s] for component sources, embed_token_count[s] for embed sources +- target_rms is component activation RMS for component targets, logit RMS for output targets """ -import dataclasses -from collections.abc import Callable +import bisect from dataclasses import dataclass from pathlib import Path from typing import Literal import torch -from jaxtyping import Float from torch import Tensor from spd.log import logger +AttrMetric = Literal["attr", "attr_abs"] + +EPS = 1e-10 + @dataclass class DatasetAttributionEntry: @@ -28,318 +43,339 @@ class DatasetAttributionEntry: value: float -@dataclass class DatasetAttributionStorage: """Dataset-aggregated attribution strengths between components. - Uses residual-based storage for scalability with large vocabularies: - - source_to_component: direct attributions to component targets - - source_to_out_residual: attributions to output residual stream (for computing output attributions) - - Output attributions are computed on-the-fly: attr[src, output_token] = out_residual[src] @ w_unembed[:, token] + All layer names use canonical addressing (e.g., "embed", "0.glu.up", "output"). - Source indexing (rows): - - [0, vocab_size): wte tokens - - [vocab_size, vocab_size + n_components): component layers - - Target indexing: - - Component targets: [0, n_components) in source_to_component - - Output targets: computed via source_to_out_residual @ w_unembed + Internally stores raw sums — normalization applied at query time. + Public interface: get_top_sources(), get_top_targets(), save/load/merge. 
Key formats: - - wte tokens: "wte:{token_id}" - - component layers: "layer:c_idx" (e.g., "h.0.attn.q_proj:5") + - embed tokens: "embed:{token_id}" + - component layers: "canonical_layer:c_idx" (e.g., "0.glu.up:5") - output tokens: "output:{token_id}" """ - component_layer_keys: list[str] - """Component layer keys in order: ["h.0.attn.q_proj:0", "h.0.attn.q_proj:1", ...]""" + def __init__( + self, + regular_attr: dict[str, dict[str, Tensor]], + regular_attr_abs: dict[str, dict[str, Tensor]], + embed_attr: dict[str, Tensor], + embed_attr_abs: dict[str, Tensor], + unembed_attr: dict[str, Tensor], + embed_unembed_attr: Tensor, + w_unembed: Tensor, + ci_sum: dict[str, Tensor], + component_act_sq_sum: dict[str, Tensor], + logit_sq_sum: Tensor, + embed_token_count: Tensor, + ci_threshold: float, + n_tokens_processed: int, + ): + self._regular_attr = regular_attr + self._regular_attr_abs = regular_attr_abs + self._embed_attr = embed_attr + self._embed_attr_abs = embed_attr_abs + self._unembed_attr = unembed_attr + self._embed_unembed_attr = embed_unembed_attr + self._w_unembed = w_unembed + self._ci_sum = ci_sum + self._component_act_sq_sum = component_act_sq_sum + self._logit_sq_sum = logit_sq_sum + self._embed_token_count = embed_token_count + self.ci_threshold = ci_threshold + self.n_tokens_processed = n_tokens_processed - vocab_size: int - """Vocabulary size (number of wte and output tokens)""" + @property + def target_layers(self) -> set[str]: + return self._regular_attr.keys() | self._embed_attr.keys() - d_model: int - """Model hidden dimension (residual stream size)""" + def _target_n_components(self, layer: str) -> int | None: + if layer in self._embed_attr: + return self._embed_attr[layer].shape[0] + if layer in self._regular_attr: + first_source = next(iter(self._regular_attr[layer].values())) + return first_source.shape[0] + return None - source_to_component: Float[Tensor, "n_sources n_components"] - """Attributions from sources to component targets. 
Shape: (vocab_size + n_components, n_components)""" + @property + def n_components(self) -> int: + total = 0 + for layer in self.target_layers: + n = self._target_n_components(layer) + assert n is not None + total += n + return total + + @staticmethod + def _parse_key(key: str) -> tuple[str, int]: + layer, idx_str = key.rsplit(":", 1) + return layer, int(idx_str) - source_to_out_residual: Float[Tensor, "n_sources d_model"] - """Attributions from sources to output residual dimensions. Shape: (vocab_size + n_components, d_model)""" + def _select_metric( + self, metric: AttrMetric + ) -> tuple[dict[str, dict[str, Tensor]], dict[str, Tensor]]: + match metric: + case "attr": + return self._regular_attr, self._embed_attr + case "attr_abs": + return self._regular_attr_abs, self._embed_attr_abs - n_batches_processed: int - n_tokens_processed: int - ci_threshold: float + def _component_activation_rms(self, layer: str) -> Tensor: + """RMS activation for a component layer. Shape (n_components,).""" + return (self._component_act_sq_sum[layer] / self.n_tokens_processed).sqrt().clamp(min=EPS) - _component_key_to_idx: dict[str, int] = dataclasses.field( - default_factory=dict, repr=False, init=False - ) + def _logit_activation_rms(self) -> Tensor: + """RMS logit per token. Shape (vocab,).""" + return (self._logit_sq_sum / self.n_tokens_processed).sqrt().clamp(min=EPS) - def __post_init__(self) -> None: - self._component_key_to_idx = {k: i for i, k in enumerate(self.component_layer_keys)} + def _layer_ci_sum(self, layer: str) -> Tensor: + """CI sum for a source layer, clamped. Shape (n_components,).""" + return self._ci_sum[layer].clamp(min=EPS) - n_components = len(self.component_layer_keys) - n_sources = self.vocab_size + n_components + def _embed_count(self) -> Tensor: + """Per-token occurrence count, clamped. 
Shape (vocab,).""" + return self._embed_token_count.float().clamp(min=EPS) - expected_comp_shape = (n_sources, n_components) - assert self.source_to_component.shape == expected_comp_shape, ( - f"source_to_component shape {self.source_to_component.shape} " - f"doesn't match expected {expected_comp_shape}" - ) + def get_top_sources( + self, + target_key: str, + k: int, + sign: Literal["positive", "negative"], + metric: AttrMetric, + ) -> list[DatasetAttributionEntry]: + target_layer, target_idx = self._parse_key(target_key) + + value_segments: list[Tensor] = [] + layer_names: list[str] = [] + if target_layer == "embed": + return [] + + if target_layer == "output": + if metric == "attr_abs": + return [] + w = self._w_unembed[:, target_idx].to(self._embed_unembed_attr.device) + target_act_rms = self._logit_activation_rms()[target_idx] + + for source_layer, attr_matrix in self._unembed_attr.items(): + raw = w @ attr_matrix # (src_c,) + value_segments.append(raw / self._layer_ci_sum(source_layer) / target_act_rms) + layer_names.append(source_layer) + + raw = w @ self._embed_unembed_attr # (vocab,) + value_segments.append(raw / self._embed_count() / target_act_rms) + layer_names.append("embed") + else: + regular_attr, embed_target_attr = self._select_metric(metric) + target_act_rms = self._component_activation_rms(target_layer)[target_idx] - expected_resid_shape = (n_sources, self.d_model) - assert self.source_to_out_residual.shape == expected_resid_shape, ( - f"source_to_out_residual shape {self.source_to_out_residual.shape} " - f"doesn't match expected {expected_resid_shape}" - ) + if target_layer in regular_attr: + for source_layer, attr_matrix in regular_attr[target_layer].items(): + raw = attr_matrix[target_idx, :] # (src_c,) + value_segments.append(raw / self._layer_ci_sum(source_layer) / target_act_rms) + layer_names.append(source_layer) - @property - def n_components(self) -> int: - return len(self.component_layer_keys) + if target_layer in embed_target_attr: + 
raw = embed_target_attr[target_layer][target_idx, :] # (vocab,) + value_segments.append(raw / self._embed_count() / target_act_rms) + layer_names.append("embed") - @property - def n_sources(self) -> int: - return self.vocab_size + self.n_components + return self._top_k_from_segments(value_segments, layer_names, k, sign) - def _parse_key(self, key: str) -> tuple[str, int]: - """Parse a key into (layer, idx).""" - layer, idx_str = key.rsplit(":", 1) - return layer, int(idx_str) + def get_top_targets( + self, + source_key: str, + k: int, + sign: Literal["positive", "negative"], + metric: AttrMetric, + include_outputs: bool = True, + ) -> list[DatasetAttributionEntry]: + source_layer, source_idx = self._parse_key(source_key) - def _source_idx(self, key: str) -> int: - """Get source (row) index for a key. Raises KeyError if not a valid source.""" - layer, idx = self._parse_key(key) - match layer: - case "wte": - assert 0 <= idx < self.vocab_size, ( - f"wte index {idx} out of range [0, {self.vocab_size})" - ) - return idx - case "output": - raise KeyError(f"output tokens cannot be sources: {key}") - case _: - return self.vocab_size + self._component_key_to_idx[key] - - def _component_target_idx(self, key: str) -> int: - """Get target index for a component key. 
Raises KeyError if output or invalid.""" - if key.startswith(("wte:", "output:")): - raise KeyError(f"Not a component target: {key}") - return self._component_key_to_idx[key] - - def _source_idx_to_key(self, idx: int) -> str: - """Convert source (row) index to key.""" - if idx < self.vocab_size: - return f"wte:{idx}" - return self.component_layer_keys[idx - self.vocab_size] - - def _component_target_idx_to_key(self, idx: int) -> str: - """Convert component target index to key.""" - return self.component_layer_keys[idx] - - def _output_target_idx_to_key(self, idx: int) -> str: - """Convert output token index to key.""" - return f"output:{idx}" - - def _is_output_target(self, key: str) -> bool: - """Check if key is an output target.""" - return key.startswith("output:") - - def _output_token_id(self, key: str) -> int: - """Extract token_id from an output key like 'output:123'. Asserts valid range.""" - _, token_id = self._parse_key(key) - assert 0 <= token_id < self.vocab_size, f"output index {token_id} out of range" - return token_id - - def has_source(self, key: str) -> bool: - """Check if a key can be a source (wte token or component layer).""" - layer, idx = self._parse_key(key) - match layer: - case "wte": - return 0 <= idx < self.vocab_size - case "output": - return False - case _: - return key in self._component_key_to_idx - - def has_target(self, key: str) -> bool: - """Check if a key can be a target (component layer or output token).""" - layer, idx = self._parse_key(key) - match layer: - case "wte": - return False - case "output": - return 0 <= idx < self.vocab_size - case _: - return key in self._component_key_to_idx + value_segments: list[Tensor] = [] + layer_names: list[str] = [] - def save(self, path: Path) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - torch.save( - { - "component_layer_keys": self.component_layer_keys, - "vocab_size": self.vocab_size, - "d_model": self.d_model, - "source_to_component": self.source_to_component.cpu(), - 
"source_to_out_residual": self.source_to_out_residual.cpu(), - "n_batches_processed": self.n_batches_processed, - "n_tokens_processed": self.n_tokens_processed, - "ci_threshold": self.ci_threshold, - }, - path, - ) - size_mb = path.stat().st_size / (1024 * 1024) - logger.info(f"Saved dataset attributions to {path} ({size_mb:.1f} MB)") + if source_layer == "output": + return [] + elif source_layer == "embed": + regular, embed = self._select_metric(metric) + embed_count = self._embed_count()[source_idx] - @classmethod - def load(cls, path: Path) -> "DatasetAttributionStorage": - data = torch.load(path, weights_only=True, mmap=True) - return cls( - component_layer_keys=data["component_layer_keys"], - vocab_size=data["vocab_size"], - d_model=data["d_model"], - source_to_component=data["source_to_component"], - source_to_out_residual=data["source_to_out_residual"], - n_batches_processed=data["n_batches_processed"], - n_tokens_processed=data["n_tokens_processed"], - ci_threshold=data["ci_threshold"], - ) + for target_layer, attr_matrix in embed.items(): + raw = attr_matrix[:, source_idx] # (tgt_c,) + value_segments.append( + raw / embed_count / self._component_activation_rms(target_layer) + ) + layer_names.append(target_layer) - def get_attribution( - self, - source_key: str, - target_key: str, - w_unembed: Float[Tensor, "d_model vocab"] | None = None, - ) -> float: - """Get attribution strength from source to target. 
- - Args: - source_key: Source component key (wte or component layer) - target_key: Target component key (component layer or output token) - w_unembed: Unembedding matrix, required if target is an output token - """ - src_idx = self._source_idx(source_key) + if include_outputs and metric == "attr": + residual = self._embed_unembed_attr[:, source_idx] # (d_model,) + raw = residual @ self._w_unembed # (vocab,) + value_segments.append(raw / embed_count / self._logit_activation_rms()) + layer_names.append("output") + else: + regular, embed = self._select_metric(metric) + ci = self._layer_ci_sum(source_layer)[source_idx] - if self._is_output_target(target_key): - assert w_unembed is not None, "w_unembed required for output target queries" - token_id = self._output_token_id(target_key) - w_unembed = w_unembed.to(self.source_to_out_residual.device) - return (self.source_to_out_residual[src_idx] @ w_unembed[:, token_id]).item() + for target_layer, sources in regular.items(): + if source_layer not in sources: + continue + raw = sources[source_layer][:, source_idx] # (tgt_c,) + value_segments.append(raw / ci / self._component_activation_rms(target_layer)) + layer_names.append(target_layer) - tgt_idx = self._component_target_idx(target_key) - return self.source_to_component[src_idx, tgt_idx].item() + if include_outputs and metric == "attr" and source_layer in self._unembed_attr: + residual = self._unembed_attr[source_layer][:, source_idx] # (d_model,) + raw = residual @ self._w_unembed # (vocab,) + value_segments.append(raw / ci / self._logit_activation_rms()) + layer_names.append("output") - def _get_top_k( + return self._top_k_from_segments(value_segments, layer_names, k, sign) + + def _top_k_from_segments( self, - values: Tensor, + value_segments: list[Tensor], + layer_names: list[str], k: int, sign: Literal["positive", "negative"], - idx_to_key: Callable[[int], str], ) -> list[DatasetAttributionEntry]: - """Get top-k entries from a 1D tensor of attribution values.""" + if 
not value_segments: + return [] + + all_values = torch.cat(value_segments) + offsets = [0] + for seg in value_segments: + offsets.append(offsets[-1] + len(seg)) + is_positive = sign == "positive" - top_vals, top_idxs = torch.topk(values, min(k, len(values)), largest=is_positive) + top_vals, top_idxs = torch.topk(all_values, min(k, len(all_values)), largest=is_positive) - # Filter to only values matching the requested sign mask = top_vals > 0 if is_positive else top_vals < 0 top_vals, top_idxs = top_vals[mask], top_idxs[mask] results = [] - for idx, val in zip(top_idxs.tolist(), top_vals.tolist(), strict=True): - key = idx_to_key(idx) - layer, c_idx = self._parse_key(key) + for flat_idx, val in zip(top_idxs.tolist(), top_vals.tolist(), strict=True): + seg_idx = bisect.bisect_right(offsets, flat_idx) - 1 + local_idx = flat_idx - offsets[seg_idx] + layer = layer_names[seg_idx] results.append( DatasetAttributionEntry( - component_key=key, + component_key=f"{layer}:{local_idx}", layer=layer, - component_idx=c_idx, + component_idx=local_idx, value=val, ) ) return results - def get_top_sources( - self, - target_key: str, - k: int, - sign: Literal["positive", "negative"], - w_unembed: Float[Tensor, "d_model vocab"] | None = None, - ) -> list[DatasetAttributionEntry]: - """Get top-k source components that attribute TO this target. 
- - Args: - target_key: Target component key (component layer or output token) - k: Number of top sources to return - sign: "positive" for strongest positive, "negative" for strongest negative - w_unembed: Unembedding matrix, required if target is an output token - """ - if self._is_output_target(target_key): - assert w_unembed is not None, "w_unembed required for output target queries" - token_id = self._output_token_id(target_key) - w_unembed = w_unembed.to(self.source_to_out_residual.device) - values = self.source_to_out_residual @ w_unembed[:, token_id] # (n_sources,) + def get_attribution(self, source_key: str, target_key: str) -> float: + source_layer, source_idx = self._parse_key(source_key) + target_layer, target_idx = self._parse_key(target_key) + + if target_layer == "output" and source_layer == "embed": + return (self._embed_unembed_attr[:, source_idx] @ self._w_unembed[:, target_idx]).item() + elif target_layer == "output" and source_layer != "embed": + return ( + self._unembed_attr[source_layer][:, source_idx] @ self._w_unembed[:, target_idx] + ).item() + elif target_layer != "output" and source_layer == "embed": + return (self._embed_attr[target_layer][target_idx, source_idx]).item() else: - tgt_idx = self._component_target_idx(target_key) - values = self.source_to_component[:, tgt_idx] + assert target_layer != "output" and source_layer != "embed" + return (self._regular_attr[target_layer][source_layer][target_idx, source_idx]).item() - return self._get_top_k(values, k, sign, self._source_idx_to_key) + def save(self, path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + torch.save( + { + "regular_attr": _to_cpu_nested(self._regular_attr), + "regular_attr_abs": _to_cpu_nested(self._regular_attr_abs), + "embed_attr": _to_cpu(self._embed_attr), + "embed_attr_abs": _to_cpu(self._embed_attr_abs), + "unembed_attr": _to_cpu(self._unembed_attr), + "embed_unembed_attr": self._embed_unembed_attr.detach().cpu(), + "w_unembed": 
self._w_unembed.detach().cpu(), + "ci_sum": _to_cpu(self._ci_sum), + "component_act_sq_sum": _to_cpu(self._component_act_sq_sum), + "logit_sq_sum": self._logit_sq_sum.detach().cpu(), + "embed_token_count": self._embed_token_count.detach().cpu(), + "ci_threshold": self.ci_threshold, + "n_tokens_processed": self.n_tokens_processed, + }, + path, + ) + size_mb = path.stat().st_size / (1024 * 1024) + logger.info(f"Saved dataset attributions to {path} ({size_mb:.1f} MB)") - def get_top_targets( - self, - source_key: str, - k: int, - sign: Literal["positive", "negative"], - w_unembed: Float[Tensor, "d_model vocab"] | None = None, - include_outputs: bool = True, - ) -> list[DatasetAttributionEntry]: - """Get top-k target components this source attributes TO. - - Args: - source_key: Source component key (wte or component layer) - k: Number of top targets to return - sign: "positive" for strongest positive, "negative" for strongest negative - w_unembed: Unembedding matrix, required if include_outputs=True - include_outputs: Whether to include output tokens in results + @classmethod + def load(cls, path: Path) -> "DatasetAttributionStorage": + data = torch.load(path, weights_only=True) + return cls( + regular_attr=data["regular_attr"], + regular_attr_abs=data["regular_attr_abs"], + embed_attr=data["embed_attr"], + embed_attr_abs=data["embed_attr_abs"], + unembed_attr=data["unembed_attr"], + embed_unembed_attr=data["embed_unembed_attr"], + w_unembed=data["w_unembed"], + ci_sum=data["ci_sum"], + component_act_sq_sum=data["component_act_sq_sum"], + logit_sq_sum=data["logit_sq_sum"], + embed_token_count=data["embed_token_count"], + ci_threshold=data["ci_threshold"], + n_tokens_processed=data["n_tokens_processed"], + ) + + @classmethod + def merge(cls, paths: list[Path]) -> "DatasetAttributionStorage": + """Merge partial attribution files from parallel workers. + + All stored values are raw sums — merge is element-wise addition. 
""" - src_idx = self._source_idx(source_key) - comp_values = self.source_to_component[src_idx, :] # (n_components,) + assert paths, "No files to merge" - if include_outputs: - assert w_unembed is not None, "w_unembed required when include_outputs=True" - # Compute attributions to all output tokens - w_unembed = w_unembed.to(self.source_to_out_residual.device) - output_values = self.source_to_out_residual[src_idx, :] @ w_unembed # (vocab,) - all_values = torch.cat([comp_values, output_values]) + merged = cls.load(paths[0]) - def combined_idx_to_key(idx: int) -> str: - if idx < self.n_components: - return self._component_target_idx_to_key(idx) - return self._output_target_idx_to_key(idx - self.n_components) + for path in paths[1:]: + other = cls.load(path) + assert other.ci_threshold == merged.ci_threshold, "CI threshold mismatch" - return self._get_top_k(all_values, k, sign, combined_idx_to_key) + for target, sources in other._regular_attr.items(): + for source, tensor in sources.items(): + merged._regular_attr[target][source] += tensor + merged._regular_attr_abs[target][source] += other._regular_attr_abs[target][ + source + ] - return self._get_top_k(comp_values, k, sign, self._component_target_idx_to_key) + for target, tensor in other._embed_attr.items(): + merged._embed_attr[target] += tensor + merged._embed_attr_abs[target] += other._embed_attr_abs[target] - def get_top_component_targets( - self, - source_key: str, - k: int, - sign: Literal["positive", "negative"], - ) -> list[DatasetAttributionEntry]: - """Get top-k component targets (excluding outputs) this source attributes TO. + for source, tensor in other._unembed_attr.items(): + merged._unembed_attr[source] += tensor - Convenience method that doesn't require w_unembed. 
- """ - return self.get_top_targets(source_key, k, sign, w_unembed=None, include_outputs=False) + merged._embed_unembed_attr += other._embed_unembed_attr - def get_top_output_targets( - self, - source_key: str, - k: int, - sign: Literal["positive", "negative"], - w_unembed: Float[Tensor, "d_model vocab"], - ) -> list[DatasetAttributionEntry]: - """Get top-k output token targets this source attributes TO.""" - src_idx = self._source_idx(source_key) - w_unembed = w_unembed.to(self.source_to_out_residual.device) - output_values = self.source_to_out_residual[src_idx, :] @ w_unembed # (vocab,) - return self._get_top_k(output_values, k, sign, self._output_target_idx_to_key) + for layer in other._ci_sum: + merged._ci_sum[layer] += other._ci_sum[layer] + + for layer in other._component_act_sq_sum: + merged._component_act_sq_sum[layer] += other._component_act_sq_sum[layer] + + merged._logit_sq_sum += other._logit_sq_sum + merged._embed_token_count += other._embed_token_count + merged.n_tokens_processed += other.n_tokens_processed + + return merged + + +def _to_cpu_nested(d: dict[str, dict[str, Tensor]]) -> dict[str, dict[str, Tensor]]: + return { + target: {source: v.detach().cpu() for source, v in sources.items()} + for target, sources in d.items() + } + + +def _to_cpu(d: dict[str, Tensor]) -> dict[str, Tensor]: + return {k: v.detach().cpu() for k, v in d.items()} diff --git a/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml b/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml index 0d57991a9..50f3b51f6 100644 --- a/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml +++ b/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml @@ -56,13 +56,17 @@ loss_metric_configs: classname: PersistentPGDReconSubsetLoss optimizer: type: adam - lr: 0.1 + lr_schedule: + start_val: 0.1 + warmup_pct: 0.0 + final_val_frac: 1.0 + fn_type: constant beta1: 0.9 beta2: 0.99 eps: 1.0e-08 scope: - type: batch_invariant - n_masks: 8 + type: repeat_across_batch + n_sources: 8 routing: type: uniform_k_subset - 
coeff: 1000000.0 @@ -102,10 +106,8 @@ eval_metric_configs: - classname: CEandKLLosses rounding_threshold: 0.0 - classname: CIMeanPerComponent -- coeff: null - classname: StochasticHiddenActsReconLoss -- coeff: null - init: random +- classname: StochasticHiddenActsReconLoss +- init: random step_size: 0.1 n_steps: 20 mask_scope: shared_across_batch diff --git a/spd/experiments/lm/pile_llama_simple_mlp-4L-finetune-s-788ccb89.yaml b/spd/experiments/lm/pile_llama_simple_mlp-4L-finetune-s-788ccb89.yaml index 463782776..e1430c4e0 100644 --- a/spd/experiments/lm/pile_llama_simple_mlp-4L-finetune-s-788ccb89.yaml +++ b/spd/experiments/lm/pile_llama_simple_mlp-4L-finetune-s-788ccb89.yaml @@ -74,7 +74,6 @@ lr_schedule: fn_type: constant steps: 5000 batch_size: 64 -gradient_accumulation_steps: 1 grad_clip_norm_components: 0.01 grad_clip_norm_ci_fns: null faithfulness_warmup_steps: 0 @@ -106,10 +105,8 @@ eval_metric_configs: - classname: CEandKLLosses rounding_threshold: 0.0 - classname: CIMeanPerComponent -- coeff: null - classname: StochasticHiddenActsReconLoss -- coeff: null - init: random +- classname: StochasticHiddenActsReconLoss +- init: random step_size: 0.1 n_steps: 20 mask_scope: shared_across_batch diff --git a/spd/experiments/lm/pile_llama_simple_mlp-4L-lucius.yaml b/spd/experiments/lm/pile_llama_simple_mlp-4L-lucius.yaml index 77d3d13f6..3e145fdc1 100644 --- a/spd/experiments/lm/pile_llama_simple_mlp-4L-lucius.yaml +++ b/spd/experiments/lm/pile_llama_simple_mlp-4L-lucius.yaml @@ -56,13 +56,17 @@ loss_metric_configs: classname: PersistentPGDReconSubsetLoss optimizer: type: adam - lr: 0.01 + lr_schedule: + start_val: 0.01 + warmup_pct: 0.0 + final_val_frac: 1.0 + fn_type: constant beta1: 0.8 beta2: 0.99 eps: 1.0e-08 scope: - type: batch_invariant - n_masks: 8 + type: repeat_across_batch + n_sources: 8 routing: type: uniform_k_subset - coeff: 1000000.0 @@ -75,7 +79,6 @@ lr_schedule: fn_type: cosine steps: 400000 batch_size: 128 -gradient_accumulation_steps: 1 
grad_clip_norm_components: 0.01 grad_clip_norm_ci_fns: null faithfulness_warmup_steps: 400 @@ -107,10 +110,8 @@ eval_metric_configs: - classname: CEandKLLosses rounding_threshold: 0.0 - classname: CIMeanPerComponent -- coeff: null - classname: StochasticHiddenActsReconLoss -- coeff: null - init: random +- classname: StochasticHiddenActsReconLoss +- init: random step_size: 0.1 n_steps: 20 mask_scope: shared_across_batch diff --git a/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml b/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml index 4b5e82bb8..5ec8190be 100644 --- a/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml +++ b/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml @@ -40,7 +40,7 @@ module_info: identity_module_info: null use_delta_component: true loss_metric_configs: -- coeff: 0.0004 +- coeff: 0.0006 classname: ImportanceMinimalityLoss pnorm: 2.0 beta: 0.2 @@ -56,10 +56,14 @@ loss_metric_configs: classname: PersistentPGDReconLoss optimizer: type: adam - lr: 0.01 beta1: 0.5 beta2: 0.99 eps: 1.0e-08 + lr_schedule: + start_val: 0.01 + warmup_pct: 0.025 + final_val_frac: 0.1 + fn_type: cosine scope: type: per_batch_per_position use_sigmoid_parameterization: false @@ -105,10 +109,8 @@ eval_metric_configs: - classname: CEandKLLosses rounding_threshold: 0.0 - classname: CIMeanPerComponent -- coeff: null - classname: StochasticHiddenActsReconLoss -- coeff: null - init: random +- classname: StochasticHiddenActsReconLoss +- init: random step_size: 0.1 n_steps: 20 mask_scope: shared_across_batch diff --git a/spd/experiments/lm/ss_gpt2_simple-1L.yaml b/spd/experiments/lm/ss_gpt2_simple-1L.yaml index 1bb58ab2a..de5f9a4d6 100644 --- a/spd/experiments/lm/ss_gpt2_simple-1L.yaml +++ b/spd/experiments/lm/ss_gpt2_simple-1L.yaml @@ -80,10 +80,8 @@ eval_metric_configs: - classname: CEandKLLosses rounding_threshold: 0.0 - classname: CIMeanPerComponent -- coeff: null - classname: StochasticHiddenActsReconLoss -- coeff: null - init: random +- classname: 
StochasticHiddenActsReconLoss +- init: random step_size: 0.1 n_steps: 20 mask_scope: shared_across_batch diff --git a/spd/experiments/lm/ss_gpt2_simple-2L.yaml b/spd/experiments/lm/ss_gpt2_simple-2L.yaml index 7efa90d05..e1ff2a0d9 100644 --- a/spd/experiments/lm/ss_gpt2_simple-2L.yaml +++ b/spd/experiments/lm/ss_gpt2_simple-2L.yaml @@ -82,10 +82,8 @@ eval_metric_configs: - classname: CEandKLLosses rounding_threshold: 0.0 - classname: CIMeanPerComponent -- coeff: null - classname: StochasticHiddenActsReconLoss -- coeff: null - init: random +- classname: StochasticHiddenActsReconLoss +- init: random step_size: 0.1 n_steps: 20 mask_scope: shared_across_batch diff --git a/spd/experiments/lm/ss_llama_simple-1L.yaml b/spd/experiments/lm/ss_llama_simple-1L.yaml index 11082574f..b9a330e48 100644 --- a/spd/experiments/lm/ss_llama_simple-1L.yaml +++ b/spd/experiments/lm/ss_llama_simple-1L.yaml @@ -80,10 +80,8 @@ eval_metric_configs: - classname: CEandKLLosses rounding_threshold: 0.0 - classname: CIMeanPerComponent -- coeff: null - classname: StochasticHiddenActsReconLoss -- coeff: null - init: random +- classname: StochasticHiddenActsReconLoss +- init: random step_size: 0.1 n_steps: 20 mask_scope: shared_across_batch diff --git a/spd/experiments/lm/ss_llama_simple-2L.yaml b/spd/experiments/lm/ss_llama_simple-2L.yaml index 3ed7a7b0d..34d6106ff 100644 --- a/spd/experiments/lm/ss_llama_simple-2L.yaml +++ b/spd/experiments/lm/ss_llama_simple-2L.yaml @@ -82,10 +82,8 @@ eval_metric_configs: - classname: CEandKLLosses rounding_threshold: 0.0 - classname: CIMeanPerComponent -- coeff: null - classname: StochasticHiddenActsReconLoss -- coeff: null - init: random +- classname: StochasticHiddenActsReconLoss +- init: random step_size: 0.1 n_steps: 20 mask_scope: shared_across_batch diff --git a/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml b/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml index 8af7f6601..146bc1362 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml +++ 
b/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml @@ -74,10 +74,8 @@ eval_metric_configs: - classname: CEandKLLosses rounding_threshold: 0.0 - classname: CIMeanPerComponent -- coeff: null - classname: StochasticHiddenActsReconLoss -- coeff: null - init: random +- classname: StochasticHiddenActsReconLoss +- init: random step_size: 0.1 n_steps: 20 mask_scope: shared_across_batch diff --git a/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml b/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml index 5e938cd5d..79ab8fea7 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml @@ -83,9 +83,7 @@ eval_metric_configs: rounding_threshold: 0 - classname: CIMeanPerComponent - classname: StochasticHiddenActsReconLoss - coeff: null - classname: PGDReconLoss - coeff: null init: random step_size: 0.1 n_steps: 20 diff --git a/spd/experiments/lm/ss_llama_simple_mlp-2L-wide_global_reverse.yaml b/spd/experiments/lm/ss_llama_simple_mlp-2L-wide_global_reverse.yaml index c5d7d684e..84c63b083 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp-2L-wide_global_reverse.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp-2L-wide_global_reverse.yaml @@ -101,10 +101,8 @@ eval_metric_configs: - classname: CEandKLLosses rounding_threshold: 0.0 - classname: CIMeanPerComponent -- coeff: null - classname: StochasticHiddenActsReconLoss -- coeff: null - init: random +- classname: StochasticHiddenActsReconLoss +- init: random step_size: 0.1 n_steps: 20 mask_scope: shared_across_batch diff --git a/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml b/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml index ac123c4eb..16c9fa644 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml @@ -80,10 +80,8 @@ eval_metric_configs: - classname: CEandKLLosses rounding_threshold: 0.0 - classname: CIMeanPerComponent -- coeff: null - classname: StochasticHiddenActsReconLoss -- coeff: null - init: random 
+- classname: StochasticHiddenActsReconLoss +- init: random step_size: 0.1 n_steps: 20 mask_scope: shared_across_batch diff --git a/spd/experiments/lm/ss_llama_simple_mlp.yaml b/spd/experiments/lm/ss_llama_simple_mlp.yaml index 4dcff849b..f77da81ff 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp.yaml @@ -105,10 +105,8 @@ eval_metric_configs: - h.2.* all_but_layer_3: - h.3.* -- coeff: null - classname: StochasticHiddenActsReconLoss -- coeff: null - init: random +- classname: StochasticHiddenActsReconLoss +- init: random step_size: 0.1 n_steps: 20 mask_scope: shared_across_batch diff --git a/spd/experiments/lm/z-jan22.yaml b/spd/experiments/lm/z-jan22.yaml index ffb11f3e8..e0d14f13f 100644 --- a/spd/experiments/lm/z-jan22.yaml +++ b/spd/experiments/lm/z-jan22.yaml @@ -104,10 +104,8 @@ eval_metric_configs: - classname: CEandKLLosses rounding_threshold: 0.0 - classname: CIMeanPerComponent -- coeff: null - classname: StochasticHiddenActsReconLoss -- coeff: null - init: random +- classname: StochasticHiddenActsReconLoss +- init: random step_size: 0.1 n_steps: 20 mask_scope: shared_across_batch diff --git a/spd/experiments/lm/z-jan22_ppgd.yaml b/spd/experiments/lm/z-jan22_ppgd.yaml index ff1b442e4..d2f4b63e7 100644 --- a/spd/experiments/lm/z-jan22_ppgd.yaml +++ b/spd/experiments/lm/z-jan22_ppgd.yaml @@ -75,10 +75,15 @@ loss_metric_configs: coeff: 0.5 optimizer: type: adam - lr: 0.01 + lr_schedule: + start_val: 0.01 + warmup_pct: 0.0 + final_val_frac: 1.0 + fn_type: constant beta1: 0.8 beta2: 0.99 - scope: unique_per_batch_per_token + scope: + type: per_batch_per_position routing: type: uniform_k_subset output_loss_type: kl diff --git a/spd/experiments/lm/z-jan22_ppgd_reverse_resid_ablations.yaml b/spd/experiments/lm/z-jan22_ppgd_reverse_resid_ablations.yaml index 58c948356..d5a102eaa 100644 --- a/spd/experiments/lm/z-jan22_ppgd_reverse_resid_ablations.yaml +++ 
b/spd/experiments/lm/z-jan22_ppgd_reverse_resid_ablations.yaml @@ -85,7 +85,11 @@ loss_metric_configs: coeff: 0.5 optimizer: type: adam - lr: 0.01 + lr_schedule: + start_val: 0.01 + warmup_pct: 0.0 + final_val_frac: 1.0 + fn_type: constant beta1: 0.8 beta2: 0.99 scope: diff --git a/spd/experiments/lm/z-jan22_ppgd_transformer_normed.yaml b/spd/experiments/lm/z-jan22_ppgd_transformer_normed.yaml index 087d05ded..8487a9ce8 100644 --- a/spd/experiments/lm/z-jan22_ppgd_transformer_normed.yaml +++ b/spd/experiments/lm/z-jan22_ppgd_transformer_normed.yaml @@ -50,10 +50,15 @@ loss_metric_configs: coeff: 0.5 optimizer: type: adam - lr: 0.01 + lr_schedule: + start_val: 0.01 + warmup_pct: 0.0 + final_val_frac: 1.0 + fn_type: constant beta1: 0.8 beta2: 0.99 - scope: unique_per_batch_per_token + scope: + type: per_batch_per_position routing: type: uniform_k_subset output_loss_type: kl diff --git a/spd/experiments/lm/z-jan22_ppgd_transformer_normed_ablations.yaml b/spd/experiments/lm/z-jan22_ppgd_transformer_normed_ablations.yaml index c74b9e739..b20a94d06 100644 --- a/spd/experiments/lm/z-jan22_ppgd_transformer_normed_ablations.yaml +++ b/spd/experiments/lm/z-jan22_ppgd_transformer_normed_ablations.yaml @@ -65,10 +65,15 @@ loss_metric_configs: coeff: 0.5 optimizer: type: adam - lr: 0.01 + lr_schedule: + start_val: 0.01 + warmup_pct: 0.0 + final_val_frac: 1.0 + fn_type: constant beta1: 0.8 beta2: 0.99 - scope: unique_per_batch_per_token + scope: + type: per_batch_per_position routing: type: uniform_k_subset output_loss_type: kl diff --git a/spd/experiments/lm/z-jan22_ppgd_transformer_normed_deep.yaml b/spd/experiments/lm/z-jan22_ppgd_transformer_normed_deep.yaml index 239d435bb..de678b368 100644 --- a/spd/experiments/lm/z-jan22_ppgd_transformer_normed_deep.yaml +++ b/spd/experiments/lm/z-jan22_ppgd_transformer_normed_deep.yaml @@ -50,7 +50,11 @@ loss_metric_configs: coeff: 0.5 optimizer: type: adam - lr: 0.01 + lr_schedule: + start_val: 0.01 + warmup_pct: 0.0 + final_val_frac: 
1.0 + fn_type: constant beta1: 0.8 beta2: 0.99 scope: diff --git a/spd/experiments/resid_mlp/resid_mlp1_global_reverse_config.yaml b/spd/experiments/resid_mlp/resid_mlp1_global_reverse_config.yaml index dd721c9e9..4e708862e 100644 --- a/spd/experiments/resid_mlp/resid_mlp1_global_reverse_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp1_global_reverse_config.yaml @@ -94,7 +94,6 @@ eval_metric_configs: groups: null - classname: "CIMeanPerComponent" - classname: "StochasticHiddenActsReconLoss" - coeff: null # --- Pretrained model info --- pretrained_model_class: "spd.experiments.resid_mlp.models.ResidMLP" diff --git a/spd/experiments/resid_mlp/resid_mlp2_global_reverse_config.yaml b/spd/experiments/resid_mlp/resid_mlp2_global_reverse_config.yaml index bd111ff84..8c239ee60 100644 --- a/spd/experiments/resid_mlp/resid_mlp2_global_reverse_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp2_global_reverse_config.yaml @@ -97,9 +97,7 @@ eval_metric_configs: groups: null - classname: "CIMeanPerComponent" - classname: "StochasticHiddenActsReconLoss" - coeff: null - classname: "PGDReconLoss" - coeff: null init: random step_size: 0.1 n_steps: 20 diff --git a/spd/graph_interp/CLAUDE.md b/spd/graph_interp/CLAUDE.md new file mode 100644 index 000000000..327db0e7c --- /dev/null +++ b/spd/graph_interp/CLAUDE.md @@ -0,0 +1,71 @@ +# Graph Interpretation Module + +Context-aware component labeling using network graph structure. Unlike standard autointerp (one-shot per component), this module uses dataset attributions to provide graph context: each component's prompt includes labels from already-labeled components connected via the attribution graph. + +## Usage + +```bash +# Via SLURM (standalone) +spd-graph-interp --config config.yaml + +# Direct execution +python -m spd.graph_interp.scripts.run --config_json '{...}' +``` + +Requires `OPENROUTER_API_KEY` env var. Requires both harvest data and dataset attributions to exist. + +## Three-Phase Pipeline + +1. 
**Output pass** (late → early): "What does this component DO?" Each component's prompt includes top-K downstream components (by attribution) with their labels. Late layers labeled first so earlier layers see labeled downstream context. + +2. **Input pass** (early → late): "What TRIGGERS this component?" Each component's prompt includes top-K upstream components (by attribution) + co-firing components (Jaccard/PMI). Early layers labeled first so later layers see labeled upstream context. Independent of the output pass. + +3. **Unification** (parallel): Synthesizes output + input labels into a single unified label per component. + +All three phases run in a single invocation. Resume is per-phase via completed key sets in the DB. + +## Data Storage + +``` +SPD_OUT_DIR/graph_interp// +└── ti-YYYYMMDD_HHMMSS/ + ├── interp.db # SQLite: output_labels, input_labels, unified_labels, prompt_edges + └── config.yaml +``` + +## Database Schema + +- `output_labels`: component_key → label, confidence, reasoning, raw_response, prompt +- `input_labels`: same schema as output_labels +- `unified_labels`: same schema as output_labels +- `prompt_edges`: directed filtered graph of (component, related_key, pass, attribution, related_label) +- `config`: key-value store + +## Architecture + +| File | Purpose | +|------|---------| +| `config.py` | `GraphInterpConfig`, `GraphInterpSlurmConfig` | +| `schemas.py` | `LabelResult`, `PromptEdge`, path helpers | +| `db.py` | `GraphInterpDB` — SQLite with WAL mode | +| `ordering.py` | Topological sort via `CanonicalWeight` from topology module | +| `graph_context.py` | `RelatedComponent`, gather attributed + co-firing components | +| `prompts.py` | Three prompt formatters (output, input, unification) | +| `interpret.py` | Main three-phase execution loop | +| `repo.py` | `GraphInterpRepo` — read-only access to results | +| `scripts/run.py` | CLI entry point (called by SLURM) | +| `scripts/run_slurm.py` | SLURM submission | +| 
`scripts/run_slurm_cli.py` | Thin CLI wrapper for `spd-graph-interp` | + +## Dependencies + +- Harvest data (component stats, correlations, token stats) +- Dataset attributions (component-to-component attribution strengths) +- Reuses `map_llm_calls` from `spd/autointerp/llm_api.py` +- Reuses prompt helpers from `spd/autointerp/prompt_helpers.py` + +## SLURM Integration + +- 0 GPUs, 16 CPUs, 240GB memory (CPU-only, LLM API calls) +- Depends on both harvest merge AND attribution merge jobs +- Entry point: `spd-graph-interp` diff --git a/spd/graph_interp/__init__.py b/spd/graph_interp/__init__.py new file mode 100644 index 000000000..61e182fda --- /dev/null +++ b/spd/graph_interp/__init__.py @@ -0,0 +1 @@ +"""Graph interpretation: context-aware component labeling using graph structure.""" diff --git a/spd/graph_interp/config.py b/spd/graph_interp/config.py new file mode 100644 index 000000000..e6e7441d3 --- /dev/null +++ b/spd/graph_interp/config.py @@ -0,0 +1,26 @@ +"""Graph interpretation configuration.""" + +from openrouter.components import Effort + +from spd.base_config import BaseConfig +from spd.dataset_attributions.storage import AttrMetric +from spd.settings import DEFAULT_PARTITION_NAME + + +class GraphInterpConfig(BaseConfig): + model: str = "google/gemini-3-flash-preview" + reasoning_effort: Effort = "low" + attr_metric: AttrMetric = "attr_abs" + top_k_attributed: int = 8 + max_examples: int = 20 + label_max_words: int = 8 + cost_limit_usd: float | None = None + max_requests_per_minute: int = 500 + max_concurrent: int = 50 + limit: int | None = None + + +class GraphInterpSlurmConfig(BaseConfig): + config: GraphInterpConfig + partition: str = DEFAULT_PARTITION_NAME + time: str = "24:00:00" diff --git a/spd/graph_interp/db.py b/spd/graph_interp/db.py new file mode 100644 index 000000000..0af8c796f --- /dev/null +++ b/spd/graph_interp/db.py @@ -0,0 +1,228 @@ +"""SQLite database for graph interpretation data.""" + +import sqlite3 +from pathlib import Path + 
+from spd.graph_interp.schemas import LabelResult, PromptEdge + +_SCHEMA = """\ +CREATE TABLE IF NOT EXISTS output_labels ( + component_key TEXT PRIMARY KEY, + label TEXT NOT NULL, + confidence TEXT NOT NULL, + reasoning TEXT NOT NULL, + raw_response TEXT NOT NULL, + prompt TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS input_labels ( + component_key TEXT PRIMARY KEY, + label TEXT NOT NULL, + confidence TEXT NOT NULL, + reasoning TEXT NOT NULL, + raw_response TEXT NOT NULL, + prompt TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS unified_labels ( + component_key TEXT PRIMARY KEY, + label TEXT NOT NULL, + confidence TEXT NOT NULL, + reasoning TEXT NOT NULL, + raw_response TEXT NOT NULL, + prompt TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS prompt_edges ( + component_key TEXT NOT NULL, + related_key TEXT NOT NULL, + pass TEXT NOT NULL, + attribution REAL NOT NULL, + related_label TEXT, + related_confidence TEXT, + PRIMARY KEY (component_key, related_key, pass) +); + +CREATE TABLE IF NOT EXISTS config ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL +); +""" + + +class GraphInterpDB: + def __init__(self, db_path: Path, readonly: bool = False) -> None: + if readonly: + self._conn = sqlite3.connect( + f"file:{db_path}?immutable=1", uri=True, check_same_thread=False + ) + else: + self._conn = sqlite3.connect(str(db_path), check_same_thread=False) + self._conn.execute("PRAGMA journal_mode=WAL") + self._conn.executescript(_SCHEMA) + self._conn.row_factory = sqlite3.Row + + # -- Output labels --------------------------------------------------------- + + def save_output_label(self, result: LabelResult) -> None: + self._conn.execute( + "INSERT OR REPLACE INTO output_labels VALUES (?, ?, ?, ?, ?, ?)", + ( + result.component_key, + result.label, + result.confidence, + result.reasoning, + result.raw_response, + result.prompt, + ), + ) + self._conn.commit() + + def get_output_label(self, component_key: str) -> LabelResult | None: + row = self._conn.execute( + "SELECT * FROM 
output_labels WHERE component_key = ?", (component_key,) + ).fetchone() + if row is None: + return None + return _row_to_label_result(row) + + def get_all_output_labels(self) -> dict[str, LabelResult]: + rows = self._conn.execute("SELECT * FROM output_labels").fetchall() + return {row["component_key"]: _row_to_label_result(row) for row in rows} + + def get_completed_output_keys(self) -> set[str]: + rows = self._conn.execute("SELECT component_key FROM output_labels").fetchall() + return {row["component_key"] for row in rows} + + # -- Input labels ---------------------------------------------------------- + + def save_input_label(self, result: LabelResult) -> None: + self._conn.execute( + "INSERT OR REPLACE INTO input_labels VALUES (?, ?, ?, ?, ?, ?)", + ( + result.component_key, + result.label, + result.confidence, + result.reasoning, + result.raw_response, + result.prompt, + ), + ) + self._conn.commit() + + def get_input_label(self, component_key: str) -> LabelResult | None: + row = self._conn.execute( + "SELECT * FROM input_labels WHERE component_key = ?", (component_key,) + ).fetchone() + if row is None: + return None + return _row_to_label_result(row) + + def get_all_input_labels(self) -> dict[str, LabelResult]: + rows = self._conn.execute("SELECT * FROM input_labels").fetchall() + return {row["component_key"]: _row_to_label_result(row) for row in rows} + + def get_completed_input_keys(self) -> set[str]: + rows = self._conn.execute("SELECT component_key FROM input_labels").fetchall() + return {row["component_key"] for row in rows} + + # -- Unified labels -------------------------------------------------------- + + def save_unified_label(self, result: LabelResult) -> None: + self._conn.execute( + "INSERT OR REPLACE INTO unified_labels VALUES (?, ?, ?, ?, ?, ?)", + ( + result.component_key, + result.label, + result.confidence, + result.reasoning, + result.raw_response, + result.prompt, + ), + ) + self._conn.commit() + + def get_unified_label(self, component_key: 
str) -> LabelResult | None: + row = self._conn.execute( + "SELECT * FROM unified_labels WHERE component_key = ?", (component_key,) + ).fetchone() + if row is None: + return None + return _row_to_label_result(row) + + def get_all_unified_labels(self) -> dict[str, LabelResult]: + rows = self._conn.execute("SELECT * FROM unified_labels").fetchall() + return {row["component_key"]: _row_to_label_result(row) for row in rows} + + def get_completed_unified_keys(self) -> set[str]: + rows = self._conn.execute("SELECT component_key FROM unified_labels").fetchall() + return {row["component_key"] for row in rows} + + # -- Prompt edges ---------------------------------------------------------- + + def save_prompt_edges(self, edges: list[PromptEdge]) -> None: + rows = [ + ( + e.component_key, + e.related_key, + e.pass_name, + e.attribution, + e.related_label, + e.related_confidence, + ) + for e in edges + ] + self._conn.executemany( + "INSERT OR REPLACE INTO prompt_edges VALUES (?, ?, ?, ?, ?, ?)", + rows, + ) + self._conn.commit() + + def get_prompt_edges(self, component_key: str) -> list[PromptEdge]: + rows = self._conn.execute( + "SELECT * FROM prompt_edges WHERE component_key = ?", (component_key,) + ).fetchall() + return [_row_to_prompt_edge(row) for row in rows] + + def get_all_prompt_edges(self) -> list[PromptEdge]: + rows = self._conn.execute("SELECT * FROM prompt_edges").fetchall() + return [_row_to_prompt_edge(row) for row in rows] + + # -- Config ---------------------------------------------------------------- + + def save_config(self, key: str, value: str) -> None: + self._conn.execute("INSERT OR REPLACE INTO config VALUES (?, ?)", (key, value)) + self._conn.commit() + + # -- Stats ----------------------------------------------------------------- + + def get_label_count(self, table: str) -> int: + assert table in ("output_labels", "input_labels", "unified_labels") + row = self._conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone() + assert row is not None + return 
row[0] + + def close(self) -> None: + self._conn.close() + + +def _row_to_label_result(row: sqlite3.Row) -> LabelResult: + return LabelResult( + component_key=row["component_key"], + label=row["label"], + confidence=row["confidence"], + reasoning=row["reasoning"], + raw_response=row["raw_response"], + prompt=row["prompt"], + ) + + +def _row_to_prompt_edge(row: sqlite3.Row) -> PromptEdge: + return PromptEdge( + component_key=row["component_key"], + related_key=row["related_key"], + pass_name=row["pass"], + attribution=row["attribution"], + related_label=row["related_label"], + related_confidence=row["related_confidence"], + ) diff --git a/spd/graph_interp/graph_context.py b/spd/graph_interp/graph_context.py new file mode 100644 index 000000000..9ac08ad73 --- /dev/null +++ b/spd/graph_interp/graph_context.py @@ -0,0 +1,98 @@ +"""Gather related components from attribution graph.""" + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Literal + +from spd.dataset_attributions.storage import DatasetAttributionEntry +from spd.graph_interp.ordering import parse_component_key +from spd.graph_interp.schemas import LabelResult +from spd.harvest.analysis import get_correlated_components +from spd.harvest.storage import CorrelationStorage + + +@dataclass +class RelatedComponent: + component_key: str + attribution: float + label: str | None + confidence: str | None + jaccard: float | None + pmi: float | None + + +GetAttributed = Callable[[str, int, Literal["positive", "negative"]], list[DatasetAttributionEntry]] + + +def get_related_components( + component_key: str, + get_attributed: GetAttributed, + correlation_storage: CorrelationStorage, + labels_so_far: dict[str, LabelResult], + k: int, +) -> list[RelatedComponent]: + """Top-K components connected via attribution, enriched with co-firing stats and labels.""" + my_layer, _ = parse_component_key(component_key) + + pos = get_attributed(component_key, k * 2, "positive") + neg = 
get_attributed(component_key, k * 2, "negative") + + candidates = pos + neg + candidates.sort(key=lambda e: abs(e.value), reverse=True) + candidates = candidates[:k] + + cofiring = _build_cofiring_lookup(component_key, correlation_storage, k * 3) + result = [_build_related(e.component_key, e.value, cofiring, labels_so_far) for e in candidates] + + for r in result: + r_layer, _ = parse_component_key(r.component_key) + assert r_layer != my_layer, ( + f"Same-layer component {r.component_key} in related list for {component_key}" + ) + + return result + + +def _build_cofiring_lookup( + component_key: str, + correlation_storage: CorrelationStorage, + k: int, +) -> dict[str, tuple[float, float | None]]: + lookup: dict[str, tuple[float, float | None]] = {} + + jaccard_results = get_correlated_components( + correlation_storage, component_key, metric="jaccard", top_k=k + ) + for c in jaccard_results: + lookup[c.component_key] = (c.score, None) + + pmi_results = get_correlated_components( + correlation_storage, component_key, metric="pmi", top_k=k + ) + for c in pmi_results: + if c.component_key in lookup: + jaccard_val = lookup[c.component_key][0] + lookup[c.component_key] = (jaccard_val, c.score) + else: + lookup[c.component_key] = (0.0, c.score) + + return lookup + + +def _build_related( + related_key: str, + attribution: float, + cofiring: dict[str, tuple[float, float | None]], + labels_so_far: dict[str, LabelResult], +) -> RelatedComponent: + label = labels_so_far.get(related_key) + jaccard, pmi = cofiring.get(related_key, (None, None)) + + return RelatedComponent( + component_key=related_key, + attribution=attribution, + label=label.label if label else None, + confidence=label.confidence if label else None, + jaccard=jaccard, + pmi=pmi, + ) diff --git a/spd/graph_interp/interpret.py b/spd/graph_interp/interpret.py new file mode 100644 index 000000000..d018f77e0 --- /dev/null +++ b/spd/graph_interp/interpret.py @@ -0,0 +1,377 @@ +"""Main three-phase graph interpretation 
execution. + +Structure: + output_labels = scan(layers_reversed, step) + input_labels = scan(layers_forward, step) + unified = map(output_labels + input_labels, unify) + +Each scan folds over layers. Within a layer, components are labeled in parallel +via async LLM calls. The fold accumulator (labels_so_far) lets each component's +prompt include labels from previously-processed layers. +""" + +import asyncio +from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable +from functools import partial +from pathlib import Path +from typing import Literal + +from spd.app.backend.app_tokenizer import AppTokenizer +from spd.autointerp.llm_api import LLMError, LLMJob, LLMResult, map_llm_calls +from spd.autointerp.schemas import ModelMetadata +from spd.dataset_attributions.storage import ( + AttrMetric, + DatasetAttributionEntry, + DatasetAttributionStorage, +) +from spd.graph_interp import graph_context +from spd.graph_interp.config import GraphInterpConfig +from spd.graph_interp.db import GraphInterpDB +from spd.graph_interp.graph_context import RelatedComponent, get_related_components +from spd.graph_interp.ordering import group_and_sort_by_layer +from spd.graph_interp.prompts import ( + LABEL_SCHEMA, + format_input_prompt, + format_output_prompt, + format_unification_prompt, +) +from spd.graph_interp.schemas import LabelResult, PromptEdge +from spd.harvest.analysis import get_input_token_stats, get_output_token_stats +from spd.harvest.repo import HarvestRepo +from spd.harvest.storage import CorrelationStorage, TokenStatsStorage +from spd.log import logger + +GetRelated = Callable[[str, dict[str, LabelResult]], list[RelatedComponent]] +Step = Callable[[list[str], dict[str, LabelResult]], Awaitable[dict[str, LabelResult]]] + + +def run_graph_interp( + openrouter_api_key: str, + config: GraphInterpConfig, + harvest: HarvestRepo, + attribution_storage: DatasetAttributionStorage, + correlation_storage: CorrelationStorage, + token_stats: TokenStatsStorage, + 
model_metadata: ModelMetadata, + db_path: Path, + tokenizer_name: str, +) -> None: + logger.info("Loading tokenizer...") + app_tok = AppTokenizer.from_pretrained(tokenizer_name) + + logger.info("Loading component summaries...") + summaries = harvest.get_summary() + alive = {k: s for k, s in summaries.items() if s.firing_density > 0.0} + all_keys = sorted(alive, key=lambda k: alive[k].firing_density, reverse=True) + if config.limit is not None: + all_keys = all_keys[: config.limit] + + layers = group_and_sort_by_layer(all_keys, model_metadata.layer_descriptions) + total = len(all_keys) + logger.info(f"Graph interp: {total} components across {len(layers)} layers") + + # -- Injected behaviours --------------------------------------------------- + + async def llm_map( + jobs: Iterable[LLMJob], n_total: int | None = None + ) -> AsyncGenerator[LLMResult | LLMError]: + async for result in map_llm_calls( + openrouter_api_key=openrouter_api_key, + model=config.model, + reasoning_effort=config.reasoning_effort, + jobs=jobs, + max_tokens=8000, + max_concurrent=config.max_concurrent, + max_requests_per_minute=config.max_requests_per_minute, + cost_limit_usd=config.cost_limit_usd, + response_schema=LABEL_SCHEMA, + n_total=n_total, + ): + yield result + + concrete_to_canon = model_metadata.layer_descriptions + canon_to_concrete = {v: k for k, v in concrete_to_canon.items()} + + def _translate_entries(entries: list[DatasetAttributionEntry]) -> list[DatasetAttributionEntry]: + for e in entries: + if e.layer in canon_to_concrete: + e.layer = canon_to_concrete[e.layer] + e.component_key = f"{e.layer}:{e.component_idx}" + return entries + + def _to_canon(concrete_key: str) -> str: + layer, idx = concrete_key.rsplit(":", 1) + return f"{concrete_to_canon[layer]}:{idx}" + + def _make_get_targets(metric: AttrMetric) -> "graph_context.GetAttributed": + def get( + key: str, k: int, sign: Literal["positive", "negative"] + ) -> list[DatasetAttributionEntry]: + return _translate_entries( + 
attribution_storage.get_top_targets(_to_canon(key), k=k, sign=sign, metric=metric) + ) + + return get + + def _make_get_sources(metric: AttrMetric) -> "graph_context.GetAttributed": + def get( + key: str, k: int, sign: Literal["positive", "negative"] + ) -> list[DatasetAttributionEntry]: + return _translate_entries( + attribution_storage.get_top_sources(_to_canon(key), k=k, sign=sign, metric=metric) + ) + + return get + + def _get_related(get_attributed: "graph_context.GetAttributed") -> GetRelated: + def get(key: str, labels_so_far: dict[str, LabelResult]) -> list[RelatedComponent]: + return get_related_components( + key, + get_attributed, + correlation_storage, + labels_so_far, + config.top_k_attributed, + ) + + return get + + # -- Layer processors ------------------------------------------------------ + + async def process_output_layer( + get_related: GetRelated, + save_label: Callable[[LabelResult], None], + pending: list[str], + labels_so_far: dict[str, LabelResult], + ) -> dict[str, LabelResult]: + def jobs() -> Iterable[LLMJob]: + for key in pending: + component = harvest.get_component(key) + assert component is not None, f"Component {key} not found in harvest DB" + o_stats = get_output_token_stats(token_stats, key, app_tok, top_k=50) + assert o_stats is not None, f"No output token stats for {key}" + + related = get_related(key, labels_so_far) + _save_edges(db, key, related, "output") + prompt = format_output_prompt( + component=component, + model_metadata=model_metadata, + app_tok=app_tok, + output_token_stats=o_stats, + related=related, + label_max_words=config.label_max_words, + max_examples=config.max_examples, + ) + yield LLMJob(prompt=prompt, schema=LABEL_SCHEMA, key=key) + + return await _collect_labels(llm_map, jobs(), len(pending), save_label) + + async def process_input_layer( + get_related: GetRelated, + save_label: Callable[[LabelResult], None], + pending: list[str], + labels_so_far: dict[str, LabelResult], + ) -> dict[str, LabelResult]: + def 
jobs() -> Iterable[LLMJob]: + for key in pending: + component = harvest.get_component(key) + assert component is not None, f"Component {key} not found in harvest DB" + i_stats = get_input_token_stats(token_stats, key, app_tok, top_k=20) + assert i_stats is not None, f"No input token stats for {key}" + + related = get_related(key, labels_so_far) + _save_edges(db, key, related, "input") + prompt = format_input_prompt( + component=component, + model_metadata=model_metadata, + app_tok=app_tok, + input_token_stats=i_stats, + related=related, + label_max_words=config.label_max_words, + max_examples=config.max_examples, + ) + yield LLMJob(prompt=prompt, schema=LABEL_SCHEMA, key=key) + + return await _collect_labels(llm_map, jobs(), len(pending), save_label) + + # -- Scan (fold over layers) ----------------------------------------------- + + async def scan( + layer_order: list[tuple[str, list[str]]], + initial: dict[str, LabelResult], + step: Step, + ) -> dict[str, LabelResult]: + labels = dict(initial) + if labels: + logger.info(f"Resuming, {len(labels)} already completed") + + completed_so_far = 0 + for layer, keys in layer_order: + pending = [k for k in keys if k not in labels] + if not pending: + completed_so_far += len(keys) + continue + + new_labels = await step(pending, labels) + labels.update(new_labels) + + completed_so_far += len(keys) + logger.info(f"Completed layer {layer} ({completed_so_far}/{total})") + + return labels + + # -- Map (parallel over all components) ------------------------------------ + + async def map_unify( + output_labels: dict[str, LabelResult], + input_labels: dict[str, LabelResult], + ) -> None: + completed = db.get_completed_unified_keys() + keys = [k for k in all_keys if k not in completed] + if not keys: + logger.info("Unification: all labels already completed") + return + if completed: + logger.info(f"Unification: resuming, {len(completed)} already completed") + + n_skipped = 0 + + def jobs() -> Iterable[LLMJob]: + nonlocal n_skipped + 
for key in keys: + out = output_labels.get(key) + inp = input_labels.get(key) + if out is None or inp is None: + n_skipped += 1 + continue + component = harvest.get_component(key) + assert component is not None, f"Component {key} not found in harvest DB" + prompt = format_unification_prompt( + output_label=out, + input_label=inp, + component=component, + model_metadata=model_metadata, + app_tok=app_tok, + label_max_words=config.label_max_words, + max_examples=config.max_examples, + ) + yield LLMJob(prompt=prompt, schema=LABEL_SCHEMA, key=key) + + logger.info(f"Unifying {len(keys)} components") + new_labels = await _collect_labels(llm_map, jobs(), len(keys), db.save_unified_label) + + if n_skipped: + logger.warning(f"Skipped {n_skipped} components missing output or input labels") + logger.info(f"Unification: completed {len(new_labels)}/{len(keys)}") + + # -- Run ------------------------------------------------------------------- + + logger.info("Initializing DB and building scan steps...") + db = GraphInterpDB(db_path) + + metric = config.attr_metric + get_targets = _make_get_targets(metric) + get_sources = _make_get_sources(metric) + + label_output = partial( + process_output_layer, + _get_related(get_targets), + db.save_output_label, + ) + label_input = partial( + process_input_layer, + _get_related(get_sources), + db.save_input_label, + ) + + async def _run() -> None: + logger.section("Phase 1: Output pass (late → early)") + output_labels = await scan(list(reversed(layers)), db.get_all_output_labels(), label_output) + + logger.section("Phase 2: Input pass (early → late)") + input_labels = await scan(list(layers), db.get_all_input_labels(), label_input) + + logger.section("Phase 3: Unification") + await map_unify(output_labels, input_labels) + + logger.info( + f"Completed: {db.get_label_count('output_labels')} output, " + f"{db.get_label_count('input_labels')} input, " + f"{db.get_label_count('unified_labels')} unified labels -> {db_path}" + ) + + try: + 
asyncio.run(_run()) + finally: + db.close() + + +# -- Shared LLM call machinery ------------------------------------------------ + + +async def _collect_labels( + llm_map: Callable[[Iterable[LLMJob], int | None], AsyncGenerator[LLMResult | LLMError]], + jobs: Iterable[LLMJob], + n_total: int, + save_label: Callable[[LabelResult], None], +) -> dict[str, LabelResult]: + """Run LLM jobs, parse results, save to DB, return new labels.""" + new_labels: dict[str, LabelResult] = {} + n_errors = 0 + + async for outcome in llm_map(jobs, n_total): + match outcome: + case LLMResult(job=job, parsed=parsed, raw=raw): + result = _parse_label(job.key, parsed, raw, job.prompt) + save_label(result) + new_labels[job.key] = result + case LLMError(job=job, error=e): + n_errors += 1 + logger.error(f"Skipping {job.key}: {type(e).__name__}: {e}") + _check_error_rate(n_errors, len(new_labels)) + + return new_labels + + +def _parse_label(key: str, parsed: dict[str, object], raw: str, prompt: str) -> LabelResult: + assert len(parsed) == 3, f"Expected 3 fields, got {len(parsed)}" + label = parsed["label"] + confidence = parsed["confidence"] + reasoning = parsed["reasoning"] + assert isinstance(label, str) and isinstance(confidence, str) and isinstance(reasoning, str) + return LabelResult( + component_key=key, + label=label, + confidence=confidence, + reasoning=reasoning, + raw_response=raw, + prompt=prompt, + ) + + +def _check_error_rate(n_errors: int, n_done: int) -> None: + total = n_errors + n_done + if total > 10 and n_errors / total > 0.05: + raise RuntimeError( + f"Error rate {n_errors / total:.0%} ({n_errors}/{total}) exceeds 5% threshold" + ) + + +def _save_edges( + db: GraphInterpDB, + component_key: str, + related: list[RelatedComponent], + pass_name: Literal["output", "input"], +) -> None: + edges = [ + PromptEdge( + component_key=component_key, + related_key=r.component_key, + pass_name=pass_name, + attribution=r.attribution, + related_label=r.label, + 
related_confidence=r.confidence, + ) + for r in related + ] + if edges: + db.save_prompt_edges(edges) diff --git a/spd/graph_interp/ordering.py b/spd/graph_interp/ordering.py new file mode 100644 index 000000000..9ef4d5afa --- /dev/null +++ b/spd/graph_interp/ordering.py @@ -0,0 +1,88 @@ +"""Layer ordering for graph interpretation. + +Uses the topology module's CanonicalWeight system for correct ordering +across all model architectures. Canonical addresses are provided by +ModelMetadata.layer_descriptions (concrete path → canonical string). +""" + +from spd.topology.canonical import ( + CanonicalWeight, + FusedAttnWeight, + GLUWeight, + LayerWeight, + MLPWeight, + SeparateAttnWeight, +) + +_SUBLAYER_ORDER = {"attn": 0, "attn_fused": 0, "glu": 1, "mlp": 1} + +_PROJECTION_ORDER: dict[type, dict[str, int]] = { + SeparateAttnWeight: {"q": 0, "k": 1, "v": 2, "o": 3}, + FusedAttnWeight: {"qkv": 0, "o": 1}, + GLUWeight: {"gate": 0, "up": 1, "down": 2}, + MLPWeight: {"up": 0, "down": 1}, +} + + +def canonical_sort_key(canonical: str) -> tuple[int, int, int]: + """Sort key for a canonical address string like '0.attn.q' or '1.mlp.down'.""" + weight = CanonicalWeight.parse(canonical) + assert isinstance(weight, LayerWeight), f"Expected LayerWeight, got {type(weight).__name__}" + + match weight.name: + case SeparateAttnWeight(weight=p): + sublayer_idx = _SUBLAYER_ORDER["attn"] + proj_idx = _PROJECTION_ORDER[SeparateAttnWeight][p] + case FusedAttnWeight(weight=p): + sublayer_idx = _SUBLAYER_ORDER["attn_fused"] + proj_idx = _PROJECTION_ORDER[FusedAttnWeight][p] + case GLUWeight(weight=p): + sublayer_idx = _SUBLAYER_ORDER["glu"] + proj_idx = _PROJECTION_ORDER[GLUWeight][p] + case MLPWeight(weight=p): + sublayer_idx = _SUBLAYER_ORDER["mlp"] + proj_idx = _PROJECTION_ORDER[MLPWeight][p] + + return weight.layer_idx, sublayer_idx, proj_idx + + +def parse_component_key(key: str) -> tuple[str, int]: + """Split 'h.1.mlp.c_fc:42' into ('h.1.mlp.c_fc', 42).""" + layer, idx_str = 
def group_and_sort_by_layer(
    component_keys: list[str],
    layer_descriptions: dict[str, str],
) -> list[tuple[str, list[str]]]:
    """Bucket component keys by their layer and order everything topologically.

    Args:
        component_keys: Component keys like 'h.0.attn.q_proj:42'.
        layer_descriptions: Mapping from concrete layer path to canonical
            address (from ModelMetadata.layer_descriptions).

    Returns:
        [(layer, [keys])] pairs, layers in topological (canonical) order and
        each layer's keys in ascending component-index order.
    """
    grouped: dict[str, list[str]] = {}
    for component_key in component_keys:
        grouped.setdefault(parse_component_key(component_key)[0], []).append(component_key)

    def component_idx(key: str) -> int:
        return parse_component_key(key)[1]

    ordered_layers = sorted(
        grouped,
        key=lambda layer: canonical_sort_key(layer_descriptions[layer]),
    )
    return [(layer, sorted(grouped[layer], key=component_idx)) for layer in ordered_layers]
+""" + +from spd.app.backend.app_tokenizer import AppTokenizer +from spd.autointerp.prompt_helpers import ( + build_fires_on_examples, + build_input_section, + build_output_section, + build_says_examples, + density_note, + human_layer_desc, + layer_position_note, +) +from spd.autointerp.schemas import ModelMetadata +from spd.graph_interp.graph_context import RelatedComponent +from spd.graph_interp.schemas import LabelResult +from spd.harvest.analysis import TokenPRLift +from spd.harvest.schemas import ComponentData + +LABEL_SCHEMA: dict[str, object] = { + "type": "object", + "properties": { + "label": {"type": "string"}, + "confidence": {"type": "string", "enum": ["low", "medium", "high"]}, + "reasoning": {"type": "string"}, + }, + "required": ["label", "confidence", "reasoning"], + "additionalProperties": False, +} + + +def _component_header( + component: ComponentData, + model_metadata: ModelMetadata, +) -> str: + canonical = model_metadata.layer_descriptions.get(component.layer, component.layer) + layer_desc = human_layer_desc(canonical, model_metadata.n_blocks) + position_note = layer_position_note(canonical, model_metadata.n_blocks) + dens_note = density_note(component.firing_density) + + rate_str = ( + f"~1 in {int(1 / component.firing_density)} tokens" + if component.firing_density > 0.0 + else "extremely rare" + ) + + context_notes = " ".join(filter(None, [position_note, dens_note])) + + return f"""\ +## Context +- Component: {layer_desc} (component {component.component_idx}), {model_metadata.n_blocks}-block model +- Firing rate: {component.firing_density * 100:.2f}% ({rate_str}) +{context_notes}""" + + +def format_output_prompt( + component: ComponentData, + model_metadata: ModelMetadata, + app_tok: AppTokenizer, + output_token_stats: TokenPRLift, + related: list[RelatedComponent], + label_max_words: int, + max_examples: int, +) -> str: + header = _component_header(component, model_metadata) + + output_pmi = ( + [(app_tok.get_tok_display(tid), pmi) for 
tid, pmi in component.output_token_pmi.top] + if component.output_token_pmi.top + else None + ) + output_section = build_output_section(output_token_stats, output_pmi) + says = build_says_examples(component, app_tok, max_examples) + related_table = _format_related_table(related, model_metadata, app_tok) + + return f"""\ +You are analyzing a component in a neural network to understand its OUTPUT FUNCTION — what it does when it fires. + +{header} + +## Output tokens (what the model produces when this component fires) +{output_section} +## Activation examples — what the model produces +{says} +## Downstream components (what this component influences) +These components in later layers are most influenced by this component (by gradient attribution): +{related_table} +## Task +Give a {label_max_words}-word-or-fewer label describing this component's OUTPUT FUNCTION — what it does when it fires. + +Say "unclear" if the evidence is too weak. + +Respond with JSON: {{"label": "...", "confidence": "low|medium|high", "reasoning": "..."}} +""" + + +def format_input_prompt( + component: ComponentData, + model_metadata: ModelMetadata, + app_tok: AppTokenizer, + input_token_stats: TokenPRLift, + related: list[RelatedComponent], + label_max_words: int, + max_examples: int, +) -> str: + header = _component_header(component, model_metadata) + + input_pmi = ( + [(app_tok.get_tok_display(tid), pmi) for tid, pmi in component.input_token_pmi.top] + if component.input_token_pmi.top + else None + ) + input_section = build_input_section(input_token_stats, input_pmi) + fires_on = build_fires_on_examples(component, app_tok, max_examples) + related_table = _format_related_table(related, model_metadata, app_tok) + + return f"""\ +You are analyzing a component in a neural network to understand its INPUT FUNCTION — what triggers it to fire. 
+ +{header} + +## Input tokens (what causes this component to fire) +{input_section} +## Activation examples — where the component fires +{fires_on} +## Upstream components (what feeds into this component) +These components in earlier layers most strongly attribute to this component: +{related_table} +## Task +Give a {label_max_words}-word-or-fewer label describing this component's INPUT FUNCTION — what conditions trigger it to fire. + +Say "unclear" if the evidence is too weak. + +Respond with JSON: {{"label": "...", "confidence": "low|medium|high", "reasoning": "..."}} +""" + + +def format_unification_prompt( + output_label: LabelResult, + input_label: LabelResult, + component: ComponentData, + model_metadata: ModelMetadata, + app_tok: AppTokenizer, + label_max_words: int, + max_examples: int, +) -> str: + header = _component_header(component, model_metadata) + fires_on = build_fires_on_examples(component, app_tok, max_examples) + says = build_says_examples(component, app_tok, max_examples) + + return f"""\ +A neural network component has been analyzed from two perspectives. + +{header} + +## Activation examples — where the component fires +{fires_on} +## Activation examples — what the model produces +{says} +## Two-perspective analysis + +OUTPUT FUNCTION: "{output_label.label}" (confidence: {output_label.confidence}) + Reasoning: {output_label.reasoning} + +INPUT FUNCTION: "{input_label.label}" (confidence: {input_label.confidence}) + Reasoning: {input_label.reasoning} + +## Task +Synthesize these into a single unified label (max {label_max_words} words) that captures the component's complete role. If input and output suggest the same concept, unify them. If they describe genuinely different aspects (e.g. fires on X, produces Y), combine both. 
+ +Respond with JSON: {{"label": "...", "confidence": "low|medium|high", "reasoning": "..."}} +""" + + +def _format_related_table( + components: list[RelatedComponent], + model_metadata: ModelMetadata, + app_tok: AppTokenizer, +) -> str: + # Filter: only show labeled components and token entries (embed/output) + visible = [n for n in components if n.label is not None or _is_token_entry(n.component_key)] + if not visible: + return "(no related components with labels found)\n" + + # Normalize attributions: strongest = 1.0 + max_attr = max(abs(n.attribution) for n in visible) + norm = max_attr if max_attr > 0 else 1.0 + + lines: list[str] = [] + for n in visible: + display = _component_display(n.component_key, model_metadata, app_tok) + rel_attr = n.attribution / norm + + parts = [f" {display} (relative attribution: {rel_attr:+.2f}"] + if n.jaccard is not None: + parts.append(f", co-firing Jaccard: {n.jaccard:.3f}") + parts.append(")") + + line = "".join(parts) + if n.label is not None: + line += f'\n label: "{n.label}" (confidence: {n.confidence})' + lines.append(line) + + return "\n".join(lines) + "\n" + + +def _is_token_entry(key: str) -> bool: + layer = key.rsplit(":", 1)[0] + return layer in ("embed", "output") + + +def _component_display(key: str, model_metadata: ModelMetadata, app_tok: AppTokenizer) -> str: + layer, idx_str = key.rsplit(":", 1) + match layer: + case "embed": + return f'input token "{app_tok.get_tok_display(int(idx_str))}"' + case "output": + return f'output token "{app_tok.get_tok_display(int(idx_str))}"' + case _: + canonical = model_metadata.layer_descriptions.get(layer, layer) + desc = human_layer_desc(canonical, model_metadata.n_blocks) + return f"component from {desc}" diff --git a/spd/graph_interp/repo.py b/spd/graph_interp/repo.py new file mode 100644 index 000000000..9906ff138 --- /dev/null +++ b/spd/graph_interp/repo.py @@ -0,0 +1,91 @@ +"""Graph interpretation data repository. 
+ +Owns SPD_OUT_DIR/graph_interp// and provides read access +to output, input, and unified labels. + +Use GraphInterpRepo.open() to construct — returns None if no data exists. +""" + +from pathlib import Path +from typing import Any + +import yaml + +from spd.graph_interp.db import GraphInterpDB +from spd.graph_interp.schemas import LabelResult, PromptEdge, get_graph_interp_dir + + +class GraphInterpRepo: + """Read access to graph interpretation data for a single run.""" + + def __init__(self, db: GraphInterpDB, subrun_dir: Path, run_id: str) -> None: + self._db = db + self._subrun_dir = subrun_dir + self.subrun_id = subrun_dir.name + self.run_id = run_id + + @classmethod + def open(cls, run_id: str) -> "GraphInterpRepo | None": + """Open graph interp data for a run. Returns None if no data exists.""" + base_dir = get_graph_interp_dir(run_id) + if not base_dir.exists(): + return None + candidates = sorted( + [d for d in base_dir.iterdir() if d.is_dir() and d.name.startswith("ti-")], + key=lambda d: d.name, + ) + if not candidates: + return None + subrun_dir = candidates[-1] + db_path = subrun_dir / "interp.db" + if not db_path.exists(): + return None + return cls( + db=GraphInterpDB(db_path, readonly=True), + subrun_dir=subrun_dir, + run_id=run_id, + ) + + def get_config(self) -> dict[str, Any] | None: + config_path = self._subrun_dir / "config.yaml" + if not config_path.exists(): + return None + with open(config_path) as f: + return yaml.safe_load(f) + + # -- Labels ---------------------------------------------------------------- + + def get_all_output_labels(self) -> dict[str, LabelResult]: + return self._db.get_all_output_labels() + + def get_all_input_labels(self) -> dict[str, LabelResult]: + return self._db.get_all_input_labels() + + def get_all_unified_labels(self) -> dict[str, LabelResult]: + return self._db.get_all_unified_labels() + + def get_output_label(self, component_key: str) -> LabelResult | None: + return self._db.get_output_label(component_key) + + 
@dataclass
class LabelResult:
    """One LLM-produced label for a single component.

    The same shape is produced by the output, input, and unification passes
    and stored in all three label tables.
    """

    # Component identifier of the form "<layer>:<component_idx>".
    component_key: str
    # Short human-readable label text (the LLM may answer "unclear").
    label: str
    # LLM self-reported confidence: one of "low" | "medium" | "high".
    confidence: str
    # Free-text justification returned alongside the label.
    reasoning: str
    # Verbatim LLM response, kept for auditing/debugging.
    raw_response: str
    # Exact prompt that produced the response, kept for reproducibility.
    prompt: str
000000000..e69de29bb diff --git a/spd/graph_interp/scripts/export_html.py b/spd/graph_interp/scripts/export_html.py new file mode 100644 index 000000000..14a5955da --- /dev/null +++ b/spd/graph_interp/scripts/export_html.py @@ -0,0 +1,268 @@ +"""Export graph interpretation data to JSON for the static HTML page. + +Usage: + python -m spd.graph_interp.scripts.export_html s-17805b61 + python -m spd.graph_interp.scripts.export_html s-17805b61 --subrun_id ti-20260223_213443 + python -m spd.graph_interp.scripts.export_html s-17805b61 --mock +""" + +import json +import random +from dataclasses import asdict +from typing import Any + +from spd.graph_interp.repo import GraphInterpRepo +from spd.graph_interp.schemas import LabelResult, get_graph_interp_dir +from spd.settings import SPD_OUT_DIR + +WWW_DIR = SPD_OUT_DIR / "www" +DATA_DIR = WWW_DIR / "data" + + +def _label_to_dict(label: LabelResult) -> dict[str, str]: + return { + "label": label.label, + "confidence": label.confidence, + "reasoning": label.reasoning, + } + + +def _parse_component_key(key: str) -> tuple[str, int]: + layer, idx_str = key.rsplit(":", 1) + return layer, int(idx_str) + + +def export_from_repo(repo: GraphInterpRepo) -> dict[str, Any]: + output_labels = repo.get_all_output_labels() + input_labels = repo.get_all_input_labels() + unified_labels = repo.get_all_unified_labels() + + all_keys = sorted( + set(output_labels) | set(input_labels) | set(unified_labels), + key=lambda k: (_parse_component_key(k)[0], _parse_component_key(k)[1]), + ) + + components = [] + for key in all_keys: + layer, component_idx = _parse_component_key(key) + entry: dict[str, Any] = { + "key": key, + "layer": layer, + "component_idx": component_idx, + } + if key in output_labels: + entry["output_label"] = _label_to_dict(output_labels[key]) + if key in input_labels: + entry["input_label"] = _label_to_dict(input_labels[key]) + if key in unified_labels: + entry["unified_label"] = _label_to_dict(unified_labels[key]) + + edges = 
repo.get_prompt_edges(key) + if edges: + entry["edges"] = [asdict(e) for e in edges] + + components.append(entry) + + label_counts = repo.get_label_counts() + + return { + "decomposition_id": repo.run_id, + "subrun_id": repo.subrun_id, + "label_counts": label_counts, + "components": components, + } + + +def generate_mock_data(decomposition_id: str) -> dict[str, Any]: + random.seed(42) + + layers = [ + "h.0.mlp.c_fc", + "h.0.mlp.down_proj", + "h.0.attn.q_proj", + "h.0.attn.k_proj", + "h.0.attn.v_proj", + "h.0.attn.o_proj", + "h.1.mlp.c_fc", + "h.1.mlp.down_proj", + "h.1.attn.q_proj", + "h.1.attn.k_proj", + "h.1.attn.v_proj", + "h.1.attn.o_proj", + ] + + output_labels_pool = [ + "sentence-final punctuation and period prediction", + "proper nouns and character name completions", + "emotional adjectives describing characters", + "temporal adverbs and time-related transitions", + "morphological suffix completion (-ing, -ed, -ly)", + "determiners preceding concrete nouns", + "dialogue-opening quotation marks and speech verbs", + "plural noun suffixes after quantity words", + "conjunction and clause boundary detection", + "verb tense agreement and auxiliary verbs", + "spatial prepositions and location descriptors", + "possessive pronouns and genitive markers", + "narrative action verbs (walked, looked, said)", + "abstract emotion nouns (fear, joy, anger)", + "comparative and superlative adjective forms", + ] + + input_labels_pool = [ + "punctuation and common function words", + "sentence-initial capital letters and proper nouns", + "mid-sentence verbs following subject nouns", + "adjective-noun boundaries in descriptive phrases", + "clause-final positions before conjunctions", + "article-noun sequences in noun phrases", + "subject pronouns at clause boundaries", + "preposition-object sequences", + "verb stems preceding inflectional suffixes", + "quotation marks and dialogue boundaries", + "comma-separated list items", + "sentence-medial adverbs after auxiliaries", + 
"concrete nouns following determiners", + "coordinating conjunctions between clauses", + "word stems requiring morphological completion", + ] + + unified_labels_pool = [ + "sentence termination tracking and terminal punctuation prediction", + "character name recognition and proper noun completion", + "emotional state description through adjective selection", + "temporal transition signaling via adverbs and tense markers", + "morphological word completion from stems to suffixed forms", + "noun phrase construction: determiners predicting concrete nouns", + "dialogue framing through quotation marks and speech attribution", + "plural morphology following quantifiers and numerals", + "clause coordination and syntactic boundary marking", + "verbal agreement and auxiliary verb selection", + "spatial relationship encoding via prepositional phrases", + "possessive construction and genitive case marking", + "narrative action sequencing through core verbs", + "abstract emotional vocabulary and sentiment expression", + "degree modification and comparative construction", + ] + + confidences = ["high", "high", "high", "medium", "medium", "low"] + + reasoning_templates = [ + "The output function focuses on {output_focus}, while the input function responds to {input_focus}. Together, this component acts as a bridge between {bridge_from} and {bridge_to}, consistent with its position in {layer}.", + "This component's output pattern of {output_focus} is activated by {input_focus} in the input. The unified interpretation captures how {bridge_from} contexts trigger {bridge_to} predictions.", + "Downstream context shows this component feeds into {output_focus} pathways, while upstream context reveals activation by {input_focus}. 
The synthesis reflects a coherent role in {bridge_from}-to-{bridge_to} processing.", + ] + + focus_terms = [ + "punctuation patterns", + "noun completions", + "verb inflections", + "emotional descriptors", + "syntactic boundaries", + "morphological suffixes", + "dialogue markers", + "temporal signals", + "spatial relationships", + ] + + components = [] + for layer in layers: + n_components = random.randint(8, 20) + indices = sorted(random.sample(range(500), n_components)) + for idx in indices: + key = f"{layer}:{idx}" + conf = random.choice(confidences) + output_conf = random.choice(confidences) + input_conf = random.choice(confidences) + + output_label = random.choice(output_labels_pool) + input_label = random.choice(input_labels_pool) + unified_label = random.choice(unified_labels_pool) + + reasoning = random.choice(reasoning_templates).format( + output_focus=random.choice(focus_terms), + input_focus=random.choice(focus_terms), + bridge_from=random.choice(focus_terms), + bridge_to=random.choice(focus_terms), + layer=layer, + ) + + components.append( + { + "key": key, + "layer": layer, + "component_idx": idx, + "output_label": { + "label": output_label, + "confidence": output_conf, + "reasoning": f"Output: {reasoning}", + }, + "input_label": { + "label": input_label, + "confidence": input_conf, + "reasoning": f"Input: {reasoning}", + }, + "unified_label": { + "label": unified_label, + "confidence": conf, + "reasoning": reasoning, + }, + } + ) + + return { + "decomposition_id": decomposition_id, + "subrun_id": "ti-mock", + "label_counts": { + "output": len(components), + "input": len(components), + "unified": len(components), + }, + "components": components, + } + + +def main( + decomposition_id: str, + subrun_id: str | None = None, + mock: bool = False, +) -> None: + DATA_DIR.mkdir(parents=True, exist_ok=True) + out_path = DATA_DIR / f"graph_interp_{decomposition_id}.json" + + if mock: + data = generate_mock_data(decomposition_id) + print(f"Generated mock data: 
{len(data['components'])} components") + else: + if subrun_id is not None: + base_dir = get_graph_interp_dir(decomposition_id) + subrun_dir = base_dir / subrun_id + assert subrun_dir.exists(), f"Subrun dir not found: {subrun_dir}" + db_path = subrun_dir / "interp.db" + assert db_path.exists(), f"No interp.db in {subrun_dir}" + from spd.graph_interp.db import GraphInterpDB + + db = GraphInterpDB(db_path, readonly=True) + repo = GraphInterpRepo(db=db, subrun_dir=subrun_dir, run_id=decomposition_id) + else: + repo = GraphInterpRepo.open(decomposition_id) + if repo is None: + print(f"No graph interp data for {decomposition_id}. Generating mock data instead.") + data = generate_mock_data(decomposition_id) + with open(out_path, "w") as f: + json.dump(data, f) + print(f"Wrote mock data to {out_path}") + return + + data = export_from_repo(repo) + print(f"Exported {len(data['components'])} components from {data['subrun_id']}") + + with open(out_path, "w") as f: + json.dump(data, f) + print(f"Wrote {out_path}") + + +if __name__ == "__main__": + import fire + + fire.Fire(main) diff --git a/spd/graph_interp/scripts/run.py b/spd/graph_interp/scripts/run.py new file mode 100644 index 000000000..2ed94638b --- /dev/null +++ b/spd/graph_interp/scripts/run.py @@ -0,0 +1,102 @@ +"""CLI entry point for graph interpretation. 
+ +Called by SLURM or directly: + python -m spd.graph_interp.scripts.run --config_json '{...}' +""" + +import os +from datetime import datetime +from typing import Any + +from dotenv import load_dotenv + +from spd.adapters import adapter_from_id +from spd.adapters.spd import SPDAdapter +from spd.dataset_attributions.repo import AttributionRepo +from spd.graph_interp.config import GraphInterpConfig +from spd.graph_interp.interpret import run_graph_interp +from spd.graph_interp.schemas import get_graph_interp_subrun_dir +from spd.harvest.repo import HarvestRepo +from spd.log import logger + + +def main( + decomposition_id: str, + config_json: dict[str, Any], + harvest_subrun_id: str | None = None, +) -> None: + assert isinstance(config_json, dict), f"Expected dict from fire, got {type(config_json)}" + config = GraphInterpConfig.model_validate(config_json) + + load_dotenv() + openrouter_api_key = os.environ.get("OPENROUTER_API_KEY") + assert openrouter_api_key, "OPENROUTER_API_KEY not set" + + subrun_id = "ti-" + datetime.now().strftime("%Y%m%d_%H%M%S") + subrun_dir = get_graph_interp_subrun_dir(decomposition_id, subrun_id) + subrun_dir.mkdir(parents=True, exist_ok=True) + config.to_file(subrun_dir / "config.yaml") + db_path = subrun_dir / "interp.db" + logger.info(f"Graph interp run: {subrun_dir}") + + logger.info("Loading adapter and model metadata...") + adapter = adapter_from_id(decomposition_id) + assert isinstance(adapter, SPDAdapter) + logger.info("Loading harvest data...") + if harvest_subrun_id is not None: + harvest = HarvestRepo(decomposition_id, subrun_id=harvest_subrun_id, readonly=True) + else: + harvest = HarvestRepo.open_most_recent(decomposition_id, readonly=True) + assert harvest is not None, f"No harvest data for {decomposition_id}" + + logger.info("Loading dataset attributions...") + attributions = AttributionRepo.open(decomposition_id) + assert attributions is not None, f"Dataset attributions required for {decomposition_id}" + attribution_storage 
= attributions.get_attributions() + logger.info( + f" {attribution_storage.n_components} components, {attribution_storage.n_tokens_processed:,} tokens" + ) + + logger.info("Loading component correlations...") + correlations = harvest.get_correlations() + assert correlations is not None, f"Component correlations required for {decomposition_id}" + + logger.info("Loading token stats...") + token_stats = harvest.get_token_stats() + assert token_stats is not None, f"Token stats required for {decomposition_id}" + + logger.info("Data loading complete") + + run_graph_interp( + openrouter_api_key=openrouter_api_key, + config=config, + harvest=harvest, + attribution_storage=attribution_storage, + correlation_storage=correlations, + token_stats=token_stats, + model_metadata=adapter.model_metadata, + db_path=db_path, + tokenizer_name=adapter.tokenizer_name, + ) + + +def get_command( + decomposition_id: str, + config: GraphInterpConfig, + harvest_subrun_id: str | None = None, +) -> str: + config_json = config.model_dump_json(exclude_none=True) + cmd = ( + "python -m spd.graph_interp.scripts.run " + f"--decomposition_id {decomposition_id} " + f"--config_json '{config_json}' " + ) + if harvest_subrun_id is not None: + cmd += f"--harvest_subrun_id {harvest_subrun_id} " + return cmd + + +if __name__ == "__main__": + import fire + + fire.Fire(main) diff --git a/spd/graph_interp/scripts/run_slurm.py b/spd/graph_interp/scripts/run_slurm.py new file mode 100644 index 000000000..fed1c146b --- /dev/null +++ b/spd/graph_interp/scripts/run_slurm.py @@ -0,0 +1,69 @@ +"""SLURM launcher for graph interpretation. + +Submits a single CPU job that runs the three-phase interpretation pipeline. +Depends on both harvest merge and attribution merge jobs. 
def submit_graph_interp(
    decomposition_id: str,
    config: GraphInterpSlurmConfig,
    dependency_job_ids: list[str],
    snapshot_branch: str | None = None,
    harvest_subrun_id: str | None = None,
) -> GraphInterpSubmitResult:
    """Submit graph interpretation to SLURM.

    Args:
        decomposition_id: ID of the target decomposition.
        config: Graph interp SLURM configuration.
        dependency_job_ids: Jobs to wait for (harvest merge + attribution merge).
            May be empty for a standalone launch.
        snapshot_branch: Git snapshot branch to use.
        harvest_subrun_id: Specific harvest subrun to use.

    Returns:
        GraphInterpSubmitResult wrapping the raw SLURM submission result.
    """
    cmd = run.get_command(
        decomposition_id=decomposition_id,
        config=config.config,
        harvest_subrun_id=harvest_subrun_id,
    )

    # SLURM takes multiple dependency job ids joined with ':'; None disables
    # the dependency entirely when there is nothing to wait on.
    dependency_str = ":".join(dependency_job_ids) if dependency_job_ids else None

    slurm_config = SlurmConfig(
        job_name="spd-graph-interp",
        partition=config.partition,
        n_gpus=0,  # CPU-only: this pipeline drives LLM API calls, no local GPU work
        cpus_per_task=16,
        mem="240G",
        time=config.time,
        snapshot_branch=snapshot_branch,
        dependency_job_id=dependency_str,
        # NOTE(review): presumably surfaced in squeue/accounting for bookkeeping — confirm
        comment=decomposition_id,
    )
    script_content = generate_script(slurm_config, cmd)
    result = submit_slurm_job(script_content, "spd-graph-interp")

    logger.section("Graph interp job submitted")
    logger.values(
        {
            "Job ID": result.job_id,
            "Decomposition ID": decomposition_id,
            "Model": config.config.model,
            "Depends on": ", ".join(dependency_job_ids),
            "Log": result.log_pattern,
        }
    )

    return GraphInterpSubmitResult(result=result)
@@ -0,0 +1,27 @@ +"""CLI entry point for graph interp SLURM launcher. + +Thin wrapper for fast --help. Heavy imports deferred to run_slurm.py. + +Usage: + spd-graph-interp --config graph_interp_config.yaml +""" + +import fire + + +def main(decomposition_id: str, config: str) -> None: + """Submit graph interpretation pipeline to SLURM. + + Args: + decomposition_id: ID of the target decomposition run. + config: Path to GraphInterpSlurmConfig YAML/JSON. + """ + from spd.graph_interp.config import GraphInterpSlurmConfig + from spd.graph_interp.scripts.run_slurm import submit_graph_interp + + slurm_config = GraphInterpSlurmConfig.from_file(config) + submit_graph_interp(decomposition_id, slurm_config, dependency_job_ids=[]) + + +def cli() -> None: + fire.Fire(main) diff --git a/spd/harvest/db.py b/spd/harvest/db.py index 52556918f..10573c276 100644 --- a/spd/harvest/db.py +++ b/spd/harvest/db.py @@ -82,7 +82,6 @@ def __init__(self, db_path: Path, readonly: bool = False) -> None: ) else: self._conn = sqlite3.connect(str(db_path)) - self._conn.execute("PRAGMA journal_mode=WAL") self._conn.executescript(_SCHEMA) self._conn.row_factory = sqlite3.Row diff --git a/spd/harvest/intruder.py b/spd/harvest/intruder.py index b16f2d7f3..f91a5e0c2 100644 --- a/spd/harvest/intruder.py +++ b/spd/harvest/intruder.py @@ -19,7 +19,7 @@ from spd.app.backend.utils import delimit_tokens from spd.autointerp.llm_api import LLMError, LLMJob, LLMResult, map_llm_calls from spd.harvest.config import IntruderEvalConfig -from spd.harvest.repo import HarvestRepo +from spd.harvest.db import HarvestDB from spd.harvest.schemas import ActivationExample, ComponentData from spd.log import logger @@ -146,7 +146,7 @@ async def run_intruder_scoring( model: str, openrouter_api_key: str, tokenizer_name: str, - harvest: HarvestRepo, + score_db: HarvestDB, eval_config: IntruderEvalConfig, limit: int | None, cost_limit_usd: float | None, @@ -163,7 +163,7 @@ async def run_intruder_scoring( density_index = 
DensityIndex(components, min_examples=n_real + 1) - existing_scores = harvest.get_scores("intruder") + existing_scores = score_db.get_scores("intruder") completed = set(existing_scores.keys()) if completed: logger.info(f"Resuming: {len(completed)} already scored") @@ -234,7 +234,7 @@ async def run_intruder_scoring( score = correct / len(trials) if trials else 0.0 result = IntruderResult(component_key=ck, score=score, trials=trials, n_errors=n_err) results.append(result) - harvest.save_score(ck, "intruder", score, json.dumps(asdict(result))) + score_db.save_score(ck, "intruder", score, json.dumps(asdict(result))) logger.info(f"Scored {len(results)} components") return results diff --git a/spd/harvest/scripts/run_intruder.py b/spd/harvest/scripts/run_intruder.py index d14766251..4dbef810a 100644 --- a/spd/harvest/scripts/run_intruder.py +++ b/spd/harvest/scripts/run_intruder.py @@ -6,6 +6,7 @@ from spd.adapters import adapter_from_id from spd.harvest.config import IntruderEvalConfig +from spd.harvest.db import HarvestDB from spd.harvest.intruder import run_intruder_scoring from spd.harvest.repo import HarvestRepo @@ -24,7 +25,8 @@ def main( tokenizer_name = adapter_from_id(decomposition_id).tokenizer_name - harvest = HarvestRepo(decomposition_id, subrun_id=harvest_subrun_id, readonly=False) + harvest = HarvestRepo(decomposition_id, subrun_id=harvest_subrun_id, readonly=True) + score_db = HarvestDB(harvest._dir / "harvest.db") components = harvest.get_all_components() @@ -34,12 +36,13 @@ def main( model=eval_config.model, openrouter_api_key=openrouter_api_key, tokenizer_name=tokenizer_name, - harvest=harvest, + score_db=score_db, eval_config=eval_config, limit=eval_config.limit, cost_limit_usd=eval_config.cost_limit_usd, ) ) + score_db.close() def get_command(decomposition_id: str, config: IntruderEvalConfig, harvest_subrun_id: str) -> str: diff --git a/spd/investigate/CLAUDE.md b/spd/investigate/CLAUDE.md new file mode 100644 index 000000000..922734220 --- /dev/null 
+++ b/spd/investigate/CLAUDE.md @@ -0,0 +1,118 @@ +# Investigation Module + +Launch a Claude Code agent to investigate a specific research question about an SPD model decomposition. + +## Usage + +```bash +spd-investigate "How does the model handle gendered pronouns?" +spd-investigate "What circuit handles verb agreement?" --max_turns 30 --time 4:00:00 +``` + +For parallel investigations, run the command multiple times with different prompts. + +## Architecture + +``` +spd/investigate/ +├── __init__.py # Public exports +├── CLAUDE.md # This file +├── schemas.py # Pydantic models for outputs (BehaviorExplanation, InvestigationEvent) +├── agent_prompt.py # System prompt template with model info injection +└── scripts/ + ├── __init__.py + ├── run_slurm_cli.py # CLI entry point (spd-investigate) + ├── run_slurm.py # SLURM submission logic + └── run_agent.py # Worker script (runs in SLURM job) +``` + +## How It Works + +1. `spd-investigate` creates output dir, metadata, git snapshot, and submits a single SLURM job +2. The SLURM job runs `run_agent.py` which: + - Starts an isolated FastAPI backend with MCP support + - Loads the SPD run onto GPU + - Fetches model architecture info + - Generates the agent prompt (research question + model context + methodology) + - Launches Claude Code with MCP tools +3. 
The agent investigates using MCP tools and writes findings to the output directory + +## MCP Tools + +The agent accesses all SPD functionality via MCP at `/mcp`: + +**Circuit Discovery:** +- `optimize_graph` — Find minimal circuit for a behavior (streams progress) +- `create_prompt` — Tokenize text and get next-token probabilities + +**Component Analysis:** +- `get_component_info` — Interpretation, token stats, correlations +- `probe_component` — Fast CI probing on custom text +- `get_component_activation_examples` — Training examples where a component fires +- `get_component_attributions` — Dataset-level component dependencies +- `get_attribution_strength` — Attribution between specific component pairs + +**Testing:** +- `run_ablation` — Test circuit with only selected components +- `search_dataset` — Search training data + +**Metadata:** +- `get_model_info` — Architecture details + +**Output:** +- `update_research_log` — Append to research log (PRIMARY OUTPUT) +- `save_graph_artifact` — Save graph for inline visualization +- `save_explanation` — Save complete behavior explanation +- `set_investigation_summary` — Set title/summary for UI + +## Output Structure + +``` +SPD_OUT_DIR/investigations// +├── metadata.json # Investigation config (wandb_path, prompt, etc.) +├── research_log.md # Human-readable progress log (PRIMARY OUTPUT) +├── events.jsonl # Structured progress events +├── explanations.jsonl # Complete behavior explanations +├── summary.json # Agent-provided title/summary for UI +├── artifacts/ # Graph artifacts for visualization +│ └── graph_001.json +├── app.db # Isolated SQLite database +├── backend.log # Backend subprocess output +├── claude_output.jsonl # Raw Claude Code output +├── agent_prompt.md # The prompt given to the agent +└── mcp_config.json # MCP server configuration +``` + +## Environment + +The backend runs with `SPD_INVESTIGATION_DIR` set to the investigation directory. 
This controls: +- Database location: `/app.db` +- Events log: `/events.jsonl` +- Research log: `/research_log.md` + +## Configuration + +CLI arguments: +- `wandb_path` — Required. WandB run path for the SPD decomposition. +- `prompt` — Required. Research question or investigation directive. +- `--context_length` — Token context length (default: 128) +- `--max_turns` — Max Claude turns (default: 50, prevents runaway) +- `--partition` — SLURM partition (default: h200-reserved) +- `--time` — Job time limit (default: 8:00:00) +- `--job_suffix` — Optional suffix for job names + +## Monitoring + +```bash +# Watch research log +tail -f SPD_OUT_DIR/investigations//research_log.md + +# Watch events +tail -f SPD_OUT_DIR/investigations//events.jsonl + +# View explanations +cat SPD_OUT_DIR/investigations//explanations.jsonl | jq . + +# Check SLURM job status +squeue --me +``` diff --git a/spd/investigate/__init__.py b/spd/investigate/__init__.py new file mode 100644 index 000000000..9e666dd7d --- /dev/null +++ b/spd/investigate/__init__.py @@ -0,0 +1,22 @@ +"""Investigation: SLURM-based agent investigation of model behaviors. + +This module provides infrastructure for launching a Claude Code agent to investigate +behaviors in an SPD model decomposition. Each investigation: +1. Starts an isolated app backend instance (separate database, unique port) +2. Receives a specific research question and detailed instructions +3. Investigates behaviors and writes findings to append-only JSONL files +""" + +from spd.investigate.schemas import ( + BehaviorExplanation, + ComponentInfo, + Evidence, + InvestigationEvent, +) + +__all__ = [ + "BehaviorExplanation", + "ComponentInfo", + "Evidence", + "InvestigationEvent", +] diff --git a/spd/investigate/agent_prompt.py b/spd/investigate/agent_prompt.py new file mode 100644 index 000000000..d53a47ac3 --- /dev/null +++ b/spd/investigate/agent_prompt.py @@ -0,0 +1,211 @@ +"""System prompt for SPD investigation agents. 
+ +This module contains the detailed instructions given to the investigation agent. +The agent has access to SPD tools via MCP - tools are self-documenting. +""" + +from typing import Any + +AGENT_SYSTEM_PROMPT = """ +# SPD Behavior Investigation Agent + +You are a research agent investigating behaviors in a neural network model decomposition. +A researcher has given you a specific question to investigate. Your job is to answer it +thoroughly using the SPD analysis tools available to you. + +## Your Mission + +{prompt} + +## Available Tools (via MCP) + +You have access to SPD analysis tools. Use them directly - they have full documentation. + +**Circuit Discovery:** +- **optimize_graph**: Find the minimal circuit for a behavior (e.g., "boy" → "he") +- **create_prompt**: Tokenize text and get next-token probabilities + +**Component Analysis:** +- **get_component_info**: Get interpretation and token stats for a component +- **probe_component**: Fast CI probing - test if a component activates on specific text +- **get_component_activation_examples**: See training examples where a component fires +- **get_component_attributions**: Dataset-level component dependencies (sources and targets) +- **get_attribution_strength**: Query attribution strength between two specific components + +**Testing:** +- **run_ablation**: Test a circuit by running with only selected components +- **search_dataset**: Find examples in the training data + +**Metadata:** +- **get_model_info**: Get model architecture details +- **get_stored_graphs**: Retrieve previously computed graphs + +**Output:** +- **update_research_log**: Append to your research log (PRIMARY OUTPUT - use frequently!) 
+- **save_graph_artifact**: Save a graph for inline visualization in your research log +- **save_explanation**: Save a complete, validated behavior explanation +- **set_investigation_summary**: Set a title and summary for your investigation + +## Investigation Methodology + +### Step 1: Understand the Question + +Read the research question carefully. Think about what behaviors, components, or mechanisms +might be relevant. Use `get_model_info` if you need to understand the model architecture. + +### Step 2: Explore and Hypothesize + +- Use `create_prompt` to test prompts and see what the model predicts +- Use `search_dataset` to find relevant examples in the training data +- Use `probe_component` to quickly test whether specific components respond to your prompts +- Use `get_component_info` to understand what components do + +### Step 3: Find Circuits + +- Use `optimize_graph` to find the minimal circuit for specific behaviors +- Examine which components have high CI values +- Note the circuit size (fewer active components = cleaner mechanism) + +### Step 4: Understand Component Roles + +For each important component in a circuit: +1. Use `get_component_info` for interpretation and token associations +2. Use `probe_component` to test activation on different inputs +3. Use `get_component_activation_examples` to see training examples +4. Use `get_component_attributions` to understand information flow +5. Check correlated components for related functions + +### Step 5: Test with Ablations + +Form hypotheses and test them: +1. Use `run_ablation` with the circuit's components +2. Verify predictions match expectations +3. Try removing individual components to find critical ones + +### Step 6: Document Your Findings + +Use `update_research_log` frequently - this is how humans monitor your work! +When you have a complete explanation, use `save_explanation` to create a structured record. 
+ +## Scientific Principles + +- **Be skeptical**: Your first hypothesis is probably incomplete +- **Triangulate**: Don't rely on a single type of evidence +- **Document uncertainty**: Note what you're confident in vs. uncertain about +- **Consider alternatives**: What else could explain the behavior? + +## Output Format + +### Research Log (PRIMARY OUTPUT - Update frequently!) + +Use `update_research_log` with markdown content. Call it every few minutes to show progress: + +Example calls: +``` +update_research_log("## Hypothesis: Gendered Pronoun Circuit\\n\\nTesting prompt: 'The boy said that' → expecting ' he'\\n\\n") + +update_research_log("## Ablation Test\\n\\nResult: P(he) = 0.89 (vs 0.22 baseline)\\n\\nThis confirms the circuit is sufficient!\\n\\n") +``` + +### Including Graph Visualizations + +After running `optimize_graph`, embed the circuit visualization in your research log: + +1. Call `save_graph_artifact` with the graph_id returned by optimize_graph +2. Reference it in your research log using the `spd:graph` code block + +Example: +``` +save_graph_artifact(graph_id=42, caption="Circuit predicting 'he' after 'The boy'") + +update_research_log('''## Circuit Visualization + +```spd:graph +artifact: graph_001 +``` + +This circuit shows the key components involved in predicting "he"... +''') +``` + +### Saving Explanations + +When you have a complete explanation, use `save_explanation`: + +``` +save_explanation( + subject_prompt="The boy said that", + behavior_description="Predicts masculine pronoun 'he' after male subject", + components_involved=[ + {{"component_key": "h.0.mlp.c_fc:407", "role": "Male subject detector"}}, + {{"component_key": "h.3.attn.o_proj:262", "role": "Masculine pronoun promoter"}} + ], + explanation="Component h.0.mlp.c_fc:407 activates on male subjects...", + confidence="medium", + limitations=["Only tested on simple sentences"] +) +``` + +## Getting Started + +1. **Create your research log** with `update_research_log` +2. 
Understand the research question and plan your approach +3. Use analysis tools to explore the model +4. **Call `update_research_log` frequently** - humans are watching! +5. Use `save_explanation` for complete findings +6. **Call `set_investigation_summary`** with a title and summary when done + +Document what you learn, even if it's "this was more complicated than expected." +""" + + +def _format_model_info(model_info: dict[str, Any]) -> str: + """Format model architecture info for inclusion in the agent prompt.""" + parts = [f"- **Architecture**: {model_info.get('summary', 'Unknown')}"] + + target_config = model_info.get("target_model_config") + if target_config: + if "n_layer" in target_config: + parts.append(f"- **Layers**: {target_config['n_layer']}") + if "n_embd" in target_config: + parts.append(f"- **Hidden dim**: {target_config['n_embd']}") + if "vocab_size" in target_config: + parts.append(f"- **Vocab size**: {target_config['vocab_size']}") + if "n_ctx" in target_config: + parts.append(f"- **Context length**: {target_config['n_ctx']}") + + topology = model_info.get("topology") + if topology and topology.get("block_structure"): + block = topology["block_structure"][0] + attn = ", ".join(block.get("attn_projections", [])) + ffn = ", ".join(block.get("ffn_projections", [])) + parts.append(f"- **Attention projections**: {attn}") + parts.append(f"- **FFN projections**: {ffn}") + + return "\n".join(parts) + + +def get_agent_prompt( + wandb_path: str, + prompt: str, + model_info: dict[str, Any], +) -> str: + """Generate the full agent prompt with runtime parameters filled in.""" + formatted_prompt = AGENT_SYSTEM_PROMPT.format(prompt=prompt) + + model_section = f""" +## Model Architecture + +{_format_model_info(model_info)} + +## Runtime Context + +- **Model Run**: {wandb_path} + +Use the MCP tools for ALL output: +- `update_research_log` → **PRIMARY OUTPUT** - Update frequently with your progress! 
+- `save_explanation` → Save complete, validated behavior explanations + +**Start by calling update_research_log to create your log, then investigate!** +""" + return formatted_prompt + model_section diff --git a/spd/investigate/schemas.py b/spd/investigate/schemas.py new file mode 100644 index 000000000..d4da1a896 --- /dev/null +++ b/spd/investigate/schemas.py @@ -0,0 +1,104 @@ +"""Schemas for investigation outputs. + +All agent outputs are append-only JSONL files. Each line is a JSON object +conforming to one of the schemas defined here. +""" + +from datetime import UTC, datetime +from typing import Any, Literal + +from pydantic import BaseModel, Field + + +class ComponentInfo(BaseModel): + """Information about a component involved in a behavior.""" + + component_key: str = Field( + ..., + description="Component key in format 'layer:component_idx' (e.g., 'h.0.mlp.c_fc:5')", + ) + role: str = Field( + ..., + description="The role this component plays in the behavior (e.g., 'stores subject gender')", + ) + interpretation: str | None = Field( + default=None, + description="Auto-interp label for this component if available", + ) + + +class Evidence(BaseModel): + """A piece of supporting evidence for an explanation.""" + + evidence_type: Literal["ablation", "attribution", "activation_pattern", "correlation", "other"] + description: str = Field( + ..., + description="Description of the evidence", + ) + details: dict[str, Any] = Field( + default_factory=dict, + description="Additional structured details (e.g., ablation results, attribution values)", + ) + + +class BehaviorExplanation(BaseModel): + """A candidate explanation for a behavior discovered by an agent. + + This is the primary output schema for agent investigations. Each explanation + describes a behavior (demonstrated by a subject prompt), the components involved, + and supporting evidence. 
+ """ + + subject_prompt: str = Field( + ..., + description="A prompt that demonstrates the behavior being explained", + ) + behavior_description: str = Field( + ..., + description="Clear description of the behavior (e.g., 'correctly predicts gendered pronoun')", + ) + components_involved: list[ComponentInfo] = Field( + ..., + description="List of components involved in this behavior and their roles", + ) + explanation: str = Field( + ..., + description="Explanation of how the components work together to produce the behavior", + ) + supporting_evidence: list[Evidence] = Field( + default_factory=list, + description="Evidence supporting this explanation (ablations, attributions, etc.)", + ) + confidence: Literal["high", "medium", "low"] = Field( + ..., + description="Agent's confidence in this explanation", + ) + alternative_hypotheses: list[str] = Field( + default_factory=list, + description="Alternative hypotheses that were considered but not fully supported", + ) + limitations: list[str] = Field( + default_factory=list, + description="Known limitations of this explanation", + ) + + +class InvestigationEvent(BaseModel): + """A generic event logged by an agent during investigation. + + Used for logging progress, observations, and other non-explanation events. 
+ """ + + event_type: Literal[ + "start", + "progress", + "observation", + "hypothesis", + "test_result", + "explanation", + "error", + "complete", + ] + timestamp: datetime = Field(default_factory=lambda: datetime.now(UTC)) + message: str + details: dict[str, Any] = Field(default_factory=dict) diff --git a/spd/investigate/scripts/__init__.py b/spd/investigate/scripts/__init__.py new file mode 100644 index 000000000..ff51f7654 --- /dev/null +++ b/spd/investigate/scripts/__init__.py @@ -0,0 +1 @@ +"""Investigation SLURM scripts.""" diff --git a/spd/investigate/scripts/run_agent.py b/spd/investigate/scripts/run_agent.py new file mode 100644 index 000000000..54806ed36 --- /dev/null +++ b/spd/investigate/scripts/run_agent.py @@ -0,0 +1,306 @@ +"""Worker script that runs inside each SLURM job. + +This script: +1. Reads the research question from the investigation metadata +2. Starts the app backend with an isolated database +3. Loads the SPD run and fetches model architecture info +4. Configures MCP server for Claude Code +5. Launches Claude Code with the investigation question +6. 
Handles cleanup on exit +""" + +import json +import os +import signal +import socket +import subprocess +import sys +import time +from pathlib import Path +from types import FrameType +from typing import Any + +import fire +import requests + +from spd.investigate.agent_prompt import get_agent_prompt +from spd.investigate.schemas import InvestigationEvent +from spd.investigate.scripts.run_slurm import get_investigation_output_dir +from spd.log import logger + + +def write_mcp_config(inv_dir: Path, port: int) -> Path: + """Write MCP configuration file for Claude Code.""" + mcp_config = { + "mcpServers": { + "spd": { + "type": "http", + "url": f"http://localhost:{port}/mcp", + } + } + } + config_path = inv_dir / "mcp_config.json" + config_path.write_text(json.dumps(mcp_config, indent=2)) + return config_path + + +def write_claude_settings(inv_dir: Path) -> None: + """Write Claude Code settings to pre-grant MCP tool permissions.""" + claude_dir = inv_dir / ".claude" + claude_dir.mkdir(exist_ok=True) + settings = {"permissions": {"allow": ["mcp__spd__*"]}} + (claude_dir / "settings.json").write_text(json.dumps(settings, indent=2)) + + +def find_available_port(start_port: int = 8000, max_attempts: int = 100) -> int: + """Find an available port starting from start_port.""" + for offset in range(max_attempts): + port = start_port + offset + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(("localhost", port)) + return port + except OSError: + continue + raise RuntimeError( + f"Could not find available port in range {start_port}-{start_port + max_attempts}" + ) + + +def wait_for_backend(port: int, timeout: float = 120.0) -> bool: + """Wait for the backend to become healthy.""" + url = f"http://localhost:{port}/api/health" + start = time.time() + while time.time() - start < timeout: + try: + resp = requests.get(url, timeout=5) + if resp.status_code == 200: + return True + except requests.exceptions.ConnectionError: + pass + time.sleep(1) + return 
False + + +def load_run(port: int, wandb_path: str, context_length: int) -> None: + """Load the SPD run into the backend. Raises on failure.""" + url = f"http://localhost:{port}/api/runs/load" + params = {"wandb_path": wandb_path, "context_length": context_length} + resp = requests.post(url, params=params, timeout=300) + assert resp.status_code == 200, ( + f"Failed to load run {wandb_path}: {resp.status_code} {resp.text}" + ) + + +def fetch_model_info(port: int) -> dict[str, Any]: + """Fetch model architecture info from the backend.""" + resp = requests.get(f"http://localhost:{port}/api/pretrain_info/loaded", timeout=30) + assert resp.status_code == 200, f"Failed to fetch model info: {resp.status_code} {resp.text}" + result: dict[str, Any] = resp.json() + return result + + +def log_event(events_path: Path, event: InvestigationEvent) -> None: + """Append an event to the events log.""" + with open(events_path, "a") as f: + f.write(event.model_dump_json() + "\n") + + +def run_agent( + wandb_path: str, + inv_id: str, + context_length: int = 128, + max_turns: int = 50, +) -> None: + """Run a single investigation agent. + + Args: + wandb_path: WandB path of the SPD run. + inv_id: Unique identifier for this investigation. + context_length: Context length for prompts. + max_turns: Maximum agentic turns before stopping (prevents runaway agents). 
+ """ + inv_dir = get_investigation_output_dir(inv_id) + assert inv_dir.exists(), f"Investigation directory does not exist: {inv_dir}" + + # Read prompt from metadata + metadata: dict[str, Any] = json.loads((inv_dir / "metadata.json").read_text()) + prompt = metadata["prompt"] + + write_claude_settings(inv_dir) + + events_path = inv_dir / "events.jsonl" + (inv_dir / "explanations.jsonl").touch() + + log_event( + events_path, + InvestigationEvent( + event_type="start", + message=f"Investigation {inv_id} starting", + details={"wandb_path": wandb_path, "inv_id": inv_id, "prompt": prompt}, + ), + ) + + port = find_available_port() + logger.info(f"[{inv_id}] Using port {port}") + + log_event( + events_path, + InvestigationEvent( + event_type="progress", + message=f"Starting backend on port {port}", + details={"port": port}, + ), + ) + + # Start backend with investigation configuration + env = os.environ.copy() + env["SPD_INVESTIGATION_DIR"] = str(inv_dir) + + backend_cmd = [ + sys.executable, + "-m", + "spd.app.backend.server", + "--port", + str(port), + ] + + backend_log_path = inv_dir / "backend.log" + backend_log = open(backend_log_path, "w") # noqa: SIM115 - managed manually + backend_proc = subprocess.Popen( + backend_cmd, + env=env, + stdout=backend_log, + stderr=subprocess.STDOUT, + ) + + def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: + _ = frame + logger.info(f"[{inv_id}] Cleaning up...") + if backend_proc.poll() is None: + backend_proc.terminate() + try: + backend_proc.wait(timeout=5) + except subprocess.TimeoutExpired: + backend_proc.kill() + backend_log.close() + if signum is not None: + sys.exit(1) + + signal.signal(signal.SIGTERM, cleanup) + signal.signal(signal.SIGINT, cleanup) + + try: + logger.info(f"[{inv_id}] Waiting for backend...") + if not wait_for_backend(port): + log_event( + events_path, + InvestigationEvent(event_type="error", message="Backend failed to start"), + ) + raise RuntimeError("Backend failed to start") 
+ + logger.info(f"[{inv_id}] Backend ready, loading run...") + log_event( + events_path, + InvestigationEvent(event_type="progress", message="Backend ready, loading run"), + ) + + load_run(port, wandb_path, context_length) + + logger.info(f"[{inv_id}] Run loaded, fetching model info...") + model_info = fetch_model_info(port) + + logger.info(f"[{inv_id}] Launching Claude Code...") + log_event( + events_path, + InvestigationEvent( + event_type="progress", message="Run loaded, launching Claude Code agent" + ), + ) + + agent_prompt = get_agent_prompt( + wandb_path=wandb_path, + prompt=prompt, + model_info=model_info, + ) + + (inv_dir / "agent_prompt.md").write_text(agent_prompt) + + mcp_config_path = write_mcp_config(inv_dir, port) + logger.info(f"[{inv_id}] MCP config written to {mcp_config_path}") + + claude_output_path = inv_dir / "claude_output.jsonl" + claude_cmd = [ + "claude", + "--print", + "--verbose", + "--output-format", + "stream-json", + "--max-turns", + str(max_turns), + # MCP: only our backend, no inherited servers + "--mcp-config", + str(mcp_config_path), + # Permissions: only MCP tools, deny everything else + "--permission-mode", + "dontAsk", + "--allowedTools", + "mcp__spd__*", + # Isolation: skip all user/project settings (no plugins, no inherited config) + "--setting-sources", + "", + "--model", + "opus", + ] + + logger.info(f"[{inv_id}] Starting Claude Code (max_turns={max_turns})...") + logger.info(f"[{inv_id}] Monitor with: tail -f {claude_output_path}") + + with open(claude_output_path, "w") as output_file: + claude_proc = subprocess.Popen( + claude_cmd, + stdin=subprocess.PIPE, + stdout=output_file, + stderr=subprocess.STDOUT, + text=True, + cwd=str(inv_dir), + ) + + assert claude_proc.stdin is not None + claude_proc.stdin.write(agent_prompt) + claude_proc.stdin.close() + + claude_proc.wait() + + log_event( + events_path, + InvestigationEvent( + event_type="complete", + message="Investigation complete", + details={"exit_code": 
claude_proc.returncode}, + ), + ) + + logger.info(f"[{inv_id}] Investigation complete") + + except Exception as e: + log_event( + events_path, + InvestigationEvent( + event_type="error", + message=f"Agent failed: {e}", + details={"error_type": type(e).__name__}, + ), + ) + logger.error(f"[{inv_id}] Failed: {e}") + raise + finally: + cleanup() + + +def cli() -> None: + fire.Fire(run_agent) + + +if __name__ == "__main__": + cli() diff --git a/spd/investigate/scripts/run_slurm.py b/spd/investigate/scripts/run_slurm.py new file mode 100644 index 000000000..703ed2f78 --- /dev/null +++ b/spd/investigate/scripts/run_slurm.py @@ -0,0 +1,95 @@ +"""SLURM submission logic for investigation jobs.""" + +import json +import secrets +import sys +from dataclasses import dataclass +from pathlib import Path + +from spd.log import logger +from spd.settings import SPD_OUT_DIR +from spd.utils.git_utils import create_git_snapshot +from spd.utils.slurm import SlurmConfig, generate_script, submit_slurm_job +from spd.utils.wandb_utils import parse_wandb_run_path + + +@dataclass +class InvestigationResult: + inv_id: str + job_id: str + output_dir: Path + + +def get_investigation_output_dir(inv_id: str) -> Path: + return SPD_OUT_DIR / "investigations" / inv_id + + +def launch_investigation( + wandb_path: str, + prompt: str, + context_length: int, + max_turns: int, + partition: str, + time: str, + job_suffix: str | None, +) -> InvestigationResult: + """Launch a single investigation agent via SLURM. + + Creates a SLURM job that starts an isolated app backend, loads the SPD run, + and launches a Claude Code agent with the given research question. 
+ """ + # Normalize wandb_path to canonical form (entity/project/run_id) + entity, project, run_id = parse_wandb_run_path(wandb_path) + canonical_wandb_path = f"{entity}/{project}/{run_id}" + + inv_id = f"inv-{secrets.token_hex(4)}" + output_dir = get_investigation_output_dir(inv_id) + output_dir.mkdir(parents=True, exist_ok=True) + + snapshot_branch, commit_hash = create_git_snapshot(inv_id) + + suffix = f"-{job_suffix}" if job_suffix else "" + job_name = f"spd-investigate{suffix}" + + metadata = { + "inv_id": inv_id, + "wandb_path": canonical_wandb_path, + "prompt": prompt, + "context_length": context_length, + "max_turns": max_turns, + "snapshot_branch": snapshot_branch, + "commit_hash": commit_hash, + } + (output_dir / "metadata.json").write_text(json.dumps(metadata, indent=2)) + + cmd = ( + f"{sys.executable} -m spd.investigate.scripts.run_agent " + f'"{wandb_path}" ' + f"--inv_id {inv_id} " + f"--context_length {context_length} " + f"--max_turns {max_turns}" + ) + + slurm_config = SlurmConfig( + job_name=job_name, + partition=partition, + n_gpus=1, + time=time, + snapshot_branch=snapshot_branch, + ) + script = generate_script(slurm_config, cmd) + result = submit_slurm_job(script, "investigate") + + logger.section("Investigation submitted") + logger.values( + { + "Investigation ID": inv_id, + "Job ID": result.job_id, + "WandB path": canonical_wandb_path, + "Prompt": prompt[:100] + ("..." if len(prompt) > 100 else ""), + "Output directory": str(output_dir), + "Logs": result.log_pattern, + } + ) + + return InvestigationResult(inv_id=inv_id, job_id=result.job_id, output_dir=output_dir) diff --git a/spd/investigate/scripts/run_slurm_cli.py b/spd/investigate/scripts/run_slurm_cli.py new file mode 100644 index 000000000..df784de61 --- /dev/null +++ b/spd/investigate/scripts/run_slurm_cli.py @@ -0,0 +1,59 @@ +"""CLI entry point for investigation SLURM launcher. 
+ +Usage: + spd-investigate "" + spd-investigate @prompt.txt + spd-investigate "" --max_turns 30 +""" + +from pathlib import Path + +import fire + +from spd.settings import DEFAULT_PARTITION_NAME + + +def _resolve_prompt(prompt: str) -> str: + """If prompt starts with @, read from that file path. Otherwise return as-is.""" + if prompt.startswith("@"): + path = Path(prompt[1:]) + assert path.exists(), f"Prompt file not found: {path}" + return path.read_text().strip() + return prompt + + +def main( + wandb_path: str, + prompt: str, + context_length: int = 128, + max_turns: int = 50, + partition: str = DEFAULT_PARTITION_NAME, + time: str = "8:00:00", + job_suffix: str | None = None, +) -> None: + """Launch a single investigation agent for a specific question. + + Args: + wandb_path: WandB run path for the SPD decomposition to investigate. + prompt: The research question, or @filepath to read from a file. + context_length: Context length for prompts (default 128). + max_turns: Maximum agentic turns (default 50, prevents runaway). + partition: SLURM partition name. + time: Job time limit (default 8 hours). + job_suffix: Optional suffix for SLURM job names. + """ + from spd.investigate.scripts.run_slurm import launch_investigation + + launch_investigation( + wandb_path=wandb_path, + prompt=_resolve_prompt(prompt), + context_length=context_length, + max_turns=max_turns, + partition=partition, + time=time, + job_suffix=job_suffix, + ) + + +def cli() -> None: + fire.Fire(main) diff --git a/spd/persistent_pgd.py b/spd/persistent_pgd.py index 03003d910..add3ba119 100644 --- a/spd/persistent_pgd.py +++ b/spd/persistent_pgd.py @@ -152,7 +152,7 @@ def __init__( assert batch_dims[0] % n == 0, ( f"n_sources={n} must divide the per-rank microbatch size " f"{batch_dims[0]}, not the global batch size. " - f"With DDP, reduce n_sources or use fewer ranks." + f"Adjust n_sources or batch_size to satisfy this." 
) source_leading_dims = [n] + list(batch_dims[1:]) case PerBatchPerPositionScope(): diff --git a/spd/postprocess/__init__.py b/spd/postprocess/__init__.py index 616cd144c..e2feab509 100644 --- a/spd/postprocess/__init__.py +++ b/spd/postprocess/__init__.py @@ -68,8 +68,7 @@ def postprocess(config: PostprocessConfig) -> Path: intruder_slurm = SlurmConfig( job_name="spd-intruder-eval", partition=config.intruder.partition, - n_gpus=0, - cpus_per_task=16, + n_gpus=2, time=config.intruder.time, snapshot_branch=snapshot_branch, dependency_job_id=harvest_result.merge_result.job_id, diff --git a/spd/settings.py b/spd/settings.py index 9e3e37f7b..56d60ecfe 100644 --- a/spd/settings.py +++ b/spd/settings.py @@ -24,3 +24,5 @@ DEFAULT_PARTITION_NAME = "h200-reserved" DEFAULT_PROJECT_NAME = "spd" + +SPD_APP_DEFAULT_RUN: str | None = os.environ.get("SPD_APP_DEFAULT_RUN") diff --git a/spd/topology/gradient_connectivity.py b/spd/topology/gradient_connectivity.py index bcaac8423..31ba61b5a 100644 --- a/spd/topology/gradient_connectivity.py +++ b/spd/topology/gradient_connectivity.py @@ -74,19 +74,20 @@ def embed_hook( cache[f"{embed_path}_post_detach"] = embed_cache[f"{embed_path}_post_detach"] cache[f"{unembed_path}_pre_detach"] = comp_output_with_cache.output - layers = [embed_path, *model.target_module_paths, unembed_path] + source_layers = [embed_path, *model.target_module_paths] # Don't include "output" as source + target_layers = [*model.target_module_paths, unembed_path] # Don't include embed as target # Test all distinct pairs for gradient flow test_pairs = [] - for in_layer in layers[:-1]: # Don't include "output" as source - for out_layer in layers[1:]: # Don't include embed as target - if in_layer != out_layer: - test_pairs.append((in_layer, out_layer)) + for source_layer in source_layers: + for target_layer in target_layers: + if source_layer != target_layer: + test_pairs.append((source_layer, target_layer)) sources_by_target: dict[str, list[str]] = defaultdict(list) - 
for in_layer, out_layer in test_pairs: - out_pre_detach = cache[f"{out_layer}_pre_detach"] - in_post_detach = cache[f"{in_layer}_post_detach"] + for source_layer, target_layer in test_pairs: + out_pre_detach = cache[f"{target_layer}_pre_detach"] + in_post_detach = cache[f"{source_layer}_post_detach"] out_value = out_pre_detach[0, 0, 0] grads = torch.autograd.grad( outputs=out_value, @@ -97,5 +98,5 @@ def embed_hook( assert len(grads) == 1 grad = grads[0] if grad is not None: # pyright: ignore[reportUnnecessaryComparison] - sources_by_target[out_layer].append(in_layer) + sources_by_target[target_layer].append(source_layer) return dict(sources_by_target) diff --git a/spd/utils/wandb_utils.py b/spd/utils/wandb_utils.py index 6b1d85813..7dd49a730 100644 --- a/spd/utils/wandb_utils.py +++ b/spd/utils/wandb_utils.py @@ -31,7 +31,11 @@ # Regex patterns for parsing W&B run references # Run IDs can be 8 chars (e.g., "d2ec3bfe") or prefixed with char-dash (e.g., "s-d2ec3bfe") +DEFAULT_WANDB_ENTITY = "goodfire" +DEFAULT_WANDB_PROJECT = "spd" + _RUN_ID_PATTERN = r"(?:[a-z0-9]-)?[a-z0-9]{8}" +_BARE_RUN_ID_RE = re.compile(r"^(s-[a-z0-9]{8})$") _WANDB_PATH_RE = re.compile(rf"^([^/\s]+)/([^/\s]+)/({_RUN_ID_PATTERN})$") _WANDB_PATH_WITH_RUNS_RE = re.compile(rf"^([^/\s]+)/([^/\s]+)/runs/({_RUN_ID_PATTERN})$") _WANDB_URL_RE = re.compile( @@ -169,6 +173,7 @@ def parse_wandb_run_path(input_path: str) -> tuple[str, str, str]: """Parse various W&B run reference formats into (entity, project, run_id). Accepts: + - "s-xxxxxxxx" (bare SPD run ID, assumes goodfire/spd) - "entity/project/runId" (compact form) - "entity/project/runs/runId" (with /runs/) - "wandb:entity/project/runId" (with wandb: prefix) @@ -187,6 +192,10 @@ def parse_wandb_run_path(input_path: str) -> tuple[str, str, str]: if s.startswith("wandb:"): s = s[6:] + # Bare run ID (e.g. 
"s-17805b61") → default entity/project + if m := _BARE_RUN_ID_RE.match(s): + return DEFAULT_WANDB_ENTITY, DEFAULT_WANDB_PROJECT, m.group(1) + # Try compact form: entity/project/runid if m := _WANDB_PATH_RE.match(s): return m.group(1), m.group(2), m.group(3) @@ -201,6 +210,7 @@ def parse_wandb_run_path(input_path: str) -> tuple[str, str, str]: raise ValueError( f"Invalid W&B run reference. Expected one of:\n" + f' - "s-xxxxxxxx" (bare run ID)\n' f' - "entity/project/xxxxxxxx"\n' f' - "entity/project/runs/xxxxxxxx"\n' f' - "wandb:entity/project/runs/xxxxxxxx"\n' diff --git a/tests/app/test_server_api.py b/tests/app/test_server_api.py index f39ef385f..cc5ed5a0e 100644 --- a/tests/app/test_server_api.py +++ b/tests/app/test_server_api.py @@ -147,6 +147,7 @@ def app_with_state(): harvest=None, interp=None, attributions=None, + graph_interp=None, ) manager = StateManager.get() diff --git a/tests/dataset_attributions/test_harvester.py b/tests/dataset_attributions/test_harvester.py deleted file mode 100644 index 96ebc5df8..000000000 --- a/tests/dataset_attributions/test_harvester.py +++ /dev/null @@ -1,265 +0,0 @@ -"""Tests for dataset attribution harvester logic.""" - -from pathlib import Path - -import torch - -from spd.dataset_attributions.storage import DatasetAttributionStorage - - -def _make_storage( - n_components: int = 2, - vocab_size: int = 3, - d_model: int = 4, - source_to_component: torch.Tensor | None = None, - source_to_out_residual: torch.Tensor | None = None, -) -> DatasetAttributionStorage: - """Helper to create storage with default values.""" - n_sources = vocab_size + n_components - if source_to_component is None: - source_to_component = torch.zeros(n_sources, n_components) - if source_to_out_residual is None: - source_to_out_residual = torch.zeros(n_sources, d_model) - - return DatasetAttributionStorage( - component_layer_keys=[f"layer1:{i}" for i in range(n_components)], - vocab_size=vocab_size, - d_model=d_model, - 
source_to_component=source_to_component, - source_to_out_residual=source_to_out_residual, - n_batches_processed=10, - n_tokens_processed=1000, - ci_threshold=0.0, - ) - - -class TestDatasetAttributionStorage: - """Tests for DatasetAttributionStorage. - - Storage structure: - - source_to_component: (n_sources, n_components) for component target attributions - - source_to_out_residual: (n_sources, d_model) for output target attributions (via w_unembed) - """ - - def test_has_source_and_target(self) -> None: - """Test has_source and has_target methods.""" - storage = _make_storage(n_components=2, vocab_size=3) - - # wte tokens can only be sources - assert storage.has_source("wte:0") - assert storage.has_source("wte:2") - assert not storage.has_source("wte:3") # Out of vocab - assert not storage.has_target("wte:0") # wte can't be target - - # Component layers can be both sources and targets - assert storage.has_source("layer1:0") - assert storage.has_source("layer1:1") - assert storage.has_target("layer1:0") - assert storage.has_target("layer1:1") - assert not storage.has_source("layer1:2") - assert not storage.has_target("layer1:2") - - # output tokens can only be targets - assert storage.has_target("output:0") - assert storage.has_target("output:2") - assert not storage.has_target("output:3") # Out of vocab - assert not storage.has_source("output:0") # output can't be source - - def test_get_attribution_component_target(self) -> None: - """Test get_attribution for component targets (no w_unembed needed).""" - # 2 component layers: layer1:0, layer1:1 - # vocab_size=2, d_model=4 - # n_sources = 2 + 2 = 4 - # source_to_component shape: (4, 2) - source_to_component = torch.tensor( - [ - [1.0, 2.0], # wte:0 -> components - [3.0, 4.0], # wte:1 -> components - [5.0, 6.0], # layer1:0 -> components - [7.0, 8.0], # layer1:1 -> components - ] - ) - storage = _make_storage( - n_components=2, vocab_size=2, source_to_component=source_to_component - ) - - # wte:0 -> layer1:0 - 
assert storage.get_attribution("wte:0", "layer1:0") == 1.0 - # wte:1 -> layer1:1 - assert storage.get_attribution("wte:1", "layer1:1") == 4.0 - # layer1:0 -> layer1:1 - assert storage.get_attribution("layer1:0", "layer1:1") == 6.0 - - def test_get_attribution_output_target(self) -> None: - """Test get_attribution for output targets (requires w_unembed).""" - # source_to_out_residual shape: (4, 4) for n_sources=4, d_model=4 - source_to_out_residual = torch.tensor( - [ - [1.0, 0.0, 0.0, 0.0], # wte:0 -> out_residual - [0.0, 1.0, 0.0, 0.0], # wte:1 -> out_residual - [0.0, 0.0, 1.0, 0.0], # layer1:0 -> out_residual - [0.0, 0.0, 0.0, 1.0], # layer1:1 -> out_residual - ] - ) - # w_unembed shape: (d_model=4, vocab=2) - w_unembed = torch.tensor( - [ - [1.0, 2.0], # d0 -> outputs - [3.0, 4.0], # d1 -> outputs - [5.0, 6.0], # d2 -> outputs - [7.0, 8.0], # d3 -> outputs - ] - ) - storage = _make_storage( - n_components=2, vocab_size=2, d_model=4, source_to_out_residual=source_to_out_residual - ) - - # wte:0 -> output:0 = out_residual[0] @ w_unembed[:, 0] = [1,0,0,0] @ [1,3,5,7] = 1.0 - assert storage.get_attribution("wte:0", "output:0", w_unembed=w_unembed) == 1.0 - # wte:1 -> output:1 = [0,1,0,0] @ [2,4,6,8] = 4.0 - assert storage.get_attribution("wte:1", "output:1", w_unembed=w_unembed) == 4.0 - # layer1:0 -> output:0 = [0,0,1,0] @ [1,3,5,7] = 5.0 - assert storage.get_attribution("layer1:0", "output:0", w_unembed=w_unembed) == 5.0 - - def test_get_top_sources_component_target(self) -> None: - """Test get_top_sources for component targets.""" - source_to_component = torch.tensor( - [ - [1.0, 2.0], # wte:0 - [5.0, 3.0], # wte:1 - [2.0, 4.0], # layer1:0 - [3.0, 1.0], # layer1:1 - ] - ) - storage = _make_storage( - n_components=2, vocab_size=2, source_to_component=source_to_component - ) - - # Top sources TO layer1:0 (column 0): wte:0=1.0, wte:1=5.0, layer1:0=2.0, layer1:1=3.0 - sources = storage.get_top_sources("layer1:0", k=2, sign="positive") - assert len(sources) == 2 - 
assert sources[0].component_key == "wte:1" - assert sources[0].value == 5.0 - assert sources[1].component_key == "layer1:1" - assert sources[1].value == 3.0 - - def test_get_top_sources_negative(self) -> None: - """Test get_top_sources with negative sign.""" - source_to_component = torch.tensor( - [ - [-1.0, 2.0], - [-5.0, 3.0], - [-2.0, 4.0], - [-3.0, 1.0], - ] - ) - storage = _make_storage( - n_components=2, vocab_size=2, source_to_component=source_to_component - ) - - sources = storage.get_top_sources("layer1:0", k=2, sign="negative") - assert len(sources) == 2 - # wte:1 has most negative (-5.0), then layer1:1 (-3.0) - assert sources[0].component_key == "wte:1" - assert sources[0].value == -5.0 - assert sources[1].component_key == "layer1:1" - assert sources[1].value == -3.0 - - def test_get_top_component_targets(self) -> None: - """Test get_top_component_targets (no w_unembed needed).""" - source_to_component = torch.tensor( - [ - [0.0, 0.0], - [0.0, 0.0], - [2.0, 4.0], # layer1:0 -> components - [0.0, 0.0], - ] - ) - storage = _make_storage( - n_components=2, vocab_size=2, source_to_component=source_to_component - ) - - targets = storage.get_top_component_targets("layer1:0", k=2, sign="positive") - assert len(targets) == 2 - assert targets[0].component_key == "layer1:1" - assert targets[0].value == 4.0 - assert targets[1].component_key == "layer1:0" - assert targets[1].value == 2.0 - - def test_get_top_targets_with_outputs(self) -> None: - """Test get_top_targets including outputs (requires w_unembed).""" - source_to_component = torch.tensor( - [ - [0.0, 0.0], - [0.0, 0.0], - [2.0, 4.0], # layer1:0 -> components - [0.0, 0.0], - ] - ) - # Make out_residual attribution that produces high output values - source_to_out_residual = torch.tensor( - [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0], # layer1:0 -> out_residual (sum=4 per output) - [0.0, 0.0, 0.0, 0.0], - ] - ) - # w_unembed that gives output:0=10, output:1=5 - w_unembed = 
torch.tensor( - [ - [2.5, 1.25], - [2.5, 1.25], - [2.5, 1.25], - [2.5, 1.25], - ] - ) - storage = _make_storage( - n_components=2, - vocab_size=2, - d_model=4, - source_to_component=source_to_component, - source_to_out_residual=source_to_out_residual, - ) - - targets = storage.get_top_targets("layer1:0", k=3, sign="positive", w_unembed=w_unembed) - assert len(targets) == 3 - # output:0 = 10.0, output:1 = 5.0, layer1:1 = 4.0 - assert targets[0].component_key == "output:0" - assert targets[0].value == 10.0 - assert targets[1].component_key == "output:1" - assert targets[1].value == 5.0 - assert targets[2].component_key == "layer1:1" - assert targets[2].value == 4.0 - - def test_save_and_load(self, tmp_path: Path) -> None: - """Test save and load roundtrip.""" - n_components = 2 - vocab_size = 3 - d_model = 4 - n_sources = vocab_size + n_components - - original = DatasetAttributionStorage( - component_layer_keys=["layer:0", "layer:1"], - vocab_size=vocab_size, - d_model=d_model, - source_to_component=torch.randn(n_sources, n_components), - source_to_out_residual=torch.randn(n_sources, d_model), - n_batches_processed=100, - n_tokens_processed=10000, - ci_threshold=0.01, - ) - - path = tmp_path / "test_attributions.pt" - original.save(path) - - loaded = DatasetAttributionStorage.load(path) - - assert loaded.component_layer_keys == original.component_layer_keys - assert loaded.vocab_size == original.vocab_size - assert loaded.d_model == original.d_model - assert loaded.n_batches_processed == original.n_batches_processed - assert loaded.n_tokens_processed == original.n_tokens_processed - assert loaded.ci_threshold == original.ci_threshold - assert torch.allclose(loaded.source_to_component, original.source_to_component) - assert torch.allclose(loaded.source_to_out_residual, original.source_to_out_residual) diff --git a/tests/dataset_attributions/test_storage.py b/tests/dataset_attributions/test_storage.py new file mode 100644 index 000000000..95fb788f7 --- /dev/null +++ 
b/tests/dataset_attributions/test_storage.py @@ -0,0 +1,181 @@ +"""Tests for DatasetAttributionStorage.""" + +from pathlib import Path + +import torch +from torch import Tensor + +from spd.dataset_attributions.storage import DatasetAttributionStorage + +VOCAB_SIZE = 4 +D_MODEL = 4 +LAYER_0 = "0.glu.up" +LAYER_1 = "1.glu.up" +C0 = 3 # components in layer 0 +C1 = 2 # components in layer 1 + + +def _make_storage(seed: int = 0, n_tokens: int = 640) -> DatasetAttributionStorage: + """Build storage for test topology. + + Sources by target: + "0.glu.up": ["embed"] -> embed edge (C0, VOCAB_SIZE) + "1.glu.up": ["embed", "0.glu.up"] -> embed edge (C1, VOCAB_SIZE) + regular (C1, C0) + "output": ["0.glu.up", "1.glu.up"] -> unembed (D_MODEL, C0), (D_MODEL, C1) + "output": ["embed"] -> embed_unembed (D_MODEL, VOCAB_SIZE) + """ + g = torch.Generator().manual_seed(seed) + + def rand(*shape: int) -> Tensor: + return torch.randn(*shape, generator=g) + + return DatasetAttributionStorage( + regular_attr={LAYER_1: {LAYER_0: rand(C1, C0)}}, + regular_attr_abs={LAYER_1: {LAYER_0: rand(C1, C0)}}, + embed_attr={LAYER_0: rand(C0, VOCAB_SIZE), LAYER_1: rand(C1, VOCAB_SIZE)}, + embed_attr_abs={LAYER_0: rand(C0, VOCAB_SIZE), LAYER_1: rand(C1, VOCAB_SIZE)}, + unembed_attr={LAYER_0: rand(D_MODEL, C0), LAYER_1: rand(D_MODEL, C1)}, + embed_unembed_attr=rand(D_MODEL, VOCAB_SIZE), + w_unembed=rand(D_MODEL, VOCAB_SIZE), + ci_sum={LAYER_0: rand(C0).abs() + 1.0, LAYER_1: rand(C1).abs() + 1.0}, + component_act_sq_sum={LAYER_0: rand(C0).abs() + 1.0, LAYER_1: rand(C1).abs() + 1.0}, + logit_sq_sum=rand(VOCAB_SIZE).abs() + 1.0, + embed_token_count=torch.randint(100, 1000, (VOCAB_SIZE,), generator=g), + ci_threshold=1e-6, + n_tokens_processed=n_tokens, + ) + + +class TestNComponents: + def test_counts_all_target_layers(self): + storage = _make_storage() + assert storage.n_components == C0 + C1 + + +class TestGetTopSources: + def test_component_target_returns_entries(self): + storage = _make_storage() + 
results = storage.get_top_sources(f"{LAYER_1}:0", k=5, sign="positive", metric="attr") + assert all(r.value > 0 for r in results) + assert len(results) <= 5 + + def test_component_target_includes_embed(self): + storage = _make_storage() + results = storage.get_top_sources(f"{LAYER_1}:0", k=20, sign="positive", metric="attr") + layers = {r.layer for r in results} + assert "embed" in layers or LAYER_0 in layers + + def test_output_target(self): + storage = _make_storage() + results = storage.get_top_sources("output:0", k=5, sign="positive", metric="attr") + assert len(results) <= 5 + + def test_output_target_attr_abs_returns_empty(self): + storage = _make_storage() + results = storage.get_top_sources("output:0", k=5, sign="positive", metric="attr_abs") + assert results == [] + + def test_target_only_in_embed_attr(self): + storage = _make_storage() + results = storage.get_top_sources(f"{LAYER_0}:0", k=5, sign="positive", metric="attr") + assert len(results) <= 5 + assert all(r.layer == "embed" for r in results) + + def test_attr_abs_metric(self): + storage = _make_storage() + results = storage.get_top_sources(f"{LAYER_1}:0", k=5, sign="positive", metric="attr_abs") + assert len(results) <= 5 + + def test_no_nan_in_results(self): + storage = _make_storage() + results = storage.get_top_sources(f"{LAYER_1}:0", k=20, sign="positive", metric="attr") + assert all(not torch.isnan(torch.tensor(r.value)) for r in results) + + +class TestGetTopTargets: + def test_component_source(self): + storage = _make_storage() + results = storage.get_top_targets( + f"{LAYER_0}:0", k=5, sign="positive", metric="attr", include_outputs=False + ) + assert len(results) <= 5 + assert all(r.value > 0 for r in results) + + def test_embed_source(self): + storage = _make_storage() + results = storage.get_top_targets( + "embed:0", k=5, sign="positive", metric="attr", include_outputs=False + ) + assert len(results) <= 5 + + def test_include_outputs(self): + storage = _make_storage() + results = 
storage.get_top_targets(f"{LAYER_0}:0", k=20, sign="positive", metric="attr") + assert len(results) > 0 + + def test_embed_source_with_outputs(self): + storage = _make_storage() + results = storage.get_top_targets("embed:0", k=20, sign="positive", metric="attr") + assert len(results) > 0 + + def test_attr_abs_skips_output_targets(self): + storage = _make_storage() + results = storage.get_top_targets(f"{LAYER_0}:0", k=20, sign="positive", metric="attr_abs") + assert all(r.layer != "output" for r in results) + + +class TestSaveLoad: + def test_roundtrip(self, tmp_path: Path): + original = _make_storage() + path = tmp_path / "attrs.pt" + original.save(path) + + loaded = DatasetAttributionStorage.load(path) + + assert loaded.ci_threshold == original.ci_threshold + assert loaded.n_tokens_processed == original.n_tokens_processed + assert loaded.n_components == original.n_components + + def test_roundtrip_query_consistency(self, tmp_path: Path): + original = _make_storage() + path = tmp_path / "attrs.pt" + original.save(path) + loaded = DatasetAttributionStorage.load(path) + + orig_results = original.get_top_sources(f"{LAYER_1}:0", k=5, sign="positive", metric="attr") + load_results = loaded.get_top_sources(f"{LAYER_1}:0", k=5, sign="positive", metric="attr") + + assert len(orig_results) == len(load_results) + for orig, loaded in zip(orig_results, load_results, strict=True): + assert orig.component_key == loaded.component_key + assert abs(orig.value - loaded.value) < 1e-5 + + +class TestMerge: + def test_two_workers_additive(self, tmp_path: Path): + s1 = _make_storage(seed=0, n_tokens=320) + s2 = _make_storage(seed=42, n_tokens=320) + + p1 = tmp_path / "rank_0.pt" + p2 = tmp_path / "rank_1.pt" + s1.save(p1) + s2.save(p2) + + merged = DatasetAttributionStorage.merge([p1, p2]) + + assert merged.n_tokens_processed == 640 + + def test_single_file(self, tmp_path: Path): + original = _make_storage(seed=7, n_tokens=640) + path = tmp_path / "rank_0.pt" + original.save(path) + + 
merged = DatasetAttributionStorage.merge([path]) + + assert merged.n_tokens_processed == original.n_tokens_processed + + orig_results = original.get_top_sources(f"{LAYER_1}:0", k=5, sign="positive", metric="attr") + merge_results = merged.get_top_sources(f"{LAYER_1}:0", k=5, sign="positive", metric="attr") + for o, m in zip(orig_results, merge_results, strict=True): + assert o.component_key == m.component_key + assert abs(o.value - m.value) < 1e-5 diff --git a/unused.py b/unused.py new file mode 100644 index 000000000..5a89b7cdc --- /dev/null +++ b/unused.py @@ -0,0 +1,171 @@ +""" +Finds potentially redundant type options across a codebase by analyzing call sites. + +Reports: +1. Params typed as Optional/X|None where None is never actually passed +2. Params with defaults where the arg is always explicitly provided (default never used) + +Limitations: +- Name-based function matching (false positives with same-name functions) +- No *args/**kwargs support +- No dynamic calls (getattr, functools.partial, etc.) +- Single-pass, no cross-module type inference +""" + +import ast +import sys +from collections import defaultdict +from dataclasses import dataclass, field +from pathlib import Path + + +@dataclass +class ParamInfo: + has_none_type: bool = False + has_default: bool = False + + +@dataclass +class FuncInfo: + params: dict[str, ParamInfo] = field(default_factory=dict) + # param_name -> list of bools: was None passed at this call site? + none_passed: dict[str, list[bool]] = field(default_factory=lambda: defaultdict(list)) + # param_name -> list of bools: was arg explicitly provided? 
+ explicitly_provided: dict[str, list[bool]] = field(default_factory=lambda: defaultdict(list)) + call_count: int = 0 + + +def annotation_includes_none(node: ast.expr | None) -> bool: + if node is None: + return False + # X | None or None | X + if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr): + return annotation_includes_none(node.left) or annotation_includes_none(node.right) + # None constant + if isinstance(node, ast.Constant) and node.value is None: + return True + # Optional[X] + if isinstance(node, ast.Subscript): + if isinstance(node.value, ast.Name) and node.value.id == "Optional": + return True + if isinstance(node.value, ast.Attribute) and node.value.attr == "Optional": + return True + return False + + +def is_none(node: ast.expr) -> bool: + return isinstance(node, ast.Constant) and node.value is None + + +def collect_functions(tree: ast.AST) -> dict[str, FuncInfo]: + funcs = {} + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + info = FuncInfo() + args = node.args + # Build param list (positional only + regular args), skip *args/**kwargs + all_params = args.posonlyargs + args.args + defaults_offset = len(all_params) - len(args.defaults) + + for i, arg in enumerate(all_params): + has_default = i >= defaults_offset + has_none = annotation_includes_none(arg.annotation) + info.params[arg.arg] = ParamInfo(has_none_type=has_none, has_default=has_default) + + funcs[node.name] = info + return funcs + + +def process_calls(tree: ast.AST, funcs: dict[str, FuncInfo]) -> None: + for node in ast.walk(tree): + if not isinstance(node, ast.Call): + continue + + # Get function name + if isinstance(node.func, ast.Name): + name = node.func.id + elif isinstance(node.func, ast.Attribute): + name = node.func.attr + else: + continue + + if name not in funcs: + continue + + info = funcs[name] + info.call_count += 1 + param_names = list(info.params.keys()) + + # Track positional args + provided = set() + 
for i, arg_node in enumerate(node.args): + if i < len(param_names): + pname = param_names[i] + provided.add(pname) + info.none_passed[pname].append(is_none(arg_node)) + + # Track keyword args + for kw in node.keywords: + if kw.arg is None: # **kwargs unpacking, skip + continue + provided.add(kw.arg) + info.none_passed[kw.arg].append(is_none(kw.value)) + + # Record which params were explicitly provided + for pname in info.params: + info.explicitly_provided[pname].append(pname in provided) + + +def analyze(root: Path) -> None: + all_funcs: dict[str, FuncInfo] = {} + + files = list(root.rglob("*.py")) + trees = [] + for f in files: + try: + source = f.read_text() + tree = ast.parse(source, filename=str(f)) + trees.append(tree) + for name, info in collect_functions(tree).items(): + all_funcs[name] = info # last definition wins on collision + except SyntaxError: + continue + + for tree in trees: + process_calls(tree, all_funcs) + + print("=== Redundant | None ===") + for name, info in sorted(all_funcs.items()): + if info.call_count == 0: + continue + for pname, param in info.params.items(): + if not param.has_none_type: + continue + calls_with_data = info.none_passed.get(pname, []) + if not calls_with_data: + continue + if not any(calls_with_data): + print( + f" {name}({pname}): None never passed ({len(calls_with_data)} call sites checked)" + ) + + print("\n=== Default never used (always explicitly provided) ===") + for name, info in sorted(all_funcs.items()): + if info.call_count == 0: + continue + for pname, param in info.params.items(): + if not param.has_default: + continue + provided_list = info.explicitly_provided.get(pname, []) + if not provided_list: + continue + if all(provided_list): + print( + f" {name}({pname}): default never used ({len(provided_list)} call sites checked)" + ) + + +if __name__ == "__main__": + root = Path(sys.argv[1]) if len(sys.argv) > 1 else Path(".") + analyze(root)