From c3e6c8e39e14d48d9f622b232120c4ede23f1c60 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Fri, 6 Mar 2026 12:32:01 +0000 Subject: [PATCH 01/20] App overhaul + clustering + cleanup App Backend: - AppTokenizer: server-side token display - Refactored graph computation, absolute-target attribution edges - SQLite prompt DB on NFS with DELETE journal + fcntl.flock locking - New routers: graph_interp, investigations, MCP, pretrain_info, run_registry, data_sources - Unified InterventionResult, target-sans masking, masked predictions - Spotlight mode, configurable optimization loss (CE/KL/positional) - Removed get_attribution_strength MCP tool (storage method was deleted) App Frontend: - Canvas edges, spotlight mode, 50K edge limit - New: DataSourcesTab, InvestigationsTab, ClustersTab, ModelGraph, DatasetExplorerTab, OptimizationSettings - Design system: CSS variables, token probability coloring - Lazy loading, bulk endpoints, Loadable pattern Clustering: - CUDA support, memory optimizations, Pile model configs Cleanup: - Remove scratch files, CLAUDE.md updates, .gitignore additions Co-Authored-By: Claude Opus 4.6 (1M context) --- .gitignore | 7 +- .mcp.json | 3 + CLAUDE.md | 95 +- find_clean_facts.py | 572 ------ package-lock.json | 6 - scripts/export_circuit_json.py | 200 +++ scripts/migrate_harvest_data.py | 369 ++++ scripts/parse_transformer_circuits_post.py | 363 ++++ scripts/render_circuit_html.py | 69 + scripts/test_abs_grad_trick.py | 150 ++ spd/app/CLAUDE.md | 78 +- spd/app/TODO.md | 173 +- spd/app/backend/app_tokenizer.py | 6 + spd/app/backend/compute.py | 424 ++++- spd/app/backend/database.py | 471 +++-- spd/app/backend/optim_cis.py | 599 +++++-- spd/app/backend/routers/__init__.py | 8 + spd/app/backend/routers/clusters.py | 16 +- spd/app/backend/routers/data_sources.py | 18 +- .../backend/routers/dataset_attributions.py | 204 +-- spd/app/backend/routers/graph_interp.py | 373 ++++ spd/app/backend/routers/graphs.py | 609 +++++-- 
spd/app/backend/routers/intervention.py | 335 +--- spd/app/backend/routers/investigations.py | 317 ++++ spd/app/backend/routers/mcp.py | 1573 +++++++++++++++++ spd/app/backend/routers/pretrain_info.py | 24 + spd/app/backend/routers/run_registry.py | 95 + spd/app/backend/routers/runs.py | 9 + spd/app/backend/schemas.py | 2 - spd/app/backend/server.py | 32 + spd/app/backend/state.py | 26 + spd/app/frontend/package-lock.json | 15 + spd/app/frontend/package.json | 3 + spd/app/frontend/src/app.css | 30 +- .../ActivationContextsPagedTable.svelte | 166 +- .../ActivationContextsViewer.svelte | 58 +- .../components/ClusterComponentCard.svelte | 254 +++ .../src/components/ClusterPathInput.svelte | 26 +- .../src/components/ClustersTab.svelte | 27 + .../src/components/ClustersViewer.svelte | 252 +++ .../src/components/DataSourcesTab.svelte | 308 ++-- .../src/components/InvestigationsTab.svelte | 645 +++++++ .../frontend/src/components/ModelGraph.svelte | 520 ++++++ .../src/components/ProbColoredTokens.svelte | 51 +- .../components/PromptAttributionsGraph.svelte | 200 ++- .../components/PromptAttributionsTab.svelte | 1101 ++++++++---- .../src/components/RunSelector.svelte | 336 +++- .../frontend/src/components/RunView.svelte | 52 +- .../src/components/TokenHighlights.svelte | 19 +- .../TokenizedSearchResultCard.svelte | 58 +- .../investigations/ArtifactGraph.svelte | 448 +++++ .../investigations/ResearchLogViewer.svelte | 223 +++ .../prompt-attr/ComponentNodeCard.svelte | 177 +- .../prompt-attr/ComputeProgressOverlay.svelte | 71 +- .../prompt-attr/InterventionsView.svelte | 1119 ++++++------ .../components/prompt-attr/NodeTooltip.svelte | 37 +- .../prompt-attr/OptimizationGrid.svelte | 134 ++ .../prompt-attr/OptimizationParams.svelte | 84 +- .../prompt-attr/OptimizationSettings.svelte | 214 ++- .../prompt-attr/OutputNodeCard.svelte | 63 +- .../prompt-attr/PromptPicker.svelte | 10 +- .../prompt-attr/StagedNodesPanel.svelte | 8 - .../prompt-attr/TokenDropdown.svelte | 52 +- 
.../components/prompt-attr/ViewTabs.svelte | 8 +- .../src/components/prompt-attr/types.ts | 34 +- .../ui/CorrelatedSubcomponentsList.svelte | 5 +- .../ui/DatasetAttributionsSection.svelte | 189 +- .../ui/DisplaySettingsDropdown.svelte | 28 + .../components/ui/EdgeAttributionGrid.svelte | 103 +- .../components/ui/EdgeAttributionList.svelte | 109 +- .../src/components/ui/GraphInterpBadge.svelte | 264 +++ .../src/components/ui/TokenSpan.svelte | 43 + spd/app/frontend/src/lib/api/correlations.ts | 18 +- spd/app/frontend/src/lib/api/dataSources.ts | 8 +- .../src/lib/api/datasetAttributions.ts | 14 +- spd/app/frontend/src/lib/api/graphInterp.ts | 81 + spd/app/frontend/src/lib/api/graphs.ts | 174 +- spd/app/frontend/src/lib/api/index.ts | 3 + spd/app/frontend/src/lib/api/intervention.ts | 33 +- .../frontend/src/lib/api/investigations.ts | 101 ++ spd/app/frontend/src/lib/api/pretrainInfo.ts | 1 + spd/app/frontend/src/lib/api/runRegistry.ts | 26 + spd/app/frontend/src/lib/api/runs.ts | 2 + spd/app/frontend/src/lib/colors.ts | 16 +- spd/app/frontend/src/lib/componentKeys.ts | 17 + .../src/lib/displaySettings.svelte.ts | 10 + spd/app/frontend/src/lib/interventionTypes.ts | 104 +- spd/app/frontend/src/lib/layerAliasing.ts | 219 --- .../src/lib/promptAttributionsTypes.ts | 63 +- spd/app/frontend/src/lib/registry.ts | 82 +- spd/app/frontend/src/lib/tokenUtils.ts | 41 + .../src/lib/useComponentData.svelte.ts | 60 +- .../useComponentDataExpectCached.svelte.ts | 62 +- spd/app/frontend/src/lib/useRun.svelte.ts | 54 +- spd/app/frontend/vite.config.ts | 1 + spd/app/run_app.py | 2 +- spd/autointerp/db.py | 1 - spd/clustering/CLAUDE.md | 12 + .../configs/pipeline-dev-simplestories.yaml | 2 +- spd/dataset_attributions/harvest.py | 1 - spd/utils/sqlite.py | 4 +- tests/app/test_server_api.py | 46 + tests/dataset_attributions/test_storage.py | 10 +- 103 files changed, 12139 insertions(+), 3899 deletions(-) create mode 100644 .mcp.json delete mode 100644 find_clean_facts.py delete mode 
100644 package-lock.json create mode 100644 scripts/export_circuit_json.py create mode 100644 scripts/migrate_harvest_data.py create mode 100644 scripts/parse_transformer_circuits_post.py create mode 100644 scripts/render_circuit_html.py create mode 100644 scripts/test_abs_grad_trick.py create mode 100644 spd/app/backend/routers/graph_interp.py create mode 100644 spd/app/backend/routers/investigations.py create mode 100644 spd/app/backend/routers/mcp.py create mode 100644 spd/app/backend/routers/run_registry.py create mode 100644 spd/app/frontend/src/components/ClusterComponentCard.svelte create mode 100644 spd/app/frontend/src/components/ClustersTab.svelte create mode 100644 spd/app/frontend/src/components/ClustersViewer.svelte create mode 100644 spd/app/frontend/src/components/InvestigationsTab.svelte create mode 100644 spd/app/frontend/src/components/ModelGraph.svelte create mode 100644 spd/app/frontend/src/components/investigations/ArtifactGraph.svelte create mode 100644 spd/app/frontend/src/components/investigations/ResearchLogViewer.svelte create mode 100644 spd/app/frontend/src/components/prompt-attr/OptimizationGrid.svelte create mode 100644 spd/app/frontend/src/components/ui/GraphInterpBadge.svelte create mode 100644 spd/app/frontend/src/components/ui/TokenSpan.svelte create mode 100644 spd/app/frontend/src/lib/api/graphInterp.ts create mode 100644 spd/app/frontend/src/lib/api/investigations.ts create mode 100644 spd/app/frontend/src/lib/api/runRegistry.ts create mode 100644 spd/app/frontend/src/lib/componentKeys.ts delete mode 100644 spd/app/frontend/src/lib/layerAliasing.ts create mode 100644 spd/app/frontend/src/lib/tokenUtils.ts diff --git a/.gitignore b/.gitignore index 8d3099ac1..3581e751a 100644 --- a/.gitignore +++ b/.gitignore @@ -9,8 +9,6 @@ scripts/outputs/ **/out/ neuronpedia_outputs/ .env -.mcp.json -.cursor/ .vscode/settings.json notebooks/ @@ -179,4 +177,7 @@ cython_debug/ #.idea/ **/*.db -**/*.db* \ No newline at end of file +**/*.db* 
+*.schema.json + +.claude/worktrees \ No newline at end of file diff --git a/.mcp.json b/.mcp.json new file mode 100644 index 000000000..700113020 --- /dev/null +++ b/.mcp.json @@ -0,0 +1,3 @@ +{ + "mcpServers": {} +} \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md index 2407dde87..13bdcf4d0 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -3,14 +3,18 @@ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. ## Environment Setup + **IMPORTANT**: Always activate the virtual environment before running Python or git operations: + ```bash source .venv/bin/activate ``` -Repo requires `.env` file with WandB credentials (see `.env.example`) +If working in a worktree, make sure there's a local `.venv` first by running `uv sync` in the worktree directory. Do NOT `cd` to the main repo — all commands (including git) should run in the worktree. +Repo requires `.env` file with WandB credentials (see `.env.example`) ## Project Overview + SPD (Stochastic Parameter Decomposition) is a research framework for analyzing neural network components and their interactions through sparse parameter decomposition techniques. 
- Target model parameters are decomposed as a sum of `parameter components` @@ -36,6 +40,8 @@ The codebase supports three experimental domains: TMS (Toy Model of Superpositio - `ss_llama_simple_mlp`, `ss_llama_simple_mlp-1L`, `ss_llama_simple_mlp-2L` - Llama MLP-only variants - `ss_gpt2`, `ss_gpt2_simple`, `ss_gpt2_simple_noln` - Simple Stories GPT-2 variants - `ss_gpt2_simple-1L`, `ss_gpt2_simple-2L` - GPT-2 simple layer variants + - `pile_llama_simple_mlp-2L`, `pile_llama_simple_mlp-4L`, `pile_llama_simple_mlp-12L` - Pile Llama MLP-only variants + - `pile_gpt2_simple-2L_global_reverse` - Pile GPT-2 with global reverse - `gpt2` - Standard GPT-2 - `ts` - TinyStories @@ -46,7 +52,7 @@ This repository implements methods from two key research papers on parameter dec **Stochastic Parameter Decomposition (SPD)** - [`papers/Stochastic_Parameter_Decomposition/spd_paper.md`](papers/Stochastic_Parameter_Decomposition/spd_paper.md) -- A version of this repository was used to run the experiments in this paper. But we continue to develop on the code, so it no longer is limited to the implementation used for this paper. +- A version of this repository was used to run the experiments in this paper. But we continue to develop on the code, so it no longer is limited to the implementation used for this paper. - Introduces the core SPD framework - Details the stochastic masking approach and optimization techniques used throughout the codebase - Useful reading for understanding the implementation details, though may be outdated. 
@@ -95,6 +101,7 @@ This repository implements methods from two key research papers on parameter dec ## Architecture Overview **Core SPD Framework:** + - `spd/run_spd.py` - Main SPD optimization logic called by all experiments - `spd/configs.py` - Pydantic config classes for all experiment types - `spd/registry.py` - Centralized experiment registry with all experiment configurations @@ -105,15 +112,17 @@ This repository implements methods from two key research papers on parameter dec - `spd/figures.py` - Figures for logging to WandB (e.g. CI histograms, Identity plots, etc.) **Terminology: Sources vs Masks:** + - **Sources** (`adv_sources`, `PPGDSources`, `self.sources`): The raw values that PGD optimizes adversarially. These are interpolated with CI to produce component masks: `mask = ci + (1 - ci) * source`. Used in both regular PGD (`spd/metrics/pgd_utils.py`) and persistent PGD (`spd/persistent_pgd.py`). - **Masks** (`component_masks`, `RoutingMasks`, `make_mask_infos`, `n_mask_samples`): The materialized per-component masks used during forward passes. These are produced from sources (in PGD) or from stochastic sampling, and are a general SPD concept across the whole codebase. 
**Experiment Structure:** Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: + - `models.py` - Experiment-specific model classes and pretrained loading - `*_decomposition.py` - Main SPD execution script -- `train_*.py` - Training script for target models +- `train_*.py` - Training script for target models - `*_config.yaml` - Configuration files - `plotting.py` - Visualization utilities @@ -127,7 +136,7 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: **Configuration System:** - YAML configs define all experiment parameters -- Pydantic models provide type safety and validation +- Pydantic models provide type safety and validation - WandB integration for experiment tracking and model storage - Supports both local paths and `wandb:project/runs/run_id` format for model loading - Centralized experiment registry (`spd/registry.py`) manages all experiment configurations @@ -137,8 +146,9 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: - `spd/harvest/` - Offline GPU pipeline for collecting component statistics (correlations, token stats, activation examples) - `spd/autointerp/` - LLM-based automated interpretation of components - `spd/dataset_attributions/` - Multi-GPU pipeline for computing component-to-component attribution strengths aggregated over training data -- Data stored at `SPD_OUT_DIR/{harvest,autointerp,dataset_attributions}//` -- See `spd/harvest/CLAUDE.md`, `spd/autointerp/CLAUDE.md`, and `spd/dataset_attributions/CLAUDE.md` for details +- `spd/graph_interp/` - Context-aware component labeling using graph structure (attributions + correlations) +- Data stored at `SPD_OUT_DIR/{harvest,autointerp,dataset_attributions,graph_interp}//` +- See `spd/harvest/CLAUDE.md`, `spd/autointerp/CLAUDE.md`, `spd/dataset_attributions/CLAUDE.md`, and `spd/graph_interp/CLAUDE.md` for details **Output Directory (`SPD_OUT_DIR`):** @@ -160,12 +170,14 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: ├── 
scripts/ # Standalone utility scripts ├── tests/ # Test suite ├── spd/ # Main source code +│ ├── investigate/ # Agent investigation (see investigate/CLAUDE.md) │ ├── app/ # Web visualization app (see app/CLAUDE.md) │ ├── autointerp/ # LLM interpretation (see autointerp/CLAUDE.md) │ ├── clustering/ # Component clustering (see clustering/CLAUDE.md) │ ├── dataset_attributions/ # Dataset attributions (see dataset_attributions/CLAUDE.md) │ ├── harvest/ # Statistics collection (see harvest/CLAUDE.md) │ ├── postprocess/ # Unified postprocessing pipeline (harvest + attributions + autointerp) +│ ├── graph_interp/ # Context-aware interpretation (see graph_interp/CLAUDE.md) │ ├── pretrain/ # Target model pretraining (see pretrain/CLAUDE.md) │ ├── experiments/ # Experiment implementations │ │ ├── tms/ # Toy Model of Superposition @@ -201,14 +213,17 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: | `spd-autointerp` | `spd/autointerp/scripts/run_slurm_cli.py` | Submit autointerp SLURM job | | `spd-attributions` | `spd/dataset_attributions/scripts/run_slurm_cli.py` | Submit dataset attribution SLURM job | | `spd-postprocess` | `spd/postprocess/cli.py` | Unified postprocessing pipeline (harvest + attributions + interpret + evals) | +| `spd-graph-interp` | `spd/graph_interp/scripts/run_slurm_cli.py` | Submit graph interpretation SLURM job | | `spd-clustering` | `spd/clustering/scripts/run_pipeline.py` | Clustering pipeline | | `spd-pretrain` | `spd/pretrain/scripts/run_slurm_cli.py` | Pretrain target models | +| `spd-investigate` | `spd/investigate/scripts/run_slurm_cli.py` | Launch investigation agent | ### Files to Skip When Searching Use `spd/` as the search root (not repo root) to avoid noise. **Always skip:** + - `.venv/` - Virtual environment - `__pycache__/`, `.pytest_cache/`, `.ruff_cache/` - Build artifacts - `node_modules/` - Frontend dependencies @@ -218,27 +233,37 @@ Use `spd/` as the search root (not repo root) to avoid noise. 
- `wandb/` - WandB local files **Usually skip unless relevant:** + - `tests/` - Test files (unless debugging test failures) - `papers/` - Research paper drafts ### Common Call Chains **Running Experiments:** + - `spd-run` → `spd/scripts/run.py` → `spd/utils/slurm.py` → SLURM → `spd/run_spd.py` - `spd-local` → `spd/scripts/run_local.py` → `spd/run_spd.py` directly **Harvest Pipeline:** + - `spd-harvest` → `spd/harvest/scripts/run_slurm_cli.py` → `spd/utils/slurm.py` → SLURM array → `spd/harvest/scripts/run.py` → `spd/harvest/harvest.py` **Autointerp Pipeline:** + - `spd-autointerp` → `spd/autointerp/scripts/run_slurm_cli.py` → `spd/utils/slurm.py` → `spd/autointerp/interpret.py` **Dataset Attributions Pipeline:** + - `spd-attributions` → `spd/dataset_attributions/scripts/run_slurm_cli.py` → `spd/utils/slurm.py` → SLURM array → `spd/dataset_attributions/harvest.py` **Clustering Pipeline:** + - `spd-clustering` → `spd/clustering/scripts/run_pipeline.py` → `spd/utils/slurm.py` → `spd/clustering/scripts/run_clustering.py` +**Investigation Pipeline:** + +- `spd-investigate` → `spd/investigate/scripts/run_slurm_cli.py` → `spd/utils/slurm.py` → SLURM → `spd/investigate/scripts/run_agent.py` → Claude Code + ## Common Usage Patterns ### Running Experiments Locally (`spd-local`) @@ -285,6 +310,28 @@ spd-autointerp # Submit SLURM job to interpret component Requires `OPENROUTER_API_KEY` env var. See `spd/autointerp/CLAUDE.md` for details. +### Agent Investigation (`spd-investigate`) + +Launch a Claude Code agent to investigate a specific question about an SPD model: + +```bash +spd-investigate "How does the model handle gendered pronouns?" +spd-investigate "What components are involved in verb agreement?" 
--time 4:00:00 +``` + +Each investigation: + +- Runs in its own SLURM job with 1 GPU +- Starts an isolated app backend instance +- Investigates the specific research question using SPD tools via MCP +- Writes findings to append-only JSONL files + +Output: `SPD_OUT_DIR/investigations//` + +For parallel investigations, run the command multiple times with different prompts. + +See `spd/investigate/CLAUDE.md` for details. + ### Unified Postprocessing (`spd-postprocess`) Run all postprocessing steps for a completed SPD run with a single command: @@ -295,6 +342,7 @@ spd-postprocess --config custom_config.yaml # Use custom config ``` Defaults are defined in `PostprocessConfig` (`spd/postprocess/config.py`). Pass a custom YAML/JSON config to override. Set any section to `null` to skip it: + - `attributions: null` — skip dataset attributions - `autointerp: null` — skip autointerp entirely (interpret + evals) - `autointerp.evals: null` — skip evals but still run interpret @@ -323,6 +371,7 @@ spd-run # Run all experiments ``` All `spd-run` executions: + - Submit jobs to SLURM - Create a git snapshot for reproducibility - Create W&B workspace views @@ -343,6 +392,7 @@ spd-run --experiments --sweep --n_agents [--cpu] ``` Examples: + ```bash spd-run --experiments tms_5-2 --sweep --n_agents 4 # Run TMS 5-2 sweep with 4 GPU agents spd-run --experiments resid_mlp2 --sweep --n_agents 3 --cpu # Run ResidualMLP2 sweep with 3 CPU agents @@ -364,6 +414,7 @@ spd-run --experiments tms_5-2 --sweep custom.yaml --n_agents 2 # Use custom swee - Default sweep parameters are loaded from `spd/scripts/sweep_params.yaml` - You can specify a custom sweep parameters file by passing its path to `--sweep` - Sweep parameters support both experiment-specific and global configurations: + ```yaml # Global parameters applied to all experiments global: @@ -376,7 +427,7 @@ spd-run --experiments tms_5-2 --sweep custom.yaml --n_agents 2 # Use custom swee # Experiment-specific parameters (override global) 
tms_5-2: seed: - values: [100, 200] # Overrides global seed + values: [100, 200] # Overrides global seed task_config: feature_probability: values: [0.05, 0.1] @@ -402,6 +453,7 @@ model = ComponentModel.from_run_info(run_info) # Local paths work too model = ComponentModel.from_pretrained("/path/to/checkpoint.pt") ``` + **Path Formats:** - WandB: `wandb:entity/project/run_id` or `wandb:entity/project/runs/run_id` @@ -415,12 +467,12 @@ Downloaded runs are cached in `SPD_OUT_DIR/runs/-/`. - This includes not setting off multiple sweeps/evals that total >8 GPUs - Monitor jobs with: `squeue --format="%.18i %.9P %.15j %.12u %.12T %.10M %.9l %.6D %b %R" --me` - ## Coding Guidelines & Software Engineering Principles **This is research code, not production. Prioritize simplicity and fail-fast over defensive programming.** Core principles: + - **Fail fast** - assert assumptions, crash on violations, don't silently recover - **No legacy support** - delete unused code, don't add fallbacks for old formats or migration shims - **Narrow types** - avoid `| None` unless null is semantically meaningful; use discriminated unions over bags of optional fields @@ -451,11 +503,12 @@ config = get_config(path) value = config.key ``` - ### Tests + - The point of tests in this codebase is to ensure that the code is working as expected, not to prevent production outages - there's no deployment here. Therefore, don't worry about lots of larger integration/end-to-end tests. These often require too much overhead for what it's worth in our case, and this codebase is interactively run so often that issues will likely be caught by the user at very little cost. ### Assertions and error handling + - If you have an invariant in your head, assert it. Are you afraid to assert? Sounds like your program might already be broken. Assert, assert, assert. Never soft fail. - Do not write: `if everythingIsOk: continueHappyPath()`. 
Instead do `assert everythingIsOk` - You should have a VERY good reason to handle an error gracefully. If your program isn't working like it should then it shouldn't be running, you should be fixing it. @@ -463,11 +516,13 @@ value = config.key - **Write for the golden path.** Never let edge cases bloat the code. Before handling them, just raise an exception. If an edge case becomes annoying enough, we'll handle it then — but write first and foremost for the common case. ### Control Flow + - Keep I/O as high up as possible. Make as many functions as possible pure. - Prefer `match` over `if/elif/else` chains when dispatching on conditions - more declarative and makes cases explicit - If you either have (a and b) or neither, don't make them both independently optional. Instead, put them in an optional tuple ### Types, Arguments, and Defaults + - Write your invariants into types as much as possible. - Use jaxtyping for tensor shapes (though for now we don't do runtime checking) - Always use the PEP 604 typing format of `|` for unions and `type | None` over `Optional`. @@ -483,12 +538,11 @@ value = config.key - Don't use `from __future__ import annotations` — use string quotes for forward references instead. ### Tensor Operations + - Try to use einops by default for clarity. - Assert shapes liberally - Document complex tensor manipulations - - ### Comments - Comments hide sloppy code. If you feel the need to write a comment, consider that you should instead @@ -498,10 +552,11 @@ value = config.key - separate an inlined computation into a meaningfully named variable - Don’t write dialogic / narrativised comments or code. Instead, write comments that describe the code as is, not the diff you're making. 
Examples of narrativising comments: - - `# the function now uses y instead of x` - - `# changed to be faster` - - `# we now traverse in reverse` + - `# the function now uses y instead of x` + - `# changed to be faster` + - `# we now traverse in reverse` - Here's an example of a bad diff, where the new comment makes reference to a change in code, not just the state of the code: + ``` 95 - # Reservoir states 96 - reservoir_states: list[ReservoirState] @@ -509,21 +564,15 @@ value = config.key 96 + reservoir: TensorReservoirState ``` - -### Fire CLI Gotchas - -This codebase uses `python-fire` for CLI entry points in SLURM worker scripts. Two known gotchas: - -- **JSON args become dicts.** Fire auto-parses JSON strings into Python dicts. So `--config_json '{"n_batches": 500}'` arrives as `dict`, not `str`. Use `model_validate()` (not `model_validate_json()`), and type the param as `dict[str, Any]`. -- **Numeric-looking strings become ints/floats.** Fire parses `1234_1` (SLURM array job ID format) as an integer. This is partly why we use string-prefixed IDs everywhere (`s-`, `h-`, `da-`, `a-`) — the prefix prevents Fire from coercing them. - ### Other Important Software Development Practices + - Don't add legacy fallbacks or migration code - just change it and let old data be manually migrated if needed. -- Delete unused code. +- Delete unused code. - If an argument is always x, strongly consider removing as an argument and just inlining - **Update CLAUDE.md files** when changing code structure, adding/removing files, or modifying key interfaces. Update the CLAUDE.md in the same directory (or nearest parent) as the changed files. ### GitHub + - To view github issues and PRs, use the github cli (e.g. `gh issue view 28` or `gh pr view 30`). - When making PRs, use the github template defined in `.github/pull_request_template.md`. - Before committing, ALWAYS ensure you are on the correct branch and do not use `git add .` to add all unstaged files. 
Instead, add only the individual files you changed, don't commit all files. diff --git a/find_clean_facts.py b/find_clean_facts.py deleted file mode 100644 index e22ca906d..000000000 --- a/find_clean_facts.py +++ /dev/null @@ -1,572 +0,0 @@ -#!/usr/bin/env python3 -""" -Find the cleanest (most monosemantic) facts from the SPD analysis. - -A fact is "clean" if the components that fire on it are monosemantic. - -For down_proj: A component is monosemantic if it responds to a single label. -For up_proj: A component is monosemantic if it responds to: - - A single label, OR - - A single input element at position 0, 1, or 2 - -We score each fact based on how monosemantic its firing components are. -""" - -import re -from collections import Counter, defaultdict - - -def parse_analysis_file(filepath: str): - """Parse the analysis.txt file to extract component and fact information.""" - - with open(filepath) as f: - lines = f.readlines() - - # Parse component-to-facts mapping (from the COMPONENT ACTIVATION ANALYSIS section) - up_proj_components = defaultdict(list) # component_id -> list of (fact_idx, input, label) - down_proj_components = defaultdict(list) - - # Parse the per-fact analysis (from PER-FACT COMPONENT ANALYSIS section) - up_proj_per_fact = {} # fact_idx -> {inputs, label, components} - down_proj_per_fact = {} - - current_module = None - current_section = None # 'component_analysis' or 'per_fact' - current_component = None - - i = 0 - while i < len(lines): - line = lines[i].strip() - - # Detect section changes - if "COMPONENT ACTIVATION ANALYSIS" in line: - current_section = "component_analysis" - elif "PER-FACT COMPONENT ANALYSIS" in line: - current_section = "per_fact" - elif "SUMMARY STATISTICS" in line: - current_section = "summary" - - # Detect module changes - if "MODULE: block.mlp.up_proj" in line: - current_module = "up_proj" - elif "MODULE: block.mlp.down_proj" in line: - current_module = "down_proj" - - # Parse component activation analysis section - if 
current_section == "component_analysis" and current_module: - # Parse component header: [Rank X] Component Y (mean CI=Z): N facts above threshold - comp_match = re.match(r"\[Rank \d+\] Component (\d+)", line) - if comp_match: - current_component = int(comp_match.group(1)) - - # Parse fact line: Fact X: input=[a, b, c] → label=Y (CI=Z) - fact_match = re.match( - r"Fact\s+(\d+): input=\[(\d+), (\d+), (\d+)\] → label=(\d+)", line - ) - if fact_match and current_component is not None: - fact_idx = int(fact_match.group(1)) - inputs = [ - int(fact_match.group(2)), - int(fact_match.group(3)), - int(fact_match.group(4)), - ] - label = int(fact_match.group(5)) - - if current_module == "up_proj": - up_proj_components[current_component].append((fact_idx, inputs, label)) - else: - down_proj_components[current_component].append((fact_idx, inputs, label)) - - # Parse per-fact analysis section - if current_section == "per_fact" and current_module: - # Parse fact line - fact_match = re.match( - r"Fact\s+(\d+): input=\[(\d+), (\d+), (\d+)\] → label=(\d+)", line - ) - if fact_match: - fact_idx = int(fact_match.group(1)) - inputs = [ - int(fact_match.group(2)), - int(fact_match.group(3)), - int(fact_match.group(4)), - ] - label = int(fact_match.group(5)) - - # Look for components in the next lines - components = [] - j = i + 1 - while j < len(lines): - next_line = lines[j].strip() - - # Check if we've hit the next fact or section - if ( - next_line.startswith("Fact ") - or next_line.startswith("===") - or next_line.startswith("MODULE:") - ): - break - - # Parse component activations like C206(1.000) - comp_matches = re.findall(r"C(\d+)\(([\d.]+)\)", next_line) - for comp_id, ci_score in comp_matches: - components.append((int(comp_id), float(ci_score))) - - j += 1 - - if current_module == "up_proj": - up_proj_per_fact[fact_idx] = { - "inputs": inputs, - "label": label, - "components": components, - } - else: - down_proj_per_fact[fact_idx] = { - "inputs": inputs, - "label": label, - 
"components": components, - } - - i += 1 - - return up_proj_components, down_proj_components, up_proj_per_fact, down_proj_per_fact - - -def compute_component_monosemanticity(component_facts: list) -> dict: - """ - Compute monosemanticity scores for a component. - """ - if not component_facts: - return None - - labels = [f[2] for f in component_facts] - pos0_vals = [f[1][0] for f in component_facts] - pos1_vals = [f[1][1] for f in component_facts] - pos2_vals = [f[1][2] for f in component_facts] - - label_counts = Counter(labels) - pos0_counts = Counter(pos0_vals) - pos1_counts = Counter(pos1_vals) - pos2_counts = Counter(pos2_vals) - - n = len(component_facts) - - dominant_label, dominant_label_count = label_counts.most_common(1)[0] - dominant_pos0, dominant_pos0_count = pos0_counts.most_common(1)[0] - dominant_pos1, dominant_pos1_count = pos1_counts.most_common(1)[0] - dominant_pos2, dominant_pos2_count = pos2_counts.most_common(1)[0] - - return { - "n_facts": n, - "n_unique_labels": len(label_counts), - "dominant_label": dominant_label, - "dominant_label_ratio": dominant_label_count / n, - "n_unique_pos0": len(pos0_counts), - "dominant_pos0": dominant_pos0, - "dominant_pos0_ratio": dominant_pos0_count / n, - "n_unique_pos1": len(pos1_counts), - "dominant_pos1": dominant_pos1, - "dominant_pos1_ratio": dominant_pos1_count / n, - "n_unique_pos2": len(pos2_counts), - "dominant_pos2": dominant_pos2, - "dominant_pos2_ratio": dominant_pos2_count / n, - } - - -def is_component_monosemantic(stats: dict, threshold: float = 0.9) -> tuple[bool, str]: - """ - Determine if a component is monosemantic based on its statistics. 
- Returns (is_monosemantic, reason) - """ - if stats is None: - return False, "no_data" - - # Check if it responds to a single label - if stats["dominant_label_ratio"] >= threshold: - return True, f"label_{stats['dominant_label']}" - - # Check if it responds to a single input element - if stats["dominant_pos0_ratio"] >= threshold: - return True, f"pos0_{stats['dominant_pos0']}" - if stats["dominant_pos1_ratio"] >= threshold: - return True, f"pos1_{stats['dominant_pos1']}" - if stats["dominant_pos2_ratio"] >= threshold: - return True, f"pos2_{stats['dominant_pos2']}" - - return False, "polysemantic" - - -def compute_monosemanticity_score(stats: dict) -> float: - """ - Compute a monosemanticity score from 0 to 1. - Higher score = more monosemantic. - """ - if stats is None: - return 0.0 - - # The score is the maximum of all the dominant ratios - return max( - stats["dominant_label_ratio"], - stats["dominant_pos0_ratio"], - stats["dominant_pos1_ratio"], - stats["dominant_pos2_ratio"], - ) - - -def score_fact( - fact_info: dict, - up_proj_mono_scores: dict, - down_proj_mono_scores: dict, - up_proj_stats: dict, - down_proj_stats: dict, -) -> tuple[float, dict]: - """ - Score a fact based on how monosemantic its firing components are. 
- Returns (score, details) - """ - up_components = fact_info.get("up_proj_components", []) - down_components = fact_info.get("down_proj_components", []) - - if not up_components and not down_components: - return 0.0, { - "reason": "no_components", - "up_proj_components": [], - "down_proj_components": [], - "n_components": 0, - } - - # For each component, get its monosemanticity score - up_scores = [] - for comp_id, ci_score in up_components: - mono_score = up_proj_mono_scores.get(comp_id, 0.0) - stats = up_proj_stats.get(comp_id) - is_mono, reason = ( - is_component_monosemantic(stats, threshold=0.9) if stats else (False, "unknown") - ) - up_scores.append((comp_id, mono_score, ci_score, is_mono, reason)) - - down_scores = [] - for comp_id, ci_score in down_components: - mono_score = down_proj_mono_scores.get(comp_id, 0.0) - stats = down_proj_stats.get(comp_id) - is_mono, reason = ( - is_component_monosemantic(stats, threshold=0.9) if stats else (False, "unknown") - ) - down_scores.append((comp_id, mono_score, ci_score, is_mono, reason)) - - # Compute fact score as minimum monosemanticity of all components - all_mono_scores = [s[1] for s in up_scores] + [s[1] for s in down_scores] - - if not all_mono_scores: - return 0.0, { - "reason": "no_scores", - "up_proj_components": up_scores, - "down_proj_components": down_scores, - "n_components": 0, - } - - min_score = min(all_mono_scores) - mean_score = sum(all_mono_scores) / len(all_mono_scores) - - # Count how many components are monosemantic - n_mono = sum(1 for s in up_scores + down_scores if s[3]) - total = len(up_scores) + len(down_scores) - - return min_score, { - "up_proj_components": up_scores, - "down_proj_components": down_scores, - "min_mono_score": min_score, - "mean_mono_score": mean_score, - "n_components": total, - "n_mono_components": n_mono, - "mono_ratio": n_mono / total if total > 0 else 0, - } - - -def main(): - print("Parsing analysis.txt...") - up_proj_comps, down_proj_comps, up_proj_facts, 
down_proj_facts = parse_analysis_file( - "analysis.txt" - ) - - print(f"\nFound {len(up_proj_comps)} up_proj components with facts") - print(f"Found {len(down_proj_comps)} down_proj components with facts") - print(f"Found {len(up_proj_facts)} facts with up_proj info") - print(f"Found {len(down_proj_facts)} facts with down_proj info") - - # Sample check - if up_proj_facts: - sample_fact = list(up_proj_facts.items())[0] - print(f"\nSample up_proj fact: {sample_fact}") - if down_proj_facts: - sample_fact = list(down_proj_facts.items())[0] - print(f"Sample down_proj fact: {sample_fact}") - - # Compute monosemanticity for each component - print("\nComputing component monosemanticity...") - - up_proj_stats = {} - up_proj_mono_scores = {} - for comp_id, facts in up_proj_comps.items(): - stats = compute_component_monosemanticity(facts) - up_proj_stats[comp_id] = stats - up_proj_mono_scores[comp_id] = compute_monosemanticity_score(stats) - - down_proj_stats = {} - down_proj_mono_scores = {} - for comp_id, facts in down_proj_comps.items(): - stats = compute_component_monosemanticity(facts) - down_proj_stats[comp_id] = stats - down_proj_mono_scores[comp_id] = compute_monosemanticity_score(stats) - - # Print some example monosemantic components - print("\n" + "=" * 80) - print("MONOSEMANTIC UP_PROJ COMPONENTS (threshold >= 0.9)") - print("=" * 80) - mono_up = [] - for comp_id, stats in up_proj_stats.items(): - is_mono, reason = is_component_monosemantic(stats, threshold=0.9) - if is_mono: - mono_up.append((comp_id, stats, reason)) - - mono_up.sort(key=lambda x: compute_monosemanticity_score(x[1]), reverse=True) - for comp_id, stats, reason in mono_up[:20]: - print( - f" Component {comp_id}: {reason}, score={compute_monosemanticity_score(stats):.3f}, n_facts={stats['n_facts']}" - ) - print(f" ... 
and {len(mono_up) - 20} more" if len(mono_up) > 20 else "") - - print(f"\nTotal monosemantic up_proj components: {len(mono_up)} / {len(up_proj_stats)}") - - print("\n" + "=" * 80) - print("MONOSEMANTIC DOWN_PROJ COMPONENTS (threshold >= 0.9)") - print("=" * 80) - mono_down = [] - for comp_id, stats in down_proj_stats.items(): - is_mono, reason = is_component_monosemantic(stats, threshold=0.9) - if is_mono: - mono_down.append((comp_id, stats, reason)) - - mono_down.sort(key=lambda x: compute_monosemanticity_score(x[1]), reverse=True) - for comp_id, stats, reason in mono_down[:20]: - print( - f" Component {comp_id}: {reason}, score={compute_monosemanticity_score(stats):.3f}, n_facts={stats['n_facts']}" - ) - print(f" ... and {len(mono_down) - 20} more" if len(mono_down) > 20 else "") - - print(f"\nTotal monosemantic down_proj components: {len(mono_down)} / {len(down_proj_stats)}") - - # Combine up_proj and down_proj info for each fact - print("\n" + "=" * 80) - print("SCORING FACTS BY MONOSEMANTICITY") - print("=" * 80) - - all_facts = set(up_proj_facts.keys()) | set(down_proj_facts.keys()) - fact_scores = [] - - for fact_idx in all_facts: - up_info = up_proj_facts.get(fact_idx, {}) - down_info = down_proj_facts.get(fact_idx, {}) - - # Get the inputs and label from either source - inputs = up_info.get("inputs") or down_info.get("inputs", []) - label = up_info.get("label", down_info.get("label", -1)) - - combined_info = { - "inputs": inputs, - "label": label, - "up_proj_components": up_info.get("components", []), - "down_proj_components": down_info.get("components", []), - } - - score, details = score_fact( - combined_info, - up_proj_mono_scores, - down_proj_mono_scores, - up_proj_stats, - down_proj_stats, - ) - - fact_scores.append( - { - "fact_idx": fact_idx, - "inputs": inputs, - "label": label, - "score": score, - "details": details, - } - ) - - # Sort by score (highest = cleanest), then by mono ratio, then by fewer components - fact_scores.sort( - key=lambda x: ( 
- x["score"], - x["details"].get("mono_ratio", 0), - -x["details"].get("n_components", 999), - ), - reverse=True, - ) - - # Print top cleanest facts - print("\nTOP 50 CLEANEST FACTS (highest monosemanticity score):") - print("-" * 80) - - for i, fs in enumerate(fact_scores[:50]): - up_comps = fs["details"].get("up_proj_components", []) - down_comps = fs["details"].get("down_proj_components", []) - - up_str = ", ".join([f"C{c[0]}({c[4]})" for c in up_comps]) - down_str = ", ".join([f"C{c[0]}({c[4]})" for c in down_comps]) - - print(f"\n{i + 1}. Fact {fs['fact_idx']}: input={fs['inputs']} → label={fs['label']}") - print(f" Score: {fs['score']:.3f}, mono_ratio: {fs['details'].get('mono_ratio', 0):.2f}") - print(f" Up_proj ({len(up_comps)}): {up_str if up_str else 'none'}") - print(f" Down_proj ({len(down_comps)}): {down_str if down_str else 'none'}") - - # Find facts where ALL components are monosemantic - print("\n" + "=" * 80) - print("FACTS WHERE ALL COMPONENTS ARE MONOSEMANTIC") - print("=" * 80) - - all_mono_facts = [ - fs - for fs in fact_scores - if fs["details"].get("n_components", 0) > 0 and fs["details"].get("mono_ratio", 0) == 1.0 - ] - - print(f"\nFound {len(all_mono_facts)} facts where ALL components are monosemantic:\n") - - for i, fs in enumerate(all_mono_facts[:30]): - up_comps = fs["details"].get("up_proj_components", []) - down_comps = fs["details"].get("down_proj_components", []) - - up_str = ", ".join([f"C{c[0]}({c[4]})" for c in up_comps]) - down_str = ", ".join([f"C{c[0]}({c[4]})" for c in down_comps]) - - print(f"{i + 1}. Fact {fs['fact_idx']}: input={fs['inputs']} → label={fs['label']}") - print(f" Up_proj: {up_str if up_str else 'none'}") - print(f" Down_proj: {down_str if down_str else 'none'}") - print() - - if len(all_mono_facts) > 30: - print(f" ... 
and {len(all_mono_facts) - 30} more") - - # Also show facts with only 1 component firing in up_proj - print("\n" + "=" * 80) - print("FACTS WITH ONLY 1 UP_PROJ COMPONENT FIRING") - print("=" * 80) - - single_comp_facts = [ - fs for fs in fact_scores if len(fs["details"].get("up_proj_components", [])) == 1 - ] - single_comp_facts.sort(key=lambda x: x["score"], reverse=True) - - print(f"\nFound {len(single_comp_facts)} facts with only 1 up_proj component:\n") - - for i, fs in enumerate(single_comp_facts[:30]): - up_comps = fs["details"].get("up_proj_components", []) - down_comps = fs["details"].get("down_proj_components", []) - - comp_id = up_comps[0][0] - comp_stats = up_proj_stats.get(comp_id) - is_mono, reason = ( - is_component_monosemantic(comp_stats, threshold=0.9) - if comp_stats - else (False, "unknown") - ) - - down_str = ", ".join([f"C{c[0]}({c[4]})" for c in down_comps]) - - print(f"{i + 1}. Fact {fs['fact_idx']}: input={fs['inputs']} → label={fs['label']}") - print( - f" Up_proj C{comp_id}: mono_score={fs['score']:.3f}, is_mono={is_mono}, reason={reason}" - ) - print(f" Down_proj: {down_str if down_str else 'none'}") - if comp_stats: - print( - f" Component stats: dominant_label={comp_stats['dominant_label']} ({comp_stats['dominant_label_ratio']:.1%})" - ) - print() - - # Print summary - print("\n" + "=" * 80) - print("SUMMARY") - print("=" * 80) - - # Count facts with at least one component - facts_with_components = [fs for fs in fact_scores if fs["details"].get("n_components", 0) > 0] - print(f"\nTotal facts with at least one component: {len(facts_with_components)}") - - score_thresholds = [1.0, 0.95, 0.9, 0.8, 0.5, 0.0] - for thresh in score_thresholds: - count = sum(1 for fs in facts_with_components if fs["score"] >= thresh) - print(f" Facts with monosemanticity score >= {thresh}: {count}") - - # Save results to a file - print("\n\nSaving detailed results to clean_facts_ranking.txt...") - with open("clean_facts_ranking.txt", "w") as f: - 
f.write("FACTS RANKED BY MONOSEMANTICITY SCORE\n") - f.write("=" * 80 + "\n\n") - f.write("A fact is 'clean' if all components that fire on it are monosemantic.\n") - f.write( - "Monosemantic = responds to a single label or single input position value (>= 90%).\n\n" - ) - - f.write(f"Total facts with at least one component: {len(facts_with_components)}\n") - f.write(f"Facts where ALL components are monosemantic: {len(all_mono_facts)}\n\n") - - f.write("=" * 80 + "\n") - f.write("CLEANEST FACTS (all components monosemantic)\n") - f.write("=" * 80 + "\n\n") - - for i, fs in enumerate(all_mono_facts): - up_comps = fs["details"].get("up_proj_components", []) - down_comps = fs["details"].get("down_proj_components", []) - - f.write(f"Rank {i + 1}: Fact {fs['fact_idx']}\n") - f.write(f" Input: {fs['inputs']} → Label: {fs['label']}\n") - f.write(f" Monosemanticity Score: {fs['score']:.4f}\n") - f.write(f" Up_proj components ({len(up_comps)}):\n") - for comp_id, mono_score, ci_score, _is_mono, reason in up_comps: - f.write( - f" C{comp_id}: CI={ci_score:.3f}, mono={mono_score:.3f}, reason={reason}\n" - ) - f.write(f" Down_proj components ({len(down_comps)}):\n") - for comp_id, mono_score, ci_score, _is_mono, reason in down_comps: - f.write( - f" C{comp_id}: CI={ci_score:.3f}, mono={mono_score:.3f}, reason={reason}\n" - ) - f.write("\n") - - f.write("\n" + "=" * 80 + "\n") - f.write("ALL FACTS RANKED\n") - f.write("=" * 80 + "\n\n") - - for i, fs in enumerate(facts_with_components): - up_comps = fs["details"].get("up_proj_components", []) - down_comps = fs["details"].get("down_proj_components", []) - - f.write(f"Rank {i + 1}: Fact {fs['fact_idx']}\n") - f.write(f" Input: {fs['inputs']} → Label: {fs['label']}\n") - f.write(f" Min Monosemanticity Score: {fs['score']:.4f}\n") - f.write( - f" Mono ratio: {fs['details'].get('mono_ratio', 0):.2f} ({fs['details'].get('n_mono_components', 0)}/{fs['details'].get('n_components', 0)})\n" - ) - f.write(f" Up_proj components 
({len(up_comps)}):\n") - for comp_id, mono_score, ci_score, is_mono, reason in up_comps: - mono_marker = "✓" if is_mono else "✗" - f.write( - f" {mono_marker} C{comp_id}: CI={ci_score:.3f}, mono={mono_score:.3f}, reason={reason}\n" - ) - f.write(f" Down_proj components ({len(down_comps)}):\n") - for comp_id, mono_score, ci_score, is_mono, reason in down_comps: - mono_marker = "✓" if is_mono else "✗" - f.write( - f" {mono_marker} C{comp_id}: CI={ci_score:.3f}, mono={mono_score:.3f}, reason={reason}\n" - ) - f.write("\n") - - print("Done!") - - -if __name__ == "__main__": - main() diff --git a/package-lock.json b/package-lock.json deleted file mode 100644 index ca5f8195a..000000000 --- a/package-lock.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "name": "spd", - "lockfileVersion": 3, - "requires": true, - "packages": {} -} diff --git a/scripts/export_circuit_json.py b/scripts/export_circuit_json.py new file mode 100644 index 000000000..144727a33 --- /dev/null +++ b/scripts/export_circuit_json.py @@ -0,0 +1,200 @@ +"""Export an OptimizedPromptAttributionResult to JSON for the circuit graph renderer. 
+ +Usage from a notebook / script: + + from scripts.export_circuit_json import export_circuit_json + + export_circuit_json( + circuit=circuit, + token_ids=tokens.tolist(), + token_strings=tok.get_spans(tokens.tolist()), + output_path=Path("king_circuit.json"), + interp=interp, # optional InterpRepo + graph_interp=gi_repo, # optional GraphInterpRepo + ) +""" + +import json +from pathlib import Path + +from spd.graph_interp.repo import GraphInterpRepo + +from spd.app.backend.compute import Edge, OptimizedPromptAttributionResult +from spd.autointerp.repo import InterpRepo + + +def _parse_node_key(key: str) -> tuple[str, int, int]: + """Parse 'h.3.attn.o_proj:11:361' -> ('h.3.attn.o_proj', 11, 361).""" + parts = key.split(":") + layer = ":".join(parts[:-2]) + seq = int(parts[-2]) + cidx = int(parts[-1]) + return layer, seq, cidx + + +def _concrete_to_canonical(layer: str) -> str: + """Map concrete model path to canonical address used by graph layout. + + h.3.attn.o_proj -> 3.attn.o + h.3.attn.v_proj -> 3.attn.v + h.3.attn.q_proj -> 3.attn.q + h.3.attn.k_proj -> 3.attn.k + h.3.mlp.c_fc -> 3.mlp.up + h.3.mlp.down_proj -> 3.mlp.down + h.3.mlp.gate_proj -> 3.glu.gate + h.3.mlp.up_proj -> 3.glu.up + wte -> embed + lm_head -> output + """ + if layer == "wte": + return "embed" + if layer == "lm_head": + return "output" + + PROJ_MAP = { + "q_proj": "q", + "k_proj": "k", + "v_proj": "v", + "o_proj": "o", + "c_fc": "up", + "down_proj": "down", + "gate_proj": "gate", + "up_proj": "up", + } + + # h.{block}.{sublayer}.{proj_name} + parts = layer.split(".") + assert len(parts) == 4, f"Expected h.N.sublayer.proj, got {layer!r}" + block_idx = parts[1] + sublayer = parts[2] # "attn" or "mlp" + proj_name = parts[3] + + canonical_proj = PROJ_MAP.get(proj_name) + assert canonical_proj is not None, f"Unknown projection: {proj_name!r} in {layer!r}" + + # Determine canonical sublayer + if sublayer == "attn": + canonical_sublayer = "attn" + elif sublayer == "mlp": + canonical_sublayer = "glu" 
if proj_name in ("gate_proj", "up_proj") else "mlp" + else: + canonical_sublayer = sublayer + + return f"{block_idx}.{canonical_sublayer}.{canonical_proj}" + + +def _edge_to_dict(e: Edge) -> dict[str, object]: + return { + "source": str(e.source), + "target": str(e.target), + "attribution": e.strength, + "is_cross_seq": e.is_cross_seq, + } + + +def _get_label( + layer: str, + cidx: int, + interp: InterpRepo | None, + graph_interp: GraphInterpRepo | None, +) -> str | None: + component_key = f"{layer}:{cidx}" + if graph_interp is not None: + unified = graph_interp.get_unified_label(component_key) + if unified is not None: + return unified.label + output = graph_interp.get_output_label(component_key) + if output is not None: + return output.label + if interp is not None: + ir = interp.get_interpretation(component_key) + if ir is not None: + return ir.label + return None + + +def export_circuit_json( + circuit: OptimizedPromptAttributionResult, + token_ids: list[int], + token_strings: list[str], + output_path: Path, + min_ci: float = 0.3, + min_edge_attr: float = 0.1, + interp: InterpRepo | None = None, + graph_interp: GraphInterpRepo | None = None, +) -> None: + """Export circuit to JSON for the standalone HTML renderer. + + Args: + circuit: The optimized circuit result from EditableModel.optimize_circuit. + token_ids: Raw token IDs for each position. + token_strings: Display strings for each position (from tok.get_spans). + output_path: Where to write the JSON. + min_ci: Minimum CI to include a node. + min_edge_attr: Minimum |attribution| to include an edge. + interp: Optional InterpRepo for autointerp labels. + graph_interp: Optional GraphInterpRepo for graph-interp labels. 
+ """ + # Build tokens list + tokens = [ + {"pos": i, "id": tid, "string": tstr} + for i, (tid, tstr) in enumerate(zip(token_ids, token_strings, strict=True)) + ] + + # Build nodes from ci_vals, filtering by min_ci + nodes = [] + node_keys_kept = set() + for key, ci in circuit.node_ci_vals.items(): + if ci < min_ci: + continue + layer, seq, cidx = _parse_node_key(key) + canonical = _concrete_to_canonical(layer) + act = circuit.node_subcomp_acts.get(key, 0.0) + + label = _get_label(layer, cidx, interp, graph_interp) + + nodes.append( + { + "key": key, + "graph_key": f"{layer}:{cidx}", + "layer": layer, + "canonical": canonical, + "seq": seq, + "cidx": cidx, + "ci": round(ci, 4), + "activation": round(act, 4), + "token": token_strings[seq] if seq < len(token_strings) else "?", + "label": label, + } + ) + node_keys_kept.add(key) + + # Build edges, filtering by min_edge_attr and requiring both endpoints in kept nodes + edges = [] + for e in circuit.edges: + if abs(e.strength) < min_edge_attr: + continue + src_key = str(e.source) + tgt_key = str(e.target) + if src_key not in node_keys_kept or tgt_key not in node_keys_kept: + continue + edges.append(_edge_to_dict(e)) + + metrics: dict[str, float] = {"l0_total": circuit.metrics.l0_total} + if circuit.metrics.ci_masked_label_prob is not None: + metrics["ci_masked_label_prob"] = circuit.metrics.ci_masked_label_prob + if circuit.metrics.stoch_masked_label_prob is not None: + metrics["stoch_masked_label_prob"] = circuit.metrics.stoch_masked_label_prob + + data = { + "tokens": tokens, + "nodes": nodes, + "edges": edges, + "metrics": metrics, + } + + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w") as f: + json.dump(data, f, indent=2) + + print(f"Exported {len(nodes)} nodes, {len(edges)} edges to {output_path}") diff --git a/scripts/migrate_harvest_data.py b/scripts/migrate_harvest_data.py new file mode 100644 index 000000000..c0da50dcf --- /dev/null +++ b/scripts/migrate_harvest_data.py @@ 
-0,0 +1,369 @@ +"""Migrate legacy harvest + autointerp data to the new layout. + +Copies data into new directories/DBs without modifying the originals. + +Legacy layout: + harvest//activation_contexts/harvest.db (schema: mean_ci, ci_values, component_acts) + harvest//correlations/*.pt + harvest//eval/intruder/*.jsonl + autointerp//interp.db (top-level, with scores already merged) + +New layout: + harvest//h-/harvest.db (schema: firing_density, mean_activations, firings, activations) + harvest//h-/*.pt + autointerp//a-/interp.db +""" + +import shutil +import sqlite3 +from dataclasses import dataclass, field +from pathlib import Path + +import fire +import orjson +import torch + +from spd.harvest.storage import TokenStatsStorage +from spd.settings import SPD_OUT_DIR + + +@dataclass +class RunDiagnostic: + run_id: str + harvest_db: bool = False + token_stats: bool = False + correlations: bool = False + n_components: int = 0 + n_tokens: int = 0 + ci_threshold: str = "" + intruder_scores: int = 0 + interp_db: bool = False + n_interpretations: int = 0 + interp_score_types: list[str] = field(default_factory=list) + already_migrated: bool = False + problems: list[str] = field(default_factory=list) + + @property + def ready(self) -> bool: + return not self.problems and not self.already_migrated + + +def diagnose_run(run_id: str, timestamp: str = "20260218_000000") -> RunDiagnostic: + harvest_root = SPD_OUT_DIR / "harvest" / run_id + autointerp_root = SPD_OUT_DIR / "autointerp" / run_id + diag = RunDiagnostic(run_id=run_id) + + # Already migrated? 
+ if (harvest_root / f"h-{timestamp}").exists(): + diag.already_migrated = True + return diag + + # Harvest DB + old_db = harvest_root / "activation_contexts" / "harvest.db" + diag.harvest_db = old_db.exists() + if not diag.harvest_db: + diag.problems.append("missing harvest.db") + return diag + + conn = sqlite3.connect(f"file:{old_db}?immutable=1", uri=True) + diag.n_components = conn.execute("SELECT COUNT(*) FROM components").fetchone()[0] + threshold_row = conn.execute("SELECT value FROM config WHERE key = 'ci_threshold'").fetchone() + diag.ci_threshold = threshold_row[0] if threshold_row else "0.0 (default)" + conn.close() + + # Token stats + ts_path = harvest_root / "correlations" / "token_stats.pt" + diag.token_stats = ts_path.exists() + if diag.token_stats: + ts_data = torch.load(ts_path, weights_only=False) + diag.n_tokens = ts_data["n_tokens"] + else: + diag.problems.append("missing token_stats.pt") + + # Correlations + diag.correlations = (harvest_root / "correlations" / "component_correlations.pt").exists() + + # Intruder scores + intruder_dir = harvest_root / "eval" / "intruder" + if intruder_dir.exists(): + jsonl_files = list(intruder_dir.glob("*.jsonl")) + if jsonl_files: + largest = max(jsonl_files, key=lambda f: f.stat().st_size) + with open(largest, "rb") as f: + diag.intruder_scores = sum(1 for _ in f) + + # Autointerp + interp_db = autointerp_root / "interp.db" + diag.interp_db = interp_db.exists() + if diag.interp_db: + conn = sqlite3.connect(f"file:{interp_db}?immutable=1", uri=True) + tables = [ + r[0] + for r in conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall() + ] + if "interpretations" in tables: + diag.n_interpretations = conn.execute( + "SELECT COUNT(*) FROM interpretations" + ).fetchone()[0] + if "scores" in tables: + diag.interp_score_types = [ + r[0] for r in conn.execute("SELECT DISTINCT score_type FROM scores").fetchall() + ] + conn.close() + + return diag + + +def print_diagnostics(diagnostics: 
list[RunDiagnostic]) -> None: + ready = [d for d in diagnostics if d.ready] + skipped = [d for d in diagnostics if d.already_migrated] + blocked = [d for d in diagnostics if d.problems] + + if skipped: + print(f"Already migrated ({len(skipped)}): {', '.join(d.run_id for d in skipped)}\n") + + if blocked: + print(f"BLOCKED ({len(blocked)}):") + for d in blocked: + print(f" {d.run_id}: {', '.join(d.problems)}") + print() + + if ready: + print(f"Ready to migrate ({len(ready)}):") + print( + f"{'run_id':20s} {'comps':>6s} {'n_tokens':>14s} {'threshold':>10s} " + f"{'intruder':>8s} {'interps':>7s} {'scores':>20s} {'corr':>5s}" + ) + print("-" * 95) + for d in ready: + scores_str = ", ".join(d.interp_score_types) if d.interp_score_types else "-" + print( + f"{d.run_id:20s} {d.n_components:>6d} {d.n_tokens:>14,} {d.ci_threshold:>10s} " + f"{d.intruder_scores:>8d} {d.n_interpretations:>7d} {scores_str:>20s} " + f"{'yes' if d.correlations else 'no':>5s}" + ) + print() + + +def migrate_harvest_db(old_db_path: Path, new_db_path: Path, token_stats_path: Path) -> int: + """Copy harvest DB, transforming schema from legacy to new format. 
+ + Old schema: mean_ci REAL, activation_examples with {token_ids, ci_values, component_acts} + New schema: firing_density REAL, mean_activations TEXT, + activation_examples with {token_ids, firings, activations} + """ + old_conn = sqlite3.connect(f"file:{old_db_path}?immutable=1", uri=True) + old_conn.row_factory = sqlite3.Row + + new_conn = sqlite3.connect(str(new_db_path)) + new_conn.execute("PRAGMA journal_mode=WAL") + new_conn.executescript("""\ + CREATE TABLE IF NOT EXISTS components ( + component_key TEXT PRIMARY KEY, + layer TEXT NOT NULL, + component_idx INTEGER NOT NULL, + firing_density REAL NOT NULL, + mean_activations TEXT NOT NULL, + activation_examples TEXT NOT NULL, + input_token_pmi TEXT NOT NULL, + output_token_pmi TEXT NOT NULL + ); + CREATE TABLE IF NOT EXISTS config ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + ); + CREATE TABLE IF NOT EXISTS scores ( + component_key TEXT NOT NULL, + score_type TEXT NOT NULL, + score REAL NOT NULL, + details TEXT NOT NULL, + PRIMARY KEY (component_key, score_type) + ); + """) + + # Load firing densities from token_stats.pt + assert token_stats_path.exists(), f"No token_stats.pt at {token_stats_path}" + ts = TokenStatsStorage.load(token_stats_path) + firing_density_map: dict[str, float] = {} + for i, key in enumerate(ts.component_keys): + firing_density_map[key] = ts.firing_counts[i].item() / ts.n_tokens + + # Read activation threshold from old DB config + threshold_row = old_conn.execute( + "SELECT value FROM config WHERE key = 'ci_threshold'" + ).fetchone() + activation_threshold = float(threshold_row["value"]) if threshold_row else 0.0 + + # Migrate config + for row in old_conn.execute("SELECT key, value FROM config").fetchall(): + key = row["key"] + value = row["value"] + if key == "ci_threshold": + key = "activation_threshold" + new_conn.execute("INSERT OR REPLACE INTO config VALUES (?, ?)", (key, value)) + + # Migrate components row-by-row + n = 0 + rows = old_conn.execute("SELECT * FROM 
components").fetchall() + for row in rows: + old_examples = orjson.loads(row["activation_examples"]) + new_examples = [] + for ex in old_examples: + ci_values = ex["ci_values"] + component_acts = ex["component_acts"] + new_examples.append( + { + "token_ids": ex["token_ids"], + "firings": [v > activation_threshold for v in ci_values], + "activations": { + "causal_importance": ci_values, + "component_activation": component_acts, + }, + } + ) + + mean_ci = row["mean_ci"] + component_key = row["component_key"] + assert component_key in firing_density_map, f"{component_key} missing from token_stats.pt" + firing_density = firing_density_map[component_key] + mean_activations = {"causal_importance": mean_ci} + + new_conn.execute( + "INSERT OR REPLACE INTO components VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + ( + component_key, + row["layer"], + row["component_idx"], + firing_density, + orjson.dumps(mean_activations).decode(), + orjson.dumps(new_examples).decode(), + row["input_token_pmi"], + row["output_token_pmi"], + ), + ) + n += 1 + + new_conn.commit() + + # Migrate intruder scores from JSONL into harvest DB scores table + old_intruder_dir = old_db_path.parent.parent / "eval" / "intruder" + if old_intruder_dir.exists(): + # Use the largest file (most complete run) + jsonl_files = sorted(old_intruder_dir.glob("*.jsonl"), key=lambda f: f.stat().st_size) + if jsonl_files: + intruder_file = jsonl_files[-1] + n_scores = 0 + with open(intruder_file, "rb") as f: + for line in f: + record = orjson.loads(line) + new_conn.execute( + "INSERT OR REPLACE INTO scores VALUES (?, ?, ?, ?)", + ( + record["component_key"], + "intruder", + record["score"], + orjson.dumps(record.get("trials", [])).decode(), + ), + ) + n_scores += 1 + new_conn.commit() + print(f" Migrated {n_scores} intruder scores from {intruder_file.name}") + + old_conn.close() + new_conn.close() + return n + + +def migrate_autointerp_db(old_db_path: Path, new_db_path: Path) -> int: + """Copy autointerp DB (schema is compatible, 
just copy and strip intruder scores).""" + shutil.copy2(old_db_path, new_db_path) + + # Remove intruder scores — those belong in harvest.db in the new layout + conn = sqlite3.connect(str(new_db_path)) + conn.execute("DELETE FROM scores WHERE score_type = 'intruder'") + conn.commit() + n = conn.execute("SELECT COUNT(*) FROM interpretations").fetchone()[0] + conn.close() + return n + + +def migrate_run(run_id: str, timestamp: str = "20260218_000000") -> None: + """Migrate a single run's harvest + autointerp data to the new layout.""" + harvest_root = SPD_OUT_DIR / "harvest" / run_id + autointerp_root = SPD_OUT_DIR / "autointerp" / run_id + + print(f"Migrating {run_id}...") + + # --- Harvest --- + old_harvest_db = harvest_root / "activation_contexts" / "harvest.db" + assert old_harvest_db.exists(), f"No legacy harvest DB at {old_harvest_db}" + + new_subrun = harvest_root / f"h-{timestamp}" + assert not new_subrun.exists(), f"Target already exists: {new_subrun}" + new_subrun.mkdir(parents=True) + + token_stats_path = harvest_root / "correlations" / "token_stats.pt" + assert token_stats_path.exists(), f"No token_stats.pt at {token_stats_path}" + print(f" Harvest DB: {old_harvest_db} -> {new_subrun / 'harvest.db'}") + n_components = migrate_harvest_db(old_harvest_db, new_subrun / "harvest.db", token_stats_path) + print(f" Migrated {n_components} components") + + # Copy .pt files + old_corr_dir = harvest_root / "correlations" + for pt_file in ["component_correlations.pt", "token_stats.pt"]: + src = old_corr_dir / pt_file + if src.exists(): + dst = new_subrun / pt_file + shutil.copy2(src, dst) + print(f" Copied {pt_file}") + + # --- Autointerp --- + old_interp_db = autointerp_root / "interp.db" + if old_interp_db.exists(): + new_autointerp_subrun = autointerp_root / f"a-{timestamp}" + assert not new_autointerp_subrun.exists(), f"Target already exists: {new_autointerp_subrun}" + new_autointerp_subrun.mkdir(parents=True) + + print(f" Interp DB: {old_interp_db} -> 
{new_autointerp_subrun / 'interp.db'}") + n_interps = migrate_autointerp_db(old_interp_db, new_autointerp_subrun / "interp.db") + print(f" Migrated {n_interps} interpretations (detection + fuzzing scores preserved)") + else: + print(" No top-level interp.db found, skipping autointerp") + + print("Done!") + + +def _find_legacy_run_ids() -> list[str]: + harvest_root = SPD_OUT_DIR / "harvest" + return sorted( + d.name + for d in harvest_root.iterdir() + if (d / "activation_contexts" / "harvest.db").exists() + ) + + +def diagnose(run_id: str | None = None, timestamp: str = "20260218_000000") -> None: + """Print diagnostic report without migrating anything.""" + run_ids = [run_id] if run_id else _find_legacy_run_ids() + diagnostics = [diagnose_run(rid, timestamp) for rid in run_ids] + print_diagnostics(diagnostics) + + +def migrate_all(timestamp: str = "20260218_000000", dry_run: bool = False) -> None: + """Discover and migrate all legacy harvest runs.""" + run_ids = _find_legacy_run_ids() + diagnostics = [diagnose_run(rid, timestamp) for rid in run_ids] + + print_diagnostics(diagnostics) + + ready = [d for d in diagnostics if d.ready] + if dry_run or not ready: + return + + for d in ready: + migrate_run(d.run_id, timestamp=timestamp) + print(f"\nAll done — migrated {len(ready)} runs") + + +if __name__ == "__main__": + fire.Fire({"run": migrate_run, "all": migrate_all, "diagnose": diagnose}) diff --git a/scripts/parse_transformer_circuits_post.py b/scripts/parse_transformer_circuits_post.py new file mode 100644 index 000000000..1f528179e --- /dev/null +++ b/scripts/parse_transformer_circuits_post.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python3 +"""Export a Transformer Circuits post to local Markdown + downloaded image assets. 
+ +Example: + .venv/bin/python scripts/parse_transformer_circuits_post.py \ + --url https://transformer-circuits.pub/2025/attribution-graphs/biology.html \ + --output-md papers/biology_source/biology.md \ + --assets-dir papers/biology_source/assets +""" + +from __future__ import annotations + +import argparse +import re +from collections.abc import Iterable +from pathlib import Path +from urllib.parse import urljoin, urlparse + +import requests +from bs4 import BeautifulSoup, NavigableString, Tag + +DEFAULT_URL = "https://transformer-circuits.pub/2025/attribution-graphs/biology.html" +USER_AGENT = "spd-biology-exporter/1.0" + + +def normalize_whitespace(text: str) -> str: + return re.sub(r"\s+", " ", text).strip() + + +def should_insert_space(left: str, right: str) -> bool: + if not left or not right: + return False + left = left.rstrip() + right = right.lstrip() + if not left or not right: + return False + if left.endswith((" ", "\n", "\t", "/", "(", "[", "{", "-", "“", '"', "'")): + return False + if right.startswith( + (" ", "\n", "\t", ".", ",", ":", ";", "!", "?", ")", "]", "}", "-", "”", '"', "'") + ): + return False + if left.endswith(" "): + return False + left_char = left[-1] + right_char = right[0] + return bool( + re.match(r"[A-Za-z0-9\]\)]", left_char) and re.match(r"[A-Za-z0-9\[\(]", right_char) + ) + + +class Exporter: + def __init__( + self, + *, + base_url: str, + output_md: Path, + assets_dir: Path, + download_assets: bool = True, + timeout: float = 30.0, + ) -> None: + self.base_url = base_url + self.output_md = output_md + self.assets_dir = assets_dir + self.download_assets = download_assets + self.timeout = timeout + self.session = requests.Session() + self.session.headers.update({"User-Agent": USER_AGENT}) + self.asset_map: dict[str, str] = {} + self.asset_counter = 0 + + def fetch_html(self) -> str: + response = self.session.get(self.base_url, timeout=self.timeout) + response.raise_for_status() + # Distill pages sometimes produce mojibake via 
default charset guessing. + return response.content.decode("utf-8", errors="replace") + + def _relative_asset_path(self, local_path: Path) -> str: + return local_path.relative_to(self.output_md.parent).as_posix() + + def _unique_asset_filename(self, remote_url: str) -> str: + parsed = urlparse(remote_url) + basename = Path(parsed.path).name or f"asset_{self.asset_counter}" + if "." not in basename: + basename = f"{basename}.bin" + candidate = basename + while (self.assets_dir / candidate).exists(): + self.asset_counter += 1 + stem = Path(basename).stem + suffix = Path(basename).suffix + candidate = f"{stem}_{self.asset_counter}{suffix}" + return candidate + + def download_asset(self, src: str) -> str: + if not src: + return "" + remote_url = urljoin(self.base_url, src) + if remote_url in self.asset_map: + return self.asset_map[remote_url] + if remote_url.startswith("data:"): + return remote_url + local_rel = remote_url + if self.download_assets: + self.assets_dir.mkdir(parents=True, exist_ok=True) + preferred = Path(urlparse(remote_url).path).name + if preferred and "." 
in preferred and (self.assets_dir / preferred).exists(): + local_path = self.assets_dir / preferred + else: + filename = self._unique_asset_filename(remote_url) + local_path = self.assets_dir / filename + response = self.session.get(remote_url, timeout=self.timeout) + response.raise_for_status() + local_path.write_bytes(response.content) + local_rel = self._relative_asset_path(local_path) + self.asset_map[remote_url] = local_rel + return local_rel + + def render_inline(self, node: Tag | NavigableString) -> str: + if isinstance(node, NavigableString): + return str(node) + if not isinstance(node, Tag): + return "" + name = node.name.lower() + if name == "br": + return " \n" + if name in {"b", "strong"}: + return f"**{self.render_children_inline(node.children)}**" + if name in {"i", "em"}: + return f"*{self.render_children_inline(node.children)}*" + if name == "code": + text = normalize_whitespace(self.render_children_inline(node.children)) + return f"`{text}`" if text else "" + if name == "a": + href = (node.get("href") or "").strip() + text = normalize_whitespace(self.render_children_inline(node.children)) + if not text: + text = href + if not href: + return text + full_href = urljoin(self.base_url, href) + return f"[{text}]({full_href})" + if name == "d-cite": + cite = normalize_whitespace(self.render_children_inline(node.children)) + return f"[{cite}]" if cite else "[citation]" + if name == "d-footnote": + note = normalize_whitespace(self.render_children_inline(node.children)) + return f" (Footnote: {note}) " if note else "" + return self.render_children_inline(node.children) + + def render_children_inline(self, nodes: Iterable[Tag | NavigableString]) -> str: + pieces = [self.render_inline(child) for child in nodes] + raw = "" + for piece in pieces: + if not piece: + continue + if raw and should_insert_space(raw, piece): + raw += " " + raw += piece + # Keep explicit markdown line breaks but normalize other whitespace. 
+ parts = raw.split(" \n") + return " \n".join(normalize_whitespace(part) for part in parts) + + def render_list(self, list_tag: Tag, level: int = 0) -> list[str]: + lines: list[str] = [] + ordered = list_tag.name.lower() == "ol" + counter = 1 + indent = " " * level + for li in list_tag.find_all("li", recursive=False): + inline_nodes: list[Tag | NavigableString] = [] + nested_lists: list[Tag] = [] + for child in li.children: + if isinstance(child, Tag) and child.name and child.name.lower() in {"ul", "ol"}: + nested_lists.append(child) + else: + inline_nodes.append(child) + text = normalize_whitespace(self.render_children_inline(inline_nodes)) + marker = f"{counter}." if ordered else "-" + if text: + lines.append(f"{indent}{marker} {text}") + else: + lines.append(f"{indent}{marker}") + for nested in nested_lists: + lines.extend(self.render_list(nested, level + 1)) + counter += 1 + lines.append("") + return lines + + def render_table(self, table: Tag) -> list[str]: + rows: list[list[str]] = [] + for tr in table.find_all("tr"): + row: list[str] = [] + cells = tr.find_all(["th", "td"]) + for cell in cells: + row.append(normalize_whitespace(self.render_children_inline(cell.children))) + if row: + rows.append(row) + if not rows: + return [] + width = max(len(row) for row in rows) + padded = [row + [""] * (width - len(row)) for row in rows] + header = padded[0] + sep = ["---"] * width + lines = [ + "| " + " | ".join(header) + " |", + "| " + " | ".join(sep) + " |", + ] + for row in padded[1:]: + lines.append("| " + " | ".join(row) + " |") + lines.append("") + return lines + + def render_figure(self, figure: Tag) -> list[str]: + lines: list[str] = [] + imgs = figure.find_all("img") + for img in imgs: + src = (img.get("src") or "").strip() + alt = normalize_whitespace(img.get("alt") or "") + local_src = self.download_asset(src) + if local_src: + lines.append(f"![{alt}]({local_src})") + caption = figure.find("figcaption") + if caption: + caption_text = 
normalize_whitespace(self.render_children_inline(caption.children)) + if caption_text: + lines.append(f"_Figure: {caption_text}_") + if lines: + lines.append("") + return lines + + def render_block(self, node: Tag | NavigableString) -> list[str]: + if isinstance(node, NavigableString): + text = normalize_whitespace(str(node)) + return [text, ""] if text else [] + if not isinstance(node, Tag): + return [] + name = node.name.lower() + if name in {"style", "script"}: + return [] + if name in {"h1", "h2", "h3", "h4", "h5", "h6"}: + level = int(name[1]) + text = normalize_whitespace(self.render_children_inline(node.children)) + if not text: + return [] + anchor = node.get("id") + anchor_suffix = f" " if anchor else "" + return [f"{'#' * level} {text}{anchor_suffix}", ""] + if name == "p": + text = normalize_whitespace(self.render_children_inline(node.children)) + return [text, ""] if text else [] + if name in {"ul", "ol"}: + return self.render_list(node) + if name == "figure": + return self.render_figure(node) + if name == "table": + return self.render_table(node) + if name == "hr": + return ["---", ""] + if name == "br": + return [""] + if name == "d-contents": + return [] + if name in {"div", "section", "nav", "d-appendix", "d-article"}: + lines: list[str] = [] + for child in node.children: + lines.extend(self.render_block(child)) + return lines + text = normalize_whitespace(self.render_children_inline(node.children)) + return [text, ""] if text else [] + + def export(self, include_appendix: bool = True) -> tuple[str, int]: + html = self.fetch_html() + soup = BeautifulSoup(html, "html.parser") + title = normalize_whitespace(soup.title.string if soup.title else "") or "Untitled" + article = soup.find("d-article") + if article is None: + raise RuntimeError("Could not find in the page") + lines: list[str] = [ + f"# {title}", + "", + f"Source: {self.base_url}", + "", + "> Auto-generated by scripts/parse_transformer_circuits_post.py", + "", + ] + for child in 
article.children: + lines.extend(self.render_block(child)) + if include_appendix: + appendix = soup.find("d-appendix") + if appendix is not None: + appendix_lines: list[str] = [] + for child in appendix.children: + appendix_lines.extend(self.render_block(child)) + appendix_content = [line for line in appendix_lines if line.strip()] + if appendix_content: + lines.extend(["## Appendix", ""]) + lines.extend(appendix_lines) + # Strip trailing whitespace and collapse excessive blank lines. + cleaned: list[str] = [] + blank_run = 0 + for line in lines: + stripped = line.rstrip() + if not stripped: + blank_run += 1 + if blank_run <= 1: + cleaned.append("") + else: + blank_run = 0 + cleaned.append(stripped) + markdown = "\n".join(cleaned).strip() + "\n" + self.output_md.parent.mkdir(parents=True, exist_ok=True) + self.output_md.write_text(markdown, encoding="utf-8") + return markdown, len(self.asset_map) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--url", default=DEFAULT_URL, help="Post URL to parse") + parser.add_argument( + "--output-md", + default="papers/biology_source/biology.md", + help="Path for generated markdown", + ) + parser.add_argument( + "--assets-dir", + default="papers/biology_source/assets", + help="Directory for downloaded assets", + ) + parser.add_argument( + "--skip-appendix", + action="store_true", + help="Do not include d-appendix content", + ) + parser.add_argument( + "--skip-download-assets", + action="store_true", + help="Keep remote asset links instead of downloading files", + ) + parser.add_argument("--timeout", type=float, default=30.0, help="HTTP timeout in seconds") + return parser.parse_args() + + +def main() -> None: + args = parse_args() + output_md = Path(args.output_md).resolve() + assets_dir = Path(args.assets_dir).resolve() + exporter = Exporter( + base_url=args.url, + output_md=output_md, + assets_dir=assets_dir, + download_assets=not 
args.skip_download_assets, + timeout=args.timeout, + ) + _, asset_count = exporter.export(include_appendix=not args.skip_appendix) + print(f"Wrote markdown: {output_md}") + if args.skip_download_assets: + print("Assets were not downloaded (--skip-download-assets set).") + else: + print(f"Downloaded/linked assets: {asset_count}") + print(f"Assets dir: {assets_dir}") + + +if __name__ == "__main__": + main() diff --git a/scripts/render_circuit_html.py b/scripts/render_circuit_html.py new file mode 100644 index 000000000..677c9cb57 --- /dev/null +++ b/scripts/render_circuit_html.py @@ -0,0 +1,69 @@ +"""Render a circuit JSON file as a self-contained HTML page. + +This is a thin wrapper that copies the circuit.html template and embeds the JSON +data inline, so the result is a single self-contained HTML file (no separate data file needed). + +Usage: + python scripts/render_circuit_html.py data/king_circuit.json -o circuit_standalone.html + +Or from Python: + from scripts.render_circuit_html import render_circuit_html + render_circuit_html(Path("data/king_circuit.json"), Path("circuit_standalone.html")) +""" + +import argparse +import json +from pathlib import Path + + +def render_circuit_html(json_path: Path, output_path: Path, title: str = "Circuit Graph") -> None: + """Render a circuit JSON as a self-contained HTML file. + + Takes the interactive circuit.html template and replaces the fetch() call + with inline data, producing a single portable HTML file. 
+ """ + with open(json_path) as f: + data = json.load(f) + + # Read the template + template_path = Path(__file__).parent.parent / "scripts" / "_circuit_template.html" + if not template_path.exists(): + # Fallback: read from www + from spd.settings import SPD_OUT_DIR + + template_path = SPD_OUT_DIR / "www" / "pile-editing" / "circuit.html" + + assert template_path.exists(), f"Template not found: {template_path}" + template = template_path.read_text() + + # Replace the fetch with inline data + inline_js = f"data = {json.dumps(data)}; init();" + template = template.replace( + "fetch(DATA_URL)\n" + " .then(r => { if (!r.ok) throw new Error(`HTTP ${r.status}`); return r.json(); })\n" + " .then(d => { data = d; init(); })\n" + " .catch(e => { document.getElementById('stats').textContent = `Error: ${e.message}`; });", + inline_js, + ) + + # Update title + template = template.replace( + 'Circuit Graph — King → "he"', f"{title}" + ) + + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w") as f: + f.write(template) + + print(f"Wrote self-contained HTML to {output_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Render circuit JSON as self-contained HTML") + parser.add_argument("json_path", type=Path, help="Path to circuit JSON file") + parser.add_argument( + "-o", "--output", type=Path, default=Path("circuit.html"), help="Output HTML path" + ) + parser.add_argument("--title", default="Circuit Graph", help="Page title") + args = parser.parse_args() + render_circuit_html(args.json_path, args.output, args.title) diff --git a/scripts/test_abs_grad_trick.py b/scripts/test_abs_grad_trick.py new file mode 100644 index 000000000..b3d15d296 --- /dev/null +++ b/scripts/test_abs_grad_trick.py @@ -0,0 +1,150 @@ +"""Verify that ∂|y|/∂x = sign(y) · ∂y/∂x for a scalar y, even through nonlinearities. 
+ +The chain rule: ∂|y|/∂x = (d|y|/dy) · (∂y/∂x) = sign(y) · ∂y/∂x + +This holds regardless of what nonlinear computation sits between x and y, +because ∂y/∂x already accounts for all intermediate nonlinearities. +The sign(y) factor is just the outermost link in the chain. +""" + +import torch +from torch import nn + + +def test_simple_linear(): + """Linear: y = Wx, trivial case.""" + x = torch.randn(5, requires_grad=True) + W = torch.randn(3, 5) + y_vec = W @ x + y = y_vec[1] # pick one scalar + + grad = torch.autograd.grad(y, x, retain_graph=True)[0] + grad_abs = torch.autograd.grad(y.abs(), x, retain_graph=True)[0] + grad_trick = y.sign() * grad + + assert torch.allclose(grad_abs, grad_trick, atol=1e-7), ( + f"FAIL: {(grad_abs - grad_trick).abs().max()}" + ) + print(f" linear: max diff = {(grad_abs - grad_trick).abs().max():.2e} ✓") + + +def test_deep_nonlinear(): + """Deep net with ReLU, tanh, and GELU — representative of a transformer.""" + torch.manual_seed(42) + net = nn.Sequential( + nn.Linear(8, 16), + nn.ReLU(), + nn.Linear(16, 16), + nn.Tanh(), + nn.Linear(16, 16), + nn.GELU(), + nn.Linear(16, 4), + ) + x = torch.randn(8, requires_grad=True) + y_vec = net(x) + y = y_vec[2] # scalar output + + grad = torch.autograd.grad(y, x, retain_graph=True)[0] + grad_abs = torch.autograd.grad(y.abs(), x, retain_graph=True)[0] + grad_trick = y.sign() * grad + + assert torch.allclose(grad_abs, grad_trick, atol=1e-6), ( + f"FAIL: {(grad_abs - grad_trick).abs().max()}" + ) + print( + f" deep nonlinear (y={y.item():.4f}): max diff = {(grad_abs - grad_trick).abs().max():.2e} ✓" + ) + + +def test_negative_target(): + """Ensure it works when y < 0 (sign flips the gradient).""" + torch.manual_seed(99) + net = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1)) + # Find an input that gives negative output + for _seed in range(200): + x = torch.randn(4, requires_grad=True) + y = net(x).squeeze() + if y.item() < -0.1: + break + assert y.item() < 0, "Couldn't find negative 
output" + + grad = torch.autograd.grad(y, x, retain_graph=True)[0] + grad_abs = torch.autograd.grad(y.abs(), x, retain_graph=True)[0] + grad_trick = y.sign() * grad + + assert torch.allclose(grad_abs, grad_trick, atol=1e-6), ( + f"FAIL: {(grad_abs - grad_trick).abs().max()}" + ) + print( + f" negative target (y={y.item():.4f}): max diff = {(grad_abs - grad_trick).abs().max():.2e} ✓" + ) + + +def test_multiple_inputs(): + """Multiple input tensors (mirrors the app's in_post_detaches list).""" + torch.manual_seed(7) + x1 = torch.randn(3, 4, requires_grad=True) + x2 = torch.randn(3, 4, requires_grad=True) + + # Nonlinear function of both inputs + h = torch.relu(x1) + torch.tanh(x2) + y = (h @ torch.randn(4, 1)).sum() # scalar + + grads = torch.autograd.grad(y, [x1, x2], retain_graph=True) + grads_abs = torch.autograd.grad(y.abs(), [x1, x2], retain_graph=True) + + for i, (g, g_abs) in enumerate(zip(grads, grads_abs, strict=True)): + g_trick = y.sign() * g + assert torch.allclose(g_abs, g_trick, atol=1e-6), ( + f"FAIL input {i}: {(g_abs - g_trick).abs().max()}" + ) + + print(f" multiple inputs (y={y.item():.4f}): all match ✓") + + +def test_sum_of_abs_DOES_NOT_work(): + """Show that the trick FAILS for sum-of-abs (dataset attributions case). + + ∂(Σ|y_i|)/∂x ≠ sign(Σy_i) · ∂(Σy_i)/∂x + because each y_i has a different sign. 
+ """ + torch.manual_seed(42) + x = torch.randn(4, requires_grad=True) + W = torch.randn(3, 4) + y_vec = W @ x # [3] + + target_signed = y_vec.sum() + target_abs = y_vec.abs().sum() + + grad_signed = torch.autograd.grad(target_signed, x, retain_graph=True)[0] + grad_abs = torch.autograd.grad(target_abs, x, retain_graph=True)[0] + + # The WRONG trick: use sign of the sum + grad_wrong = target_signed.sign() * grad_signed + + # The correct per-element version + grad_correct = sum( + y_vec[i].sign() * torch.autograd.grad(y_vec[i], x, retain_graph=True)[0] + for i in range(len(y_vec)) + ) + + wrong_diff = (grad_abs - grad_wrong).abs().max() + correct_diff = (grad_abs - grad_correct).abs().max() + print( + f" sum-of-abs: wrong trick diff = {wrong_diff:.4f}, correct per-element diff = {correct_diff:.2e}" + ) + assert wrong_diff > 0.01, "Expected the wrong trick to fail for sum-of-abs" + assert correct_diff < 1e-6, "Per-element version should match" + print(" → confirms: trick works for scalar y, NOT for sum-of-abs ✓") + + +if __name__ == "__main__": + print("Testing ∂|y|/∂x = sign(y) · ∂y/∂x for scalar y:\n") + test_simple_linear() + test_deep_nonlinear() + test_negative_target() + test_multiple_inputs() + print() + print("Testing that the trick does NOT work for sum-of-abs:\n") + test_sum_of_abs_DOES_NOT_work() + print("\nAll tests passed.") diff --git a/spd/app/CLAUDE.md b/spd/app/CLAUDE.md index 0f95d54a8..42f745de6 100644 --- a/spd/app/CLAUDE.md +++ b/spd/app/CLAUDE.md @@ -4,17 +4,18 @@ Web-based visualization and analysis tool for exploring neural network component - **Backend**: Python FastAPI (`backend/`) - **Frontend**: Svelte 5 + TypeScript (`frontend/`) -- **Database**: SQLite at `.data/app/prompt_attr.db` (relative to repo root) +- **Database**: SQLite at `SPD_OUT_DIR/app/prompt_attr.db` (shared across team via NFS) - **TODOs**: See `TODO.md` for open work items ## Project Context This is a **rapidly iterated research tool**. 
Key implications: -- **Please do not code for backwards compatibility**: Schema changes don't need migrations, expect state can be deleted, etc. -- **Database is disposable**: Delete `.data/app/prompt_attr.db` if schema changes break things +- **Please do not code for backwards compatibility**: Schema changes don't need migrations +- **Database is shared state**: Lives at `SPD_OUT_DIR/app/prompt_attr.db` on NFS, accessible by multiple backends. Do not delete without checking with the team. Uses DELETE journal mode (NFS-safe) with `fcntl.flock` write locking for concurrent access - **Prefer simplicity**: Avoid over-engineering for hypothetical future needs - **Fail loud and fast**: The users are a small team of highly technical people. Errors are good. We want to know immediately if something is wrong. No soft failing, assert, assert, assert +- **Token display**: Always ship token strings rendered server-side via `AppTokenizer`, never raw token IDs. For embed/output layers, `component_idx` is a token ID — resolve it to a display string in the backend response. ## Running the App @@ -34,14 +35,14 @@ This launches both backend (FastAPI/uvicorn) and frontend (Vite) dev servers. 
backend/ ├── server.py # FastAPI app, CORS, routers ├── state.py # Singleton StateManager + HarvestRepo (lazy-loaded harvest data) -├── compute.py # Core attribution computation +├── compute.py # Core attribution computation + intervention evaluation ├── app_tokenizer.py # AppTokenizer: wraps HF tokenizers for display/encoding ├── (topology lives at spd/topology.py — TransformerTopology) ├── schemas.py # Pydantic API models ├── dependencies.py # FastAPI dependency injection ├── utils.py # Logging/timing utilities ├── database.py # SQLite interface -├── optim_cis.py # Sparse CI optimization +├── optim_cis.py # Sparse CI optimization, loss configs, PGD └── routers/ ├── runs.py # Load W&B runs + GET /api/model_info ├── graphs.py # Compute attribution graphs @@ -50,6 +51,9 @@ backend/ ├── intervention.py # Selective component activation ├── correlations.py # Component correlations + token stats + interpretations ├── clusters.py # Component clustering + ├── dataset_search.py # SimpleStories dataset search + ├── agents.py # Various useful endpoints that AI agents should look at when helping + ├── mcp.py # MCP (Model Context Protocol) endpoint for Claude Code ├── dataset_search.py # Dataset search (reads dataset from run config) └── agents.py # Various useful endpoints that AI agents should look at when helping ``` @@ -90,7 +94,7 @@ frontend/src/ ├── ActivationContextsPagedTable.svelte ├── DatasetSearchTab.svelte # Dataset search UI ├── DatasetSearchResults.svelte - ├── ClusterPathInput.svelte # Cluster path selector + ├── ClusterPathInput.svelte # Cluster path selector (dropdown populated from registry.ts) ├── ComponentProbeInput.svelte # Component probe UI ├── TokenHighlights.svelte # Token highlighting ├── prompt-attr/ @@ -154,8 +158,13 @@ Edge(source: Node, target: Node, strength: float, is_cross_seq: bool) # strength = gradient * activation # is_cross_seq = True for k/v → o_proj (attention pattern) -PromptAttributionResult(edges: list[Edge], output_probs: Tensor[seq, 
vocab], node_ci_vals: dict[str, float]) -# node_ci_vals maps "layer:seq:c_idx" → CI value +PromptAttributionResult(edges, ci_masked_out_logits, target_out_logits, node_ci_vals, node_subcomp_acts) + +TokenPrediction(token, token_id, prob, logit, target_prob, target_logit) + +InterventionResult(input_tokens, ci, stochastic, adversarial, ci_loss, stochastic_loss, adversarial_loss) +# ci/stochastic/adversarial are list[list[TokenPrediction]] (per-position top-k) +# losses are evaluated using the graph's implied loss context ``` ### Frontend Types (`promptAttributionsTypes.ts`) @@ -211,13 +220,31 @@ Finds sparse CI mask that: - Minimizes L0 (active component count) - Uses importance minimality + CE loss (or KL loss) -### Intervention Forward +### Interventions (`compute.py → compute_intervention`) + +A single unified function evaluates a node selection under three masking regimes: + +- **CI**: mask = selection (binary on/off) +- **Stochastic**: mask = selection + (1-selection) × Uniform(0,1) +- **Adversarial**: PGD optimizes alive-but-unselected components to maximize loss; non-alive get Uniform(0,1) + +Returns `InterventionResult` with top-k `TokenPrediction`s per position for each regime, plus per-regime loss values. + +**Loss context**: Every graph has an implied loss that interventions evaluate against: -`compute_intervention_forward()`: +- **Standard/manual graphs** → `MeanKLLossConfig` (mean KL divergence from target across all positions) +- **Optimized graphs** → the graph's optimization loss (CE for a specific token at a position, or KL at a position) -1. Build component masks (all zeros) -2. Set mask=1.0 for selected nodes -3. Forward pass → top-k predictions per position +This loss is used for two things: (1) what PGD maximizes during adversarial evaluation, and (2) the `ci_loss`/`stochastic_loss`/`adversarial_loss` metrics reported in `InterventionResult`. 
+ +**Alive masks**: `compute_intervention` recomputes the model's natural CI (one forward pass + `calc_causal_importances`) and binarizes at 0 to get alive masks. This ensures the alive set is always the full model's CI — not the graph's potentially sparse optimized CI. PGD can only manipulate alive-but-unselected components. + +**Training PGD vs Eval PGD**: The PGD settings in the graph optimization config (`adv_pgd_n_steps`, +`adv_pgd_step_size`) are a _training_ regularizer — they make CI optimization robust. The PGD in +`compute_intervention` is an _eval_ metric — it measures worst-case performance for a given node +selection. Eval PGD defaults are in `compute.py` (`DEFAULT_EVAL_PGD_CONFIG`). + +**Base intervention run**: Created automatically during graph computation. Uses all interventable nodes with CI > 0. Persisted as an `intervention_run` so predictions are available synchronously. --- @@ -245,9 +272,14 @@ POST /api/graphs ### Intervention ``` -POST /api/intervention {text, nodes: ["h.0.attn.q_proj:3:5", ...]} - → compute_intervention_forward() - ← InterventionResponse with top-k predictions +POST /api/intervention/run {graph_id, selected_nodes, top_k, adv_pgd} + → compute_intervention(active_nodes, graph_alive_masks, loss_config) + ← InterventionRunSummary {id, selected_nodes, result: InterventionResult} + +InterventionResult = { + input_tokens, ci, stochastic, adversarial, // TokenPrediction[][] per regime + ci_loss, stochastic_loss, adversarial_loss // loss under each regime +} ``` ### Component Correlations & Interpretations @@ -281,14 +313,14 @@ GET /api/dataset/results?page=1&page_size=20 ## Database Schema -Located at `.data/app/prompt_attr.db`. Delete this file if schema changes cause issues. +Located at `SPD_OUT_DIR/app/prompt_attr.db` (shared via NFS). Uses DELETE journal mode with `fcntl.flock` write locking for safe concurrent access from multiple backends. 
-| Table | Key | Purpose | -| ------------------ | ---------------------------------- | ------------------------------------------------- | -| `runs` | `wandb_path` | W&B run references | -| `prompts` | `(run_id, context_length)` | Token sequences | -| `graphs` | `(prompt_id, optimization_params)` | Attribution edges + output probs + node CI values | -| `intervention_runs`| `graph_id` | Saved intervention results | +| Table | Key | Purpose | +| ------------------- | ---------------------------------- | -------------------------------------------------------- | +| `runs` | `wandb_path` | W&B run references | +| `prompts` | `(run_id, context_length)` | Token sequences | +| `graphs` | `(prompt_id, optimization_params)` | Attribution edges + CI/target logits + node CI values | +| `intervention_runs` | `graph_id` | Saved `InterventionResult` JSON (single `result` column) | Note: Activation contexts, correlations, token stats, and interpretations are loaded from pre-harvested data at `SPD_OUT_DIR/{harvest,autointerp}/` (see `spd/harvest/` and `spd/autointerp/`). diff --git a/spd/app/TODO.md b/spd/app/TODO.md index a3c7eb1aa..21a8ba1fd 100644 --- a/spd/app/TODO.md +++ b/spd/app/TODO.md @@ -1,3 +1,172 @@ -# App TODOs +# App Backend Review & Action Items -- Audit SQLite access pragma stuff — `immutable=1` in `HarvestDB` causes "database disk image is malformed" errors when the app reads a harvest DB mid-write (WAL not yet checkpointed). Investigate whether to check for WAL file existence, use normal locking mode, or add another safeguard. See `spd/harvest/db.py:79`. +Review date: 2026-03-04. Scope: `spd/app/backend/` — all Python files. + +Context: the app is a **researcher-first local tool** (frontend + backend launched together, opened in browser). Errors should be loud, silent failures absent, the prompt DB is deletable short-term state, no backwards compatibility needed. + +## Overview + +The backend is ~6,500 lines across 18 Python files. 
The core architecture (FastAPI + SQLite + singleton state + SSE streaming) is sound. The main concerns are: a few real bugs, accumulated dead code, some silent failures that violate the "loud errors" principle, and a few design seams where complexity hides. + +### File size inventory + +| File | Lines | Risk | +|---|---|---| +| `routers/mcp.py` | 1637 | High — mixed concerns, largest file | +| `routers/graphs.py` | 1036 | Medium — streaming complexity | +| `compute.py` | 920 | Low — core algorithm, well-structured | +| `database.py` | 827 | Medium — manual serialization | +| `optim_cis.py` | 504 | Low | +| `routers/dataset_search.py` | 473 | Medium — hardcoded dataset names | +| `routers/correlations.py` | 386 | Low | +| `routers/graph_interp.py` | 373 | Low | +| `routers/investigations.py` | 317 | Low | +| `routers/pretrain_info.py` | 246 | Low | +| `server.py` | 212 | Low — clean | +| `routers/activation_contexts.py` | 207 | Low | +| `routers/runs.py` | 191 | Low | +| `routers/dataset_attributions.py` | 170 | Low | +| `routers/intervention.py` | 169 | Low — clean | +| `state.py` | 132 | Low — clean | +| `app_tokenizer.py` | 119 | Low | +| `routers/prompts.py` | 115 | Low | + +--- + +## Bugs + +### 1. `dataset_search.py:262` — KeyError on tokenized results + +`get_tokenized_results` accesses `result["story"]` but `search_dataset` stores results with key `"text"` (line 137: `results.append({"text": text, ...})`). This will crash with `KeyError: 'story'` whenever tokenized results are requested. + +**Fix:** Change line 262 from `result["story"]` to `result["text"]`. Also line 287: the metadata exclusion list references `"story"` — should be `"text"`. + +### 2. `dataset_search.py` — random endpoints hardcode SimpleStories + +`get_random_samples` (line 336) and `get_random_samples_with_loss` (line 415) both hardcode `load_dataset("lennart-finke/SimpleStories", ...)` and access `item_dict["story"]`. 
Since primary models are now Pile-trained, these endpoints are broken for current research. They also don't use `DepLoadedRun` to get the dataset name from the run config like `search_dataset` does. + +**Fix:** Make them take `DepLoadedRun`, read `task_config.dataset_name` and `task_config.column_name`, and use those instead of hardcoded values. Or, if the random endpoints aren't used with Pile models, consider deleting them. + +--- + +## Dead code to delete + +### 3. `ForkedInterventionRunRecord` + `forked_intervention_runs` table + +`database.py:117-125` defines `ForkedInterventionRunRecord`. Lines 256-265 create the `forked_intervention_runs` table. Lines 744-827 implement 3 CRUD methods (`save_forked_intervention_run`, `get_forked_intervention_runs`, `delete_forked_intervention_run`). No router references any of these — the fork endpoints were removed. Delete all of it. + +Files: `database.py` + +### 4. `optim_cis.py:500-504` — `get_out_dir()` never called + +Dead utility function that creates a local `out/` directory. Nothing references it. + +Files: `optim_cis.py` + +### 5. Unused schemas in `graphs.py:188-209` + +`ComponentStats`, `PromptSearchQuery`, and `PromptSearchResponse` are defined but no endpoint uses them. They appear to be leftovers from a removed prompt search feature. The `PromptPreview` in `graphs.py:114` also duplicates the one in `prompts.py:25`. + +Files: `routers/graphs.py` + +### 6. `spd/app/TODO.md` was empty + +(This file — now repurposed for this review.) + +--- + +## Design issues + +### 7. `OptimizationParams` mixes config inputs with computed outputs + +`database.py:69-82` — Fields like `imp_min_coeff`, `steps`, `pnorm`, `beta` are optimization *inputs*. Fields like `ci_masked_label_prob`, `stoch_masked_label_prob`, `adv_pgd_label_prob` are computed *outputs*. These metrics are mutated in-place after construction in `graphs.py:759-761`. + +This makes the object's contract unclear — is it immutable config or mutable state? 
+ +**Suggestion:** Either nest the metrics in a sub-model (`metrics: OptimMetrics | None`), or at minimum stop mutating after construction (compute the metrics before constructing `OptimizationParams`). + +### 8. `StoredGraph.id = -1` sentinel value + +`database.py:90` uses `-1` as "unsaved graph". If a graph is accidentally used before being saved, that `-1` leaks into API responses or DB queries. `id: int | None = None` is more honest and lets the type system catch misuse. + +### 9. GPU lock accessed two different ways + +- `graphs.py:603,844` — `stream_computation(work, manager._gpu_lock)` reaches into the private lock directly +- `intervention.py:86` — `with manager.gpu_lock():` uses the context manager + +The stream pattern is inherently different (hold lock across SSE generator lifetime), but accessing `_gpu_lock` directly breaks encapsulation. + +**Suggestion:** Add a `stream_with_gpu_lock(work)` method on `StateManager` that encapsulates the lock acquisition + SSE streaming pattern. Then `graphs.py` calls `manager.stream_with_gpu_lock(work)` instead of reaching into privates. + +### 10. `load_run` returns untyped dicts + +`runs.py:96,139` returns `{"status": "loaded", "run_id": ...}` and `{"status": "already_loaded", ...}`. No response model, so the frontend has no type-safe contract for this endpoint. + +**Fix:** Define a `LoadRunResponse(BaseModel)` with `status`, `run_id`, `wandb_path`. + +### 11. Edge truncation is invisible to the user + +`graphs.py:903` logs a warning when edges exceed `GLOBAL_EDGE_LIMIT = 50_000` and are truncated, but this info only goes to server logs. The researcher never sees it. + +**Fix:** Add `edges_truncated: bool` (or `total_edge_count: int`) to `GraphData` so the frontend can show a notice. + +### 12. Module-level `DEVICE = get_device()` in multiple files + +`graphs.py:266`, `intervention.py:48`, `dataset_search.py`, `prompts.py:18` all call `get_device()` at import time. 
Fine in practice but makes testing and non-GPU imports impossible. + +**Suggestion:** Move to a function call or lazily-evaluated property when/if this becomes a testing bottleneck. Low priority. + +### 13. `_GRAPH_INTERP_MOCK_MODE` cross-router import + +`runs.py:13` imports `MOCK_MODE` from `routers/graph_interp.py` and uses it in the status endpoint (line 174). The TODO comment says to remove it. This cross-router dependency for a mock flag should be cleaned up — the mock mode should either be a config flag on `StateManager` or deleted entirely. + +--- + +## Silent failure patterns (violate "loud errors" principle) + +### 14. `compute.py:79-86` — output node capping is silent + +`compute_layer_alive_info` caps output nodes to `MAX_OUTPUT_NODES_PER_POS = 15` per position without any logging or indication. If a researcher has >15 high-probability output tokens at a position, they silently lose some. + +At minimum, log when capping occurs. + +### 15. `correlations.py:291,302` — token stats returns `None` silently + +`get_component_token_stats` returns `None` when token stats haven't been harvested. This means the endpoint returns a `200 null` response, which the frontend has to special-case. An explicit 404 with a message is more honest. + +### 16. `correlations.py:112,260` — interpretations/intruder scores return `{}` silently + +`get_all_interpretations` and `get_intruder_scores` return empty dicts when data isn't available. This is defensible for bulk endpoints (the frontend can check emptiness), but it means the researcher has no way to distinguish "no interpretations exist" from "interpretations not yet generated." Consider logging or adding a `has_interpretations` flag to `LoadedRun`. + +Note: `LoadedRun.autointerp_available` already partially addresses this. But the endpoints themselves don't use it — they independently check `loaded.interp is None`. + +--- + +## Lower priority / nice-to-haves + +### 17. 
`extract_node_ci_vals` Python double loop + +`compute.py:640-648` iterates every `(seq_pos, component_idx)` pair in Python. For large models (39K components × 512 seq), this is a lot of Python overhead. Could be vectorized to only extract non-zero entries. + +### 18. `database.py` manual graph get-or-create race + +Lines 528-539: catches `IntegrityError` on manual graph save, then re-queries. There's a small race window between the failed insert and the re-query. Acceptable for a single-user local app but worth noting. + +### 19. `mcp.py` is 1637 lines + +The MCP router is the largest file, mixing tool definitions, implementation logic, and JSON-RPC handling. It has module-level global state (`_investigation_config`). This file would benefit from being split, but it's also likely to be rewritten when MCP tooling matures, so the ROI of refactoring now is debatable. + +--- + +## Suggested priority order for implementation + +1. Fix `result["story"]` KeyError (bug #1) — 2 min +2. Delete dead code (items #3-5) — 10 min +3. Fix random dataset endpoints or delete if unused (#2) — 15 min +4. Add `edges_truncated` to GraphData (#11) — 10 min +5. Type the `load_run` response (#10) — 5 min +6. Clean up `_GRAPH_INTERP_MOCK_MODE` (#13) — 5 min +7. Deduplicate `MAX_OUTPUT_NODES_PER_POS` (#5 partial) — 2 min +8. `StoredGraph.id` sentinel → `None` (#8) — 10 min +9. Split `OptimizationParams` (#7) — 20 min +10. 
GPU lock encapsulation (#9) — 15 min diff --git a/spd/app/backend/app_tokenizer.py b/spd/app/backend/app_tokenizer.py index 0d79cd9ba..acfa4d7eb 100644 --- a/spd/app/backend/app_tokenizer.py +++ b/spd/app/backend/app_tokenizer.py @@ -53,6 +53,12 @@ def vocab_size(self) -> int: assert isinstance(size, int) return size + @property + def eos_token_id(self) -> int: + eos = self._tok.eos_token_id + assert isinstance(eos, int) + return eos + def encode(self, text: str) -> list[int]: return self._tok.encode(text, add_special_tokens=False) diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index 8992e0e06..c99de5a02 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -12,12 +12,26 @@ import torch from jaxtyping import Bool, Float +from pydantic import BaseModel from torch import Tensor, nn from spd.app.backend.app_tokenizer import AppTokenizer -from spd.app.backend.optim_cis import OptimCIConfig, OptimizationMetrics, optimize_ci_values +from spd.app.backend.optim_cis import ( + AdvPGDConfig, + CELossConfig, + CISnapshotCallback, + LogitLossConfig, + LossConfig, + OptimCIConfig, + OptimizationMetrics, + compute_recon_loss, + optimize_ci_values, + optimize_ci_values_batched, + run_adv_pgd, +) from spd.configs import SamplingType from spd.log import logger +from spd.metrics.pgd_utils import interpolate_pgd_mask from spd.models.component_model import ComponentModel, OutputWithCache from spd.models.components import make_mask_infos from spd.topology import TransformerTopology @@ -115,6 +129,7 @@ class PromptAttributionResult: """Result of computing prompt attributions for a prompt.""" edges: list[Edge] + edges_abs: list[Edge] # absolute-target variant: ∂|y|/∂x · x ci_masked_out_probs: Float[Tensor, "seq vocab"] # CI-masked (SPD model) softmax probabilities ci_masked_out_logits: Float[Tensor, "seq vocab"] # CI-masked (SPD model) raw logits target_out_probs: Float[Tensor, "seq vocab"] # Target model softmax probabilities @@ -128,6 +143,7 
@@ class OptimizedPromptAttributionResult: """Result of computing prompt attributions with optimized CI values.""" edges: list[Edge] + edges_abs: list[Edge] # absolute-target variant: ∂|y|/∂x · x ci_masked_out_probs: Float[Tensor, "seq vocab"] # CI-masked (SPD model) softmax probabilities ci_masked_out_logits: Float[Tensor, "seq vocab"] # CI-masked (SPD model) raw logits target_out_probs: Float[Tensor, "seq vocab"] # Target model softmax probabilities @@ -135,7 +151,6 @@ class OptimizedPromptAttributionResult: node_ci_vals: dict[str, float] # layer:seq:c_idx -> ci_val node_subcomp_acts: dict[str, float] # layer:seq:c_idx -> subcomponent activation (v_i^T @ a) metrics: OptimizationMetrics # Final loss metrics from optimization - adv_pgd_out_logits: Float[Tensor, "seq vocab"] | None # Adversarial PGD output logits ProgressCallback = Callable[[int, int, str], None] # (current, total, stage) @@ -168,17 +183,22 @@ def _compute_edges_for_target( cache: dict[str, Tensor], loss_seq_pos: int, topology: TransformerTopology, -) -> list[Edge]: +) -> tuple[list[Edge], list[Edge]]: """Compute all edges flowing into a single target layer. For each alive (s_out, c_out) in the target layer, computes gradient-based - attribution strengths from all alive source components. + attribution strengths from all alive source components. Computes both signed + (∂y/∂x · x) and absolute-target (∂|y|/∂x · x) variants. Args: loss_seq_pos: Maximum sequence position to include (inclusive). Only compute edges for target positions <= loss_seq_pos. + + Returns: + (edges, edges_abs): Signed and absolute-target edge lists. 
""" edges: list[Edge] = [] + edges_abs: list[Edge] = [] out_pre_detach: Float[Tensor, "1 s C"] = cache[f"{target}_pre_detach"] in_post_detaches: list[Float[Tensor, "1 s C"]] = [ cache[f"{source}_post_detach"] for source in sources @@ -190,11 +210,19 @@ def _compute_edges_for_target( continue for c_out in s_out_alive_c: + target_val = out_pre_detach[0, s_out, c_out] grads = torch.autograd.grad( - outputs=out_pre_detach[0, s_out, c_out], + outputs=target_val, inputs=in_post_detaches, retain_graph=True, ) + # ∂|y|/∂x = sign(y) · ∂y/∂x — avoids a second backward pass. + # This works because target_val is a single scalar. In dataset_attributions/ + # harvester.py, the target is sum(|y_i|) over batch+seq — there each y_i has a + # different sign, so you can't factor out one scalar. The issue isn't the chain + # rule (sign·grad is always valid per-element), it's that abs breaks the + # grad(sum)=sum(grad) trick that makes the batch reduction a single backward pass. + target_sign = target_val.sign() with torch.no_grad(): canonical_target = topology.target_to_canon(target) for source, source_info, grad, in_post_detach in zip( @@ -203,27 +231,35 @@ def _compute_edges_for_target( canonical_source = topology.target_to_canon(source) is_cross_seq = topology.is_cross_seq_pair(canonical_source, canonical_target) weighted: Float[Tensor, "s C"] = (grad * in_post_detach)[0] + weighted_abs: Float[Tensor, "s C"] = weighted * target_sign if canonical_source == "embed": weighted = weighted.sum(dim=1, keepdim=True) + weighted_abs = weighted_abs.sum(dim=1, keepdim=True) s_in_range = range(s_out + 1) if is_cross_seq else [s_out] for s_in in s_in_range: for c_in in source_info.alive_c_idxs: if not source_info.alive_mask[s_in, c_in]: continue + src = Node(layer=canonical_source, seq_pos=s_in, component_idx=c_in) + tgt = Node(layer=canonical_target, seq_pos=s_out, component_idx=c_out) edges.append( Edge( - source=Node( - layer=canonical_source, seq_pos=s_in, component_idx=c_in - ), - 
target=Node( - layer=canonical_target, seq_pos=s_out, component_idx=c_out - ), + source=src, + target=tgt, strength=weighted[s_in, c_in].item(), is_cross_seq=is_cross_seq, ) ) - return edges + edges_abs.append( + Edge( + source=src, + target=tgt, + strength=weighted_abs[s_in, c_in].item(), + is_cross_seq=is_cross_seq, + ) + ) + return edges, edges_abs def compute_edges_from_ci( @@ -330,12 +366,13 @@ def compute_edges_from_ci( # Compute edges for each target layer t0 = time.perf_counter() edges: list[Edge] = [] + edges_abs: list[Edge] = [] total_source_layers = sum(len(sources) for sources in sources_by_target.values()) progress_count = 0 for target, sources in sources_by_target.items(): t_target = time.perf_counter() - target_edges = _compute_edges_for_target( + target_edges, target_edges_abs = _compute_edges_for_target( target=target, sources=sources, target_info=alive_info[target], @@ -345,6 +382,7 @@ def compute_edges_from_ci( topology=topology, ) edges.extend(target_edges) + edges_abs.extend(target_edges_abs) canonical_target = topology.target_to_canon(target) logger.info( f"[perf] {canonical_target}: {time.perf_counter() - t_target:.2f}s, " @@ -375,6 +413,7 @@ def compute_edges_from_ci( return PromptAttributionResult( edges=edges, + edges_abs=edges_abs, ci_masked_out_probs=ci_masked_out_probs[0, : loss_seq_pos + 1], ci_masked_out_logits=ci_masked_logits[0, : loss_seq_pos + 1], target_out_probs=target_out_probs[0, : loss_seq_pos + 1], @@ -508,6 +547,7 @@ def compute_prompt_attributions_optimized( output_prob_threshold: float, device: str, on_progress: ProgressCallback | None = None, + on_ci_snapshot: CISnapshotCallback | None = None, ) -> OptimizedPromptAttributionResult: """Compute prompt attributions using optimized sparse CI values. 
@@ -528,6 +568,7 @@ def compute_prompt_attributions_optimized( config=optim_config, device=device, on_progress=on_progress, + on_ci_snapshot=on_ci_snapshot, ) ci_outputs = optim_result.params.create_ci_outputs(model, device) @@ -557,13 +598,9 @@ def compute_prompt_attributions_optimized( loss_seq_pos=loss_seq_pos, ) - # Slice adversarial logits to match the loss_seq_pos range - adv_pgd_out_logits: Float[Tensor, "seq vocab"] | None = None - if optim_result.adv_pgd_out_logits is not None: - adv_pgd_out_logits = optim_result.adv_pgd_out_logits[: loss_seq_pos + 1] - return OptimizedPromptAttributionResult( edges=result.edges, + edges_abs=result.edges_abs, ci_masked_out_probs=result.ci_masked_out_probs, ci_masked_out_logits=result.ci_masked_out_logits, target_out_probs=result.target_out_probs, @@ -571,10 +608,81 @@ def compute_prompt_attributions_optimized( node_ci_vals=result.node_ci_vals, node_subcomp_acts=result.node_subcomp_acts, metrics=optim_result.metrics, - adv_pgd_out_logits=adv_pgd_out_logits, ) +def compute_prompt_attributions_optimized_batched( + model: ComponentModel, + topology: TransformerTopology, + tokens: Float[Tensor, "1 seq"], + sources_by_target: dict[str, list[str]], + configs: list[OptimCIConfig], + output_prob_threshold: float, + device: str, + on_progress: ProgressCallback | None = None, + on_ci_snapshot: CISnapshotCallback | None = None, +) -> list[OptimizedPromptAttributionResult]: + """Compute prompt attributions for multiple sparsity coefficients in one batched optimization.""" + with torch.no_grad(), bf16_autocast(): + target_logits = model(tokens) + target_out_probs = torch.softmax(target_logits, dim=-1) + + optim_results = optimize_ci_values_batched( + model=model, + tokens=tokens, + configs=configs, + device=device, + on_progress=on_progress, + on_ci_snapshot=on_ci_snapshot, + ) + + if on_progress is not None: + on_progress(0, len(optim_results), "graph") + + with torch.no_grad(), bf16_autocast(): + pre_weight_acts = model(tokens, 
cache_type="input").cache + + loss_seq_pos = configs[0].loss_config.position + + results: list[OptimizedPromptAttributionResult] = [] + for i, optim_result in enumerate(optim_results): + ci_outputs = optim_result.params.create_ci_outputs(model, device) + + result = compute_edges_from_ci( + model=model, + topology=topology, + tokens=tokens, + ci_lower_leaky=ci_outputs.lower_leaky, + pre_weight_acts=pre_weight_acts, + sources_by_target=sources_by_target, + target_out_probs=target_out_probs, + target_out_logits=target_logits, + output_prob_threshold=output_prob_threshold, + device=device, + on_progress=on_progress, + loss_seq_pos=loss_seq_pos, + ) + + results.append( + OptimizedPromptAttributionResult( + edges=result.edges, + edges_abs=result.edges_abs, + ci_masked_out_probs=result.ci_masked_out_probs, + ci_masked_out_logits=result.ci_masked_out_logits, + target_out_probs=result.target_out_probs, + target_out_logits=result.target_out_logits, + node_ci_vals=result.node_ci_vals, + node_subcomp_acts=result.node_subcomp_acts, + metrics=optim_result.metrics, + ) + ) + + if on_progress is not None: + on_progress(i + 1, len(optim_results), "graph") + + return results + + @dataclass class CIOnlyResult: """Result of computing CI values only (no attribution graph).""" @@ -666,94 +774,248 @@ def extract_node_subcomp_acts( return node_subcomp_acts -@dataclass -class InterventionResult: - """Result of intervention forward pass.""" +class TokenPrediction(BaseModel): + """A single token prediction with probability.""" + + token: str + token_id: int + prob: float + logit: float + target_prob: float + target_logit: float + + +class LabelPredictions(BaseModel): + """Prediction stats for the CE label token at the optimized position, per masking regime.""" + + position: int + ci: TokenPrediction + stochastic: TokenPrediction + adversarial: TokenPrediction + ablated: TokenPrediction | None + + +class InterventionResult(BaseModel): + """Unified result of an intervention evaluation under 
multiple masking regimes.""" input_tokens: list[str] - predictions_per_position: list[ - list[tuple[str, int, float, float, float, float]] - ] # [(token, id, spd_prob, logit, target_prob, target_logit)] + ci: list[list[TokenPrediction]] + stochastic: list[list[TokenPrediction]] + adversarial: list[list[TokenPrediction]] + ablated: list[list[TokenPrediction]] | None + ci_loss: float + stochastic_loss: float + adversarial_loss: float + ablated_loss: float | None + label: LabelPredictions | None + + +# Default eval PGD settings (distinct from optimization PGD which is a training regularizer) +DEFAULT_EVAL_PGD_CONFIG = AdvPGDConfig(n_steps=4, step_size=1.0, init="random") + + +def _extract_topk_predictions( + logits: Float[Tensor, "1 seq vocab"], + target_logits: Float[Tensor, "1 seq vocab"], + tokenizer: AppTokenizer, + top_k: int, +) -> list[list[TokenPrediction]]: + """Extract top-k token predictions per position, paired with target probs.""" + probs = torch.softmax(logits, dim=-1) + target_probs = torch.softmax(target_logits, dim=-1) + result: list[list[TokenPrediction]] = [] + for pos in range(probs.shape[1]): + top_vals, top_ids = torch.topk(probs[0, pos], top_k) + pos_preds: list[TokenPrediction] = [] + for p, tid_t in zip(top_vals, top_ids, strict=True): + tid = int(tid_t.item()) + pos_preds.append( + TokenPrediction( + token=tokenizer.get_tok_display(tid), + token_id=tid, + prob=float(p.item()), + logit=float(logits[0, pos, tid].item()), + target_prob=float(target_probs[0, pos, tid].item()), + target_logit=float(target_logits[0, pos, tid].item()), + ) + ) + result.append(pos_preds) + return result + + +def _extract_label_prediction( + logits: Float[Tensor, "1 seq vocab"], + target_logits: Float[Tensor, "1 seq vocab"], + tokenizer: AppTokenizer, + position: int, + label_token: int, +) -> TokenPrediction: + """Extract the prediction for a specific token at a specific position.""" + probs = torch.softmax(logits[0, position], dim=-1) + target_probs = 
torch.softmax(target_logits[0, position], dim=-1) + return TokenPrediction( + token=tokenizer.get_tok_display(label_token), + token_id=label_token, + prob=float(probs[label_token].item()), + logit=float(logits[0, position, label_token].item()), + target_prob=float(target_probs[label_token].item()), + target_logit=float(target_logits[0, position, label_token].item()), + ) -def compute_intervention_forward( +def compute_intervention( model: ComponentModel, tokens: Float[Tensor, "1 seq"], - active_nodes: list[tuple[str, int, int]], # [(layer, seq_pos, component_idx)] - top_k: int, + active_nodes: list[tuple[str, int, int]], + nodes_to_ablate: list[tuple[str, int, int]] | None, tokenizer: AppTokenizer, + adv_pgd_config: AdvPGDConfig, + loss_config: LossConfig, + sampling: SamplingType, + top_k: int, ) -> InterventionResult: - """Forward pass with only specified nodes active. + """Unified intervention evaluation: CI, stochastic, adversarial, and optionally ablated. Args: - model: ComponentModel to run intervention on. - tokens: Input tokens of shape [1, seq]. - active_nodes: List of (layer, seq_pos, component_idx) tuples specifying which nodes to activate. + active_nodes: (concrete_path, seq_pos, component_idx) tuples for selected nodes. + Used for CI, stochastic, and adversarial masking. + nodes_to_ablate: If provided, nodes to ablate in ablated (full target model minus these). + The frontend computes this as all_graph_nodes - selected_nodes. + If None, ablated is skipped. + loss_config: Loss for PGD adversary to maximize and for reporting metrics. + sampling: Sampling type for CI computation. top_k: Number of top predictions to return per position. - tokenizer: Tokenizer for decoding tokens. - - Returns: - InterventionResult with input tokens and top-k predictions per position. 
""" - seq_len = tokens.shape[1] device = tokens.device - # Build component masks: all zeros, then set 1s for active nodes - component_masks: dict[str, Float[Tensor, "1 seq C"]] = {} - for layer_name, C in model.module_to_c.items(): - component_masks[layer_name] = torch.zeros(1, seq_len, C, device=device) + # Compute natural CI alive masks (the model's own binarized CI, independent of graph) + with torch.no_grad(), bf16_autocast(): + output_with_cache: OutputWithCache = model(tokens, cache_type="input") + ci_outputs = model.calc_causal_importances( + pre_weight_acts=output_with_cache.cache, + sampling=sampling, + detach_inputs=False, + ) + alive_masks: dict[str, Bool[Tensor, "1 seq C"]] = { + k: v > 0 for k, v in ci_outputs.lower_leaky.items() + } + # Build binary CI masks from active nodes (selected = 1, rest = 0) + ci_masks: dict[str, Float[Tensor, "1 seq C"]] = {} + for layer_name, C in model.module_to_c.items(): + ci_masks[layer_name] = torch.zeros(1, seq_len, C, device=device) for layer, seq_pos, c_idx in active_nodes: - assert layer in component_masks, f"Layer {layer} not in model" - assert 0 <= seq_pos < seq_len, f"seq_pos {seq_pos} out of bounds [0, {seq_len})" - assert 0 <= c_idx < model.module_to_c[layer], ( - f"component_idx {c_idx} out of bounds [0, {model.module_to_c[layer]})" + ci_masks[layer][0, seq_pos, c_idx] = 1.0 + assert alive_masks[layer][0, seq_pos, c_idx], ( + f"Selected node {layer}:{seq_pos}:{c_idx} is not alive (CI=0)" ) - component_masks[layer][0, seq_pos, c_idx] = 1.0 - mask_infos = make_mask_infos(component_masks, routing_masks="all") + with torch.no_grad(), bf16_autocast(): + # Target forward (unmasked) + target_logits: Float[Tensor, "1 seq vocab"] = model(tokens) + + # CI forward (binary mask) + ci_mask_infos = make_mask_infos(ci_masks, routing_masks="all") + ci_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=ci_mask_infos) + + # Stochastic forward: ci + (1-ci) * uniform + stoch_masks = { + layer: ci_masks[layer] + (1 - 
ci_masks[layer]) * torch.rand_like(ci_masks[layer]) + for layer in ci_masks + } + stoch_mask_infos = make_mask_infos(stoch_masks, routing_masks="all") + stoch_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=stoch_mask_infos) + + # Target-sans forward (only if nodes_to_ablate provided) + ts_logits: Float[Tensor, "1 seq vocab"] | None = None + if nodes_to_ablate is not None: + ts_masks: dict[str, Float[Tensor, "1 seq C"]] = {} + for layer_name in ci_masks: + ts_masks[layer_name] = torch.ones_like(ci_masks[layer_name]) + for layer, seq_pos, c_idx in nodes_to_ablate: + ts_masks[layer][0, seq_pos, c_idx] = 0.0 + weight_deltas = model.calc_weight_deltas() + ts_wd = { + k: (v, torch.ones(tokens.shape, device=device)) for k, v in weight_deltas.items() + } + ts_mask_infos = make_mask_infos( + ts_masks, routing_masks="all", weight_deltas_and_masks=ts_wd + ) + ts_logits = model(tokens, mask_infos=ts_mask_infos) + # Adversarial: PGD optimizes alive-but-unselected components + adv_sources = run_adv_pgd( + model=model, + tokens=tokens, + ci=ci_masks, + alive_masks=alive_masks, + adv_config=adv_pgd_config, + target_out=target_logits, + loss_config=loss_config, + ) + # Non-alive positions get uniform random fill + adv_masks = interpolate_pgd_mask(ci_masks, adv_sources) + with torch.no_grad(): + for layer in adv_masks: + non_alive = ~alive_masks[layer] + adv_masks[layer][non_alive] = torch.rand(int(non_alive.sum().item()), device=device) with torch.no_grad(), bf16_autocast(): - # SPD model forward pass (with component masks) - spd_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=mask_infos) - spd_probs: Float[Tensor, "1 seq vocab"] = torch.softmax(spd_logits, dim=-1) + adv_mask_infos = make_mask_infos(adv_masks, routing_masks="all") + adv_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=adv_mask_infos) + + # Extract predictions and loss metrics + device_str = str(device) + with torch.no_grad(): + ci_preds = 
_extract_topk_predictions(ci_logits, target_logits, tokenizer, top_k) + stoch_preds = _extract_topk_predictions(stoch_logits, target_logits, tokenizer, top_k) + adv_preds = _extract_topk_predictions(adv_logits, target_logits, tokenizer, top_k) + + ci_loss = float( + compute_recon_loss(ci_logits, loss_config, target_logits, device_str).item() + ) + stoch_loss = float( + compute_recon_loss(stoch_logits, loss_config, target_logits, device_str).item() + ) + adv_loss = float( + compute_recon_loss(adv_logits, loss_config, target_logits, device_str).item() + ) - # Target model forward pass (no masks) - target_logits: Float[Tensor, "1 seq vocab"] = model(tokens) - target_out_probs: Float[Tensor, "1 seq vocab"] = torch.softmax(target_logits, dim=-1) - - # Get top-k predictions per position (based on SPD model's top-k) - predictions_per_position: list[list[tuple[str, int, float, float, float, float]]] = [] - for pos in range(seq_len): - pos_spd_probs = spd_probs[0, pos] - pos_spd_logits = spd_logits[0, pos] - pos_target_out_probs = target_out_probs[0, pos] - pos_target_logits = target_logits[0, pos] - top_probs, top_ids = torch.topk(pos_spd_probs, top_k) - - pos_predictions: list[tuple[str, int, float, float, float, float]] = [] - for spd_prob, token_id in zip(top_probs, top_ids, strict=True): - tid = int(token_id.item()) - token_str = tokenizer.get_tok_display(tid) - target_prob = float(pos_target_out_probs[tid].item()) - target_logit = float(pos_target_logits[tid].item()) - pos_predictions.append( - ( - token_str, - tid, - float(spd_prob.item()), - float(pos_spd_logits[tid].item()), - target_prob, - target_logit, - ) + ts_preds: list[list[TokenPrediction]] | None = None + ts_loss: float | None = None + if ts_logits is not None: + ts_preds = _extract_topk_predictions(ts_logits, target_logits, tokenizer, top_k) + ts_loss = float( + compute_recon_loss(ts_logits, loss_config, target_logits, device_str).item() ) - predictions_per_position.append(pos_predictions) - # Decode 
input tokens + label: LabelPredictions | None = None + if isinstance(loss_config, CELossConfig | LogitLossConfig): + pos, tid = loss_config.position, loss_config.label_token + ts_label = ( + _extract_label_prediction(ts_logits, target_logits, tokenizer, pos, tid) + if ts_logits is not None + else None + ) + label = LabelPredictions( + position=pos, + ci=_extract_label_prediction(ci_logits, target_logits, tokenizer, pos, tid), + stochastic=_extract_label_prediction(stoch_logits, target_logits, tokenizer, pos, tid), + adversarial=_extract_label_prediction(adv_logits, target_logits, tokenizer, pos, tid), + ablated=ts_label, + ) + input_tokens = tokenizer.get_spans([int(t.item()) for t in tokens[0]]) return InterventionResult( input_tokens=input_tokens, - predictions_per_position=predictions_per_position, + ci=ci_preds, + stochastic=stoch_preds, + adversarial=adv_preds, + ablated=ts_preds, + ci_loss=ci_loss, + stochastic_loss=stoch_loss, + adversarial_loss=adv_loss, + ablated_loss=ts_loss, + label=label, ) diff --git a/spd/app/backend/database.py b/spd/app/backend/database.py index 6b2b09552..61be17ad5 100644 --- a/spd/app/backend/database.py +++ b/spd/app/backend/database.py @@ -6,10 +6,13 @@ Interpretations are stored separately at SPD_OUT_DIR/autointerp//. 
""" +import fcntl import hashlib import io import json +import os import sqlite3 +from contextlib import contextmanager from pathlib import Path from typing import Literal @@ -17,14 +20,35 @@ from pydantic import BaseModel from spd.app.backend.compute import Edge, Node -from spd.app.backend.optim_cis import CELossConfig, KLLossConfig, LossConfig, MaskType -from spd.settings import REPO_ROOT +from spd.app.backend.optim_cis import ( + CELossConfig, + KLLossConfig, + LogitLossConfig, + MaskType, + PositionalLossConfig, +) +from spd.settings import SPD_OUT_DIR GraphType = Literal["standard", "optimized", "manual"] -# Persistent data directories -_APP_DATA_DIR = REPO_ROOT / ".data" / "app" -DEFAULT_DB_PATH = _APP_DATA_DIR / "prompt_attr.db" +_DEFAULT_DB_PATH = SPD_OUT_DIR / "app" / "prompt_attr.db" + + +def get_default_db_path() -> Path: + """Get the default database path. + + Checks env vars in order: + 1. SPD_INVESTIGATION_DIR - investigation mode, db at dir/app.db + 2. SPD_APP_DB_PATH - explicit override + 3. 
Default: SPD_OUT_DIR/app/prompt_attr.db + """ + investigation_dir = os.environ.get("SPD_INVESTIGATION_DIR") + if investigation_dir: + return Path(investigation_dir) / "app.db" + env_path = os.environ.get("SPD_APP_DB_PATH") + if env_path: + return Path(env_path) + return _DEFAULT_DB_PATH class Run(BaseModel): @@ -43,6 +67,11 @@ class PromptRecord(BaseModel): is_custom: bool = False +class PgdConfig(BaseModel): + n_steps: int + step_size: float + + class OptimizationParams(BaseModel): """Optimization parameters that affect graph computation.""" @@ -51,9 +80,12 @@ class OptimizationParams(BaseModel): pnorm: float beta: float mask_type: MaskType - loss: LossConfig - adv_pgd_n_steps: int | None = None - adv_pgd_step_size: float | None = None + loss: PositionalLossConfig + pgd: PgdConfig | None = None + # Computed metrics (persisted for display on reload) + ci_masked_label_prob: float | None = None + stoch_masked_label_prob: float | None = None + adv_pgd_label_prob: float | None = None class StoredGraph(BaseModel): @@ -66,9 +98,11 @@ class StoredGraph(BaseModel): # Core graph data (all types) edges: list[Edge] + edges_abs: list[Edge] | None = ( + None # absolute-target variant (∂|y|/∂x · x), None for old graphs + ) ci_masked_out_logits: torch.Tensor # [seq, vocab] target_out_logits: torch.Tensor # [seq, vocab] - adv_pgd_out_logits: torch.Tensor | None = None # [seq, vocab] adversarial PGD logits node_ci_vals: dict[str, float] # layer:seq:c_idx -> ci_val (required for all graphs) node_subcomp_acts: dict[str, float] = {} # layer:seq:c_idx -> subcomp act (v_i^T @ a) @@ -85,17 +119,17 @@ class InterventionRunRecord(BaseModel): id: int graph_id: int selected_nodes: list[str] # node keys that were selected - result_json: str # JSON-encoded InterventionResponse + result_json: str # JSON-encoded InterventionResult created_at: str class ForkedInterventionRunRecord(BaseModel): - """A forked intervention run with modified tokens.""" + """A forked intervention run with modified 
tokens (currently unused).""" id: int intervention_run_id: int token_replacements: list[tuple[int, int]] # [(seq_pos, new_token_id), ...] - result_json: str # JSON-encoded InterventionResponse + result_json: str created_at: str @@ -111,7 +145,8 @@ class PromptAttrDB: """ def __init__(self, db_path: Path | None = None, check_same_thread: bool = True): - self.db_path = db_path or DEFAULT_DB_PATH + self.db_path = db_path or get_default_db_path() + self._lock_path = self.db_path.with_suffix(".db.lock") self._check_same_thread = check_same_thread self._conn: sqlite3.Connection | None = None @@ -135,6 +170,16 @@ def __enter__(self) -> "PromptAttrDB": def __exit__(self, *args: object) -> None: self.close() + @contextmanager + def _write_lock(self): + """Acquire an exclusive file lock for write operations (NFS-safe).""" + with open(self._lock_path, "w") as lock_fd: + try: + fcntl.flock(lock_fd, fcntl.LOCK_EX) + yield + finally: + fcntl.flock(lock_fd, fcntl.LOCK_UN) + # ------------------------------------------------------------------------- # Schema initialization # ------------------------------------------------------------------------- @@ -142,7 +187,7 @@ def __exit__(self, *args: object) -> None: def init_schema(self) -> None: """Initialize the database schema. 
Safe to call multiple times.""" conn = self._get_conn() - conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA journal_mode=DELETE") conn.execute("PRAGMA foreign_keys=ON") conn.executescript(""" CREATE TABLE IF NOT EXISTS runs ( @@ -178,12 +223,19 @@ def init_schema(self) -> None: adv_pgd_n_steps INTEGER, adv_pgd_step_size REAL, + -- Optimization metrics (NULL for non-optimized graphs) + ci_masked_label_prob REAL, + stoch_masked_label_prob REAL, + adv_pgd_label_prob REAL, + -- Manual graph params (NULL for non-manual graphs) included_nodes TEXT, -- JSON array of node keys in this graph included_nodes_hash TEXT, -- SHA256 hash of sorted JSON for uniqueness -- The actual graph data (JSON) edges_data TEXT NOT NULL, + -- Absolute-target edges (∂|y|/∂x · x), NULL for old graphs + edges_data_abs TEXT, -- Node CI values: "layer:seq:c_idx" -> ci_val (required for all graphs) node_ci_vals TEXT NOT NULL, -- Node subcomponent activations: "layer:seq:c_idx" -> v_i^T @ a @@ -216,7 +268,7 @@ def init_schema(self) -> None: id INTEGER PRIMARY KEY AUTOINCREMENT, graph_id INTEGER NOT NULL REFERENCES graphs(id), selected_nodes TEXT NOT NULL, -- JSON array of node keys - result TEXT NOT NULL, -- JSON InterventionResponse + result TEXT NOT NULL, -- JSON InterventionResult created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); @@ -234,6 +286,12 @@ def init_schema(self) -> None: CREATE INDEX IF NOT EXISTS idx_forked_intervention_runs_parent ON forked_intervention_runs(intervention_run_id); """) + + # Migration: add edges_data_abs column if missing (backwards compat with existing DBs) + columns = {row[1] for row in conn.execute("PRAGMA table_info(graphs)").fetchall()} + if "edges_data_abs" not in columns: + conn.execute("ALTER TABLE graphs ADD COLUMN edges_data_abs TEXT") + conn.commit() # ------------------------------------------------------------------------- @@ -242,15 +300,16 @@ def init_schema(self) -> None: def create_run(self, wandb_path: str) -> int: """Create a new run. 
Returns the run ID.""" - conn = self._get_conn() - cursor = conn.execute( - "INSERT INTO runs (wandb_path) VALUES (?)", - (wandb_path,), - ) - conn.commit() - run_id = cursor.lastrowid - assert run_id is not None - return run_id + with self._write_lock(): + conn = self._get_conn() + cursor = conn.execute( + "INSERT INTO runs (wandb_path) VALUES (?)", + (wandb_path,), + ) + conn.commit() + run_id = cursor.lastrowid + assert run_id is not None + return run_id def get_run_by_wandb_path(self, wandb_path: str) -> Run | None: """Get a run by its wandb path.""" @@ -308,19 +367,20 @@ def add_custom_prompt( Returns: The prompt ID (existing or newly created). """ - existing_id = self.find_prompt_by_token_ids(run_id, token_ids, context_length) - if existing_id is not None: - return existing_id + with self._write_lock(): + existing_id = self.find_prompt_by_token_ids(run_id, token_ids, context_length) + if existing_id is not None: + return existing_id - conn = self._get_conn() - cursor = conn.execute( - "INSERT INTO prompts (run_id, token_ids, context_length, is_custom) VALUES (?, ?, ?, 1)", - (run_id, json.dumps(token_ids), context_length), - ) - prompt_id = cursor.lastrowid - assert prompt_id is not None - conn.commit() - return prompt_id + conn = self._get_conn() + cursor = conn.execute( + "INSERT INTO prompts (run_id, token_ids, context_length, is_custom) VALUES (?, ?, ?, 1)", + (run_id, json.dumps(token_ids), context_length), + ) + prompt_id = cursor.lastrowid + assert prompt_id is not None + conn.commit() + return prompt_id def get_prompt(self, prompt_id: int) -> PromptRecord | None: """Get a prompt by ID.""" @@ -384,24 +444,26 @@ def _node_to_dict(n: Node) -> dict[str, str | int]: "component_idx": n.component_idx, } - edges_json = json.dumps( - [ - { - "source": _node_to_dict(e.source), - "target": _node_to_dict(e.target), - "strength": e.strength, - "is_cross_seq": e.is_cross_seq, - } - for e in graph.edges - ] - ) + def _edges_to_json(edges: list[Edge]) -> str: + 
return json.dumps( + [ + { + "source": _node_to_dict(e.source), + "target": _node_to_dict(e.target), + "strength": e.strength, + "is_cross_seq": e.is_cross_seq, + } + for e in edges + ] + ) + + edges_json = _edges_to_json(graph.edges) + edges_abs_json = _edges_to_json(graph.edges_abs) if graph.edges_abs is not None else None buf = io.BytesIO() logits_dict: dict[str, torch.Tensor] = { "ci_masked": graph.ci_masked_out_logits, "target": graph.target_out_logits, } - if graph.adv_pgd_out_logits is not None: - logits_dict["adv_pgd"] = graph.adv_pgd_out_logits torch.save(logits_dict, buf) output_logits_blob = buf.getvalue() node_ci_vals_json = json.dumps(graph.node_ci_vals) @@ -417,6 +479,9 @@ def _node_to_dict(n: Node) -> dict[str, str | int]: loss_config_hash: str | None = None adv_pgd_n_steps = None adv_pgd_step_size = None + ci_masked_label_prob = None + stoch_masked_label_prob = None + adv_pgd_label_prob = None if graph.optimization_params: imp_min_coeff = graph.optimization_params.imp_min_coeff @@ -426,8 +491,15 @@ def _node_to_dict(n: Node) -> dict[str, str | int]: mask_type = graph.optimization_params.mask_type loss_config_json = graph.optimization_params.loss.model_dump_json() loss_config_hash = hashlib.sha256(loss_config_json.encode()).hexdigest() - adv_pgd_n_steps = graph.optimization_params.adv_pgd_n_steps - adv_pgd_step_size = graph.optimization_params.adv_pgd_step_size + adv_pgd_n_steps = ( + graph.optimization_params.pgd.n_steps if graph.optimization_params.pgd else None + ) + adv_pgd_step_size = ( + graph.optimization_params.pgd.step_size if graph.optimization_params.pgd else None + ) + ci_masked_label_prob = graph.optimization_params.ci_masked_label_prob + stoch_masked_label_prob = graph.optimization_params.stoch_masked_label_prob + adv_pgd_label_prob = graph.optimization_params.adv_pgd_label_prob # Extract manual-specific values (NULL for non-manual graphs) # Sort included_nodes and compute hash for reliable uniqueness @@ -437,64 +509,70 @@ def 
_node_to_dict(n: Node) -> dict[str, str | int]: included_nodes_json = json.dumps(sorted(graph.included_nodes)) included_nodes_hash = hashlib.sha256(included_nodes_json.encode()).hexdigest() - try: - cursor = conn.execute( - """INSERT INTO graphs - (prompt_id, graph_type, - imp_min_coeff, steps, pnorm, beta, mask_type, - loss_config, loss_config_hash, - adv_pgd_n_steps, adv_pgd_step_size, - included_nodes, included_nodes_hash, - edges_data, output_logits, node_ci_vals, node_subcomp_acts) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - ( - prompt_id, - graph.graph_type, - imp_min_coeff, - steps, - pnorm, - beta, - mask_type, - loss_config_json, - loss_config_hash, - adv_pgd_n_steps, - adv_pgd_step_size, - included_nodes_json, - included_nodes_hash, - edges_json, - output_logits_blob, - node_ci_vals_json, - node_subcomp_acts_json, - ), - ) - conn.commit() - graph_id = cursor.lastrowid - assert graph_id is not None - return graph_id - except sqlite3.IntegrityError as e: - match graph.graph_type: - case "standard": - raise ValueError( - f"Standard graph already exists for prompt_id={prompt_id}. " - "Use get_graphs() to retrieve existing graph or delete it first." - ) from e - case "optimized": - raise ValueError( - f"Optimized graph with same parameters already exists for prompt_id={prompt_id}." - ) from e - case "manual": - # Get-or-create semantics: return existing graph ID - conn.rollback() - row = conn.execute( - """SELECT id FROM graphs - WHERE prompt_id = ? 
AND graph_type = 'manual' - AND included_nodes_hash = ?""", - (prompt_id, included_nodes_hash), - ).fetchone() - if row: - return row["id"] - # Should not happen if constraint triggered - raise ValueError("A manual graph with the same nodes already exists.") from e + with self._write_lock(): + try: + cursor = conn.execute( + """INSERT INTO graphs + (prompt_id, graph_type, + imp_min_coeff, steps, pnorm, beta, mask_type, + loss_config, loss_config_hash, + adv_pgd_n_steps, adv_pgd_step_size, + ci_masked_label_prob, stoch_masked_label_prob, adv_pgd_label_prob, + included_nodes, included_nodes_hash, + edges_data, edges_data_abs, output_logits, node_ci_vals, node_subcomp_acts) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + prompt_id, + graph.graph_type, + imp_min_coeff, + steps, + pnorm, + beta, + mask_type, + loss_config_json, + loss_config_hash, + adv_pgd_n_steps, + adv_pgd_step_size, + ci_masked_label_prob, + stoch_masked_label_prob, + adv_pgd_label_prob, + included_nodes_json, + included_nodes_hash, + edges_json, + edges_abs_json, + output_logits_blob, + node_ci_vals_json, + node_subcomp_acts_json, + ), + ) + conn.commit() + graph_id = cursor.lastrowid + assert graph_id is not None + return graph_id + except sqlite3.IntegrityError as e: + match graph.graph_type: + case "standard": + raise ValueError( + f"Standard graph already exists for prompt_id={prompt_id}. " + "Use get_graphs() to retrieve existing graph or delete it first." + ) from e + case "optimized": + raise ValueError( + f"Optimized graph with same parameters already exists for prompt_id={prompt_id}." + ) from e + case "manual": + conn.rollback() + row = conn.execute( + """SELECT id FROM graphs + WHERE prompt_id = ? AND graph_type = 'manual' + AND included_nodes_hash = ?""", + (prompt_id, included_nodes_hash), + ).fetchone() + if row: + return row["id"] + raise ValueError( + "A manual graph with the same nodes already exists." 
+ ) from e def _row_to_stored_graph(self, row: sqlite3.Row) -> StoredGraph: """Convert a database row to a StoredGraph.""" @@ -506,19 +584,22 @@ def _node_from_dict(d: dict[str, str | int]) -> Node: component_idx=int(d["component_idx"]), ) - edges = [ - Edge( - source=_node_from_dict(e["source"]), - target=_node_from_dict(e["target"]), - strength=float(e["strength"]), - is_cross_seq=bool(e["is_cross_seq"]), - ) - for e in json.loads(row["edges_data"]) - ] + def _parse_edges(data: str) -> list[Edge]: + return [ + Edge( + source=_node_from_dict(e["source"]), + target=_node_from_dict(e["target"]), + strength=float(e["strength"]), + is_cross_seq=bool(e["is_cross_seq"]), + ) + for e in json.loads(data) + ] + + edges = _parse_edges(row["edges_data"]) + edges_abs = _parse_edges(row["edges_data_abs"]) if row["edges_data_abs"] else None logits_data = torch.load(io.BytesIO(row["output_logits"]), weights_only=True) ci_masked_out_logits: torch.Tensor = logits_data["ci_masked"] target_out_logits: torch.Tensor = logits_data["target"] - adv_pgd_out_logits: torch.Tensor | None = logits_data.get("adv_pgd") node_ci_vals: dict[str, float] = json.loads(row["node_ci_vals"]) node_subcomp_acts: dict[str, float] = json.loads(row["node_subcomp_acts"] or "{}") @@ -526,12 +607,18 @@ def _node_from_dict(d: dict[str, str | int]) -> Node: if row["graph_type"] == "optimized": loss_config_data = json.loads(row["loss_config"]) loss_type = loss_config_data["type"] - assert loss_type in ("ce", "kl"), f"Unknown loss type: {loss_type}" - loss_config: LossConfig - if loss_type == "ce": - loss_config = CELossConfig(**loss_config_data) - else: - loss_config = KLLossConfig(**loss_config_data) + assert loss_type in ("ce", "kl", "logit"), f"Unknown loss type: {loss_type}" + loss_config: PositionalLossConfig + match loss_type: + case "ce": + loss_config = CELossConfig(**loss_config_data) + case "kl": + loss_config = KLLossConfig(**loss_config_data) + case "logit": + loss_config = 
LogitLossConfig(**loss_config_data) + pgd = None + if row["adv_pgd_n_steps"] is not None: + pgd = PgdConfig(n_steps=row["adv_pgd_n_steps"], step_size=row["adv_pgd_step_size"]) opt_params = OptimizationParams( imp_min_coeff=row["imp_min_coeff"], steps=row["steps"], @@ -539,8 +626,10 @@ def _node_from_dict(d: dict[str, str | int]) -> Node: beta=row["beta"], mask_type=row["mask_type"], loss=loss_config, - adv_pgd_n_steps=row["adv_pgd_n_steps"], - adv_pgd_step_size=row["adv_pgd_step_size"], + pgd=pgd, + ci_masked_label_prob=row["ci_masked_label_prob"], + stoch_masked_label_prob=row["stoch_masked_label_prob"], + adv_pgd_label_prob=row["adv_pgd_label_prob"], ) # Parse manual-specific fields @@ -552,9 +641,9 @@ def _node_from_dict(d: dict[str, str | int]) -> Node: id=row["id"], graph_type=row["graph_type"], edges=edges, + edges_abs=edges_abs, ci_masked_out_logits=ci_masked_out_logits, target_out_logits=target_out_logits, - adv_pgd_out_logits=adv_pgd_out_logits, node_ci_vals=node_ci_vals, node_subcomp_acts=node_subcomp_acts, optimization_params=opt_params, @@ -572,9 +661,10 @@ def get_graphs(self, prompt_id: int) -> list[StoredGraph]: """ conn = self._get_conn() rows = conn.execute( - """SELECT id, graph_type, edges_data, output_logits, node_ci_vals, + """SELECT id, graph_type, edges_data, edges_data_abs, output_logits, node_ci_vals, node_subcomp_acts, imp_min_coeff, steps, pnorm, beta, mask_type, - loss_config, adv_pgd_n_steps, adv_pgd_step_size, included_nodes + loss_config, adv_pgd_n_steps, adv_pgd_step_size, included_nodes, + ci_masked_label_prob, stoch_masked_label_prob, adv_pgd_label_prob FROM graphs WHERE prompt_id = ? ORDER BY @@ -588,9 +678,11 @@ def get_graph(self, graph_id: int) -> tuple[StoredGraph, int] | None: """Retrieve a single graph by its ID. 
Returns (graph, prompt_id) or None.""" conn = self._get_conn() row = conn.execute( - """SELECT id, prompt_id, graph_type, edges_data, output_logits, node_ci_vals, - node_subcomp_acts, imp_min_coeff, steps, pnorm, beta, mask_type, - loss_config, adv_pgd_n_steps, adv_pgd_step_size, included_nodes + """SELECT id, prompt_id, graph_type, edges_data, edges_data_abs, output_logits, + node_ci_vals, node_subcomp_acts, imp_min_coeff, steps, pnorm, beta, + mask_type, loss_config, adv_pgd_n_steps, adv_pgd_step_size, + included_nodes, ci_masked_label_prob, stoch_masked_label_prob, + adv_pgd_label_prob FROM graphs WHERE id = ?""", (graph_id,), @@ -599,23 +691,45 @@ def get_graph(self, graph_id: int) -> tuple[StoredGraph, int] | None: return None return (self._row_to_stored_graph(row), row["prompt_id"]) + def delete_prompt(self, prompt_id: int) -> None: + """Delete a prompt and all its graphs, intervention runs, and forked runs.""" + with self._write_lock(): + conn = self._get_conn() + graph_ids_query = "SELECT id FROM graphs WHERE prompt_id = ?" + intervention_ids_query = ( + f"SELECT id FROM intervention_runs WHERE graph_id IN ({graph_ids_query})" + ) + conn.execute( + f"DELETE FROM forked_intervention_runs WHERE intervention_run_id IN ({intervention_ids_query})", + (prompt_id,), + ) + conn.execute( + f"DELETE FROM intervention_runs WHERE graph_id IN ({graph_ids_query})", + (prompt_id,), + ) + conn.execute("DELETE FROM graphs WHERE prompt_id = ?", (prompt_id,)) + conn.execute("DELETE FROM prompts WHERE id = ?", (prompt_id,)) + conn.commit() + def delete_graphs_for_prompt(self, prompt_id: int) -> int: """Delete all graphs for a prompt. 
Returns the number of deleted rows.""" - conn = self._get_conn() - cursor = conn.execute("DELETE FROM graphs WHERE prompt_id = ?", (prompt_id,)) - conn.commit() - return cursor.rowcount + with self._write_lock(): + conn = self._get_conn() + cursor = conn.execute("DELETE FROM graphs WHERE prompt_id = ?", (prompt_id,)) + conn.commit() + return cursor.rowcount def delete_graphs_for_run(self, run_id: int) -> int: """Delete all graphs for all prompts in a run. Returns the number of deleted rows.""" - conn = self._get_conn() - cursor = conn.execute( - """DELETE FROM graphs - WHERE prompt_id IN (SELECT id FROM prompts WHERE run_id = ?)""", - (run_id,), - ) - conn.commit() - return cursor.rowcount + with self._write_lock(): + conn = self._get_conn() + cursor = conn.execute( + """DELETE FROM graphs + WHERE prompt_id IN (SELECT id FROM prompts WHERE run_id = ?)""", + (run_id,), + ) + conn.commit() + return cursor.rowcount # ------------------------------------------------------------------------- # Intervention run operations @@ -632,21 +746,22 @@ def save_intervention_run( Args: graph_id: The graph ID this run belongs to. selected_nodes: List of node keys that were selected. - result_json: JSON-encoded InterventionResponse. + result_json: JSON-encoded InterventionResult. Returns: The intervention run ID. 
""" - conn = self._get_conn() - cursor = conn.execute( - """INSERT INTO intervention_runs (graph_id, selected_nodes, result) - VALUES (?, ?, ?)""", - (graph_id, json.dumps(selected_nodes), result_json), - ) - conn.commit() - run_id = cursor.lastrowid - assert run_id is not None - return run_id + with self._write_lock(): + conn = self._get_conn() + cursor = conn.execute( + """INSERT INTO intervention_runs (graph_id, selected_nodes, result) + VALUES (?, ?, ?)""", + (graph_id, json.dumps(selected_nodes), result_json), + ) + conn.commit() + run_id = cursor.lastrowid + assert run_id is not None + return run_id def get_intervention_runs(self, graph_id: int) -> list[InterventionRunRecord]: """Get all intervention runs for a graph. @@ -679,16 +794,18 @@ def get_intervention_runs(self, graph_id: int) -> list[InterventionRunRecord]: def delete_intervention_run(self, run_id: int) -> None: """Delete an intervention run.""" - conn = self._get_conn() - conn.execute("DELETE FROM intervention_runs WHERE id = ?", (run_id,)) - conn.commit() + with self._write_lock(): + conn = self._get_conn() + conn.execute("DELETE FROM intervention_runs WHERE id = ?", (run_id,)) + conn.commit() def delete_intervention_runs_for_graph(self, graph_id: int) -> int: """Delete all intervention runs for a graph. Returns count deleted.""" - conn = self._get_conn() - cursor = conn.execute("DELETE FROM intervention_runs WHERE graph_id = ?", (graph_id,)) - conn.commit() - return cursor.rowcount + with self._write_lock(): + conn = self._get_conn() + cursor = conn.execute("DELETE FROM intervention_runs WHERE graph_id = ?", (graph_id,)) + conn.commit() + return cursor.rowcount # ------------------------------------------------------------------------- # Forked intervention run operations @@ -710,16 +827,17 @@ def save_forked_intervention_run( Returns: The forked intervention run ID. 
""" - conn = self._get_conn() - cursor = conn.execute( - """INSERT INTO forked_intervention_runs (intervention_run_id, token_replacements, result) - VALUES (?, ?, ?)""", - (intervention_run_id, json.dumps(token_replacements), result_json), - ) - conn.commit() - fork_id = cursor.lastrowid - assert fork_id is not None - return fork_id + with self._write_lock(): + conn = self._get_conn() + cursor = conn.execute( + """INSERT INTO forked_intervention_runs (intervention_run_id, token_replacements, result) + VALUES (?, ?, ?)""", + (intervention_run_id, json.dumps(token_replacements), result_json), + ) + conn.commit() + fork_id = cursor.lastrowid + assert fork_id is not None + return fork_id def get_forked_intervention_runs( self, intervention_run_id: int @@ -775,6 +893,7 @@ def get_intervention_run(self, run_id: int) -> InterventionRunRecord | None: def delete_forked_intervention_run(self, fork_id: int) -> None: """Delete a forked intervention run.""" - conn = self._get_conn() - conn.execute("DELETE FROM forked_intervention_runs WHERE id = ?", (fork_id,)) - conn.commit() + with self._write_lock(): + conn = self._get_conn() + conn.execute("DELETE FROM forked_intervention_runs WHERE id = ?", (fork_id,)) + conn.commit() diff --git a/spd/app/backend/optim_cis.py b/spd/app/backend/optim_cis.py index 48403feb6..4c7e42fbb 100644 --- a/spd/app/backend/optim_cis.py +++ b/spd/app/backend/optim_cis.py @@ -13,6 +13,7 @@ from spd.configs import ImportanceMinimalityLossConfig, PGDInitStrategy, SamplingType from spd.metrics import importance_minimality_loss +from spd.metrics.pgd_utils import get_pgd_init_tensor, interpolate_pgd_mask from spd.models.component_model import CIOutputs, ComponentModel, OutputWithCache from spd.models.components import make_mask_infos from spd.routing import AllLayersRouter @@ -48,16 +49,33 @@ class KLLossConfig(BaseModel): position: int -LossConfig = CELossConfig | KLLossConfig +class LogitLossConfig(BaseModel): + """Logit loss: maximize the pre-softmax 
logit for a specific token at a position.""" + type: Literal["logit"] = "logit" + coeff: float + position: int + label_token: int + + +class MeanKLLossConfig(BaseModel): + """Mean KL divergence loss: match target model distribution across all positions.""" + + type: Literal["mean_kl"] = "mean_kl" + coeff: float = 1.0 -def _compute_recon_loss( + +PositionalLossConfig = CELossConfig | KLLossConfig | LogitLossConfig +LossConfig = CELossConfig | KLLossConfig | LogitLossConfig | MeanKLLossConfig + + +def compute_recon_loss( logits: Tensor, loss_config: LossConfig, target_out: Tensor, device: str, ) -> Tensor: - """Compute recon loss (CE or KL) from model output logits at the configured position.""" + """Compute recon loss (CE, KL, or mean KL) from model output logits.""" match loss_config: case CELossConfig(position=pos, label_token=label_token): return F.cross_entropy( @@ -68,14 +86,12 @@ def _compute_recon_loss( target_probs = F.softmax(target_out[0, pos, :], dim=-1) pred_log_probs = F.log_softmax(logits[0, pos, :], dim=-1) return F.kl_div(pred_log_probs, target_probs, reduction="sum") - - -def _interpolate_masks( - ci: dict[str, Tensor], - sources: dict[str, Tensor], -) -> dict[str, Tensor]: - """Compute PGD component masks: ci + (1 - ci) * source.""" - return {layer: ci[layer] + (1 - ci[layer]) * sources[layer] for layer in ci} + case LogitLossConfig(position=pos, label_token=label_token): + return -logits[0, pos, label_token] + case MeanKLLossConfig(): + target_probs = F.softmax(target_out, dim=-1) + pred_log_probs = F.log_softmax(logits, dim=-1) + return F.kl_div(pred_log_probs, target_probs, reduction="batchmean") @dataclass @@ -188,94 +204,6 @@ def create_optimizable_ci_params( ) -def compute_l0_stats( - ci_outputs: CIOutputs, - ci_alive_threshold: float, -) -> dict[str, float]: - """Compute L0 statistics for each layer.""" - stats: dict[str, float] = {} - for layer_name, layer_ci in ci_outputs.lower_leaky.items(): - l0_val = calc_ci_l_zero(layer_ci, 
ci_alive_threshold) - stats[f"l0/{layer_name}"] = l0_val - stats["l0/total"] = sum(stats.values()) - return stats - - -def compute_specific_pos_ce_kl( - model: ComponentModel, - batch: Tensor, - target_out: Tensor, - ci: dict[str, Tensor], - rounding_threshold: float, - loss_seq_pos: int, -) -> dict[str, float]: - """Compute CE and KL metrics for a specific sequence position. - - Args: - model: The ComponentModel. - batch: Input tokens of shape [1, seq_len]. - target_out: Target model output logits of shape [1, seq_len, vocab]. - ci: Causal importance values (lower_leaky) per layer. - rounding_threshold: Threshold for rounding CI values to binary masks. - loss_seq_pos: Sequence position to compute metrics for. - - Returns: - Dict with kl and ce_difference metrics for ci_masked, unmasked, and rounded_masked. - """ - assert batch.ndim == 2 and batch.shape[0] == 1, "Expected batch shape [1, seq_len]" - - # Get target logits at the specified position - target_logits = target_out[0, loss_seq_pos, :] # [vocab] - - def kl_vs_target(logits: Tensor) -> float: - """KL divergence between predicted and target logits at target position.""" - pos_logits = logits[0, loss_seq_pos, :] # [vocab] - target_probs = F.softmax(target_logits, dim=-1) - pred_log_probs = F.log_softmax(pos_logits, dim=-1) - return F.kl_div(pred_log_probs, target_probs, reduction="sum").item() - - def ce_vs_target(logits: Tensor) -> float: - """CE between predicted logits and target's argmax at target position.""" - pos_logits = logits[0, loss_seq_pos, :] # [vocab] - target_token = target_logits.argmax() - return F.cross_entropy(pos_logits.unsqueeze(0), target_token.unsqueeze(0)).item() - - # Target model CE (baseline) - target_ce = ce_vs_target(target_out) - - # CI masked - ci_mask_infos = make_mask_infos(ci) - with bf16_autocast(): - ci_masked_logits = model(batch, mask_infos=ci_mask_infos) - ci_masked_kl = kl_vs_target(ci_masked_logits) - ci_masked_ce = ce_vs_target(ci_masked_logits) - - # Unmasked (all 
components active) - unmasked_infos = make_mask_infos({k: torch.ones_like(v) for k, v in ci.items()}) - with bf16_autocast(): - unmasked_logits = model(batch, mask_infos=unmasked_infos) - unmasked_kl = kl_vs_target(unmasked_logits) - unmasked_ce = ce_vs_target(unmasked_logits) - - # Rounded masked (binary masks based on threshold) - rounded_mask_infos = make_mask_infos( - {k: (v > rounding_threshold).float() for k, v in ci.items()} - ) - with bf16_autocast(): - rounded_masked_logits = model(batch, mask_infos=rounded_mask_infos) - rounded_masked_kl = kl_vs_target(rounded_masked_logits) - rounded_masked_ce = ce_vs_target(rounded_masked_logits) - - return { - "kl_ci_masked": ci_masked_kl, - "kl_unmasked": unmasked_kl, - "kl_rounded_masked": rounded_masked_kl, - "ce_difference_ci_masked": ci_masked_ce - target_ce, - "ce_difference_unmasked": unmasked_ce - target_ce, - "ce_difference_rounded_masked": rounded_masked_ce - target_ce, - } - - @dataclass class OptimCIConfig: """Configuration for optimizing CI values on a single prompt.""" @@ -292,9 +220,9 @@ class OptimCIConfig: log_freq: int - # Loss config (exactly one of CE or KL) + # Loss config (CE or KL — must target a specific position) imp_min_config: ImportanceMinimalityLossConfig - loss_config: LossConfig + loss_config: PositionalLossConfig sampling: SamplingType @@ -306,43 +234,51 @@ class OptimCIConfig: ProgressCallback = Callable[[int, int, str], None] # (current, total, stage) +class CISnapshot(BaseModel): + """Snapshot of alive component counts during CI optimization for visualization.""" + + step: int + total_steps: int + layers: list[str] + seq_len: int + initial_alive: list[list[int]] # layers × seq + current_alive: list[list[int]] # layers × seq + l0_total: float + loss: float + + +CISnapshotCallback = Callable[[CISnapshot], None] + + @dataclass class OptimizeCIResult: """Result from CI optimization including params and final metrics.""" params: OptimizableCIParams metrics: OptimizationMetrics - 
adv_pgd_out_logits: Float[Tensor, "seq vocab"] | None = None -def _run_adv_pgd( +def run_adv_pgd( model: ComponentModel, tokens: Tensor, - ci_lower_leaky: dict[str, Float[Tensor, "1 seq C"]], + ci: dict[str, Float[Tensor, "1 seq C"]], alive_masks: dict[str, Bool[Tensor, "1 seq C"]], adv_config: AdvPGDConfig, - loss_config: LossConfig, target_out: Tensor, - device: str, + loss_config: LossConfig, ) -> dict[str, Float[Tensor, "1 seq C"]]: - """Run PGD to find adversarial sources maximizing reconstruction loss. + """Run PGD to find adversarial sources maximizing loss. Sources are optimized via signed gradient ascent. Only alive positions are optimized. Masks are computed as ci + (1 - ci) * source (same interpolation as training PGD). Returns detached adversarial source tensors. """ - ci_detached = {k: v.detach() for k, v in ci_lower_leaky.items()} + ci_detached = {k: v.detach() for k, v in ci.items()} adv_sources: dict[str, Tensor] = {} - for layer_name, ci in ci_detached.items(): - match adv_config.init: - case "random": - source = torch.rand_like(ci) - case "ones": - source = torch.ones_like(ci) - case "zeroes": - source = torch.zeros_like(ci) + for layer_name, ci_val in ci_detached.items(): + source = get_pgd_init_tensor(adv_config.init, tuple(ci_val.shape), str(ci_val.device)) source[~alive_masks[layer_name]] = 0.0 source.requires_grad_(True) adv_sources[layer_name] = source @@ -350,12 +286,13 @@ def _run_adv_pgd( source_list = list(adv_sources.values()) for _ in range(adv_config.n_steps): - mask_infos = make_mask_infos(_interpolate_masks(ci_detached, adv_sources)) + mask_infos = make_mask_infos(interpolate_pgd_mask(ci_detached, adv_sources)) with bf16_autocast(): out = model(tokens, mask_infos=mask_infos) - loss = _compute_recon_loss(out, loss_config, target_out, device) + loss = compute_recon_loss(out, loss_config, target_out, str(tokens.device)) + grads = torch.autograd.grad(loss, source_list) with torch.no_grad(): for (layer_name, source), grad in 
zip(adv_sources.items(), grads, strict=True): @@ -372,6 +309,7 @@ def optimize_ci_values( config: OptimCIConfig, device: str, on_progress: ProgressCallback | None = None, + on_ci_snapshot: CISnapshotCallback | None = None, ) -> OptimizeCIResult: """Optimize CI values for a single prompt. @@ -406,13 +344,40 @@ def optimize_ci_values( weight_deltas = model.calc_weight_deltas() + # Precompute snapshot metadata for CI visualization + snapshot_layers = list(alive_info.alive_counts.keys()) + snapshot_initial_alive = [alive_info.alive_counts[layer] for layer in snapshot_layers] + snapshot_seq_len = tokens.shape[1] + params = ci_params.get_parameters() optimizer = optim.AdamW(params, lr=config.lr, weight_decay=config.weight_decay) progress_interval = max(1, config.steps // 20) # Report ~20 times during optimization + latest_loss: float = 0.0 for step in tqdm(range(config.steps), desc="Optimizing CI values"): - if on_progress is not None and step % progress_interval == 0: - on_progress(step, config.steps, "optimizing") + if step % progress_interval == 0: + if on_progress is not None: + on_progress(step, config.steps, "optimizing") + + if on_ci_snapshot is not None: + with torch.no_grad(): + snap_ci = ci_params.create_ci_outputs(model, device) + current_alive = [ + (snap_ci.lower_leaky[layer][0] > 0.0).sum(dim=-1).tolist() + for layer in snapshot_layers + ] + on_ci_snapshot( + CISnapshot( + step=step, + total_steps=config.steps, + layers=snapshot_layers, + seq_len=snapshot_seq_len, + initial_alive=snapshot_initial_alive, + current_alive=current_alive, + l0_total=sum(sum(row) for row in current_alive), + loss=latest_loss, + ) + ) optimizer.zero_grad() @@ -444,82 +409,46 @@ def optimize_ci_values( p_anneal_end_frac=config.imp_min_config.p_anneal_end_frac, ) - recon_loss = _compute_recon_loss(recon_out, config.loss_config, target_out, device) + recon_loss = compute_recon_loss(recon_out, config.loss_config, target_out, device) total_loss = config.loss_config.coeff * recon_loss + 
imp_min_coeff * imp_min_loss + latest_loss = total_loss.item() # PGD adversarial loss (runs in tandem with recon) if config.adv_pgd is not None: - adv_sources = _run_adv_pgd( + adv_sources = run_adv_pgd( model=model, tokens=tokens, - ci_lower_leaky=ci_outputs.lower_leaky, + ci=ci_outputs.lower_leaky, alive_masks=alive_info.alive_masks, adv_config=config.adv_pgd, loss_config=config.loss_config, target_out=target_out, - device=device, ) pgd_mask_infos = make_mask_infos( - _interpolate_masks(ci_outputs.lower_leaky, adv_sources) + interpolate_pgd_mask(ci_outputs.lower_leaky, adv_sources) ) with bf16_autocast(): pgd_out = model(tokens, mask_infos=pgd_mask_infos) - pgd_loss = _compute_recon_loss(pgd_out, config.loss_config, target_out, device) + pgd_loss = compute_recon_loss(pgd_out, config.loss_config, target_out, device) total_loss = total_loss + config.loss_config.coeff * pgd_loss - if step % config.log_freq == 0 or step == config.steps - 1: - l0_stats = compute_l0_stats(ci_outputs, ci_alive_threshold=0.0) - - with torch.no_grad(): - ce_kl_stats = compute_specific_pos_ce_kl( - model=model, - batch=tokens, - target_out=target_out, - ci=ci_outputs.lower_leaky, - rounding_threshold=config.ce_kl_rounding_threshold, - loss_seq_pos=config.loss_config.position, - ) - - log_terms: dict[str, float] = { - "imp_min_loss": imp_min_loss.item(), - "total_loss": total_loss.item(), - "recon_loss": recon_loss.item(), - } - - if isinstance(config.loss_config, CELossConfig): - pos = config.loss_config.position - label_token = config.loss_config.label_token - recon_label_prob = F.softmax(recon_out[0, pos, :], dim=-1)[label_token] - log_terms["recon_masked_label_prob"] = recon_label_prob.item() - - with torch.no_grad(): - mask_infos = make_mask_infos(ci_outputs.lower_leaky, routing_masks="all") - logits = model(tokens, mask_infos=mask_infos) - probs = F.softmax(logits[0, pos, :], dim=-1) - log_terms["ci_masked_label_prob"] = float(probs[label_token].item()) - - tqdm.write(f"\n--- Step 
{step} ---") - for name, value in log_terms.items(): - tqdm.write(f" {name}: {value:.6f}") - for name, value in l0_stats.items(): - tqdm.write(f" {name}: {value:.2f}") - for name, value in ce_kl_stats.items(): - tqdm.write(f" {name}: {value:.6f}") - total_loss.backward() optimizer.step() # Compute final metrics after optimization with torch.no_grad(): final_ci_outputs = ci_params.create_ci_outputs(model, device) - final_l0_stats = compute_l0_stats(final_ci_outputs, ci_alive_threshold=0.0) + + total_l0 = sum( + calc_ci_l_zero(layer_ci, 0.0) for layer_ci in final_ci_outputs.lower_leaky.values() + ) final_ci_masked_label_prob: float | None = None final_stoch_masked_label_prob: float | None = None - if isinstance(config.loss_config, CELossConfig): + if isinstance(config.loss_config, CELossConfig | LogitLossConfig): pos = config.loss_config.position label_token = config.loss_config.label_token @@ -541,29 +470,26 @@ def optimize_ci_values( final_stoch_masked_label_prob = float(stoch_probs[label_token].item()) # Adversarial PGD final evaluation (needs gradients for PGD, so outside no_grad block) - adv_pgd_out_logits: Float[Tensor, "seq vocab"] | None = None final_adv_pgd_label_prob: float | None = None if config.adv_pgd is not None: - final_adv_sources = _run_adv_pgd( + final_adv_sources = run_adv_pgd( model=model, tokens=tokens, - ci_lower_leaky=final_ci_outputs.lower_leaky, + ci=final_ci_outputs.lower_leaky, alive_masks=alive_info.alive_masks, adv_config=config.adv_pgd, - loss_config=config.loss_config, target_out=target_out, - device=device, + loss_config=config.loss_config, ) with torch.no_grad(): adv_pgd_masks = make_mask_infos( - _interpolate_masks(final_ci_outputs.lower_leaky, final_adv_sources) + interpolate_pgd_mask(final_ci_outputs.lower_leaky, final_adv_sources) ) with bf16_autocast(): adv_logits = model(tokens, mask_infos=adv_pgd_masks) - adv_pgd_out_logits = adv_logits[0].detach() # [seq, vocab] - if isinstance(config.loss_config, CELossConfig): + if 
isinstance(config.loss_config, CELossConfig | LogitLossConfig): pos = config.loss_config.position label_token = config.loss_config.label_token adv_probs = F.softmax(adv_logits[0, pos, :], dim=-1) @@ -573,16 +499,335 @@ def optimize_ci_values( ci_masked_label_prob=final_ci_masked_label_prob, stoch_masked_label_prob=final_stoch_masked_label_prob, adv_pgd_label_prob=final_adv_pgd_label_prob, - l0_total=final_l0_stats["l0/total"], + l0_total=total_l0, ) return OptimizeCIResult( params=ci_params, metrics=metrics, - adv_pgd_out_logits=adv_pgd_out_logits, ) +def compute_recon_loss_batched( + logits: Float[Tensor, "N seq vocab"], + loss_config: LossConfig, + target_out: Float[Tensor, "N seq vocab"], + device: str, +) -> Float[Tensor, " N"]: + """Compute per-element reconstruction loss for batched logits.""" + match loss_config: + case CELossConfig(position=pos, label_token=label_token): + labels = torch.full((logits.shape[0],), label_token, device=device) + return F.cross_entropy(logits[:, pos, :], labels, reduction="none") + case KLLossConfig(position=pos): + target_probs = F.softmax(target_out[:, pos, :], dim=-1) + pred_log_probs = F.log_softmax(logits[:, pos, :], dim=-1) + return F.kl_div(pred_log_probs, target_probs, reduction="none").sum(dim=-1) + case LogitLossConfig(position=pos, label_token=label_token): + return -logits[:, pos, label_token] + case MeanKLLossConfig(): + target_probs = F.softmax(target_out, dim=-1) + pred_log_probs = F.log_softmax(logits, dim=-1) + return F.kl_div(pred_log_probs, target_probs, reduction="none").sum(dim=-1).mean(dim=-1) + + +def importance_minimality_loss_per_element( + ci_upper_leaky_batched: dict[str, Float[Tensor, "N seq C"]], + n_batch: int, + current_frac_of_training: float, + pnorm: float, + beta: float, + eps: float, + p_anneal_start_frac: float, + p_anneal_final_p: float | None, + p_anneal_end_frac: float, +) -> Float[Tensor, " N"]: + """Compute importance minimality loss independently for each batch element.""" + losses = [] 
+ for i in range(n_batch): + element_ci = {k: v[i : i + 1] for k, v in ci_upper_leaky_batched.items()} + losses.append( + importance_minimality_loss( + ci_upper_leaky=element_ci, + current_frac_of_training=current_frac_of_training, + pnorm=pnorm, + beta=beta, + eps=eps, + p_anneal_start_frac=p_anneal_start_frac, + p_anneal_final_p=p_anneal_final_p, + p_anneal_end_frac=p_anneal_end_frac, + ) + ) + return torch.stack(losses) + + +def run_adv_pgd_batched( + model: ComponentModel, + tokens: Float[Tensor, "N seq"], + ci: dict[str, Float[Tensor, "N seq C"]], + alive_masks: dict[str, Bool[Tensor, "N seq C"]], + adv_config: AdvPGDConfig, + target_out: Float[Tensor, "N seq vocab"], + loss_config: LossConfig, +) -> dict[str, Float[Tensor, "N seq C"]]: + """Run PGD adversary with batched tensors. Returns detached adversarial sources.""" + ci_detached = {k: v.detach() for k, v in ci.items()} + + adv_sources: dict[str, Tensor] = {} + for layer_name, ci_val in ci_detached.items(): + source = get_pgd_init_tensor(adv_config.init, tuple(ci_val.shape), str(ci_val.device)) + source[~alive_masks[layer_name]] = 0.0 + source.requires_grad_(True) + adv_sources[layer_name] = source + + source_list = list(adv_sources.values()) + + for _ in range(adv_config.n_steps): + mask_infos = make_mask_infos(interpolate_pgd_mask(ci_detached, adv_sources)) + + with bf16_autocast(): + out = model(tokens, mask_infos=mask_infos) + + losses = compute_recon_loss_batched(out, loss_config, target_out, str(tokens.device)) + loss = losses.sum() + + grads = torch.autograd.grad(loss, source_list) + with torch.no_grad(): + for (layer_name, source), grad in zip(adv_sources.items(), grads, strict=True): + source.add_(adv_config.step_size * grad.sign()) + source.clamp_(0.0, 1.0) + source[~alive_masks[layer_name]] = 0.0 + + return {k: v.detach() for k, v in adv_sources.items()} + + +def optimize_ci_values_batched( + model: ComponentModel, + tokens: Float[Tensor, "1 seq"], + configs: list[OptimCIConfig], + device: str, 
+ on_progress: ProgressCallback | None = None, + on_ci_snapshot: CISnapshotCallback | None = None, +) -> list[OptimizeCIResult]: + """Optimize CI values for N sparsity coefficients in a single batched loop. + + All configs must share the same loss_config, steps, mask_type, adv_pgd settings — + only imp_min_config.coeff varies between them. + """ + N = len(configs) + assert N > 0 + + config = configs[0] + imp_min_coeffs = torch.tensor([c.imp_min_config.coeff for c in configs], device=device) + for c in configs: + assert c.imp_min_config.coeff is not None + + model.requires_grad_(False) + + with torch.no_grad(), bf16_autocast(): + output_with_cache: OutputWithCache = model(tokens, cache_type="input") + initial_ci_outputs = model.calc_causal_importances( + pre_weight_acts=output_with_cache.cache, + sampling=config.sampling, + detach_inputs=False, + ) + target_out = output_with_cache.output.detach() + + alive_info = compute_alive_info(initial_ci_outputs.lower_leaky) + + ci_params_list = [ + create_optimizable_ci_params( + alive_info=alive_info, + initial_pre_sigmoid=initial_ci_outputs.pre_sigmoid, + ) + for _ in range(N) + ] + + weight_deltas = model.calc_weight_deltas() + + all_params: list[Tensor] = [] + for ci_params in ci_params_list: + all_params.extend(ci_params.get_parameters()) + + optimizer = optim.AdamW(all_params, lr=config.lr, weight_decay=config.weight_decay) + + tokens_batched = tokens.expand(N, -1) + target_out_batched = target_out.expand(N, -1, -1) + + snapshot_layers = list(alive_info.alive_counts.keys()) + snapshot_initial_alive = [alive_info.alive_counts[layer] for layer in snapshot_layers] + snapshot_seq_len = tokens.shape[1] + + progress_interval = max(1, config.steps // 20) + latest_loss = 0.0 + + for step in tqdm(range(config.steps), desc="Optimizing CI values (batched)"): + if step % progress_interval == 0: + if on_progress is not None: + on_progress(step, config.steps, "optimizing") + + if on_ci_snapshot is not None: + with torch.no_grad(): + 
snap_ci = ci_params_list[0].create_ci_outputs(model, device) + current_alive = [ + (snap_ci.lower_leaky[layer][0] > 0.0).sum(dim=-1).tolist() + for layer in snapshot_layers + ] + on_ci_snapshot( + CISnapshot( + step=step, + total_steps=config.steps, + layers=snapshot_layers, + seq_len=snapshot_seq_len, + initial_alive=snapshot_initial_alive, + current_alive=current_alive, + l0_total=sum(sum(row) for row in current_alive), + loss=latest_loss, + ) + ) + + optimizer.zero_grad() + + ci_outputs_list = [cp.create_ci_outputs(model, device) for cp in ci_params_list] + + layers = list(ci_outputs_list[0].lower_leaky.keys()) + batched_ci_lower_leaky: dict[str, Tensor] = { + layer: torch.cat([co.lower_leaky[layer] for co in ci_outputs_list], dim=0) + for layer in layers + } + batched_ci_upper_leaky: dict[str, Tensor] = { + layer: torch.cat([co.upper_leaky[layer] for co in ci_outputs_list], dim=0) + for layer in layers + } + + match config.mask_type: + case "stochastic": + recon_mask_infos = calc_stochastic_component_mask_info( + causal_importances=batched_ci_lower_leaky, + component_mask_sampling=config.sampling, + weight_deltas=weight_deltas, + router=AllLayersRouter(), + ) + case "ci": + recon_mask_infos = make_mask_infos(component_masks=batched_ci_lower_leaky) + + with bf16_autocast(): + recon_out = model(tokens_batched, mask_infos=recon_mask_infos) + + imp_min_losses = importance_minimality_loss_per_element( + ci_upper_leaky_batched=batched_ci_upper_leaky, + n_batch=N, + current_frac_of_training=step / config.steps, + pnorm=config.imp_min_config.pnorm, + beta=config.imp_min_config.beta, + eps=config.imp_min_config.eps, + p_anneal_start_frac=config.imp_min_config.p_anneal_start_frac, + p_anneal_final_p=config.imp_min_config.p_anneal_final_p, + p_anneal_end_frac=config.imp_min_config.p_anneal_end_frac, + ) + + recon_losses = compute_recon_loss_batched( + recon_out, config.loss_config, target_out_batched, device + ) + + loss_coeff = config.loss_config.coeff + total_loss = 
(loss_coeff * recon_losses + imp_min_coeffs * imp_min_losses).sum() + latest_loss = total_loss.item() + + if config.adv_pgd is not None: + batched_alive_masks = { + k: v.expand(N, -1, -1) for k, v in alive_info.alive_masks.items() + } + adv_sources = run_adv_pgd_batched( + model=model, + tokens=tokens_batched, + ci=batched_ci_lower_leaky, + alive_masks=batched_alive_masks, + adv_config=config.adv_pgd, + target_out=target_out_batched, + loss_config=config.loss_config, + ) + pgd_masks = interpolate_pgd_mask(batched_ci_lower_leaky, adv_sources) + pgd_mask_infos = make_mask_infos(pgd_masks) + with bf16_autocast(): + pgd_out = model(tokens_batched, mask_infos=pgd_mask_infos) + pgd_losses = compute_recon_loss_batched( + pgd_out, config.loss_config, target_out_batched, device + ) + total_loss = total_loss + (loss_coeff * pgd_losses).sum() + + total_loss.backward() + optimizer.step() + + # Compute final metrics per element + results: list[OptimizeCIResult] = [] + for ci_params in ci_params_list: + with torch.no_grad(): + final_ci = ci_params.create_ci_outputs(model, device) + total_l0 = sum( + calc_ci_l_zero(layer_ci, 0.0) for layer_ci in final_ci.lower_leaky.values() + ) + + ci_masked_label_prob: float | None = None + stoch_masked_label_prob: float | None = None + + if isinstance(config.loss_config, CELossConfig | LogitLossConfig): + pos = config.loss_config.position + label_token = config.loss_config.label_token + + ci_mask_infos = make_mask_infos(final_ci.lower_leaky, routing_masks="all") + ci_logits = model(tokens, mask_infos=ci_mask_infos) + ci_probs = F.softmax(ci_logits[0, pos, :], dim=-1) + ci_masked_label_prob = float(ci_probs[label_token].item()) + + stoch_mask_infos = calc_stochastic_component_mask_info( + causal_importances=final_ci.lower_leaky, + component_mask_sampling=config.sampling, + weight_deltas=weight_deltas, + router=AllLayersRouter(), + ) + stoch_logits = model(tokens, mask_infos=stoch_mask_infos) + stoch_probs = F.softmax(stoch_logits[0, pos, :], 
dim=-1) + stoch_masked_label_prob = float(stoch_probs[label_token].item()) + + adv_pgd_label_prob: float | None = None + if config.adv_pgd is not None: + final_adv_sources = run_adv_pgd( + model=model, + tokens=tokens, + ci=final_ci.lower_leaky, + alive_masks=alive_info.alive_masks, + adv_config=config.adv_pgd, + target_out=target_out, + loss_config=config.loss_config, + ) + with torch.no_grad(): + adv_masks = make_mask_infos( + interpolate_pgd_mask(final_ci.lower_leaky, final_adv_sources) + ) + with bf16_autocast(): + adv_logits = model(tokens, mask_infos=adv_masks) + if isinstance(config.loss_config, CELossConfig | LogitLossConfig): + pos = config.loss_config.position + label_token = config.loss_config.label_token + adv_probs = F.softmax(adv_logits[0, pos, :], dim=-1) + adv_pgd_label_prob = float(adv_probs[label_token].item()) + + results.append( + OptimizeCIResult( + params=ci_params, + metrics=OptimizationMetrics( + ci_masked_label_prob=ci_masked_label_prob, + stoch_masked_label_prob=stoch_masked_label_prob, + adv_pgd_label_prob=adv_pgd_label_prob, + l0_total=total_l0, + ), + ) + ) + + return results + + def get_out_dir() -> Path: """Get the output directory for optimization results.""" out_dir = Path(__file__).parent / "out" diff --git a/spd/app/backend/routers/__init__.py b/spd/app/backend/routers/__init__.py index b7a6f8ed3..7b1729fbb 100644 --- a/spd/app/backend/routers/__init__.py +++ b/spd/app/backend/routers/__init__.py @@ -7,10 +7,14 @@ from spd.app.backend.routers.data_sources import router as data_sources_router from spd.app.backend.routers.dataset_attributions import router as dataset_attributions_router from spd.app.backend.routers.dataset_search import router as dataset_search_router +from spd.app.backend.routers.graph_interp import router as graph_interp_router from spd.app.backend.routers.graphs import router as graphs_router from spd.app.backend.routers.intervention import router as intervention_router +from 
spd.app.backend.routers.investigations import router as investigations_router +from spd.app.backend.routers.mcp import router as mcp_router from spd.app.backend.routers.pretrain_info import router as pretrain_info_router from spd.app.backend.routers.prompts import router as prompts_router +from spd.app.backend.routers.run_registry import router as run_registry_router from spd.app.backend.routers.runs import router as runs_router __all__ = [ @@ -20,10 +24,14 @@ "correlations_router", "data_sources_router", "dataset_attributions_router", + "graph_interp_router", "dataset_search_router", "graphs_router", "intervention_router", + "investigations_router", + "mcp_router", "pretrain_info_router", "prompts_router", + "run_registry_router", "runs_router", ] diff --git a/spd/app/backend/routers/clusters.py b/spd/app/backend/routers/clusters.py index e2dbae37a..b2dc1d5b9 100644 --- a/spd/app/backend/routers/clusters.py +++ b/spd/app/backend/routers/clusters.py @@ -10,6 +10,7 @@ from spd.app.backend.utils import log_errors from spd.base_config import BaseConfig from spd.settings import SPD_OUT_DIR +from spd.topology import TransformerTopology router = APIRouter(prefix="/api/clusters", tags=["clusters"]) @@ -86,4 +87,17 @@ def load_cluster_mapping(file_path: str) -> ClusterMapping: f"but loaded run is '{run_state.run.wandb_path}'", ) - return ClusterMapping(mapping=parsed.clusters) + canonical_clusters = _to_canonical_keys(parsed.clusters, run_state.topology) + return ClusterMapping(mapping=canonical_clusters) + + +def _to_canonical_keys( + clusters: dict[str, int | None], topology: TransformerTopology +) -> dict[str, int | None]: + """Convert concrete component keys (e.g. 'h.3.mlp.down_proj:5') to canonical (e.g. 
'3.mlp.down:5').""" + result: dict[str, int | None] = {} + for key, cluster_id in clusters.items(): + layer, idx = key.rsplit(":", 1) + canonical_layer = topology.target_to_canon(layer) + result[f"{canonical_layer}:{idx}"] = cluster_id + return result diff --git a/spd/app/backend/routers/data_sources.py b/spd/app/backend/routers/data_sources.py index 5287d91bd..6888b339f 100644 --- a/spd/app/backend/routers/data_sources.py +++ b/spd/app/backend/routers/data_sources.py @@ -28,15 +28,21 @@ class AutointerpInfo(BaseModel): class AttributionsInfo(BaseModel): subrun_id: str - n_batches_processed: int n_tokens_processed: int ci_threshold: float +class GraphInterpInfo(BaseModel): + subrun_id: str + config: dict[str, Any] | None + label_counts: dict[str, int] + + class DataSourcesResponse(BaseModel): harvest: HarvestInfo | None autointerp: AutointerpInfo | None attributions: AttributionsInfo | None + graph_interp: GraphInterpInfo | None router = APIRouter(prefix="/api/data_sources", tags=["data_sources"]) @@ -70,13 +76,21 @@ def get_data_sources(loaded: DepLoadedRun) -> DataSourcesResponse: storage = loaded.attributions.get_attributions() attributions_info = AttributionsInfo( subrun_id=loaded.attributions.subrun_id, - n_batches_processed=storage.n_batches_processed, n_tokens_processed=storage.n_tokens_processed, ci_threshold=storage.ci_threshold, ) + graph_interp_info: GraphInterpInfo | None = None + if loaded.graph_interp is not None: + graph_interp_info = GraphInterpInfo( + subrun_id=loaded.graph_interp.subrun_id, + config=loaded.graph_interp.get_config(), + label_counts=loaded.graph_interp.get_label_counts(), + ) + return DataSourcesResponse( harvest=harvest_info, autointerp=autointerp_info, attributions=attributions_info, + graph_interp=graph_interp_info, ) diff --git a/spd/app/backend/routers/dataset_attributions.py b/spd/app/backend/routers/dataset_attributions.py index 4c3d07753..178eefc72 100644 --- a/spd/app/backend/routers/dataset_attributions.py +++ 
b/spd/app/backend/routers/dataset_attributions.py @@ -7,46 +7,43 @@ from typing import Annotated, Literal from fastapi import APIRouter, HTTPException, Query -from jaxtyping import Float from pydantic import BaseModel -from torch import Tensor from spd.app.backend.dependencies import DepLoadedRun from spd.app.backend.utils import log_errors +from spd.dataset_attributions.storage import AttrMetric, DatasetAttributionStorage from spd.dataset_attributions.storage import DatasetAttributionEntry as StorageEntry -from spd.dataset_attributions.storage import DatasetAttributionStorage +ATTR_METRICS: list[AttrMetric] = ["attr", "attr_abs"] -class DatasetAttributionEntry(BaseModel): - """A single entry in attribution results.""" +class DatasetAttributionEntry(BaseModel): component_key: str layer: str component_idx: int value: float + token_str: str | None = None class DatasetAttributionMetadata(BaseModel): - """Metadata about dataset attributions availability.""" - available: bool - n_batches_processed: int | None n_tokens_processed: int | None n_component_layer_keys: int | None - vocab_size: int | None - d_model: int | None ci_threshold: float | None class ComponentAttributions(BaseModel): - """All attribution data for a single component (sources and targets, positive and negative).""" - positive_sources: list[DatasetAttributionEntry] negative_sources: list[DatasetAttributionEntry] positive_targets: list[DatasetAttributionEntry] negative_targets: list[DatasetAttributionEntry] +class AllMetricAttributions(BaseModel): + attr: ComponentAttributions + attr_abs: ComponentAttributions + + router = APIRouter(prefix="/api/dataset_attributions", tags=["dataset_attributions"]) NOT_AVAILABLE_MSG = ( @@ -54,91 +51,67 @@ class ComponentAttributions(BaseModel): ) -def _to_concrete_key(canonical_layer: str, component_idx: int, loaded: DepLoadedRun) -> str: - """Translate canonical layer + idx to concrete storage key. - - "embed" maps to the concrete embedding path (e.g. "wte") in storage. 
- "output" is a pseudo-layer used as-is in storage. - """ - if canonical_layer == "output": - return f"output:{component_idx}" - concrete = loaded.topology.canon_to_target(canonical_layer) - return f"{concrete}:{component_idx}" - - def _require_storage(loaded: DepLoadedRun) -> DatasetAttributionStorage: - """Get storage or raise 404.""" if loaded.attributions is None: raise HTTPException(status_code=404, detail=NOT_AVAILABLE_MSG) return loaded.attributions.get_attributions() -def _require_source(storage: DatasetAttributionStorage, component_key: str) -> None: - """Validate component exists as a source or raise 404.""" - if not storage.has_source(component_key): - raise HTTPException( - status_code=404, - detail=f"Component {component_key} not found as source in attributions", - ) - - -def _require_target(storage: DatasetAttributionStorage, component_key: str) -> None: - """Validate component exists as a target or raise 404.""" - if not storage.has_target(component_key): - raise HTTPException( - status_code=404, - detail=f"Component {component_key} not found as target in attributions", - ) - - -def _get_w_unembed(loaded: DepLoadedRun) -> Float[Tensor, "d_model vocab"]: - """Get the unembedding matrix from the loaded model.""" - return loaded.topology.get_unembed_weight() - - def _to_api_entries( - loaded: DepLoadedRun, entries: list[StorageEntry] + entries: list[StorageEntry], loaded: DepLoadedRun ) -> list[DatasetAttributionEntry]: - """Convert storage entries to API response format with canonical keys.""" - - def _canonicalize_layer(layer: str) -> str: - if layer == "output": - return layer - return loaded.topology.target_to_canon(layer) - return [ DatasetAttributionEntry( - component_key=f"{_canonicalize_layer(e.layer)}:{e.component_idx}", - layer=_canonicalize_layer(e.layer), + component_key=e.component_key, + layer=e.layer, component_idx=e.component_idx, value=e.value, + token_str=loaded.tokenizer.decode([e.component_idx]) + if e.layer in ("embed", "output") + 
else None, ) for e in entries ] +def _get_component_attributions_for_metric( + storage: DatasetAttributionStorage, + loaded: DepLoadedRun, + component_key: str, + k: int, + metric: AttrMetric, +) -> ComponentAttributions: + return ComponentAttributions( + positive_sources=_to_api_entries( + storage.get_top_sources(component_key, k, "positive", metric), loaded + ), + negative_sources=_to_api_entries( + storage.get_top_sources(component_key, k, "negative", metric), loaded + ), + positive_targets=_to_api_entries( + storage.get_top_targets(component_key, k, "positive", metric), loaded + ), + negative_targets=_to_api_entries( + storage.get_top_targets(component_key, k, "negative", metric), loaded + ), + ) + + @router.get("/metadata") @log_errors def get_attribution_metadata(loaded: DepLoadedRun) -> DatasetAttributionMetadata: - """Get metadata about dataset attributions availability.""" if loaded.attributions is None: return DatasetAttributionMetadata( available=False, - n_batches_processed=None, n_tokens_processed=None, n_component_layer_keys=None, - vocab_size=None, - d_model=None, ci_threshold=None, ) storage = loaded.attributions.get_attributions() return DatasetAttributionMetadata( available=True, - n_batches_processed=storage.n_batches_processed, n_tokens_processed=storage.n_tokens_processed, n_component_layer_keys=storage.n_components, - vocab_size=storage.vocab_size, - d_model=storage.d_model, ci_threshold=storage.ci_threshold, ) @@ -150,58 +123,18 @@ def get_component_attributions( component_idx: int, loaded: DepLoadedRun, k: Annotated[int, Query(ge=1)] = 10, -) -> ComponentAttributions: - """Get all attribution data for a component (sources and targets, positive and negative).""" +) -> AllMetricAttributions: + """Get all attribution data for a component across all metrics.""" storage = _require_storage(loaded) - component_key = _to_concrete_key(layer, component_idx, loaded) - - # Component can be both a source and a target, so we need to check both - is_source 
= storage.has_source(component_key) - is_target = storage.has_target(component_key) - - if not is_source and not is_target: - raise HTTPException( - status_code=404, - detail=f"Component {component_key} not found in attributions", - ) - - w_unembed = _get_w_unembed(loaded) if is_source else None - - return ComponentAttributions( - positive_sources=_to_api_entries( - loaded, storage.get_top_sources(component_key, k, "positive") - ) - if is_target - else [], - negative_sources=_to_api_entries( - loaded, storage.get_top_sources(component_key, k, "negative") - ) - if is_target - else [], - positive_targets=_to_api_entries( - loaded, - storage.get_top_targets( - component_key, - k, - "positive", - w_unembed=w_unembed, - include_outputs=w_unembed is not None, - ), - ) - if is_source - else [], - negative_targets=_to_api_entries( - loaded, - storage.get_top_targets( - component_key, - k, - "negative", - w_unembed=w_unembed, - include_outputs=w_unembed is not None, - ), - ) - if is_source - else [], + component_key = f"{layer}:{component_idx}" + + return AllMetricAttributions( + **{ + metric: _get_component_attributions_for_metric( + storage, loaded, component_key, k, metric + ) + for metric in ATTR_METRICS + } ) @@ -213,16 +146,11 @@ def get_attribution_sources( loaded: DepLoadedRun, k: Annotated[int, Query(ge=1)] = 10, sign: Literal["positive", "negative"] = "positive", + metric: AttrMetric = "attr", ) -> list[DatasetAttributionEntry]: - """Get top-k source components that attribute TO this target over the dataset.""" storage = _require_storage(loaded) - target_key = _to_concrete_key(layer, component_idx, loaded) - _require_target(storage, target_key) - - w_unembed = _get_w_unembed(loaded) if layer == "output" else None - return _to_api_entries( - loaded, storage.get_top_sources(target_key, k, sign, w_unembed=w_unembed) + storage.get_top_sources(f"{layer}:{component_idx}", k, sign, metric), loaded ) @@ -234,35 +162,9 @@ def get_attribution_targets( loaded: DepLoadedRun, 
k: Annotated[int, Query(ge=1)] = 10, sign: Literal["positive", "negative"] = "positive", + metric: AttrMetric = "attr", ) -> list[DatasetAttributionEntry]: - """Get top-k target components this source attributes TO over the dataset.""" storage = _require_storage(loaded) - source_key = _to_concrete_key(layer, component_idx, loaded) - _require_source(storage, source_key) - - w_unembed = _get_w_unembed(loaded) - return _to_api_entries( - loaded, storage.get_top_targets(source_key, k, sign, w_unembed=w_unembed) + storage.get_top_targets(f"{layer}:{component_idx}", k, sign, metric), loaded ) - - -@router.get("/between/{source_layer}/{source_idx}/{target_layer}/{target_idx}") -@log_errors -def get_attribution_between( - source_layer: str, - source_idx: int, - target_layer: str, - target_idx: int, - loaded: DepLoadedRun, -) -> float: - """Get attribution strength from source component to target component.""" - storage = _require_storage(loaded) - source_key = _to_concrete_key(source_layer, source_idx, loaded) - target_key = _to_concrete_key(target_layer, target_idx, loaded) - _require_source(storage, source_key) - _require_target(storage, target_key) - - w_unembed = _get_w_unembed(loaded) if target_layer == "output" else None - - return storage.get_attribution(source_key, target_key, w_unembed=w_unembed) diff --git a/spd/app/backend/routers/graph_interp.py b/spd/app/backend/routers/graph_interp.py new file mode 100644 index 000000000..525dbce9c --- /dev/null +++ b/spd/app/backend/routers/graph_interp.py @@ -0,0 +1,373 @@ +"""Graph interpretation endpoints. + +Serves context-aware component labels (output/input/unified) and the +prompt-edge graph produced by the graph_interp pipeline. 
+""" + +import random + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel +from spd.graph_interp.schemas import LabelResult + +from spd.app.backend.dependencies import DepLoadedRun +from spd.app.backend.utils import log_errors +from spd.topology import TransformerTopology + +# TODO(oli): Remove MOCK_MODE after real data is available +MOCK_MODE = False + +MAX_GRAPH_NODES = 500 + + +_ALREADY_CANONICAL = {"embed", "output"} + + +def _concrete_to_canonical_key(concrete_key: str, topology: TransformerTopology) -> str: + layer, idx = concrete_key.rsplit(":", 1) + if layer in _ALREADY_CANONICAL: + return concrete_key + canonical = topology.target_to_canon(layer) + return f"{canonical}:{idx}" + + +def _canonical_to_concrete_key( + canonical_layer: str, component_idx: int, topology: TransformerTopology +) -> str: + concrete = topology.canon_to_target(canonical_layer) + return f"{concrete}:{component_idx}" + + +# -- Schemas ------------------------------------------------------------------- + + +class GraphInterpHeadline(BaseModel): + label: str + confidence: str + output_label: str | None + input_label: str | None + + +class LabelDetail(BaseModel): + label: str + confidence: str + reasoning: str + prompt: str + + +class GraphInterpDetail(BaseModel): + output: LabelDetail | None + input: LabelDetail | None + unified: LabelDetail | None + + +class PromptEdgeResponse(BaseModel): + related_key: str + pass_name: str + attribution: float + related_label: str | None + related_confidence: str | None + token_str: str | None + + +class GraphInterpComponentDetail(BaseModel): + output: LabelDetail | None + input: LabelDetail | None + unified: LabelDetail | None + edges: list[PromptEdgeResponse] + + +class GraphNode(BaseModel): + component_key: str + label: str + confidence: str + + +class GraphEdge(BaseModel): + source: str + target: str + attribution: float + pass_name: str + + +class ModelGraphResponse(BaseModel): + nodes: list[GraphNode] + edges: 
list[GraphEdge] + + +# -- Router -------------------------------------------------------------------- + +router = APIRouter(prefix="/api/graph_interp", tags=["graph_interp"]) + + +@router.get("/labels") +@log_errors +def get_all_labels(loaded: DepLoadedRun) -> dict[str, GraphInterpHeadline]: + if MOCK_MODE: + return _mock_all_labels(loaded) + + repo = loaded.graph_interp + if repo is None: + return {} + + topology = loaded.topology + unified = repo.get_all_unified_labels() + output = repo.get_all_output_labels() + input_ = repo.get_all_input_labels() + + all_keys = set(unified) | set(output) | set(input_) + result: dict[str, GraphInterpHeadline] = {} + + for concrete_key in all_keys: + u = unified.get(concrete_key) + o = output.get(concrete_key) + i = input_.get(concrete_key) + + label = u or o or i + assert label is not None + canonical_key = _concrete_to_canonical_key(concrete_key, topology) + + result[canonical_key] = GraphInterpHeadline( + label=label.label, + confidence=label.confidence, + output_label=o.label if o else None, + input_label=i.label if i else None, + ) + + return result + + +def _to_detail(label: LabelResult | None) -> LabelDetail | None: + if label is None: + return None + return LabelDetail( + label=label.label, + confidence=label.confidence, + reasoning=label.reasoning, + prompt=label.prompt, + ) + + +@router.get("/labels/{layer}/{c_idx}") +@log_errors +def get_label_detail(layer: str, c_idx: int, loaded: DepLoadedRun) -> GraphInterpDetail: + if MOCK_MODE: + return _mock_label_detail(layer, c_idx) + + repo = loaded.graph_interp + if repo is None: + raise HTTPException(status_code=404, detail="Graph interp data not available") + + concrete_key = _canonical_to_concrete_key(layer, c_idx, loaded.topology) + + return GraphInterpDetail( + output=_to_detail(repo.get_output_label(concrete_key)), + input=_to_detail(repo.get_input_label(concrete_key)), + unified=_to_detail(repo.get_unified_label(concrete_key)), + ) + + 
+@router.get("/detail/{layer}/{c_idx}") +@log_errors +def get_component_detail( + layer: str, c_idx: int, loaded: DepLoadedRun +) -> GraphInterpComponentDetail: + repo = loaded.graph_interp + if repo is None: + raise HTTPException(status_code=404, detail="Graph interp data not available") + + topology = loaded.topology + concrete_key = _canonical_to_concrete_key(layer, c_idx, topology) + + raw_edges = repo.get_prompt_edges(concrete_key) + tokenizer = loaded.tokenizer + edges = [] + for e in raw_edges: + rel_layer, rel_idx = e.related_key.rsplit(":", 1) + token_str = tokenizer.decode([int(rel_idx)]) if rel_layer in ("embed", "output") else None + edges.append( + PromptEdgeResponse( + related_key=_concrete_to_canonical_key(e.related_key, topology), + pass_name=e.pass_name, + attribution=e.attribution, + related_label=e.related_label, + related_confidence=e.related_confidence, + token_str=token_str, + ) + ) + + return GraphInterpComponentDetail( + output=_to_detail(repo.get_output_label(concrete_key)), + input=_to_detail(repo.get_input_label(concrete_key)), + unified=_to_detail(repo.get_unified_label(concrete_key)), + edges=edges, + ) + + +@router.get("/graph") +@log_errors +def get_model_graph(loaded: DepLoadedRun) -> ModelGraphResponse: + if MOCK_MODE: + return _mock_model_graph(loaded) + + repo = loaded.graph_interp + if repo is None: + raise HTTPException(status_code=404, detail="Graph interp data not available") + + topology = loaded.topology + + unified = repo.get_all_unified_labels() + nodes = [] + for concrete_key, label in unified.items(): + canonical_key = _concrete_to_canonical_key(concrete_key, topology) + nodes.append( + GraphNode( + component_key=canonical_key, + label=label.label, + confidence=label.confidence, + ) + ) + + nodes = nodes[:MAX_GRAPH_NODES] + node_keys = {n.component_key for n in nodes} + + raw_edges = repo.get_all_prompt_edges() + edges = [] + for e in raw_edges: + comp_canon = _concrete_to_canonical_key(e.component_key, topology) + 
rel_canon = _concrete_to_canonical_key(e.related_key, topology) + + match e.pass_name: + case "output": + source, target = comp_canon, rel_canon + case "input": + source, target = rel_canon, comp_canon + + if source not in node_keys or target not in node_keys: + continue + + edges.append( + GraphEdge( + source=source, + target=target, + attribution=e.attribution, + pass_name=e.pass_name, + ) + ) + + return ModelGraphResponse(nodes=nodes, edges=edges) + + +# -- Mock data (TODO: remove after real data available) ------------------------ + +_MOCK_LABELS = [ + "sentence-final punctuation", + "proper noun completion", + "emotional adjective selection", + "temporal adverb prediction", + "morphological suffix (-ing/-ed)", + "determiner before noun", + "dialogue quotation marks", + "plural noun suffix", + "clause boundary detection", + "verb tense agreement", + "spatial preposition", + "possessive pronoun", + "narrative action verb", + "abstract emotion noun", + "comparative adjective form", + "subject-verb agreement", + "article selection (a/the)", + "comma splice detection", + "pronoun resolution", + "negation scope", +] + +_MOCK_INPUT_LABELS = [ + "sentence-initial capitals", + "mid-sentence verb position", + "adjective-noun boundary", + "clause-final position", + "article-noun sequence", + "subject pronoun at boundary", + "preposition-object pair", + "verb stem before suffix", + "quotation boundary", + "comma-separated items", +] + + +def _mock_all_labels(loaded: DepLoadedRun) -> dict[str, GraphInterpHeadline]: + rng = random.Random(42) + topology = loaded.topology + confidences = ["high", "high", "high", "medium", "medium", "low"] + + result: dict[str, GraphInterpHeadline] = {} + for target_path, components in loaded.model.components.items(): + canon = topology.target_to_canon(target_path) + n_components = components.C + n_mock = min(n_components, rng.randint(5, 20)) + indices = sorted(rng.sample(range(n_components), n_mock)) + for idx in indices: + key = 
f"{canon}:{idx}" + result[key] = GraphInterpHeadline( + label=rng.choice(_MOCK_LABELS), + confidence=rng.choice(confidences), + output_label=rng.choice(_MOCK_LABELS), + input_label=rng.choice(_MOCK_INPUT_LABELS), + ) + return result + + +def _mock_label_detail(layer: str, c_idx: int) -> GraphInterpDetail: + rng = random.Random(hash((layer, c_idx))) + conf = rng.choice(["high", "medium", "low"]) + return GraphInterpDetail( + output=LabelDetail( + label=rng.choice(_MOCK_LABELS), + confidence=conf, + reasoning=f"Output: Component {layer}:{c_idx} writes {rng.choice(_MOCK_LABELS).lower()} tokens to the residual stream.", + prompt="(mock prompt)", + ), + input=LabelDetail( + label=rng.choice(_MOCK_INPUT_LABELS), + confidence=conf, + reasoning=f"Input: Component {layer}:{c_idx} fires on {rng.choice(_MOCK_INPUT_LABELS).lower()} patterns.", + prompt="(mock prompt)", + ), + unified=LabelDetail( + label=rng.choice(_MOCK_LABELS), + confidence=conf, + reasoning=f"Unified: Combines output ({rng.choice(_MOCK_LABELS).lower()}) and input ({rng.choice(_MOCK_INPUT_LABELS).lower()}) functions.", + prompt="(mock prompt)", + ), + ) + + +def _mock_model_graph(loaded: DepLoadedRun) -> ModelGraphResponse: + labels = _mock_all_labels(loaded) + + nodes = [ + GraphNode(component_key=key, label=h.label, confidence=h.confidence) + for key, h in labels.items() + ] + + rng = random.Random(42) + keys = list(labels.keys()) + edges: list[GraphEdge] = [] + + for key in keys: + layer = key.rsplit(":", 1)[0] + later_keys = [k for k in keys if k.rsplit(":", 1)[0] != layer] + n_edges = rng.randint(1, 4) + for target in rng.sample(later_keys, min(n_edges, len(later_keys))): + edges.append( + GraphEdge( + source=key, + target=target, + attribution=rng.uniform(-1.0, 1.0), + pass_name=rng.choice(["output", "input"]), + ) + ) + + return ModelGraphResponse(nodes=nodes, edges=edges) diff --git a/spd/app/backend/routers/graphs.py b/spd/app/backend/routers/graphs.py index a51b1649c..e2b439da1 100644 --- 
a/spd/app/backend/routers/graphs.py +++ b/spd/app/backend/routers/graphs.py @@ -19,26 +19,90 @@ from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.compute import ( + DEFAULT_EVAL_PGD_CONFIG, Edge, + compute_intervention, compute_prompt_attributions, compute_prompt_attributions_optimized, + compute_prompt_attributions_optimized_batched, +) +from spd.app.backend.database import ( + GraphType, + OptimizationParams, + PgdConfig, + PromptAttrDB, + StoredGraph, ) -from spd.app.backend.database import GraphType, OptimizationParams, StoredGraph from spd.app.backend.dependencies import DepLoadedRun, DepStateManager from spd.app.backend.optim_cis import ( AdvPGDConfig, CELossConfig, + CISnapshot, KLLossConfig, + LogitLossConfig, LossConfig, MaskType, + MeanKLLossConfig, OptimCIConfig, ) from spd.app.backend.schemas import OutputProbability from spd.app.backend.utils import log_errors -from spd.configs import ImportanceMinimalityLossConfig +from spd.configs import ImportanceMinimalityLossConfig, SamplingType from spd.log import logger +from spd.models.component_model import ComponentModel +from spd.topology import TransformerTopology from spd.utils.distributed_utils import get_device +NON_INTERVENTABLE_LAYERS = {"embed", "output"} + + +def _save_base_intervention_run( + graph_id: int, + model: ComponentModel, + tokens: torch.Tensor, + node_ci_vals: dict[str, float], + tokenizer: AppTokenizer, + topology: TransformerTopology, + db: PromptAttrDB, + sampling: SamplingType, + loss_config: LossConfig | None = None, +) -> None: + """Compute intervention for all interventable nodes and save as an intervention run.""" + interventable_keys = [ + k + for k, ci in node_ci_vals.items() + if k.split(":")[0] not in NON_INTERVENTABLE_LAYERS and ci > 0 + ] + assert len(interventable_keys) > 0, "No interventable nodes with CI > 0" + + active_nodes: list[tuple[str, int, int]] = [] + for key in interventable_keys: + canon_layer, seq_str, cidx_str = key.split(":") + 
concrete_path = topology.canon_to_target(canon_layer) + active_nodes.append((concrete_path, int(seq_str), int(cidx_str))) + + effective_loss_config: LossConfig = ( + loss_config if loss_config is not None else MeanKLLossConfig() + ) + + result = compute_intervention( + model=model, + tokens=tokens, + active_nodes=active_nodes, + nodes_to_ablate=None, + tokenizer=tokenizer, + adv_pgd_config=DEFAULT_EVAL_PGD_CONFIG, + loss_config=effective_loss_config, + sampling=sampling, + top_k=10, + ) + + db.save_intervention_run( + graph_id=graph_id, + selected_nodes=interventable_keys, + result_json=result.model_dump_json(), + ) + class EdgeData(BaseModel): """Edge in the attribution graph.""" @@ -65,12 +129,14 @@ class GraphData(BaseModel): graphType: GraphType tokens: list[str] edges: list[EdgeData] + edgesAbs: list[EdgeData] | None = None # absolute-target variant, None for old graphs outputProbs: dict[str, OutputProbability] nodeCiVals: dict[ str, float ] # node key -> CI value (or output prob for output nodes or 1 for embed node) nodeSubcompActs: dict[str, float] # node key -> subcomponent activation (v_i^T @ a) maxAbsAttr: float # max absolute edge value + maxAbsAttrAbs: float | None = None # max absolute edge value for abs-target variant maxAbsSubcompAct: float # max absolute subcomponent activation for normalization l0_total: int # total active components at current CI threshold @@ -93,6 +159,16 @@ class KLLossResult(BaseModel): position: int +class LogitLossResult(BaseModel): + """Logit loss result (maximize pre-softmax logit).""" + + type: Literal["logit"] = "logit" + coeff: float + position: int + label_token: int + label_str: str + + class OptimizationMetricsResult(BaseModel): """Final loss metrics from CI optimization.""" @@ -112,10 +188,9 @@ class OptimizationResult(BaseModel): pnorm: float beta: float mask_type: MaskType - loss: CELossResult | KLLossResult + loss: CELossResult | KLLossResult | LogitLossResult metrics: OptimizationMetricsResult - adv_pgd_n_steps: 
int | None = None - adv_pgd_step_size: float | None = None + pgd: PgdConfig | None = None class GraphDataWithOptimization(GraphData): @@ -156,19 +231,6 @@ class TokenizeResponse(BaseModel): next_token_probs: list[float | None] # Probability of next token (last token is None) -class TokenInfo(BaseModel): - """A single token from the tokenizer vocabulary.""" - - id: int - string: str - - -class TokensResponse(BaseModel): - """Response containing all tokens in the vocabulary.""" - - tokens: list[TokenInfo] - - # SSE streaming message types class ProgressMessage(BaseModel): """Progress update during streaming computation.""" @@ -200,6 +262,12 @@ class CompleteMessageWithOptimization(BaseModel): data: GraphDataWithOptimization +class BatchGraphResult(BaseModel): + """Batch optimization result containing multiple graphs.""" + + graphs: list[GraphDataWithOptimization] + + router = APIRouter(prefix="/api/graphs", tags=["graphs"]) DEVICE = get_device() @@ -218,7 +286,6 @@ def _build_out_probs( ci_masked_out_logits: torch.Tensor, target_out_logits: torch.Tensor, tok_display: Callable[[int], str], - adv_pgd_out_logits: torch.Tensor | None = None, ) -> dict[str, OutputProbability]: """Build output probs dict from logit tensors. 
@@ -226,9 +293,6 @@ def _build_out_probs( """ ci_masked_out_probs = torch.softmax(ci_masked_out_logits, dim=-1) target_out_probs = torch.softmax(target_out_logits, dim=-1) - adv_pgd_out_probs = ( - torch.softmax(adv_pgd_out_logits, dim=-1) if adv_pgd_out_logits is not None else None - ) out_probs: dict[str, OutputProbability] = {} for s in range(ci_masked_out_probs.shape[0]): @@ -243,65 +307,78 @@ def _build_out_probs( target_prob = float(target_out_probs[s, c_idx].item()) target_logit = float(target_out_logits[s, c_idx].item()) - adv_pgd_prob: float | None = None - adv_pgd_logit: float | None = None - if adv_pgd_out_probs is not None and adv_pgd_out_logits is not None: - adv_pgd_prob = round(float(adv_pgd_out_probs[s, c_idx].item()), 6) - adv_pgd_logit = round(float(adv_pgd_out_logits[s, c_idx].item()), 4) - key = f"{s}:{c_idx}" out_probs[key] = OutputProbability( prob=round(prob, 6), logit=round(logit, 4), target_prob=round(target_prob, 6), target_logit=round(target_logit, 4), - adv_pgd_prob=adv_pgd_prob, - adv_pgd_logit=adv_pgd_logit, token=tok_display(c_idx), ) return out_probs +CISnapshotCallback = Callable[[CISnapshot], None] + + def stream_computation( - work: Callable[[ProgressCallback], GraphData | GraphDataWithOptimization], + work: Callable[[ProgressCallback, CISnapshotCallback | None], BaseModel], + gpu_lock: threading.Lock, ) -> StreamingResponse: - """Run graph computation in a thread with SSE streaming for progress updates.""" + """Run graph computation in a thread with SSE streaming for progress updates. + + Acquires gpu_lock before starting and holds it until computation completes. + Raises 503 if the lock is already held by another operation. + """ + # Try to acquire lock non-blocking - fail fast if GPU is busy + if not gpu_lock.acquire(blocking=False): + raise HTTPException( + status_code=503, + detail="GPU operation already in progress. 
Please wait and retry.", + ) + progress_queue: queue.Queue[dict[str, Any]] = queue.Queue() def on_progress(current: int, total: int, stage: str) -> None: progress_queue.put({"type": "progress", "current": current, "total": total, "stage": stage}) + def on_ci_snapshot(snapshot: CISnapshot) -> None: + progress_queue.put({"type": "ci_snapshot", **snapshot.model_dump()}) + def compute_thread() -> None: try: - result = work(on_progress) + result = work(on_progress, on_ci_snapshot) progress_queue.put({"type": "result", "result": result}) except Exception as e: traceback.print_exc(file=sys.stderr) progress_queue.put({"type": "error", "error": str(e)}) def generate() -> Generator[str]: - thread = threading.Thread(target=compute_thread) - thread.start() - - while True: - try: - msg = progress_queue.get(timeout=0.1) - except queue.Empty: - if not thread.is_alive(): + try: + thread = threading.Thread(target=compute_thread) + thread.start() + + while True: + try: + msg = progress_queue.get(timeout=0.1) + except queue.Empty: + if not thread.is_alive(): + break + continue + + if msg["type"] in ("progress", "ci_snapshot"): + yield f"data: {json.dumps(msg)}\n\n" + elif msg["type"] == "error": + yield f"data: {json.dumps(msg)}\n\n" + break + elif msg["type"] == "result": + complete_data = {"type": "complete", "data": msg["result"].model_dump()} + yield f"data: {json.dumps(complete_data)}\n\n" break - continue - - if msg["type"] == "progress": - yield f"data: {json.dumps(msg)}\n\n" - elif msg["type"] == "error": - yield f"data: {json.dumps(msg)}\n\n" - break - elif msg["type"] == "result": - complete_data = {"type": "complete", "data": msg["result"].model_dump()} - yield f"data: {json.dumps(complete_data)}\n\n" - break - thread.join() + thread.join() + finally: + gpu_lock.release() return StreamingResponse(generate(), media_type="text/event-stream") @@ -343,40 +420,55 @@ def tokenize_text(text: str, loaded: DepLoadedRun) -> TokenizeResponse: ) -@router.get("/tokens") -@log_errors 
-def get_all_tokens(loaded: DepLoadedRun) -> TokensResponse: - """Get all tokens in the tokenizer vocabulary for client-side search.""" - tokens = [ - TokenInfo(id=tid, string=loaded.tokenizer.get_tok_display(tid)) - for tid in range(loaded.tokenizer.vocab_size) - ] - return TokensResponse(tokens=tokens) +class TokenSearchResult(BaseModel): + """A token search result with model probability at the queried position.""" + + id: int + string: str + prob: float class TokenSearchResponse(BaseModel): """Response from token search endpoint.""" - tokens: list[TokenInfo] + tokens: list[TokenSearchResult] @router.get("/tokens/search") @log_errors def search_tokens( q: Annotated[str, Query(min_length=1)], + prompt_id: Annotated[int, Query()], + position: Annotated[int, Query()], loaded: DepLoadedRun, - limit: Annotated[int, Query(ge=1, le=50)] = 10, + manager: DepStateManager, + limit: Annotated[int, Query(ge=1, le=50)] = 20, ) -> TokenSearchResponse: - """Search tokens by substring match. Returns up to `limit` results.""" + """Search tokens by substring match, sorted by target model probability at position.""" + prompt = manager.state.db.get_prompt(prompt_id) + if prompt is None: + raise HTTPException(status_code=404, detail=f"prompt {prompt_id} not found") + if not (0 <= position < len(prompt.token_ids)): + raise HTTPException( + status_code=422, + detail=f"position {position} out of range for prompt with {len(prompt.token_ids)} tokens", + ) + + device = next(loaded.model.parameters()).device + tokens_tensor = torch.tensor([prompt.token_ids], device=device) + with torch.no_grad(): + logits = loaded.model(tokens_tensor) + probs = torch.softmax(logits[0, position], dim=-1) + query = q.lower() - matches: list[TokenInfo] = [] + matches: list[TokenSearchResult] = [] for tid in range(loaded.tokenizer.vocab_size): string = loaded.tokenizer.get_tok_display(tid) if query in string.lower(): - matches.append(TokenInfo(id=tid, string=string)) - if len(matches) >= limit: - break - return 
TokenSearchResponse(tokens=matches) + matches.append(TokenSearchResult(id=tid, string=string, prob=probs[tid].item())) + + matches.sort(key=lambda m: m.prob, reverse=True) + return TokenSearchResponse(tokens=matches[:limit]) NormalizeType = Literal["none", "target", "layer"] @@ -450,7 +542,9 @@ def compute_graph_stream( spans = loaded.tokenizer.get_spans(token_ids) tokens_tensor = torch.tensor([token_ids], device=DEVICE) - def work(on_progress: ProgressCallback) -> GraphData: + def work( + on_progress: ProgressCallback, _on_ci_snapshot: CISnapshotCallback | None + ) -> GraphData: t_total = time.perf_counter() result = compute_prompt_attributions( @@ -474,6 +568,7 @@ def work(on_progress: ProgressCallback) -> GraphData: graph=StoredGraph( graph_type=graph_type, edges=result.edges, + edges_abs=result.edges_abs, ci_masked_out_logits=ci_masked_out_logits, target_out_logits=target_out_logits, node_ci_vals=result.node_ci_vals, @@ -483,6 +578,19 @@ def work(on_progress: ProgressCallback) -> GraphData: ) logger.info(f"[perf] save_graph: {time.perf_counter() - t0:.2f}s") + t0 = time.perf_counter() + _save_base_intervention_run( + graph_id=graph_id, + model=loaded.model, + tokens=tokens_tensor, + node_ci_vals=result.node_ci_vals, + tokenizer=loaded.tokenizer, + topology=loaded.topology, + db=db, + sampling=loaded.config.sampling, + ) + logger.info(f"[perf] base intervention run: {time.perf_counter() - t0:.2f}s") + t0 = time.perf_counter() fg = filter_graph_for_display( raw_edges=result.edges, @@ -494,6 +602,7 @@ def work(on_progress: ProgressCallback) -> GraphData: num_tokens=len(token_ids), ci_threshold=ci_threshold, normalize=normalize, + raw_edges_abs=result.edges_abs, ) logger.info( f"[perf] filter_graph: {time.perf_counter() - t0:.2f}s ({len(fg.edges)} edges after filter)" @@ -505,15 +614,17 @@ def work(on_progress: ProgressCallback) -> GraphData: graphType=graph_type, tokens=spans, edges=fg.edges, + edgesAbs=fg.edges_abs, outputProbs=fg.out_probs, 
nodeCiVals=fg.node_ci_vals, nodeSubcompActs=result.node_subcomp_acts, maxAbsAttr=fg.max_abs_attr, + maxAbsAttrAbs=fg.max_abs_attr_abs, maxAbsSubcompAct=fg.max_abs_subcomp_act, l0_total=fg.l0_total, ) - return stream_computation(work) + return stream_computation(work, manager._gpu_lock) def _edge_to_edge_data(edge: Edge) -> EdgeData: @@ -557,7 +668,7 @@ def get_group_key(edge: Edge) -> str: return out_edges -LossType = Literal["ce", "kl"] +LossType = Literal["ce", "kl", "logit"] @router.post("/optimized/stream") @@ -597,6 +708,14 @@ def compute_graph_optimized_stream( ) case "kl": loss_config = KLLossConfig(coeff=loss_coeff, position=loss_position) + case "logit": + if label_token is None: + raise HTTPException( + status_code=400, detail="label_token is required for logit loss" + ) + loss_config = LogitLossConfig( + coeff=loss_coeff, position=loss_position, label_token=label_token + ) lr = 1e-2 @@ -627,8 +746,9 @@ def compute_graph_optimized_stream( beta=beta, mask_type=mask_type, loss=loss_config, - adv_pgd_n_steps=adv_pgd_n_steps, - adv_pgd_step_size=adv_pgd_step_size, + pgd=PgdConfig(n_steps=adv_pgd_n_steps, step_size=adv_pgd_step_size) + if adv_pgd_n_steps is not None and adv_pgd_step_size is not None + else None, ) optim_config = OptimCIConfig( @@ -650,7 +770,9 @@ def compute_graph_optimized_stream( else None, ) - def work(on_progress: ProgressCallback) -> GraphDataWithOptimization: + def work( + on_progress: ProgressCallback, on_ci_snapshot: CISnapshotCallback | None + ) -> GraphDataWithOptimization: result = compute_prompt_attributions_optimized( model=loaded.model, topology=loaded.topology, @@ -660,28 +782,42 @@ def work(on_progress: ProgressCallback) -> GraphDataWithOptimization: output_prob_threshold=0.01, device=DEVICE, on_progress=on_progress, + on_ci_snapshot=on_ci_snapshot, ) ci_masked_out_logits = result.ci_masked_out_logits.cpu() target_out_logits = result.target_out_logits.cpu() - adv_pgd_out_logits = ( - result.adv_pgd_out_logits.cpu() if 
result.adv_pgd_out_logits is not None else None - ) + + opt_params.ci_masked_label_prob = result.metrics.ci_masked_label_prob + opt_params.stoch_masked_label_prob = result.metrics.stoch_masked_label_prob + opt_params.adv_pgd_label_prob = result.metrics.adv_pgd_label_prob graph_id = db.save_graph( prompt_id=prompt_id, graph=StoredGraph( graph_type="optimized", edges=result.edges, + edges_abs=result.edges_abs, ci_masked_out_logits=ci_masked_out_logits, target_out_logits=target_out_logits, - adv_pgd_out_logits=adv_pgd_out_logits, node_ci_vals=result.node_ci_vals, node_subcomp_acts=result.node_subcomp_acts, optimization_params=opt_params, ), ) + _save_base_intervention_run( + graph_id=graph_id, + model=loaded.model, + tokens=tokens_tensor, + node_ci_vals=result.node_ci_vals, + tokenizer=loaded.tokenizer, + topology=loaded.topology, + db=db, + sampling=loaded.config.sampling, + loss_config=loss_config, + ) + fg = filter_graph_for_display( raw_edges=result.edges, node_ci_vals=result.node_ci_vals, @@ -692,11 +828,11 @@ def work(on_progress: ProgressCallback) -> GraphDataWithOptimization: num_tokens=num_tokens, ci_threshold=ci_threshold, normalize=normalize, - adv_pgd_out_logits=adv_pgd_out_logits, + raw_edges_abs=result.edges_abs, ) # Build loss result based on config type - loss_result: CELossResult | KLLossResult + loss_result: CELossResult | KLLossResult | LogitLossResult match loss_config: case CELossConfig(coeff=coeff, position=pos, label_token=label_tok): assert label_str is not None @@ -708,16 +844,26 @@ def work(on_progress: ProgressCallback) -> GraphDataWithOptimization: ) case KLLossConfig(coeff=coeff, position=pos): loss_result = KLLossResult(coeff=coeff, position=pos) + case LogitLossConfig(coeff=coeff, position=pos, label_token=label_tok): + assert label_str is not None + loss_result = LogitLossResult( + coeff=coeff, + position=pos, + label_token=label_tok, + label_str=label_str, + ) return GraphDataWithOptimization( id=graph_id, graphType="optimized", 
tokens=spans_sliced, edges=fg.edges, + edgesAbs=fg.edges_abs, outputProbs=fg.out_probs, nodeCiVals=fg.node_ci_vals, nodeSubcompActs=result.node_subcomp_acts, maxAbsAttr=fg.max_abs_attr, + maxAbsAttrAbs=fg.max_abs_attr_abs, maxAbsSubcompAct=fg.max_abs_subcomp_act, l0_total=fg.l0_total, optimization=OptimizationResult( @@ -733,12 +879,240 @@ def work(on_progress: ProgressCallback) -> GraphDataWithOptimization: adv_pgd_label_prob=result.metrics.adv_pgd_label_prob, l0_total=result.metrics.l0_total, ), - adv_pgd_n_steps=adv_pgd_n_steps, - adv_pgd_step_size=adv_pgd_step_size, + pgd=PgdConfig(n_steps=adv_pgd_n_steps, step_size=adv_pgd_step_size) + if adv_pgd_n_steps is not None and adv_pgd_step_size is not None + else None, + ), + ) + + return stream_computation(work, manager._gpu_lock) + + +class BatchOptimizedRequest(BaseModel): + """Request body for batch optimized graph computation.""" + + prompt_id: int + imp_min_coeffs: list[float] + steps: int + pnorm: float + beta: float + normalize: NormalizeType + ci_threshold: float + mask_type: MaskType + loss_type: LossType + loss_coeff: float + loss_position: int + label_token: int | None = None + adv_pgd_n_steps: int | None = None + adv_pgd_step_size: float | None = None + + +@router.post("/optimized/batch/stream") +@log_errors +def compute_graph_optimized_batch_stream( + body: BatchOptimizedRequest, + loaded: DepLoadedRun, + manager: DepStateManager, +): + """Compute optimized graphs for multiple sparsity coefficients in one batched optimization. + + Returns N graphs (one per imp_min_coeff) via SSE streaming. + All coefficients share the same loss config, steps, and other hyperparameters. 
+ """ + assert len(body.imp_min_coeffs) > 0, "At least one coefficient required" + assert len(body.imp_min_coeffs) <= 20, "Too many coefficients (max 20)" + + loss_config: LossConfig + match body.loss_type: + case "ce": + assert body.label_token is not None, "label_token is required for CE loss" + loss_config = CELossConfig( + coeff=body.loss_coeff, position=body.loss_position, label_token=body.label_token + ) + case "kl": + loss_config = KLLossConfig(coeff=body.loss_coeff, position=body.loss_position) + case "logit": + assert body.label_token is not None, "label_token is required for logit loss" + loss_config = LogitLossConfig( + coeff=body.loss_coeff, position=body.loss_position, label_token=body.label_token + ) + + lr = 1e-2 + + db = manager.db + prompt = db.get_prompt(body.prompt_id) + assert prompt is not None, f"prompt {body.prompt_id} not found" + + token_ids = prompt.token_ids + assert body.loss_position < len(token_ids), ( + f"loss_position {body.loss_position} out of bounds for prompt with {len(token_ids)} tokens" + ) + + label_str = ( + loaded.tokenizer.get_tok_display(body.label_token) if body.label_token is not None else None + ) + spans = loaded.tokenizer.get_spans(token_ids) + tokens_tensor = torch.tensor([token_ids], device=DEVICE) + + num_tokens = body.loss_position + 1 + spans_sliced = spans[:num_tokens] + + adv_pgd = ( + AdvPGDConfig(n_steps=body.adv_pgd_n_steps, step_size=body.adv_pgd_step_size, init="random") + if body.adv_pgd_n_steps is not None and body.adv_pgd_step_size is not None + else None + ) + + configs = [ + OptimCIConfig( + seed=0, + lr=lr, + steps=body.steps, + weight_decay=0.0, + lr_schedule="cosine", + lr_exponential_halflife=None, + lr_warmup_pct=0.01, + log_freq=max(1, body.steps // 4), + imp_min_config=ImportanceMinimalityLossConfig( + coeff=coeff, pnorm=body.pnorm, beta=body.beta ), + loss_config=loss_config, + sampling=loaded.config.sampling, + ce_kl_rounding_threshold=0.5, + mask_type=body.mask_type, + adv_pgd=adv_pgd, ) + 
for coeff in body.imp_min_coeffs + ] - return stream_computation(work) + def work( + on_progress: ProgressCallback, on_ci_snapshot: CISnapshotCallback | None + ) -> BatchGraphResult: + results = compute_prompt_attributions_optimized_batched( + model=loaded.model, + topology=loaded.topology, + tokens=tokens_tensor, + sources_by_target=loaded.sources_by_target, + configs=configs, + output_prob_threshold=0.01, + device=DEVICE, + on_progress=on_progress, + on_ci_snapshot=on_ci_snapshot, + ) + + graphs: list[GraphDataWithOptimization] = [] + for result, coeff in zip(results, body.imp_min_coeffs, strict=True): + ci_masked_out_logits = result.ci_masked_out_logits.cpu() + target_out_logits = result.target_out_logits.cpu() + + opt_params = OptimizationParams( + imp_min_coeff=coeff, + steps=body.steps, + pnorm=body.pnorm, + beta=body.beta, + mask_type=body.mask_type, + loss=loss_config, + pgd=PgdConfig(n_steps=body.adv_pgd_n_steps, step_size=body.adv_pgd_step_size) + if body.adv_pgd_n_steps is not None and body.adv_pgd_step_size is not None + else None, + ) + opt_params.ci_masked_label_prob = result.metrics.ci_masked_label_prob + opt_params.stoch_masked_label_prob = result.metrics.stoch_masked_label_prob + opt_params.adv_pgd_label_prob = result.metrics.adv_pgd_label_prob + + graph_id = db.save_graph( + prompt_id=body.prompt_id, + graph=StoredGraph( + graph_type="optimized", + edges=result.edges, + edges_abs=result.edges_abs, + ci_masked_out_logits=ci_masked_out_logits, + target_out_logits=target_out_logits, + node_ci_vals=result.node_ci_vals, + node_subcomp_acts=result.node_subcomp_acts, + optimization_params=opt_params, + ), + ) + + _save_base_intervention_run( + graph_id=graph_id, + model=loaded.model, + tokens=tokens_tensor, + node_ci_vals=result.node_ci_vals, + tokenizer=loaded.tokenizer, + topology=loaded.topology, + db=db, + sampling=loaded.config.sampling, + loss_config=loss_config, + ) + + fg = filter_graph_for_display( + raw_edges=result.edges, + 
node_ci_vals=result.node_ci_vals, + node_subcomp_acts=result.node_subcomp_acts, + ci_masked_out_logits=ci_masked_out_logits, + target_out_logits=target_out_logits, + tok_display=loaded.tokenizer.get_tok_display, + num_tokens=num_tokens, + ci_threshold=body.ci_threshold, + normalize=body.normalize, + raw_edges_abs=result.edges_abs, + ) + + loss_result: CELossResult | KLLossResult | LogitLossResult + match loss_config: + case CELossConfig(coeff=lc, position=pos, label_token=label_tok): + assert label_str is not None + loss_result = CELossResult( + coeff=lc, position=pos, label_token=label_tok, label_str=label_str + ) + case KLLossConfig(coeff=lc, position=pos): + loss_result = KLLossResult(coeff=lc, position=pos) + case LogitLossConfig(coeff=lc, position=pos, label_token=label_tok): + assert label_str is not None + loss_result = LogitLossResult( + coeff=lc, position=pos, label_token=label_tok, label_str=label_str + ) + + graphs.append( + GraphDataWithOptimization( + id=graph_id, + graphType="optimized", + tokens=spans_sliced, + edges=fg.edges, + edgesAbs=fg.edges_abs, + outputProbs=fg.out_probs, + nodeCiVals=fg.node_ci_vals, + nodeSubcompActs=result.node_subcomp_acts, + maxAbsAttr=fg.max_abs_attr, + maxAbsAttrAbs=fg.max_abs_attr_abs, + maxAbsSubcompAct=fg.max_abs_subcomp_act, + l0_total=fg.l0_total, + optimization=OptimizationResult( + imp_min_coeff=coeff, + steps=body.steps, + pnorm=body.pnorm, + beta=body.beta, + mask_type=body.mask_type, + loss=loss_result, + metrics=OptimizationMetricsResult( + ci_masked_label_prob=result.metrics.ci_masked_label_prob, + stoch_masked_label_prob=result.metrics.stoch_masked_label_prob, + adv_pgd_label_prob=result.metrics.adv_pgd_label_prob, + l0_total=result.metrics.l0_total, + ), + pgd=PgdConfig( + n_steps=body.adv_pgd_n_steps, step_size=body.adv_pgd_step_size + ) + if body.adv_pgd_n_steps is not None and body.adv_pgd_step_size is not None + else None, + ), + ) + ) + + return BatchGraphResult(graphs=graphs) + + return 
stream_computation(work, manager._gpu_lock) @dataclass @@ -746,9 +1120,11 @@ class FilteredGraph: """Result of filtering a raw graph for display.""" edges: list[EdgeData] + edges_abs: list[EdgeData] | None # absolute-target variant, None for old graphs node_ci_vals: dict[str, float] # with pseudo nodes out_probs: dict[str, OutputProbability] max_abs_attr: float + max_abs_attr_abs: float | None # max abs for absolute-target edges max_abs_subcomp_act: float l0_total: int @@ -763,8 +1139,8 @@ def filter_graph_for_display( num_tokens: int, ci_threshold: float, normalize: NormalizeType, + raw_edges_abs: list[Edge] | None = None, edge_limit: int = GLOBAL_EDGE_LIMIT, - adv_pgd_out_logits: torch.Tensor | None = None, ) -> FilteredGraph: """Filter and transform a raw attribution graph for display. @@ -775,9 +1151,7 @@ def filter_graph_for_display( 5. Normalize edge strengths (if requested) 6. Cap edges at edge_limit """ - out_probs = _build_out_probs( - ci_masked_out_logits, target_out_logits, tok_display, adv_pgd_out_logits - ) + out_probs = _build_out_probs(ci_masked_out_logits, target_out_logits, tok_display) filtered_node_ci_vals = {k: v for k, v in node_ci_vals.items() if v > ci_threshold} @@ -789,25 +1163,33 @@ def filter_graph_for_display( seq_pos, token_id = key.split(":") node_ci_vals_with_pseudo[f"output:{seq_pos}:{token_id}"] = out_prob.prob - # Filter edges to only those connecting surviving nodes + # Filter, normalize, sort, and truncate an edge list to the surviving node set. 
node_keys = set(node_ci_vals_with_pseudo.keys()) - edges = [e for e in raw_edges if str(e.source) in node_keys and str(e.target) in node_keys] - edges = _normalize_edges(edges=edges, normalize=normalize) - max_abs_attr = compute_max_abs_attr(edges=edges) + def _filter_edges(raw: list[Edge]) -> tuple[list[EdgeData], float]: + filtered = [e for e in raw if str(e.source) in node_keys and str(e.target) in node_keys] + filtered = _normalize_edges(edges=filtered, normalize=normalize) + max_abs = compute_max_abs_attr(edges=filtered) + filtered = sorted(filtered, key=lambda e: abs(e.strength), reverse=True) + if len(filtered) > edge_limit: + logger.warning(f"Edge limit {edge_limit} exceeded ({len(filtered)} edges), truncating") + filtered = filtered[:edge_limit] + return [_edge_to_edge_data(e) for e in filtered], max_abs - # Always sort by abs(strength) desc so frontend can just slice(0, topK) without re-sorting - edges = sorted(edges, key=lambda e: abs(e.strength), reverse=True) + edges_out, max_abs_attr = _filter_edges(raw_edges) - if len(edges) > edge_limit: - logger.warning(f"Edge limit {edge_limit} exceeded ({len(edges)} edges), truncating") - edges = edges[:edge_limit] + edges_abs_out: list[EdgeData] | None = None + max_abs_attr_abs: float | None = None + if raw_edges_abs is not None: + edges_abs_out, max_abs_attr_abs = _filter_edges(raw_edges_abs) return FilteredGraph( - edges=[_edge_to_edge_data(e) for e in edges], + edges=edges_out, + edges_abs=edges_abs_out, node_ci_vals=node_ci_vals_with_pseudo, out_probs=out_probs, max_abs_attr=max_abs_attr, + max_abs_attr_abs=max_abs_attr_abs, max_abs_subcomp_act=compute_max_abs_subcomp_act(node_subcomp_acts), l0_total=len(filtered_node_ci_vals), ) @@ -840,7 +1222,7 @@ def stored_graph_to_response( num_tokens=num_tokens, ci_threshold=ci_threshold, normalize=normalize, - adv_pgd_out_logits=graph.adv_pgd_out_logits, + raw_edges_abs=graph.edges_abs, ) if not is_optimized: @@ -849,10 +1231,12 @@ def stored_graph_to_response( 
graphType=graph.graph_type, tokens=spans, edges=fg.edges, + edgesAbs=fg.edges_abs, outputProbs=fg.out_probs, nodeCiVals=fg.node_ci_vals, nodeSubcompActs=graph.node_subcomp_acts, maxAbsAttr=fg.max_abs_attr, + maxAbsAttrAbs=fg.max_abs_attr_abs, maxAbsSubcompAct=fg.max_abs_subcomp_act, l0_total=fg.l0_total, ) @@ -861,7 +1245,7 @@ def stored_graph_to_response( opt = graph.optimization_params # Build loss result based on stored config type - loss_result: CELossResult | KLLossResult + loss_result: CELossResult | KLLossResult | LogitLossResult match opt.loss: case CELossConfig(coeff=coeff, position=pos, label_token=label_tok): label_str = tokenizer.get_tok_display(label_tok) @@ -873,16 +1257,26 @@ def stored_graph_to_response( ) case KLLossConfig(coeff=coeff, position=pos): loss_result = KLLossResult(coeff=coeff, position=pos) + case LogitLossConfig(coeff=coeff, position=pos, label_token=label_tok): + label_str = tokenizer.get_tok_display(label_tok) + loss_result = LogitLossResult( + coeff=coeff, + position=pos, + label_token=label_tok, + label_str=label_str, + ) return GraphDataWithOptimization( id=graph.id, graphType=graph.graph_type, tokens=spans, edges=fg.edges, + edgesAbs=fg.edges_abs, outputProbs=fg.out_probs, nodeCiVals=fg.node_ci_vals, nodeSubcompActs=graph.node_subcomp_acts, maxAbsAttr=fg.max_abs_attr, + maxAbsAttrAbs=fg.max_abs_attr_abs, maxAbsSubcompAct=fg.max_abs_subcomp_act, l0_total=fg.l0_total, optimization=OptimizationResult( @@ -892,10 +1286,15 @@ def stored_graph_to_response( beta=opt.beta, mask_type=opt.mask_type, loss=loss_result, - # Metrics not stored in DB for cached graphs - use l0_total from graph - metrics=OptimizationMetricsResult(l0_total=float(fg.l0_total)), - adv_pgd_n_steps=opt.adv_pgd_n_steps, - adv_pgd_step_size=opt.adv_pgd_step_size, + metrics=OptimizationMetricsResult( + l0_total=float(fg.l0_total), + ci_masked_label_prob=opt.ci_masked_label_prob, + stoch_masked_label_prob=opt.stoch_masked_label_prob, + 
adv_pgd_label_prob=opt.adv_pgd_label_prob, + ), + pgd=PgdConfig(n_steps=opt.pgd.n_steps, step_size=opt.pgd.step_size) + if opt.pgd is not None + else None, ), ) diff --git a/spd/app/backend/routers/intervention.py b/spd/app/backend/routers/intervention.py index 1ccdbf86c..e26a73462 100644 --- a/spd/app/backend/routers/intervention.py +++ b/spd/app/backend/routers/intervention.py @@ -4,8 +4,12 @@ from fastapi import APIRouter, HTTPException from pydantic import BaseModel -from spd.app.backend.compute import compute_intervention_forward +from spd.app.backend.compute import ( + InterventionResult, + compute_intervention, +) from spd.app.backend.dependencies import DepDB, DepLoadedRun, DepStateManager +from spd.app.backend.optim_cis import AdvPGDConfig, LossConfig, MeanKLLossConfig from spd.app.backend.utils import log_errors from spd.topology import TransformerTopology from spd.utils.distributed_utils import get_device @@ -15,56 +19,19 @@ # ============================================================================= -class InterventionNode(BaseModel): - """A specific node to activate during intervention.""" - - layer: str - seq_pos: int - component_idx: int - - -class InterventionRequest(BaseModel): - """Request for intervention forward pass.""" - - text: str - nodes: list[InterventionNode] - top_k: int - - -class TokenPrediction(BaseModel): - """A single token prediction with probability.""" - - token: str - token_id: int - spd_prob: float - target_prob: float - logit: float - target_logit: float - - -class InterventionResponse(BaseModel): - """Response from intervention forward pass.""" - - input_tokens: list[str] - predictions_per_position: list[list[TokenPrediction]] +class AdvPgdParams(BaseModel): + n_steps: int + step_size: float class RunInterventionRequest(BaseModel): """Request to run and save an intervention.""" graph_id: int - text: str selected_nodes: list[str] # node keys (layer:seq:cIdx) - top_k: int = 10 - - -class 
ForkedInterventionRunSummary(BaseModel): - """Summary of a forked intervention run with modified tokens.""" - - id: int - token_replacements: list[tuple[int, int]] # [(seq_pos, new_token_id), ...] - result: InterventionResponse - created_at: str + nodes_to_ablate: list[str] | None = None # node keys to ablate in ablated (omit to skip) + top_k: int + adv_pgd: AdvPgdParams class InterventionRunSummary(BaseModel): @@ -72,16 +39,8 @@ class InterventionRunSummary(BaseModel): id: int selected_nodes: list[str] - result: InterventionResponse + result: InterventionResult created_at: str - forked_runs: list[ForkedInterventionRunSummary] - - -class ForkInterventionRequest(BaseModel): - """Request to fork an intervention run with modified tokens.""" - - token_replacements: list[tuple[int, int]] # [(seq_pos, new_token_id), ...] - top_k: int = 10 router = APIRouter(prefix="/api/intervention", tags=["intervention"]) @@ -104,100 +63,15 @@ def _parse_node_key(key: str, topology: TransformerTopology) -> tuple[str, int, return concrete_path, int(seq_str), int(cidx_str) -def _run_intervention_forward( - text: str, - selected_nodes: list[str], - top_k: int, - loaded: DepLoadedRun, -) -> InterventionResponse: - """Run intervention forward pass and return response.""" - token_ids = loaded.tokenizer.encode(text) - tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) - - active_nodes = [_parse_node_key(key, loaded.topology) for key in selected_nodes] - - seq_len = tokens.shape[1] - for _, seq_pos, _ in active_nodes: - if seq_pos >= seq_len: - raise ValueError(f"seq_pos {seq_pos} out of bounds for text with {seq_len} tokens") - - result = compute_intervention_forward( - model=loaded.model, - tokens=tokens, - active_nodes=active_nodes, - top_k=top_k, - tokenizer=loaded.tokenizer, - ) - - predictions_per_position = [ - [ - TokenPrediction( - token=token, - token_id=token_id, - spd_prob=spd_prob, - target_prob=target_prob, - logit=logit, - target_logit=target_logit, - ) - for 
token, token_id, spd_prob, logit, target_prob, target_logit in pos_predictions - ] - for pos_predictions in result.predictions_per_position - ] - - return InterventionResponse( - input_tokens=result.input_tokens, - predictions_per_position=predictions_per_position, - ) - - -@router.post("") -@log_errors -def run_intervention(request: InterventionRequest, loaded: DepLoadedRun) -> InterventionResponse: - """Run intervention forward pass with specified nodes active (legacy endpoint).""" - token_ids = loaded.tokenizer.encode(request.text) - tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) - - active_nodes = [ - ( - loaded.topology.canon_to_target(n.layer), - n.seq_pos, - n.component_idx, - ) - for n in request.nodes - ] - - seq_len = tokens.shape[1] +def _parse_and_validate_active_nodes( + selected_nodes: list[str], topology: TransformerTopology, seq_len: int +) -> list[tuple[str, int, int]]: + """Parse node keys and validate sequence bounds for the current prompt.""" + active_nodes = [_parse_node_key(key, topology) for key in selected_nodes] for _, seq_pos, _ in active_nodes: if seq_pos >= seq_len: raise ValueError(f"seq_pos {seq_pos} out of bounds for text with {seq_len} tokens") - - result = compute_intervention_forward( - model=loaded.model, - tokens=tokens, - active_nodes=active_nodes, - top_k=request.top_k, - tokenizer=loaded.tokenizer, - ) - - predictions_per_position = [ - [ - TokenPrediction( - token=token, - token_id=token_id, - spd_prob=spd_prob, - target_prob=target_prob, - logit=logit, - target_logit=target_logit, - ) - for token, token_id, spd_prob, logit, target_prob, target_logit in pos_predictions - ] - for pos_predictions in result.predictions_per_position - ] - - return InterventionResponse( - input_tokens=result.input_tokens, - predictions_per_position=predictions_per_position, - ) + return active_nodes @router.post("/run") @@ -206,19 +80,59 @@ def run_and_save_intervention( request: RunInterventionRequest, loaded: DepLoadedRun, 
db: DepDB, + manager: DepStateManager, ) -> InterventionRunSummary: """Run an intervention and save the result.""" - response = _run_intervention_forward( - text=request.text, - selected_nodes=request.selected_nodes, - top_k=request.top_k, - loaded=loaded, - ) + with manager.gpu_lock(): + graph_record = db.get_graph(request.graph_id) + if graph_record is None: + raise HTTPException(status_code=404, detail="Graph not found") + graph, prompt_id = graph_record + + prompt = db.get_prompt(prompt_id) + if prompt is None: + raise HTTPException(status_code=404, detail="Prompt not found") + + token_ids = prompt.token_ids + active_nodes = _parse_and_validate_active_nodes( + request.selected_nodes, loaded.topology, len(token_ids) + ) + nodes_to_ablate = ( + _parse_and_validate_active_nodes( + request.nodes_to_ablate, loaded.topology, len(token_ids) + ) + if request.nodes_to_ablate is not None + else None + ) + tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) + + # Use graph's loss config if optimized, else mean KL + loss_config: LossConfig = ( + graph.optimization_params.loss + if graph.optimization_params is not None + else MeanKLLossConfig() + ) + + result = compute_intervention( + model=loaded.model, + tokens=tokens, + active_nodes=active_nodes, + nodes_to_ablate=nodes_to_ablate, + tokenizer=loaded.tokenizer, + adv_pgd_config=AdvPGDConfig( + n_steps=request.adv_pgd.n_steps, + step_size=request.adv_pgd.step_size, + init="random", + ), + loss_config=loss_config, + sampling=loaded.config.sampling, + top_k=request.top_k, + ) run_id = db.save_intervention_run( graph_id=request.graph_id, selected_nodes=request.selected_nodes, - result_json=response.model_dump_json(), + result_json=result.model_dump_json(), ) record = db.get_intervention_runs(request.graph_id) @@ -228,41 +142,25 @@ def run_and_save_intervention( return InterventionRunSummary( id=run_id, selected_nodes=request.selected_nodes, - result=response, + result=result, created_at=saved_run.created_at, - 
forked_runs=[], ) @router.get("/runs/{graph_id}") @log_errors def get_intervention_runs(graph_id: int, db: DepDB) -> list[InterventionRunSummary]: - """Get all intervention runs for a graph, including forked runs.""" + """Get all intervention runs for a graph.""" records = db.get_intervention_runs(graph_id) - results = [] - for r in records: - # Get forked runs for this intervention run - forked_records = db.get_forked_intervention_runs(r.id) - forked_runs = [ - ForkedInterventionRunSummary( - id=fr.id, - token_replacements=fr.token_replacements, - result=InterventionResponse.model_validate_json(fr.result_json), - created_at=fr.created_at, - ) - for fr in forked_records - ] - - results.append( - InterventionRunSummary( - id=r.id, - selected_nodes=r.selected_nodes, - result=InterventionResponse.model_validate_json(r.result_json), - created_at=r.created_at, - forked_runs=forked_runs, - ) + return [ + InterventionRunSummary( + id=r.id, + selected_nodes=r.selected_nodes, + result=InterventionResult.model_validate_json(r.result_json), + created_at=r.created_at, ) - return results + for r in records + ] @router.delete("/runs/{run_id}") @@ -271,86 +169,3 @@ def delete_intervention_run(run_id: int, db: DepDB) -> dict[str, bool]: """Delete an intervention run.""" db.delete_intervention_run(run_id) return {"success": True} - - -@router.post("/runs/{run_id}/fork") -@log_errors -def fork_intervention_run( - run_id: int, - request: ForkInterventionRequest, - loaded: DepLoadedRun, - manager: DepStateManager, -) -> ForkedInterventionRunSummary: - """Fork an intervention run with modified tokens. - - Takes the same selected_nodes from the parent run, applies token replacements - to the original prompt, and runs the intervention forward pass. 
- """ - db = manager.db - - # Get the parent intervention run - parent_run = db.get_intervention_run(run_id) - if parent_run is None: - raise HTTPException(status_code=404, detail="Intervention run not found") - - # Get the prompt_id from the graph - conn = db._get_conn() - row = conn.execute( - "SELECT prompt_id FROM graphs WHERE id = ?", (parent_run.graph_id,) - ).fetchone() - if row is None: - raise HTTPException(status_code=404, detail="Graph not found") - prompt_id = row["prompt_id"] - - # Get the prompt to get original token_ids - prompt = db.get_prompt(prompt_id) - if prompt is None: - raise HTTPException(status_code=404, detail="Prompt not found") - - # Apply token replacements to get modified token_ids - modified_token_ids = list(prompt.token_ids) # Make a copy - for seq_pos, new_token_id in request.token_replacements: - if seq_pos < 0 or seq_pos >= len(modified_token_ids): - raise HTTPException( - status_code=400, - detail=f"Invalid seq_pos {seq_pos} for prompt with {len(modified_token_ids)} tokens", - ) - modified_token_ids[seq_pos] = new_token_id - - # Decode the modified tokens back to text - modified_text = loaded.tokenizer.decode(modified_token_ids) - - # Run the intervention forward pass with modified tokens but same selected nodes - response = _run_intervention_forward( - text=modified_text, - selected_nodes=parent_run.selected_nodes, - top_k=request.top_k, - loaded=loaded, - ) - - # Save the forked run - fork_id = db.save_forked_intervention_run( - intervention_run_id=run_id, - token_replacements=request.token_replacements, - result_json=response.model_dump_json(), - ) - - # Get the saved record for created_at - forked_records = db.get_forked_intervention_runs(run_id) - saved_fork = next((f for f in forked_records if f.id == fork_id), None) - assert saved_fork is not None - - return ForkedInterventionRunSummary( - id=fork_id, - token_replacements=request.token_replacements, - result=response, - created_at=saved_fork.created_at, - ) - - 
-@router.delete("/forks/{fork_id}") -@log_errors -def delete_forked_intervention_run(fork_id: int, db: DepDB) -> dict[str, bool]: - """Delete a forked intervention run.""" - db.delete_forked_intervention_run(fork_id) - return {"success": True} diff --git a/spd/app/backend/routers/investigations.py b/spd/app/backend/routers/investigations.py new file mode 100644 index 000000000..0784a0eb4 --- /dev/null +++ b/spd/app/backend/routers/investigations.py @@ -0,0 +1,317 @@ +"""Investigations endpoint for viewing agent investigation results. + +Lists and serves investigation data from SPD_OUT_DIR/investigations/. +Each investigation directory contains findings from a single agent run. +""" + +import json +from datetime import datetime +from pathlib import Path +from typing import Any + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from spd.app.backend.dependencies import DepLoadedRun +from spd.settings import SPD_OUT_DIR +from spd.utils.wandb_utils import parse_wandb_run_path + +router = APIRouter(prefix="/api/investigations", tags=["investigations"]) + +INVESTIGATIONS_DIR = SPD_OUT_DIR / "investigations" + + +class InvestigationSummary(BaseModel): + """Summary of a single investigation.""" + + id: str + wandb_path: str | None + prompt: str | None + created_at: str + has_research_log: bool + has_explanations: bool + event_count: int + last_event_time: str | None + last_event_message: str | None + title: str | None + summary: str | None + status: str | None + + +class EventEntry(BaseModel): + """A single event from events.jsonl.""" + + event_type: str + timestamp: str + message: str + details: dict[str, Any] | None = None + + +class InvestigationDetail(BaseModel): + """Full detail of an investigation including logs.""" + + id: str + wandb_path: str | None + prompt: str | None + created_at: str + research_log: str | None + events: list[EventEntry] + explanations: list[dict[str, Any]] + artifact_ids: list[str] + title: str | None + summary: 
str | None + status: str | None + + +def _parse_metadata(inv_path: Path) -> dict[str, Any] | None: + """Parse metadata.json from an investigation directory.""" + metadata_path = inv_path / "metadata.json" + if not metadata_path.exists(): + return None + try: + data: dict[str, Any] = json.loads(metadata_path.read_text()) + return data + except Exception: + return None + + +def _get_last_event(events_path: Path) -> tuple[str | None, str | None, int]: + """Get the last event timestamp, message, and total count from events.jsonl.""" + if not events_path.exists(): + return None, None, 0 + + last_time = None + last_msg = None + count = 0 + + try: + with open(events_path) as f: + for line in f: + line = line.strip() + if not line: + continue + count += 1 + try: + event = json.loads(line) + last_time = event.get("timestamp") + last_msg = event.get("message") + except json.JSONDecodeError: + continue + except Exception: + pass + + return last_time, last_msg, count + + +def _parse_task_summary(inv_path: Path) -> tuple[str | None, str | None, str | None]: + """Parse summary.json from an investigation directory. 
Returns (title, summary, status).""" + summary_path = inv_path / "summary.json" + if not summary_path.exists(): + return None, None, None + try: + data: dict[str, Any] = json.loads(summary_path.read_text()) + return data.get("title"), data.get("summary"), data.get("status") + except Exception: + return None, None, None + + +def _list_artifact_ids(inv_path: Path) -> list[str]: + """List all artifact IDs for an investigation.""" + artifacts_dir = inv_path / "artifacts" + if not artifacts_dir.exists(): + return [] + return [f.stem for f in sorted(artifacts_dir.glob("graph_*.json"))] + + +def _get_created_at(inv_path: Path, metadata: dict[str, Any] | None) -> str: + """Get creation time for an investigation.""" + events_path = inv_path / "events.jsonl" + if events_path.exists(): + try: + with open(events_path) as f: + first_line = f.readline().strip() + if first_line: + event = json.loads(first_line) + if "timestamp" in event: + return event["timestamp"] + except Exception: + pass + + if metadata and "created_at" in metadata: + return metadata["created_at"] + + return datetime.fromtimestamp(inv_path.stat().st_mtime).isoformat() + + +@router.get("") +def list_investigations(loaded: DepLoadedRun) -> list[InvestigationSummary]: + """List investigations for the currently loaded run.""" + if not INVESTIGATIONS_DIR.exists(): + return [] + + wandb_path = loaded.run.wandb_path + results = [] + + for inv_path in INVESTIGATIONS_DIR.iterdir(): + if not inv_path.is_dir() or not inv_path.name.startswith("inv-"): + continue + + inv_id = inv_path.name + metadata = _parse_metadata(inv_path) + + meta_wandb_path = metadata.get("wandb_path") if metadata else None + if meta_wandb_path is None: + continue + # Normalize to canonical form for comparison (strips "runs/", "wandb:" prefix, etc.) 
+ try: + e, p, r = parse_wandb_run_path(meta_wandb_path) + canonical_meta_path = f"{e}/{p}/{r}" + except ValueError: + continue + if canonical_meta_path != wandb_path: + continue + + events_path = inv_path / "events.jsonl" + last_time, last_msg, event_count = _get_last_event(events_path) + title, summary, status = _parse_task_summary(inv_path) + + explanations_path = inv_path / "explanations.jsonl" + + results.append( + InvestigationSummary( + id=inv_id, + wandb_path=meta_wandb_path, + prompt=metadata.get("prompt") if metadata else None, + created_at=_get_created_at(inv_path, metadata), + has_research_log=(inv_path / "research_log.md").exists(), + has_explanations=explanations_path.exists() + and explanations_path.stat().st_size > 0, + event_count=event_count, + last_event_time=last_time, + last_event_message=last_msg, + title=title, + summary=summary, + status=status, + ) + ) + + results.sort(key=lambda x: x.created_at, reverse=True) + return results + + +@router.get("/{inv_id}") +def get_investigation(inv_id: str) -> InvestigationDetail: + """Get full details of an investigation.""" + inv_path = INVESTIGATIONS_DIR / inv_id + + if not inv_path.exists() or not inv_path.is_dir(): + raise HTTPException(status_code=404, detail=f"Investigation {inv_id} not found") + + metadata = _parse_metadata(inv_path) + + research_log = None + research_log_path = inv_path / "research_log.md" + if research_log_path.exists(): + research_log = research_log_path.read_text() + + events = [] + events_path = inv_path / "events.jsonl" + if events_path.exists(): + with open(events_path) as f: + for line in f: + line = line.strip() + if not line: + continue + try: + event = json.loads(line) + events.append( + EventEntry( + event_type=event.get("event_type", "unknown"), + timestamp=event.get("timestamp", ""), + message=event.get("message", ""), + details=event.get("details"), + ) + ) + except json.JSONDecodeError: + continue + + explanations: list[dict[str, Any]] = [] + explanations_path = 
inv_path / "explanations.jsonl" + if explanations_path.exists(): + with open(explanations_path) as f: + for line in f: + line = line.strip() + if not line: + continue + try: + explanations.append(json.loads(line)) + except json.JSONDecodeError: + continue + + title, summary, status = _parse_task_summary(inv_path) + artifact_ids = _list_artifact_ids(inv_path) + + return InvestigationDetail( + id=inv_id, + wandb_path=metadata.get("wandb_path") if metadata else None, + prompt=metadata.get("prompt") if metadata else None, + created_at=_get_created_at(inv_path, metadata), + research_log=research_log, + events=events, + explanations=explanations, + artifact_ids=artifact_ids, + title=title, + summary=summary, + status=status, + ) + + +class LaunchRequest(BaseModel): + prompt: str + + +class LaunchResponse(BaseModel): + inv_id: str + job_id: str + + +@router.post("/launch") +def launch_investigation_endpoint(request: LaunchRequest, loaded: DepLoadedRun) -> LaunchResponse: + """Launch a new investigation for the currently loaded run.""" + from spd.investigate.scripts.run_slurm import launch_investigation + + result = launch_investigation( + wandb_path=loaded.run.wandb_path, + prompt=request.prompt, + context_length=loaded.context_length, + max_turns=50, + partition="h200-reserved", + time="8:00:00", + job_suffix=None, + ) + return LaunchResponse(inv_id=result.inv_id, job_id=result.job_id) + + +@router.get("/{inv_id}/artifacts") +def list_artifacts(inv_id: str) -> list[str]: + """List all artifact IDs for an investigation.""" + inv_path = INVESTIGATIONS_DIR / inv_id + if not inv_path.exists(): + raise HTTPException(status_code=404, detail=f"Investigation {inv_id} not found") + return _list_artifact_ids(inv_path) + + +@router.get("/{inv_id}/artifacts/{artifact_id}") +def get_artifact(inv_id: str, artifact_id: str) -> dict[str, Any]: + """Get a specific artifact by ID.""" + inv_path = INVESTIGATIONS_DIR / inv_id + artifact_path = inv_path / "artifacts" / f"{artifact_id}.json" 
+ + if not artifact_path.exists(): + raise HTTPException( + status_code=404, + detail=f"Artifact {artifact_id} not found in {inv_id}", + ) + + data: dict[str, Any] = json.loads(artifact_path.read_text()) + return data diff --git a/spd/app/backend/routers/mcp.py b/spd/app/backend/routers/mcp.py new file mode 100644 index 000000000..448162199 --- /dev/null +++ b/spd/app/backend/routers/mcp.py @@ -0,0 +1,1573 @@ +"""MCP (Model Context Protocol) endpoint for Claude Code integration. + +This router implements the MCP JSON-RPC protocol over HTTP, allowing Claude Code +to use SPD tools directly with proper schemas and streaming progress. + +MCP Spec: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports +""" + +import inspect +import json +import queue +import threading +import traceback +from collections.abc import Callable, Generator +from dataclasses import dataclass +from datetime import UTC, datetime +from pathlib import Path +from typing import Any, Literal + +import torch +from fastapi import APIRouter, Request +from fastapi.responses import JSONResponse, StreamingResponse +from pydantic import BaseModel + +from spd.app.backend.compute import ( + compute_ci_only, + compute_prompt_attributions_optimized, +) +from spd.app.backend.database import StoredGraph +from spd.app.backend.optim_cis import CELossConfig, OptimCIConfig +from spd.app.backend.routers.graphs import _build_out_probs +from spd.app.backend.routers.pretrain_info import _get_pretrain_info +from spd.app.backend.state import StateManager +from spd.configs import ImportanceMinimalityLossConfig +from spd.harvest import analysis +from spd.log import logger +from spd.utils.distributed_utils import get_device + +router = APIRouter(tags=["mcp"]) + +DEVICE = get_device() + +# MCP protocol version +MCP_PROTOCOL_VERSION = "2024-11-05" + + +@dataclass +class InvestigationConfig: + """Configuration for investigation mode. 
All paths are required when in investigation mode.""" + + events_log_path: Path + investigation_dir: Path + + +_investigation_config: InvestigationConfig | None = None + + +def set_investigation_config(config: InvestigationConfig) -> None: + """Configure MCP for investigation mode.""" + global _investigation_config + _investigation_config = config + + +def _log_event(event_type: str, message: str, details: dict[str, Any] | None = None) -> None: + """Log an event to the events file if in investigation mode.""" + if _investigation_config is None: + return + event = { + "event_type": event_type, + "timestamp": datetime.now(UTC).isoformat(), + "message": message, + "details": details or {}, + } + with open(_investigation_config.events_log_path, "a") as f: + f.write(json.dumps(event) + "\n") + + +# ============================================================================= +# MCP Protocol Types +# ============================================================================= + + +class MCPRequest(BaseModel): + """JSON-RPC 2.0 request.""" + + jsonrpc: Literal["2.0"] + id: int | str | None = None + method: str + params: dict[str, Any] | None = None + + +class MCPResponse(BaseModel): + """JSON-RPC 2.0 response. + + Per JSON-RPC 2.0 spec, exactly one of result/error must be present (not both, not neither). + Use model_dump(exclude_none=True) when serializing to avoid including null fields. + """ + + jsonrpc: Literal["2.0"] = "2.0" + id: int | str | None + result: Any | None = None + error: dict[str, Any] | None = None + + +class ToolDefinition(BaseModel): + """MCP tool definition.""" + + name: str + description: str + inputSchema: dict[str, Any] + + +# ============================================================================= +# Tool Definitions +# ============================================================================= + +TOOLS: list[ToolDefinition] = [ + ToolDefinition( + name="optimize_graph", + description="""Optimize a sparse circuit for a specific behavior. 
+ +Given a prompt and target token, finds the minimal set of components that produce the target prediction. +Returns the optimized graph with component CI values and edges showing information flow. + +This is the primary tool for understanding how the model produces a specific output.""", + inputSchema={ + "type": "object", + "properties": { + "prompt_text": { + "type": "string", + "description": "The input text to analyze (e.g., 'The boy said that')", + }, + "target_token": { + "type": "string", + "description": "The token to predict (e.g., ' he'). Include leading space if needed.", + }, + "loss_position": { + "type": "integer", + "description": "Position to optimize prediction at (0-indexed, usually last position). If not specified, uses the last position.", + }, + "steps": { + "type": "integer", + "description": "Optimization steps (default: 100, more = sparser but slower)", + "default": 100, + }, + "ci_threshold": { + "type": "number", + "description": "CI threshold for including components (default: 0.5, lower = more components)", + "default": 0.5, + }, + }, + "required": ["prompt_text", "target_token"], + }, + ), + ToolDefinition( + name="get_component_info", + description="""Get detailed information about a component. + +Returns the component's interpretation (what it does), token statistics (what tokens +activate it and what it predicts), and correlated components. 
+ +Use this to understand what role a component plays in a circuit.""", + inputSchema={ + "type": "object", + "properties": { + "layer": { + "type": "string", + "description": "Canonical layer name (e.g., '0.mlp.up', '2.attn.o')", + }, + "component_idx": { + "type": "integer", + "description": "Component index within the layer", + }, + "top_k": { + "type": "integer", + "description": "Number of top tokens/correlations to return (default: 20)", + "default": 20, + }, + }, + "required": ["layer", "component_idx"], + }, + ), + ToolDefinition( + name="run_ablation", + description="""Run an ablation experiment with only selected components active. + +Tests a hypothesis by running the model with a sparse set of components. +Returns predictions showing what the circuit produces vs the full model. + +Use this to verify that identified components are necessary and sufficient.""", + inputSchema={ + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "Input text for the ablation", + }, + "selected_nodes": { + "type": "array", + "items": {"type": "string"}, + "description": "Node keys to keep active (format: 'layer:seq_pos:component_idx')", + }, + "top_k": { + "type": "integer", + "description": "Number of top predictions to return per position (default: 10)", + "default": 10, + }, + }, + "required": ["text", "selected_nodes"], + }, + ), + ToolDefinition( + name="search_dataset", + description="""Search the SimpleStories training dataset for patterns. + +Finds stories containing the query string. Use this to find examples of +specific linguistic patterns (pronouns, verb forms, etc.) 
for investigation.""", + inputSchema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Text to search for (case-insensitive)", + }, + "limit": { + "type": "integer", + "description": "Maximum results to return (default: 20)", + "default": 20, + }, + }, + "required": ["query"], + }, + ), + ToolDefinition( + name="create_prompt", + description="""Create a prompt for analysis. + +Tokenizes the text and returns token IDs and next-token probabilities. +The returned prompt_id can be used with other tools.""", + inputSchema={ + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "The text to create a prompt from", + }, + }, + "required": ["text"], + }, + ), + ToolDefinition( + name="update_research_log", + description="""Append content to your research log. + +Use this to document your investigation progress, findings, and next steps. +The research log is your primary output for humans to follow your work. + +Call this frequently (every few minutes) with updates on what you're doing.""", + inputSchema={ + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "Markdown content to append to the research log", + }, + }, + "required": ["content"], + }, + ), + ToolDefinition( + name="save_explanation", + description="""Save a complete behavior explanation. + +Use this when you have finished investigating a behavior and want to document +your findings. This creates a structured record of the behavior, the components +involved, and your explanation of how they work together. 
+ +Only call this for complete, validated explanations - not preliminary hypotheses.""", + inputSchema={ + "type": "object", + "properties": { + "subject_prompt": { + "type": "string", + "description": "A prompt that demonstrates the behavior", + }, + "behavior_description": { + "type": "string", + "description": "Clear description of the behavior", + }, + "components_involved": { + "type": "array", + "items": { + "type": "object", + "properties": { + "component_key": { + "type": "string", + "description": "Component key (e.g., '0.mlp.up:5')", + }, + "role": { + "type": "string", + "description": "The role this component plays", + }, + "interpretation": { + "type": "string", + "description": "Auto-interp label if available", + }, + }, + "required": ["component_key", "role"], + }, + "description": "List of components and their roles", + }, + "explanation": { + "type": "string", + "description": "How the components work together", + }, + "supporting_evidence": { + "type": "array", + "items": { + "type": "object", + "properties": { + "evidence_type": { + "type": "string", + "enum": [ + "ablation", + "attribution", + "activation_pattern", + "correlation", + "other", + ], + }, + "description": {"type": "string"}, + "details": {"type": "object"}, + }, + "required": ["evidence_type", "description"], + }, + "description": "Evidence supporting this explanation", + }, + "confidence": { + "type": "string", + "enum": ["high", "medium", "low"], + "description": "Your confidence level", + }, + "alternative_hypotheses": { + "type": "array", + "items": {"type": "string"}, + "description": "Other hypotheses you considered", + }, + "limitations": { + "type": "array", + "items": {"type": "string"}, + "description": "Known limitations of this explanation", + }, + }, + "required": [ + "subject_prompt", + "behavior_description", + "components_involved", + "explanation", + "confidence", + ], + }, + ), + ToolDefinition( + name="set_investigation_summary", + description="""Set a title and 
summary for your investigation. + +Call this when you've completed your investigation (or periodically as you make progress) +to provide a human-readable title and summary that will be shown in the investigations UI. + +The title should be short and descriptive. The summary should be 1-3 sentences +explaining what you investigated and what you found.""", + inputSchema={ + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "Short title for the investigation (e.g., 'Gendered Pronoun Circuit')", + }, + "summary": { + "type": "string", + "description": "Brief summary of findings (1-3 sentences)", + }, + "status": { + "type": "string", + "enum": ["in_progress", "completed", "inconclusive"], + "description": "Current status of the investigation", + "default": "in_progress", + }, + }, + "required": ["title", "summary"], + }, + ), + ToolDefinition( + name="save_graph_artifact", + description="""Save a graph as an artifact for inclusion in your research report. + +After calling optimize_graph and getting a graph_id, call this to save the graph +as an artifact. Then reference it in your research log using the spd:graph syntax: + +```spd:graph +artifact: graph_001 +``` + +This allows humans reviewing your investigation to see interactive circuit visualizations +inline with your research notes.""", + inputSchema={ + "type": "object", + "properties": { + "graph_id": { + "type": "integer", + "description": "The graph ID returned by optimize_graph", + }, + "caption": { + "type": "string", + "description": "Optional caption describing what this graph shows", + }, + }, + "required": ["graph_id"], + }, + ), + ToolDefinition( + name="probe_component", + description="""Fast CI probing on custom text. + +Computes causal importance values and subcomponent activations for a specific component +across all positions in the input text. Also returns next-token probabilities. 
+ +Use this for quick, targeted analysis of how a component responds to specific inputs.""", + inputSchema={ + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "The input text to probe", + }, + "layer": { + "type": "string", + "description": "Canonical layer name (e.g., '0.mlp.up')", + }, + "component_idx": { + "type": "integer", + "description": "Component index within the layer", + }, + }, + "required": ["text", "layer", "component_idx"], + }, + ), + ToolDefinition( + name="get_component_activation_examples", + description="""Get activation examples from harvest data for a component. + +Returns examples showing token windows where the component fires, along with +CI values and activation strengths at each position. + +Use this to understand what inputs activate a component.""", + inputSchema={ + "type": "object", + "properties": { + "layer": { + "type": "string", + "description": "Canonical layer name (e.g., '0.mlp.up')", + }, + "component_idx": { + "type": "integer", + "description": "Component index within the layer", + }, + "limit": { + "type": "integer", + "description": "Maximum number of examples to return (default: 10)", + "default": 10, + }, + }, + "required": ["layer", "component_idx"], + }, + ), + ToolDefinition( + name="get_component_attributions", + description="""Get dataset-level component dependencies from pre-computed attributions. + +Returns the top source and target components that this component attributes to/from, +aggregated over the training dataset. Both positive and negative attributions are returned. 
+ +Use this to understand a component's role in the broader network.""", + inputSchema={ + "type": "object", + "properties": { + "layer": { + "type": "string", + "description": "Canonical layer name (e.g., '0.mlp.up') or 'output'", + }, + "component_idx": { + "type": "integer", + "description": "Component index within the layer", + }, + "k": { + "type": "integer", + "description": "Number of top attributions to return per direction (default: 10)", + "default": 10, + }, + }, + "required": ["layer", "component_idx"], + }, + ), + ToolDefinition( + name="get_model_info", + description="""Get architecture details about the pretrained model. + +Returns model type, summary, target model config, topology, and pretrain info. +No parameters required.""", + inputSchema={ + "type": "object", + "properties": {}, + }, + ), +] + + +# ============================================================================= +# Tool Implementations +# ============================================================================= + + +def _get_state(): + """Get state manager and loaded run, raising clear errors if not available.""" + manager = StateManager.get() + if manager.run_state is None: + raise ValueError("No run loaded. The backend must load a run first.") + return manager, manager.run_state + + +def _canonicalize_layer(layer: str, loaded: Any) -> str: + """Translate concrete layer name to canonical, passing through 'output'.""" + if layer == "output": + return layer + return loaded.topology.target_to_canon(layer) + + +def _canonicalize_key(concrete_key: str, loaded: Any) -> str: + """Translate concrete component key (e.g. 'h.0.mlp.c_fc:444') to canonical ('0.mlp.up:444').""" + layer, idx = concrete_key.rsplit(":", 1) + return f"{_canonicalize_layer(layer, loaded)}:{idx}" + + +def _tool_optimize_graph(params: dict[str, Any]) -> Generator[dict[str, Any]]: + """Optimize a sparse circuit for a behavior. 
Yields progress events.""" + manager, loaded = _get_state() + + prompt_text = params["prompt_text"] + target_token = params["target_token"] + steps = params.get("steps", 100) + ci_threshold = params.get("ci_threshold", 0.5) + + # Tokenize prompt + token_ids = loaded.tokenizer.encode(prompt_text) + if not token_ids: + raise ValueError("Prompt text produced no tokens") + + # Find target token ID + target_token_ids = loaded.tokenizer.encode(target_token) + if len(target_token_ids) != 1: + raise ValueError( + f"Target token '{target_token}' tokenizes to {len(target_token_ids)} tokens, expected 1. " + f"Token IDs: {target_token_ids}" + ) + label_token = target_token_ids[0] + + # Determine loss position + loss_position = params.get("loss_position") + if loss_position is None: + loss_position = len(token_ids) - 1 + + if loss_position >= len(token_ids): + raise ValueError( + f"loss_position {loss_position} out of bounds for prompt with {len(token_ids)} tokens" + ) + + _log_event( + "tool_start", + f"optimize_graph: '{prompt_text}' → '{target_token}'", + {"steps": steps, "loss_position": loss_position}, + ) + + yield {"type": "progress", "current": 0, "total": steps, "stage": "starting optimization"} + + # Create prompt in DB + prompt_id = manager.db.add_custom_prompt( + run_id=loaded.run.id, + token_ids=token_ids, + context_length=loaded.context_length, + ) + + # Build optimization config + loss_config = CELossConfig(coeff=1.0, position=loss_position, label_token=label_token) + + optim_config = OptimCIConfig( + adv_pgd=None, # AdvPGDConfig(n_steps=10, step_size=0.01, init="random"), + seed=0, + lr=1e-2, + steps=steps, + weight_decay=0.0, + lr_schedule="cosine", + lr_exponential_halflife=None, + lr_warmup_pct=0.01, + log_freq=max(1, steps // 10), + imp_min_config=ImportanceMinimalityLossConfig(coeff=0.1, pnorm=0.5, beta=0.0), + loss_config=loss_config, + sampling=loaded.config.sampling, + ce_kl_rounding_threshold=0.5, + mask_type="ci", + ) + + tokens_tensor = 
torch.tensor([token_ids], device=DEVICE) + progress_queue: queue.Queue[dict[str, Any]] = queue.Queue() + + def on_progress(current: int, total: int, stage: str) -> None: + progress_queue.put({"current": current, "total": total, "stage": stage}) + + # Run optimization in thread + result_holder: list[Any] = [] + error_holder: list[Exception] = [] + + def compute(): + try: + with manager.gpu_lock(): + result = compute_prompt_attributions_optimized( + model=loaded.model, + topology=loaded.topology, + tokens=tokens_tensor, + sources_by_target=loaded.sources_by_target, + optim_config=optim_config, + output_prob_threshold=0.01, + device=DEVICE, + on_progress=on_progress, + ) + result_holder.append(result) + except Exception as e: + error_holder.append(e) + + thread = threading.Thread(target=compute) + thread.start() + + # Yield progress events (throttle logging to every 10% or 10 steps) + last_logged_step = -1 + log_interval = max(1, steps // 10) + + while thread.is_alive() or not progress_queue.empty(): + try: + progress = progress_queue.get(timeout=0.1) + current = progress["current"] + # Log to events.jsonl at intervals (for human monitoring) + if current - last_logged_step >= log_interval or current == progress["total"]: + _log_event( + "optimization_progress", + f"optimize_graph: step {current}/{progress['total']} ({progress['stage']})", + {"prompt": prompt_text, "target": target_token, **progress}, + ) + last_logged_step = current + # Always yield to SSE stream (for Claude) + yield {"type": "progress", **progress} + except queue.Empty: + continue + + thread.join() + + if error_holder: + raise error_holder[0] + + if not result_holder: + raise RuntimeError("Optimization completed but no result was produced") + + result = result_holder[0] + + ci_masked_out_logits = result.ci_masked_out_logits.cpu() + target_out_logits = result.target_out_logits.cpu() + + # Build output probs for response + out_probs = _build_out_probs( + ci_masked_out_logits, + target_out_logits, + 
loaded.tokenizer.get_tok_display, + ) + + # Save graph to DB + from spd.app.backend.database import OptimizationParams + + opt_params = OptimizationParams( + imp_min_coeff=0.1, + steps=steps, + pnorm=0.5, + beta=0.0, + mask_type="ci", + loss=loss_config, + ci_masked_label_prob=result.metrics.ci_masked_label_prob, + stoch_masked_label_prob=result.metrics.stoch_masked_label_prob, + adv_pgd_label_prob=result.metrics.adv_pgd_label_prob, + ) + graph_id = manager.db.save_graph( + prompt_id=prompt_id, + graph=StoredGraph( + graph_type="optimized", + edges=result.edges, + edges_abs=result.edges_abs, + ci_masked_out_logits=ci_masked_out_logits, + target_out_logits=target_out_logits, + node_ci_vals=result.node_ci_vals, + node_subcomp_acts=result.node_subcomp_acts, + optimization_params=opt_params, + ), + ) + + # Filter nodes by CI threshold + active_components = {k: v for k, v in result.node_ci_vals.items() if v >= ci_threshold} + + # Get target token probability + target_key = f"{loss_position}:{label_token}" + target_prob = out_probs.get(target_key) + + token_strings = [loaded.tokenizer.get_tok_display(t) for t in token_ids] + + final_result = { + "graph_id": graph_id, + "prompt_id": prompt_id, + "tokens": token_strings, + "target_token": target_token, + "target_token_id": label_token, + "target_position": loss_position, + "target_probability": target_prob.prob if target_prob else None, + "target_probability_baseline": target_prob.target_prob if target_prob else None, + "active_components": active_components, + "total_active": len(active_components), + "output_probs": {k: {"prob": v.prob, "token": v.token} for k, v in out_probs.items()}, + } + + _log_event( + "tool_complete", + f"optimize_graph complete: {len(active_components)} active components", + {"graph_id": graph_id, "target_prob": target_prob.prob if target_prob else None}, + ) + + yield {"type": "result", "data": final_result} + + +def _tool_get_component_info(params: dict[str, Any]) -> dict[str, Any]: + """Get 
detailed information about a component.""" + _, loaded = _get_state() + + layer = params["layer"] + component_idx = params["component_idx"] + top_k = params.get("top_k", 20) + canonical_key = f"{layer}:{component_idx}" + + # Harvest/interp repos store concrete keys (e.g. "h.0.mlp.c_fc:444") + concrete_layer = loaded.topology.canon_to_target(layer) + concrete_key = f"{concrete_layer}:{component_idx}" + + _log_event( + "tool_call", + f"get_component_info: {canonical_key}", + {"layer": layer, "idx": component_idx}, + ) + + result: dict[str, Any] = {"component_key": canonical_key} + + # Get interpretation + if loaded.interp is not None: + interp = loaded.interp.get_interpretation(concrete_key) + if interp is not None: + result["interpretation"] = { + "label": interp.label, + "confidence": interp.confidence, + "reasoning": interp.reasoning, + } + else: + result["interpretation"] = None + else: + result["interpretation"] = None + + # Get token stats + assert loaded.harvest is not None, "harvest data not loaded" + token_stats = loaded.harvest.get_token_stats() + if token_stats is not None: + input_stats = analysis.get_input_token_stats( + token_stats, concrete_key, loaded.tokenizer, top_k + ) + output_stats = analysis.get_output_token_stats( + token_stats, concrete_key, loaded.tokenizer, top_k + ) + if input_stats and output_stats: + result["token_stats"] = { + "input": { + "top_recall": input_stats.top_recall, + "top_precision": input_stats.top_precision, + "top_pmi": input_stats.top_pmi, + }, + "output": { + "top_recall": output_stats.top_recall, + "top_precision": output_stats.top_precision, + "top_pmi": output_stats.top_pmi, + "bottom_pmi": output_stats.bottom_pmi, + }, + } + else: + result["token_stats"] = None + else: + result["token_stats"] = None + + # Get correlations (return canonical keys) + correlations = loaded.harvest.get_correlations() + if correlations is not None and analysis.has_component(correlations, concrete_key): + result["correlated_components"] = { 
+ "precision": [ + {"key": _canonicalize_key(c.component_key, loaded), "score": c.score} + for c in analysis.get_correlated_components( + correlations, concrete_key, "precision", top_k + ) + ], + "pmi": [ + {"key": _canonicalize_key(c.component_key, loaded), "score": c.score} + for c in analysis.get_correlated_components( + correlations, concrete_key, "pmi", top_k + ) + ], + } + else: + result["correlated_components"] = None + + return result + + +def _tool_run_ablation(params: dict[str, Any]) -> dict[str, Any]: + """Run ablation with selected components.""" + from spd.app.backend.compute import ( + DEFAULT_EVAL_PGD_CONFIG, + compute_intervention, + ) + from spd.app.backend.optim_cis import MeanKLLossConfig + + manager, loaded = _get_state() + + text = params["text"] + selected_nodes = params["selected_nodes"] + top_k = params.get("top_k", 10) + + _log_event( + "tool_call", + f"run_ablation: '{text[:50]}...' with {len(selected_nodes)} nodes", + {"text": text, "n_nodes": len(selected_nodes)}, + ) + + token_ids = loaded.tokenizer.encode(text) + tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) + + active_nodes = [] + for key in selected_nodes: + parts = key.split(":") + if len(parts) != 3: + raise ValueError(f"Invalid node key format: {key!r} (expected 'layer:seq:cIdx')") + layer, seq_str, cidx_str = parts + if layer in ("wte", "embed", "output"): + raise ValueError(f"Cannot intervene on {layer!r} nodes - only internal layers allowed") + active_nodes.append((layer, int(seq_str), int(cidx_str))) + + with manager.gpu_lock(): + result = compute_intervention( + model=loaded.model, + tokens=tokens, + active_nodes=active_nodes, + nodes_to_ablate=None, + tokenizer=loaded.tokenizer, + adv_pgd_config=DEFAULT_EVAL_PGD_CONFIG, + loss_config=MeanKLLossConfig(), + sampling=loaded.config.sampling, + top_k=top_k, + ) + + predictions = [] + for pos_predictions in result.ci: + pos_result = [] + for pred in pos_predictions: + pos_result.append( + { + "token": 
pred.token, + "token_id": pred.token_id, + "circuit_prob": round(pred.prob, 6), + "full_model_prob": round(pred.target_prob, 6), + } + ) + predictions.append(pos_result) + + return { + "input_tokens": result.input_tokens, + "predictions_per_position": predictions, + "selected_nodes": selected_nodes, + } + + +def _tool_search_dataset(params: dict[str, Any]) -> dict[str, Any]: + """Search the SimpleStories dataset.""" + import time + + from datasets import Dataset, load_dataset + + query = params["query"] + limit = params.get("limit", 20) + search_query = query.lower() + + _log_event("tool_call", f"search_dataset: '{query}'", {"query": query, "limit": limit}) + + start_time = time.time() + dataset = load_dataset("lennart-finke/SimpleStories", split="train") + assert isinstance(dataset, Dataset) + + filtered = dataset.filter( + lambda x: search_query in x["story"].lower(), + num_proc=4, + ) + + results = [] + for i, item in enumerate(filtered): + if i >= limit: + break + item_dict: dict[str, Any] = dict(item) + story: str = item_dict["story"] + results.append( + { + "story": story[:500] + "..." 
if len(story) > 500 else story, + "occurrence_count": story.lower().count(search_query), + } + ) + + return { + "query": query, + "total_matches": len(filtered), + "returned": len(results), + "search_time_seconds": round(time.time() - start_time, 2), + "results": results, + } + + +def _tool_create_prompt(params: dict[str, Any]) -> dict[str, Any]: + """Create a prompt from text.""" + manager, loaded = _get_state() + + text = params["text"] + + _log_event("tool_call", f"create_prompt: '{text[:50]}...'", {"text": text}) + + token_ids = loaded.tokenizer.encode(text) + if not token_ids: + raise ValueError("Text produced no tokens") + + prompt_id = manager.db.add_custom_prompt( + run_id=loaded.run.id, + token_ids=token_ids, + context_length=loaded.context_length, + ) + + # Compute next token probs + tokens_tensor = torch.tensor([token_ids], device=DEVICE) + with torch.no_grad(): + logits = loaded.model(tokens_tensor) + probs = torch.softmax(logits, dim=-1) + + next_token_probs = [] + for i in range(len(token_ids) - 1): + next_token_id = token_ids[i + 1] + prob = probs[0, i, next_token_id].item() + next_token_probs.append(round(prob, 6)) + next_token_probs.append(None) + + token_strings = [loaded.tokenizer.get_tok_display(t) for t in token_ids] + + return { + "prompt_id": prompt_id, + "text": text, + "tokens": token_strings, + "token_ids": token_ids, + "next_token_probs": next_token_probs, + } + + +def _require_investigation_config() -> InvestigationConfig: + """Get investigation config, raising if not in investigation mode.""" + assert _investigation_config is not None, "Not running in investigation mode" + return _investigation_config + + +def _tool_update_research_log(params: dict[str, Any]) -> dict[str, Any]: + """Append content to the research log.""" + config = _require_investigation_config() + content = params["content"] + research_log_path = config.investigation_dir / "research_log.md" + + _log_event( + "tool_call", f"update_research_log: {len(content)} chars", 
{"preview": content[:100]} + ) + + with open(research_log_path, "a") as f: + f.write(content) + if not content.endswith("\n"): + f.write("\n") + + return {"status": "ok", "path": str(research_log_path)} + + +def _tool_save_explanation(params: dict[str, Any]) -> dict[str, Any]: + """Save a behavior explanation to explanations.jsonl.""" + from spd.investigate.schemas import BehaviorExplanation, ComponentInfo, Evidence + + config = _require_investigation_config() + + _log_event( + "tool_call", + f"save_explanation: '{params['behavior_description'][:50]}...'", + {"prompt": params["subject_prompt"]}, + ) + + components = [ + ComponentInfo( + component_key=c["component_key"], + role=c["role"], + interpretation=c.get("interpretation"), + ) + for c in params["components_involved"] + ] + + evidence = [ + Evidence( + evidence_type=e["evidence_type"], + description=e["description"], + details=e.get("details", {}), + ) + for e in params.get("supporting_evidence", []) + ] + + explanation = BehaviorExplanation( + subject_prompt=params["subject_prompt"], + behavior_description=params["behavior_description"], + components_involved=components, + explanation=params["explanation"], + supporting_evidence=evidence, + confidence=params["confidence"], + alternative_hypotheses=params.get("alternative_hypotheses", []), + limitations=params.get("limitations", []), + ) + + explanations_path = config.investigation_dir / "explanations.jsonl" + with open(explanations_path, "a") as f: + f.write(explanation.model_dump_json() + "\n") + + _log_event( + "explanation", + f"Saved explanation: {params['behavior_description']}", + {"confidence": params["confidence"], "n_components": len(components)}, + ) + + return {"status": "ok", "path": str(explanations_path)} + + +def _tool_set_investigation_summary(params: dict[str, Any]) -> dict[str, Any]: + """Set the investigation title and summary.""" + config = _require_investigation_config() + + summary = { + "title": params["title"], + "summary": 
params["summary"], + "status": params.get("status", "in_progress"), + "updated_at": datetime.now(UTC).isoformat(), + } + + _log_event( + "tool_call", + f"set_investigation_summary: {params['title']}", + summary, + ) + + summary_path = config.investigation_dir / "summary.json" + summary_path.write_text(json.dumps(summary, indent=2)) + + return {"status": "ok", "path": str(summary_path)} + + +def _tool_save_graph_artifact(params: dict[str, Any]) -> dict[str, Any]: + """Save a graph as an artifact for the research report. + + Uses the same filtering logic as the main graph API: + 1. Filter nodes by CI threshold + 2. Add pseudo nodes (wte, output) + 3. Filter edges to only active nodes + 4. Apply edge limit + """ + config = _require_investigation_config() + manager, loaded = _get_state() + + graph_id = params["graph_id"] + caption = params.get("caption") + ci_threshold = params.get("ci_threshold", 0.5) + edge_limit = params.get("edge_limit", 5000) + + _log_event( + "tool_call", + f"save_graph_artifact: graph_id={graph_id}", + {"graph_id": graph_id, "caption": caption}, + ) + + # Fetch graph from DB + result = manager.db.get_graph(graph_id) + if result is None: + raise ValueError(f"Graph with id={graph_id} not found") + + graph, prompt_id = result + + # Get tokens from prompt + prompt_record = manager.db.get_prompt(prompt_id) + if prompt_record is None: + raise ValueError(f"Prompt with id={prompt_id} not found") + + tokens = [loaded.tokenizer.get_tok_display(tid) for tid in prompt_record.token_ids] + num_tokens = len(tokens) + + # Create artifacts directory + artifacts_dir = config.investigation_dir / "artifacts" + artifacts_dir.mkdir(exist_ok=True) + + # Generate artifact ID (find max existing number to avoid collisions) + existing_nums = [] + for f in artifacts_dir.glob("graph_*.json"): + try: + num = int(f.stem.split("_")[1]) + existing_nums.append(num) + except (IndexError, ValueError): + continue + artifact_num = max(existing_nums, default=0) + 1 + artifact_id = 
f"graph_{artifact_num:03d}" + + # Compute out_probs from stored logits + out_probs = _build_out_probs( + graph.ci_masked_out_logits, + graph.target_out_logits, + loaded.tokenizer.get_tok_display, + ) + + # Step 1: Filter nodes by CI threshold (same as main graph API) + filtered_ci_vals = {k: v for k, v in graph.node_ci_vals.items() if v > ci_threshold} + l0_total = len(filtered_ci_vals) + + # Step 2: Add pseudo nodes (embed and output) - same as _add_pseudo_layer_nodes + node_ci_vals_with_pseudo = dict(filtered_ci_vals) + for seq_pos in range(num_tokens): + node_ci_vals_with_pseudo[f"embed:{seq_pos}:0"] = 1.0 + for key, out_prob in out_probs.items(): + seq_pos, token_id = key.split(":") + node_ci_vals_with_pseudo[f"output:{seq_pos}:{token_id}"] = out_prob.prob + + # Step 3: Filter edges to only active nodes + active_node_keys = set(node_ci_vals_with_pseudo.keys()) + filtered_edges = [ + e + for e in graph.edges + if str(e.source) in active_node_keys and str(e.target) in active_node_keys + ] + + # Step 4: Sort by strength and apply edge limit + filtered_edges.sort(key=lambda e: abs(e.strength), reverse=True) + filtered_edges = filtered_edges[:edge_limit] + + # Build edges data + edges_data = [ + { + "src": str(e.source), + "tgt": str(e.target), + "val": e.strength, + } + for e in filtered_edges + ] + + # Compute max abs attr from filtered edges + max_abs_attr = max((abs(e.strength) for e in filtered_edges), default=0.0) + + # Filter nodeSubcompActs to match nodeCiVals + filtered_subcomp_acts = { + k: v for k, v in graph.node_subcomp_acts.items() if k in node_ci_vals_with_pseudo + } + + # Build artifact data (self-contained GraphData, same structure as API response) + artifact = { + "type": "graph", + "id": artifact_id, + "caption": caption, + "graph_id": graph_id, + "data": { + "tokens": tokens, + "edges": edges_data, + "outputProbs": { + k: { + "prob": v.prob, + "logit": v.logit, + "target_prob": v.target_prob, + "target_logit": v.target_logit, + "token": v.token, 
+ } + for k, v in out_probs.items() + }, + "nodeCiVals": node_ci_vals_with_pseudo, + "nodeSubcompActs": filtered_subcomp_acts, + "maxAbsAttr": max_abs_attr, + "l0_total": l0_total, + }, + } + + # Save artifact + artifact_path = artifacts_dir / f"{artifact_id}.json" + artifact_path.write_text(json.dumps(artifact, indent=2)) + + _log_event( + "artifact_saved", + f"Saved graph artifact: {artifact_id}", + {"artifact_id": artifact_id, "graph_id": graph_id, "path": str(artifact_path)}, + ) + + return {"artifact_id": artifact_id, "path": str(artifact_path)} + + +def _tool_probe_component(params: dict[str, Any]) -> dict[str, Any]: + """Fast CI probing on custom text for a specific component.""" + manager, loaded = _get_state() + + text = params["text"] + layer = params["layer"] + component_idx = params["component_idx"] + + _log_event( + "tool_call", + f"probe_component: '{text[:50]}...' layer={layer} idx={component_idx}", + {"text": text, "layer": layer, "component_idx": component_idx}, + ) + + token_ids = loaded.tokenizer.encode(text) + assert token_ids, "Text produced no tokens" + tokens_tensor = torch.tensor([token_ids], device=DEVICE) + + concrete_layer = loaded.topology.canon_to_target(layer) + + with manager.gpu_lock(): + result = compute_ci_only( + model=loaded.model, tokens=tokens_tensor, sampling=loaded.config.sampling + ) + + ci_values = result.ci_lower_leaky[concrete_layer][0, :, component_idx].tolist() + subcomp_acts = result.component_acts[concrete_layer][0, :, component_idx].tolist() + + # Get next token probs from target model output + next_token_probs = [] + for i in range(len(token_ids) - 1): + next_token_id = token_ids[i + 1] + prob = result.target_out_probs[0, i, next_token_id].item() + next_token_probs.append(round(prob, 6)) + next_token_probs.append(None) + + token_strings = [loaded.tokenizer.get_tok_display(t) for t in token_ids] + + return { + "tokens": token_strings, + "ci_values": ci_values, + "subcomp_acts": subcomp_acts, + "next_token_probs": 
next_token_probs, + } + + +def _tool_get_component_activation_examples(params: dict[str, Any]) -> dict[str, Any]: + """Get activation examples from harvest data.""" + _, loaded = _get_state() + + layer = params["layer"] + component_idx = params["component_idx"] + limit = params.get("limit", 10) + + concrete_layer = loaded.topology.canon_to_target(layer) + component_key = f"{concrete_layer}:{component_idx}" + + _log_event( + "tool_call", + f"get_component_activation_examples: {component_key}", + {"layer": layer, "component_idx": component_idx, "limit": limit}, + ) + + assert loaded.harvest is not None, "harvest data not loaded" + canonical_key = f"{layer}:{component_idx}" + comp = loaded.harvest.get_component(component_key) + if comp is None: + return {"component_key": canonical_key, "examples": [], "total": 0} + + examples = [] + for ex in comp.activation_examples[:limit]: + token_strings = [loaded.tokenizer.get_tok_display(t) for t in ex.token_ids] + examples.append( + { + "tokens": token_strings, + "ci_values": ex.activations["causal_importance"], + "component_acts": ex.activations["component_activation"], + } + ) + + return { + "component_key": canonical_key, + "examples": examples, + "total": len(comp.activation_examples), + "mean_ci": comp.mean_activations["causal_importance"], + } + + +# def _tool_get_component_attributions(params: dict[str, Any]) -> dict[str, Any]: +# """Get dataset-level component dependencies.""" +# _, loaded = _get_state() + +# layer = params["layer"] +# component_idx = params["component_idx"] +# k = params.get("k", 10) + +# assert loaded.attributions is not None, "dataset attributions not loaded" +# storage = loaded.attributions.get_attributions() + +# concrete_layer = loaded.topology.canon_to_target(layer) if layer != "output" else "output" +# component_key = f"{concrete_layer}:{component_idx}" + +# _log_event( +# "tool_call", +# f"get_component_attributions: {component_key}", +# {"layer": layer, "component_idx": component_idx, "k": k}, 
+# ) + +# is_source = storage.has_source(component_key) +# is_target = storage.has_target(component_key) + +# assert is_source or is_target, f"Component {component_key} not found in attributions" + +# w_unembed = loaded.topology.get_unembed_weight() if is_source else None + +# def _entries_to_dicts( +# entries: list[Any], +# ) -> list[dict[str, Any]]: +# return [ +# { +# "component_key": f"{_canonicalize_layer(e.layer, loaded)}:{e.component_idx}", +# "layer": _canonicalize_layer(e.layer, loaded), +# "component_idx": e.component_idx, +# "value": e.value, +# } +# for e in entries +# ] + +# positive_sources = ( +# _entries_to_dicts(storage.get_top_sources(component_key, k, "positive")) +# if is_target +# else [] +# ) +# negative_sources = ( +# _entries_to_dicts(storage.get_top_sources(component_key, k, "negative")) +# if is_target +# else [] +# ) +# positive_targets = ( +# _entries_to_dicts( +# storage.get_top_targets( +# component_key, +# k, +# "positive", +# w_unembed=w_unembed, +# include_outputs=w_unembed is not None, +# ) +# ) +# if is_source +# else [] +# ) +# negative_targets = ( +# _entries_to_dicts( +# storage.get_top_targets( +# component_key, +# k, +# "negative", +# w_unembed=w_unembed, +# include_outputs=w_unembed is not None, +# ) +# ) +# if is_source +# else [] +# ) + +# return { +# "component_key": component_key, +# "positive_sources": positive_sources, +# "negative_sources": negative_sources, +# "positive_targets": positive_targets, +# "negative_targets": negative_targets, +# } + + +def _tool_get_model_info(_params: dict[str, Any]) -> dict[str, Any]: + """Get architecture details about the pretrained model.""" + _, loaded = _get_state() + + _log_event("tool_call", "get_model_info", {}) + + info = _get_pretrain_info(loaded.config) + return info.model_dump() + + +# ============================================================================= +# MCP Protocol Handler +# ============================================================================= + + 
+
+# Tools whose execution yields progress events; served to the client over SSE.
+_STREAMING_TOOLS: dict[str, Callable[..., Generator[dict[str, Any]]]] = {
+    "optimize_graph": _tool_optimize_graph,
+}
+
+# Tools that return a single JSON result (no streaming).
+_SIMPLE_TOOLS: dict[str, Callable[..., dict[str, Any]]] = {
+    "get_component_info": _tool_get_component_info,
+    "run_ablation": _tool_run_ablation,
+    "search_dataset": _tool_search_dataset,
+    "create_prompt": _tool_create_prompt,
+    "update_research_log": _tool_update_research_log,
+    "save_explanation": _tool_save_explanation,
+    "set_investigation_summary": _tool_set_investigation_summary,
+    "save_graph_artifact": _tool_save_graph_artifact,
+    "probe_component": _tool_probe_component,
+    "get_component_activation_examples": _tool_get_component_activation_examples,
+    # NOTE(review): get_component_attributions is still advertised in TOOLS
+    # (via tools/list) but its implementation is commented out, so tools/call
+    # raises "Unknown tool" for it — confirm whether the ToolDefinition should
+    # be removed as well.
+    # "get_component_attributions": _tool_get_component_attributions,
+    "get_model_info": _tool_get_model_info,
+}
+
+
+def _handle_initialize(_params: dict[str, Any] | None) -> dict[str, Any]:
+    """Handle initialize request (client params are ignored)."""
+    return {
+        "protocolVersion": MCP_PROTOCOL_VERSION,
+        "capabilities": {"tools": {}},
+        "serverInfo": {"name": "spd-app", "version": "1.0.0"},
+    }
+
+
+def _handle_tools_list() -> dict[str, Any]:
+    """Handle tools/list request: serialize every registered ToolDefinition."""
+    return {"tools": [t.model_dump() for t in TOOLS]}
+
+
+def _handle_tools_call(
+    params: dict[str, Any],
+) -> Generator[dict[str, Any]] | dict[str, Any]:
+    """Handle tools/call request. May return generator for streaming tools."""
+    name = params.get("name")
+    arguments = params.get("arguments", {})
+
+    # Streaming tools hand back their generator unchanged; the endpoint wraps it in SSE.
+    if name in _STREAMING_TOOLS:
+        return _STREAMING_TOOLS[name](arguments)
+
+    # Simple tools are wrapped in the MCP text-content envelope here.
+    if name in _SIMPLE_TOOLS:
+        result = _SIMPLE_TOOLS[name](arguments)
+        return {"content": [{"type": "text", "text": json.dumps(result, indent=2)}]}
+
+    raise ValueError(f"Unknown tool: {name}")
+
+
+@router.post("/mcp")
+async def mcp_endpoint(request: Request):
+    """MCP JSON-RPC endpoint.
+
+    Handles initialize, tools/list, and tools/call methods.
+    Returns SSE stream for streaming tools, JSON for others.
+ """ + try: + body = await request.json() + mcp_request = MCPRequest(**body) + except Exception as e: + return JSONResponse( + status_code=400, + content=MCPResponse( + id=None, error={"code": -32700, "message": f"Parse error: {e}"} + ).model_dump(exclude_none=True), + ) + + logger.info(f"[MCP] {mcp_request.method} (id={mcp_request.id})") + + try: + if mcp_request.method == "initialize": + result = _handle_initialize(mcp_request.params) + return JSONResponse( + content=MCPResponse(id=mcp_request.id, result=result).model_dump(exclude_none=True), + headers={"Mcp-Session-Id": "spd-session"}, + ) + + elif mcp_request.method == "notifications/initialized": + # Client confirms initialization + return JSONResponse(status_code=202, content={}) + + elif mcp_request.method == "tools/list": + result = _handle_tools_list() + return JSONResponse( + content=MCPResponse(id=mcp_request.id, result=result).model_dump(exclude_none=True) + ) + + elif mcp_request.method == "tools/call": + if mcp_request.params is None: + raise ValueError("tools/call requires params") + + result = _handle_tools_call(mcp_request.params) + + # Check if result is a generator (streaming) + if inspect.isgenerator(result): + # Streaming response via SSE + gen = result # Capture for closure + + def generate_sse() -> Generator[str]: + try: + final_result = None + for event in gen: + if event.get("type") == "progress": + # Send progress notification + progress_msg = { + "jsonrpc": "2.0", + "method": "notifications/progress", + "params": event, + } + yield f"data: {json.dumps(progress_msg)}\n\n" + elif event.get("type") == "result": + final_result = event["data"] + + # Send final response + response = MCPResponse( + id=mcp_request.id, + result={ + "content": [ + {"type": "text", "text": json.dumps(final_result, indent=2)} + ] + }, + ) + yield f"data: {json.dumps(response.model_dump(exclude_none=True))}\n\n" + except Exception as e: + tb = traceback.format_exc() + logger.error(f"[MCP] Tool error: {e}\n{tb}") + 
error_response = MCPResponse( + id=mcp_request.id, + error={"code": -32000, "message": str(e)}, + ) + yield f"data: {json.dumps(error_response.model_dump(exclude_none=True))}\n\n" + + return StreamingResponse(generate_sse(), media_type="text/event-stream") + + else: + # Non-streaming response + return JSONResponse( + content=MCPResponse(id=mcp_request.id, result=result).model_dump( + exclude_none=True + ) + ) + + else: + return JSONResponse( + content=MCPResponse( + id=mcp_request.id, + error={"code": -32601, "message": f"Method not found: {mcp_request.method}"}, + ).model_dump(exclude_none=True) + ) + + except Exception as e: + tb = traceback.format_exc() + logger.error(f"[MCP] Error handling {mcp_request.method}: {e}\n{tb}") + return JSONResponse( + content=MCPResponse( + id=mcp_request.id, + error={"code": -32000, "message": str(e)}, + ).model_dump(exclude_none=True) + ) diff --git a/spd/app/backend/routers/pretrain_info.py b/spd/app/backend/routers/pretrain_info.py index 2872423c9..424f7b035 100644 --- a/spd/app/backend/routers/pretrain_info.py +++ b/spd/app/backend/routers/pretrain_info.py @@ -38,6 +38,7 @@ class TopologyInfo(BaseModel): class PretrainInfoResponse(BaseModel): model_type: str summary: str + dataset_short: str | None target_model_config: dict[str, Any] | None pretrain_config: dict[str, Any] | None pretrain_wandb_path: str | None @@ -161,6 +162,27 @@ def _build_summary(model_type: str, target_model_config: dict[str, Any] | None) return " · ".join(parts) +_DATASET_SHORT_NAMES: dict[str, str] = { + "simplestories": "SS", + "pile": "Pile", + "tinystories": "TS", +} + + +def _get_dataset_short(pretrain_config: dict[str, Any] | None) -> str | None: + """Extract a short dataset label from the pretrain config.""" + if pretrain_config is None: + return None + dataset_name: str = ( + pretrain_config.get("train_dataset_config", {}).get("name", "") + or pretrain_config.get("dataset", "") + ).lower() + for key, short in _DATASET_SHORT_NAMES.items(): + if key 
in dataset_name: + return short + return None + + def _get_pretrain_info(spd_config: Config) -> PretrainInfoResponse: """Extract pretrain info from an SPD config.""" model_class_name = spd_config.pretrained_model_class @@ -190,10 +212,12 @@ def _get_pretrain_info(spd_config: Config) -> PretrainInfoResponse: n_blocks = target_model_config.get("n_layer", 0) if target_model_config else 0 topology = _build_topology(model_type, n_blocks) summary = _build_summary(model_type, target_model_config) + dataset_short = _get_dataset_short(pretrain_config) return PretrainInfoResponse( model_type=model_type, summary=summary, + dataset_short=dataset_short, target_model_config=target_model_config, pretrain_config=pretrain_config, pretrain_wandb_path=pretrain_wandb_path, diff --git a/spd/app/backend/routers/run_registry.py b/spd/app/backend/routers/run_registry.py new file mode 100644 index 000000000..d44a7108f --- /dev/null +++ b/spd/app/backend/routers/run_registry.py @@ -0,0 +1,95 @@ +"""Run registry endpoint. + +Returns architecture and data availability for requested SPD runs. +The canonical run list lives in the frontend; the backend just hydrates it. 
+""" + +import asyncio +from pathlib import Path + +from fastapi import APIRouter +from pydantic import BaseModel + +from spd.app.backend.routers.pretrain_info import _get_pretrain_info, _load_spd_config_lightweight +from spd.app.backend.utils import log_errors +from spd.log import logger +from spd.settings import SPD_OUT_DIR +from spd.utils.wandb_utils import parse_wandb_run_path + +router = APIRouter(prefix="/api/run_registry", tags=["run_registry"]) + + +class DataAvailability(BaseModel): + harvest: bool + autointerp: bool + attributions: bool + graph_interp: bool + + +class RunInfoResponse(BaseModel): + wandb_run_id: str + architecture: str | None + availability: DataAvailability + + +def _has_glob_match(pattern_dir: Path, glob_pattern: str) -> bool: + """Check if any file matches a glob pattern under a directory.""" + if not pattern_dir.exists(): + return False + return next(pattern_dir.glob(glob_pattern), None) is not None + + +def _check_availability(run_id: str) -> DataAvailability: + """Lightweight filesystem checks for post-processing data availability.""" + harvest_dir = SPD_OUT_DIR / "harvest" / run_id + autointerp_dir = SPD_OUT_DIR / "autointerp" / run_id + attributions_dir = SPD_OUT_DIR / "dataset_attributions" / run_id + graph_interp_dir = SPD_OUT_DIR / "graph_interp" / run_id + + return DataAvailability( + harvest=_has_glob_match(harvest_dir, "h-*/harvest.db"), + autointerp=_has_glob_match(autointerp_dir, "a-*/.done"), + attributions=_has_glob_match(attributions_dir, "da-*/dataset_attributions.pt"), + graph_interp=_has_glob_match(graph_interp_dir, "*/interp.db"), + ) + + +def _get_architecture_summary(wandb_path: str) -> str | None: + """Get a short architecture label for a run. 
Returns None on failure.""" + try: + spd_config = _load_spd_config_lightweight(wandb_path) + info = _get_pretrain_info(spd_config) + parts: list[str] = [] + if info.dataset_short: + parts.append(info.dataset_short) + parts.append(info.model_type) + cfg = info.target_model_config + if cfg: + n_layer = cfg.get("n_layer") + n_embd = cfg.get("n_embd") + if n_layer is not None: + parts.append(f"{n_layer}L") + if n_embd is not None: + parts.append(f"d{n_embd}") + return " ".join(parts) + except Exception: + logger.exception(f"[run_registry] Failed to get architecture for {wandb_path}") + return None + + +def _build_run_info(wandb_run_id: str) -> RunInfoResponse: + _, _, run_id = parse_wandb_run_path(wandb_run_id) + return RunInfoResponse( + wandb_run_id=wandb_run_id, + architecture=_get_architecture_summary(wandb_run_id), + availability=_check_availability(run_id), + ) + + +@router.post("") +@log_errors +async def get_run_info(wandb_run_ids: list[str]) -> list[RunInfoResponse]: + """Return architecture and availability for the requested runs.""" + loop = asyncio.get_running_loop() + tasks = [loop.run_in_executor(None, _build_run_info, wid) for wid in wandb_run_ids] + return list(await asyncio.gather(*tasks)) diff --git a/spd/app/backend/routers/runs.py b/spd/app/backend/routers/runs.py index 0989cea54..914998a4e 100644 --- a/spd/app/backend/routers/runs.py +++ b/spd/app/backend/routers/runs.py @@ -7,9 +7,11 @@ import yaml from fastapi import APIRouter from pydantic import BaseModel +from spd.graph_interp.repo import GraphInterpRepo from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.dependencies import DepStateManager +from spd.app.backend.routers.graph_interp import MOCK_MODE as _GRAPH_INTERP_MOCK_MODE from spd.app.backend.state import RunState from spd.app.backend.utils import log_errors from spd.autointerp.repo import InterpRepo @@ -42,6 +44,8 @@ class LoadedRun(BaseModel): backend_user: str dataset_attributions_available: bool 
dataset_search_enabled: bool + graph_interp_available: bool + autointerp_available: bool router = APIRouter(prefix="/api", tags=["runs"]) @@ -128,6 +132,7 @@ def load_run(wandb_path: str, context_length: int, manager: DepStateManager): harvest=HarvestRepo.open_most_recent(run_id), interp=InterpRepo.open(run_id), attributions=AttributionRepo.open(run_id), + graph_interp=GraphInterpRepo.open(run_id), ) logger.info(f"[API] Run {run.id} loaded on {DEVICE}") @@ -165,6 +170,10 @@ def get_status(manager: DepStateManager) -> LoadedRun | None: backend_user=getpass.getuser(), dataset_attributions_available=manager.run_state.attributions is not None, dataset_search_enabled=dataset_search_enabled, + # TODO(oli): Remove MOCK_MODE import after real data available + graph_interp_available=manager.run_state.graph_interp is not None + or _GRAPH_INTERP_MOCK_MODE, + autointerp_available=manager.run_state.interp is not None, ) diff --git a/spd/app/backend/schemas.py b/spd/app/backend/schemas.py index bc99dc831..61fa3d9d2 100644 --- a/spd/app/backend/schemas.py +++ b/spd/app/backend/schemas.py @@ -19,8 +19,6 @@ class OutputProbability(BaseModel): logit: float # CI-masked (SPD model) raw logit target_prob: float # Target model probability target_logit: float # Target model raw logit - adv_pgd_prob: float | None = None # Adversarial PGD probability - adv_pgd_logit: float | None = None # Adversarial PGD raw logit token: str diff --git a/spd/app/backend/server.py b/spd/app/backend/server.py index 3804ce756..89ac602b3 100644 --- a/spd/app/backend/server.py +++ b/spd/app/backend/server.py @@ -32,14 +32,19 @@ data_sources_router, dataset_attributions_router, dataset_search_router, + graph_interp_router, graphs_router, intervention_router, + investigations_router, + mcp_router, pretrain_info_router, prompts_router, + run_registry_router, runs_router, ) from spd.app.backend.state import StateManager from spd.log import logger +from spd.settings import SPD_APP_DEFAULT_RUN from 
spd.utils.distributed_utils import get_device DEVICE = get_device() @@ -48,6 +53,11 @@ @asynccontextmanager async def lifespan(app: FastAPI): # pyright: ignore[reportUnusedParameter] """Initialize DB connection at startup. Model loaded on-demand via /api/runs/load.""" + import os + from pathlib import Path + + from spd.app.backend.routers.mcp import InvestigationConfig, set_investigation_config + manager = StateManager.get() db = PromptAttrDB(check_same_thread=False) @@ -58,6 +68,24 @@ async def lifespan(app: FastAPI): # pyright: ignore[reportUnusedParameter] logger.info(f"[STARTUP] Device: {DEVICE}") logger.info(f"[STARTUP] CUDA available: {torch.cuda.is_available()}") + # Configure MCP for investigation mode (derives paths from investigation dir) + investigation_dir = os.environ.get("SPD_INVESTIGATION_DIR") + if investigation_dir: + inv_dir = Path(investigation_dir) + set_investigation_config( + InvestigationConfig( + events_log_path=inv_dir / "events.jsonl", + investigation_dir=inv_dir, + ) + ) + logger.info(f"[STARTUP] Investigation mode enabled: dir={investigation_dir}") + + if SPD_APP_DEFAULT_RUN is not None: + from spd.app.backend.routers.runs import load_run + + logger.info(f"[STARTUP] Auto-loading default run: {SPD_APP_DEFAULT_RUN}") + load_run(SPD_APP_DEFAULT_RUN, context_length=512, manager=manager) + yield manager.close() @@ -157,8 +185,12 @@ async def global_exception_handler(request: Request, exc: Exception) -> JSONResp app.include_router(dataset_search_router) app.include_router(dataset_attributions_router) app.include_router(agents_router) +app.include_router(investigations_router) +app.include_router(mcp_router) app.include_router(data_sources_router) +app.include_router(graph_interp_router) app.include_router(pretrain_info_router) +app.include_router(run_registry_router) def cli(port: int = 8000) -> None: diff --git a/spd/app/backend/state.py b/spd/app/backend/state.py index cf71c2bc6..dd70d22ca 100644 --- a/spd/app/backend/state.py +++ 
b/spd/app/backend/state.py @@ -5,9 +5,15 @@ - StateManager: Singleton managing app-wide state with proper lifecycle """ +import threading +from collections.abc import Generator +from contextlib import contextmanager from dataclasses import dataclass, field from typing import Any +from fastapi import HTTPException +from spd.graph_interp.repo import GraphInterpRepo + from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.database import PromptAttrDB, Run from spd.autointerp.repo import InterpRepo @@ -32,6 +38,7 @@ class RunState: harvest: HarvestRepo | None interp: InterpRepo | None attributions: AttributionRepo | None + graph_interp: GraphInterpRepo | None @dataclass @@ -62,6 +69,7 @@ class StateManager: def __init__(self) -> None: self._state: AppState | None = None + self._gpu_lock = threading.Lock() @classmethod def get(cls) -> "StateManager": @@ -104,3 +112,21 @@ def close(self) -> None: """Clean up resources.""" if self._state is not None: self._state.db.close() + + @contextmanager + def gpu_lock(self) -> Generator[None]: + """Acquire GPU lock or fail with 503 if another GPU operation is in progress. + + Use this for GPU-intensive endpoints to prevent concurrent operations + that would cause the server to hang. + """ + acquired = self._gpu_lock.acquire(blocking=False) + if not acquired: + raise HTTPException( + status_code=503, + detail="GPU operation already in progress. 
Please wait and retry.", + ) + try: + yield + finally: + self._gpu_lock.release() diff --git a/spd/app/frontend/package-lock.json b/spd/app/frontend/package-lock.json index b6c451303..32da0218c 100644 --- a/spd/app/frontend/package-lock.json +++ b/spd/app/frontend/package-lock.json @@ -7,6 +7,9 @@ "": { "name": "frontend", "version": "0.0.0", + "dependencies": { + "marked": "^17.0.1" + }, "devDependencies": { "@eslint/js": "^9.38.0", "@sveltejs/vite-plugin-svelte": "^6.2.1", @@ -2347,6 +2350,18 @@ "@jridgewell/sourcemap-codec": "^1.5.5" } }, + "node_modules/marked": { + "version": "17.0.1", + "resolved": "https://registry.npmjs.org/marked/-/marked-17.0.1.tgz", + "integrity": "sha512-boeBdiS0ghpWcSwoNm/jJBwdpFaMnZWRzjA6SkUMYb40SVaN1x7mmfGKp0jvexGcx+7y2La5zRZsYFZI6Qpypg==", + "license": "MIT", + "bin": { + "marked": "bin/marked.js" + }, + "engines": { + "node": ">= 20" + } + }, "node_modules/merge2": { "version": "1.4.1", "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", diff --git a/spd/app/frontend/package.json b/spd/app/frontend/package.json index f54e1bb3d..f298885ce 100644 --- a/spd/app/frontend/package.json +++ b/spd/app/frontend/package.json @@ -27,5 +27,8 @@ "typescript": "~5.9.3", "typescript-eslint": "^8.46.2", "vite": "^7.1.7" + }, + "dependencies": { + "marked": "^17.0.1" } } diff --git a/spd/app/frontend/src/app.css b/spd/app/frontend/src/app.css index 8bb0c490f..bf6649aee 100644 --- a/spd/app/frontend/src/app.css +++ b/spd/app/frontend/src/app.css @@ -1,22 +1,22 @@ :root { - /* Punchy Research - crisp whites, bold contrasts */ + /* Goodfire-inspired - warm whites, navy text, vibrant blue accent */ --bg-base: #ffffff; --bg-surface: #ffffff; --bg-elevated: #ffffff; - --bg-inset: #f8f9fa; - --bg-hover: #f3f4f6; + --bg-inset: #f7f6f2; + --bg-hover: #f0efeb; - --border-subtle: #e0e0e0; - --border-default: #c0c0c0; - --border-strong: #888888; + --border-subtle: #e5e3dc; + --border-default: #c8c5bc; + --border-strong: #8a8780; - 
--text-primary: #111111; - --text-secondary: #555555; - --text-muted: #999999; + --text-primary: #1d272a; + --text-secondary: #646464; + --text-muted: #b4b4b4; - --accent-primary: #2563eb; - --accent-primary-bright: #3b82f6; - --accent-primary-dim: #1d4ed8; + --accent-primary: #7c4d33; + --accent-primary-bright: #96613f; + --accent-primary-dim: #5e3a27; --status-positive: #16a34a; --status-positive-bright: #22c55e; @@ -24,8 +24,10 @@ --status-negative-bright: #ef4444; --status-warning: #eab308; --status-warning-bright: #facc15; - --status-info: #2563eb; - --status-info-bright: #3b82f6; + --status-info: #4d65ff; + --status-info-bright: #6b7fff; + + --focus-ring: #4d65ff; /* Typography - Clean system fonts with mono for code */ --font-mono: "SF Mono", "Menlo", "Monaco", "Consolas", monospace; diff --git a/spd/app/frontend/src/components/ActivationContextsPagedTable.svelte b/spd/app/frontend/src/components/ActivationContextsPagedTable.svelte index e20ba1adf..c9c304950 100644 --- a/spd/app/frontend/src/components/ActivationContextsPagedTable.svelte +++ b/spd/app/frontend/src/components/ActivationContextsPagedTable.svelte @@ -1,23 +1,29 @@ @@ -106,21 +112,22 @@
- @@ -129,58 +136,69 @@
-
- {#if displaySettings.centerOnPeak} -
- {#each paginatedIndices as idx (idx)} - {@const fp = firingPositions[idx]} -
-
- -
-
- + {#if loading} +
+
+ {#each Array(pageSize) as _, i (i)} +
+ {/each} +
+
+ {:else} + {@const d = loaded!} +
+ {#if displaySettings.centerOnPeak} +
+ {#each paginatedIndices as idx (idx)} + {@const fp = firingPositions[idx]} +
+
+ +
+
+ +
+
+ +
-
+ {/each} +
+ {:else} +
+ {#each paginatedIndices as idx (idx)} +
-
- {/each} -
- {:else} -
- {#each paginatedIndices as idx (idx)} -
- -
- {/each} -
- {/if} -
+ {/each} +
+ {/if} +
+ {/if}
diff --git a/spd/app/frontend/src/components/ActivationContextsViewer.svelte b/spd/app/frontend/src/components/ActivationContextsViewer.svelte index d20831c1a..232e4cd39 100644 --- a/spd/app/frontend/src/components/ActivationContextsViewer.svelte +++ b/spd/app/frontend/src/components/ActivationContextsViewer.svelte @@ -1,15 +1,16 @@
@@ -288,7 +304,7 @@
@@ -412,26 +428,22 @@ {/if} - +
+ + {#if currentGraphInterpLabel && componentData.graphInterpDetail.status === "loaded" && componentData.graphInterpDetail.data} + + {/if} +
- {#if componentData.componentDetail.status === "loading"} -
Loading component data...
- {:else if componentData.componentDetail.status === "loaded"} - - {:else if componentData.componentDetail.status === "error"} - Error loading component data: {String(componentData.componentDetail.error)} + {#if activationExamples.status === "error"} + Error loading component data: {String(activationExamples.error)} {:else} - Something went wrong loading component data. + {/if} + import { getContext, onMount } from "svelte"; + import { computeMaxAbsComponentAct } from "../lib/colors"; + import { mapLoadable } from "../lib/index"; + import { anyCorrelationStatsEnabled } from "../lib/displaySettings.svelte"; + import { useComponentDataExpectCached } from "../lib/useComponentDataExpectCached.svelte"; + import { RUN_KEY, type RunContext } from "../lib/useRun.svelte"; + import ActivationContextsPagedTable, { type ActivationExamplesData } from "./ActivationContextsPagedTable.svelte"; + import ComponentProbeInput from "./ComponentProbeInput.svelte"; + import ComponentCorrelationMetrics from "./ui/ComponentCorrelationMetrics.svelte"; + import DatasetAttributionsSection from "./ui/DatasetAttributionsSection.svelte"; + import InterpretationBadge from "./ui/InterpretationBadge.svelte"; + import SectionHeader from "./ui/SectionHeader.svelte"; + import StatusText from "./ui/StatusText.svelte"; + import TokenStatsSection from "./ui/TokenStatsSection.svelte"; + + const runState = getContext(RUN_KEY); + + type Props = { + layer: string; + cIdx: number; + }; + + let { layer, cIdx }: Props = $props(); + + const intruderScore = $derived(runState.getIntruderScore(`${layer}:${cIdx}`)); + + const componentData = useComponentDataExpectCached(); + + onMount(() => { + componentData.load(layer, cIdx); + }); + + const inputTokenLists = $derived.by(() => { + const tokenStats = componentData.tokenStats; + if (tokenStats.status !== "loaded" || tokenStats.data === null) return null; + return [ + { + title: "Top Precision", + mathNotation: "P(component fires | token)", + items: 
tokenStats.data.input.top_precision.map(([token, value]) => ({ + token, + value, + })), + maxScale: 1, + }, + ]; + }); + + const outputTokenLists = $derived.by(() => { + const tokenStats = componentData.tokenStats; + if (tokenStats.status !== "loaded" || tokenStats.data === null) return null; + const maxAbsPmi = Math.max( + tokenStats.data.output.top_pmi[0]?.[1] ?? 0, + Math.abs(tokenStats.data.output.bottom_pmi?.[0]?.[1] ?? 0), + ); + return [ + { + title: "Top PMI", + mathNotation: "positive association with predictions", + items: tokenStats.data.output.top_pmi.map(([token, value]) => ({ token, value })), + maxScale: maxAbsPmi, + }, + { + title: "Bottom PMI", + mathNotation: "negative association with predictions", + items: tokenStats.data.output.bottom_pmi.map(([token, value]) => ({ + token, + value, + })), + maxScale: maxAbsPmi, + }, + ]; + }); + + function formatNumericalValue(val: number): string { + return Math.abs(val) < 0.001 ? val.toExponential(2) : val.toFixed(3); + } + + const maxAbsComponentAct = $derived.by(() => { + if (componentData.componentDetail.status !== "loaded") return 1; + return computeMaxAbsComponentAct(componentData.componentDetail.data.example_component_acts); + }); + + const activationExamples = $derived( + mapLoadable( + componentData.componentDetail, + (d): ActivationExamplesData => ({ + tokens: d.example_tokens, + ci: d.example_ci, + componentActs: d.example_component_acts, + maxAbsComponentAct: computeMaxAbsComponentAct(d.example_component_acts), + }), + ), + ); + + +
+
+

{layer}:{cIdx}

+
+ {#if componentData.componentDetail.status === "loaded"} + Mean CI: {formatNumericalValue(componentData.componentDetail.data.mean_ci)} + {/if} + {#if intruderScore !== null} + Intruder: {Math.round(intruderScore * 100)}% + {/if} +
+
+ + + +
+ + {#if activationExamples.status === "error"} + Error loading details: {String(activationExamples.error)} + {:else if activationExamples.status === "loaded" && activationExamples.data.tokens.length === 0} + + {:else} + + {/if} +
+ + + + {#if componentData.datasetAttributions.status === "uninitialized"} + uninitialized + {:else if componentData.datasetAttributions.status === "loaded"} + {#if componentData.datasetAttributions.data !== null} + + {:else} + No dataset attributions available. + {/if} + {:else if componentData.datasetAttributions.status === "loading"} +
+ + Loading... +
+ {:else if componentData.datasetAttributions.status === "error"} +
+ + Error: {String(componentData.datasetAttributions.error)} +
+ {/if} + +
+ +
+ {#if componentData.tokenStats === null || componentData.tokenStats.status === "loading"} + Loading token stats... + {:else if componentData.tokenStats.status === "error"} + Error: {String(componentData.tokenStats.error)} + {:else} + + + + {/if} +
+
+ + {#if anyCorrelationStatsEnabled()} +
+ + {#if componentData.correlations.status === "loading"} + Loading... + {:else if componentData.correlations.status === "loaded" && componentData.correlations.data} + + {:else if componentData.correlations.status === "error"} + Error loading correlations: {String(componentData.correlations.error)} + {:else} + No correlations available. + {/if} +
+ {/if} +
+ + diff --git a/spd/app/frontend/src/components/ClusterPathInput.svelte b/spd/app/frontend/src/components/ClusterPathInput.svelte index 27b066c2c..6adcfb2b3 100644 --- a/spd/app/frontend/src/components/ClusterPathInput.svelte +++ b/spd/app/frontend/src/components/ClusterPathInput.svelte @@ -1,8 +1,8 @@ + +
+ {#if clusterMapping} + + {:else} + No clusters loaded. Use the cluster path input in the header bar to load a cluster mapping. + {/if} +
+ + diff --git a/spd/app/frontend/src/components/ClustersViewer.svelte b/spd/app/frontend/src/components/ClustersViewer.svelte new file mode 100644 index 000000000..324b04e77 --- /dev/null +++ b/spd/app/frontend/src/components/ClustersViewer.svelte @@ -0,0 +1,252 @@ + + +
+ {#if selectedClusterId === null} +
+

Clusters ({clusterGroups.sorted.length})

+ {#each clusterGroups.sorted as [clusterId, members] (clusterId)} + {@const previewLabels = getPreviewLabels(members)} + + {/each} + {#if clusterGroups.singletons.length > 0} + + {/if} +
+ {:else} +
+
+ +

+ {selectedClusterId === "unclustered" ? "Unclustered" : `Cluster ${selectedClusterId}`} +

+ {selectedMembers.length} components +
+
+ {#each selectedMembers as member (`${member.layer}:${member.cIdx}`)} +
+ +
+ {/each} +
+
+ {/if} +
+ + diff --git a/spd/app/frontend/src/components/DataSourcesTab.svelte b/spd/app/frontend/src/components/DataSourcesTab.svelte index bc9282c27..c54a1fa0a 100644 --- a/spd/app/frontend/src/components/DataSourcesTab.svelte +++ b/spd/app/frontend/src/components/DataSourcesTab.svelte @@ -29,7 +29,7 @@ }); function formatConfigValue(value: unknown): string { - if (value === null || value === undefined) return "—"; + if (value === null || value === undefined) return "\u2014"; if (typeof value === "object") return JSON.stringify(value); return String(value); } @@ -50,116 +50,91 @@ } -
- {#if runState.run.status === "loaded" && runState.run.data.config_yaml} -
-

Run Config

-
{runState.run.data.config_yaml}
-
- {/if} - - - {#if pretrainData.status === "loading"} -
-

Target Model

-

Loading target model info...

-
- {:else if pretrainData.status === "loaded"} - {@const pt = pretrainData.data} -
-

Target Model

-
- Architecture - {pt.summary} - - {#if pt.pretrain_wandb_path} - Pretrain run - {pt.pretrain_wandb_path} - {/if} -
+
+ +
+ {#if runState.run.status === "loaded" && runState.run.data.config_yaml} +
+

Run Config

+
{runState.run.data.config_yaml}
+
+ {/if} - {#if pt.topology} -
-

Topology

- +
+

Target Model

+ {#if pretrainData.status === "loading"} +

Loading...

+ {:else if pretrainData.status === "loaded"} + {@const pt = pretrainData.data} +
+ Architecture + {pt.summary} + {#if pt.pretrain_wandb_path} + Pretrain run + {pt.pretrain_wandb_path} + {/if}
+ {#if pt.topology} +
+ +
+ {/if} + {#if pt.pretrain_config} +
+ Pretraining config +
{formatPretrainConfigYaml(pt.pretrain_config)}
+
+ {/if} + {:else if pretrainData.status === "error"} +

Failed to load target model info

{/if} - - {#if pt.pretrain_config} -
- Pretraining config -
{formatPretrainConfigYaml(pt.pretrain_config)}
-
- {/if} -
- {:else if pretrainData.status === "error"} -
-

Target Model

-

Failed to load target model info

- {/if} - - {#if data.status === "loading"} -

Loading data sources...

- {:else if data.status === "error"} -

Failed to load data sources: {data.error}

- {:else if data.status === "loaded"} - {@const { harvest, autointerp, attributions } = data.data} - - {#if !harvest && !autointerp && !attributions} -

No pipeline data available for this run.

- {/if} - - {#if harvest} -
-

Harvest

+
+ + +
+ +
+
+ +

Harvest

+
+ {#if data.status === "loading"} +

Loading...

+ {:else if data.status === "loaded" && data.data.harvest} + {@const harvest = data.data.harvest}
Subrun {harvest.subrun_id} - Components {harvest.n_components.toLocaleString()} - Intruder eval {harvest.has_intruder_scores ? "yes" : "no"} - {#each Object.entries(harvest.config) as [key, value] (key)} {key} {formatConfigValue(value)} {/each}
-
- {/if} - - {#if attributions} -
-

Dataset Attributions

-
- Subrun - {attributions.subrun_id} - - Batches - {attributions.n_batches_processed.toLocaleString()} - - Tokens - {attributions.n_tokens_processed.toLocaleString()} - - CI threshold - {attributions.ci_threshold} -
-
- {/if} + {:else if data.status === "loaded"} +

Not available

+ {/if} +
- {#if autointerp} -
-

Autointerp

+ +
+
+ +

Autointerp

+
+ {#if data.status === "loading"} +

Loading...

+ {:else if data.status === "loaded" && data.data.autointerp} + {@const autointerp = data.data.autointerp}
Subrun {autointerp.subrun_id} - Interpretations {autointerp.n_interpretations.toLocaleString()} - Eval scores {#if autointerp.eval_scores.length > 0} @@ -168,65 +143,156 @@ none {/if} - {#each Object.entries(autointerp.config) as [key, value] (key)} {key} {formatConfigValue(value)} {/each}
-
- {/if} - {/if} + {:else if data.status === "loaded"} +

Not available

+ {/if} +
+ + +
+
+ +

Dataset Attributions

+
+ {#if data.status === "loading"} +

Loading...

+ {:else if data.status === "loaded" && data.data.attributions} + {@const attributions = data.data.attributions} +
+ Subrun + {attributions.subrun_id} + Tokens + {attributions.n_tokens_processed.toLocaleString()} + CI threshold + {attributions.ci_threshold} +
+ {:else if data.status === "loaded"} +

Not available

+ {/if} +
+ + +
+
+ +

Graph Interp

+
+ {#if data.status === "loading"} +

Loading...

+ {:else if data.status === "loaded" && data.data.graph_interp} + {@const graph_interp = data.data.graph_interp} +
+ Subrun + {graph_interp.subrun_id} + {#each Object.entries(graph_interp.label_counts) as [key, value] (key)} + {key} labels + {value.toLocaleString()} + {/each} + {#if graph_interp.config} + {#each Object.entries(graph_interp.config) as [key, value] (key)} + {key} + {formatConfigValue(value)} + {/each} + {/if} +
+ {:else if data.status === "loaded"} +

Not available

+ {/if} +
+
diff --git a/spd/app/frontend/src/components/InvestigationsTab.svelte b/spd/app/frontend/src/components/InvestigationsTab.svelte new file mode 100644 index 000000000..b7752cb5f --- /dev/null +++ b/spd/app/frontend/src/components/InvestigationsTab.svelte @@ -0,0 +1,645 @@ + + +
+ {#if selected?.status === "loaded"} + +
+ +

{selected.data.title || formatId(selected.data.id)}

+ + {#if selected.data.status} + + {selected.data.status} + + {/if} +
+ + {#if selected.data.summary} +

{selected.data.summary}

+ {/if} + + +

+ {formatId(selected.data.id)} · Started {formatDate(selected.data.created_at)} + {#if selected.data.wandb_path} + · {selected.data.wandb_path} + {/if} +

+ +
+ + +
+ +
+ {#if activeTab === "research"} + {#if selected.data.research_log} + + {:else} +

No research log available

+ {/if} + {:else} +
+ {#each selected.data.events as event, i (i)} +
+ + {event.event_type} + + {formatDate(event.timestamp)} + {event.message} + {#if event.details && Object.keys(event.details).length > 0} +
+ Details +
{JSON.stringify(event.details, null, 2)}
+
+ {/if} +
+ {:else} +

No events recorded

+ {/each} +
+ {/if} +
+ {:else if selected?.status === "loading"} +
Loading investigation...
+ {:else} + +
+

Investigations

+ +
+ +
{ + e.preventDefault(); + handleLaunch(); + }} + > + + +
+ {#if launchState.status === "error"} +
{launchState.error}
+ {/if} + + {#if investigations.status === "loading"} +
Loading investigations...
+ {:else if investigations.status === "error"} +
{investigations.error}
+ {:else if investigations.status === "loaded"} +
+ {#each investigations.data as inv (inv.id)} + + {:else} +

+ No investigations found. Run spd-investigate to create one. +

+ {/each} +
+ {/if} + {/if} +
+ + diff --git a/spd/app/frontend/src/components/ModelGraph.svelte b/spd/app/frontend/src/components/ModelGraph.svelte new file mode 100644 index 000000000..ae49a25e3 --- /dev/null +++ b/spd/app/frontend/src/components/ModelGraph.svelte @@ -0,0 +1,520 @@ + + +
+ +
+
+ + + +
+
+ +
+
+ +
+
+ +
+
+ {filteredNodes.length} nodes, {visibleEdges.length} edges +
+
+ + + +
+
+ + + + + + {#if tooltipNode} +
+
{tooltipNode.label}
+
+ {tooltipNode.confidence} + {tooltipNode.key} +
+
+ {/if} + + + {#if selectedNodeKey} + {@const selectedNode = layout.nodes.get(selectedNodeKey)} + {#if selectedNode} +
+
+ {selectedNode.label} + {selectedNode.confidence} + +
+
{selectedNode.key}
+
+ {#if selectedNodeEdges.length > 0} +
+ {#each selectedNodeEdges as e, i (i)} + {@const other = e.source === selectedNodeKey ? e.target : e.source} + {@const otherNode = layout.nodes.get(other)} +
+ {e.source === selectedNodeKey ? "to" : "from"} + {otherNode?.label ?? other} + {e.attribution.toFixed(3)} +
+ {/each} +
+ {:else} + No edges + {/if} +
+
+ {/if} + {/if} +
+ + diff --git a/spd/app/frontend/src/components/ProbColoredTokens.svelte b/spd/app/frontend/src/components/ProbColoredTokens.svelte index 9a81b4858..7b30870c3 100644 --- a/spd/app/frontend/src/components/ProbColoredTokens.svelte +++ b/spd/app/frontend/src/components/ProbColoredTokens.svelte @@ -1,5 +1,7 @@ {#each tokens as tok, i (i)}{tok}{#each tokens as tok, i (i)}{@const prob = getProbAtPosition(nextTokenProbs, i)}{/each} @@ -33,34 +26,10 @@ display: inline-flex; flex-wrap: wrap; gap: 1px; - font-family: var(--font-mono); } - .prob-token { - padding: 1px 2px; + .prob-token-wrapper { border-right: 1px solid var(--border-subtle); - position: relative; - white-space: pre; - } - - .prob-token::after { - content: attr(data-tooltip); - position: absolute; - top: calc(100% + 4px); - left: 0; - background: var(--bg-elevated); - border: 1px solid var(--border-strong); - color: var(--text-primary); - padding: var(--space-1) var(--space-2); - font-size: var(--text-xs); - font-family: var(--font-mono); - white-space: nowrap; - opacity: 0; - pointer-events: none; - z-index: 1000; - } - - .prob-token:hover::after { - opacity: 1; + padding: 1px 0; } diff --git a/spd/app/frontend/src/components/PromptAttributionsGraph.svelte b/spd/app/frontend/src/components/PromptAttributionsGraph.svelte index da0d0e935..c5e182917 100644 --- a/spd/app/frontend/src/components/PromptAttributionsGraph.svelte +++ b/spd/app/frontend/src/components/PromptAttributionsGraph.svelte @@ -1,6 +1,6 @@
@@ -40,7 +57,15 @@ {/if} @@ -93,6 +128,10 @@ {runState.run.error}
{/if} + +
+ +
{#if runState.prompts.status === "loaded"}
@@ -109,6 +148,11 @@
+ {#if runState.clusterMapping} +
+ +
+ {/if} {:else if runState.run.status === "loading" || runState.prompts.status === "loading"}

Loading run...

diff --git a/spd/app/frontend/src/components/TokenHighlights.svelte b/spd/app/frontend/src/components/TokenHighlights.svelte index 4916c1f00..456cdbc72 100644 --- a/spd/app/frontend/src/components/TokenHighlights.svelte +++ b/spd/app/frontend/src/components/TokenHighlights.svelte @@ -1,5 +1,6 @@ + +
+ {#if caption} +
{caption}
+ {/if} + + +
+ + +
+ + + {#each Object.entries(layout.layerYPositions) as [layer, y] (layer)} + + {getRowLabel(getRowKey(layer))} + + {/each} + + +
+ +
+ + + + + {@html edgesSvg} + + + + {#each Object.entries(layout.nodePositions) as [key, pos] (key)} + {@const style = nodeStyles[key]} + {@const [layer, seqIdxStr, cIdxStr] = key.split(":")} + {@const seqIdx = parseInt(seqIdxStr)} + {@const cIdx = parseInt(cIdxStr)} + + handleNodeHover(e, layer, seqIdx, cIdx)} + onmouseleave={handleNodeLeave} + /> + + + {/each} + + + + +
+ + + {#each data.tokens as token, i (i)} + {@const colLeft = layout.seqXStarts[i] + 8} + + {token} + + [{i}] + {/each} + + +
+
+
+ +
+ L0: {data.l0_total} · Edges: {filteredEdges.length} +
+ + + {#if hoveredNode && runState} + (isHoveringTooltip = true)} + onMouseLeave={() => { + isHoveringTooltip = false; + hoveredNode = null; + }} + /> + {/if} +
+ + diff --git a/spd/app/frontend/src/components/investigations/ResearchLogViewer.svelte b/spd/app/frontend/src/components/investigations/ResearchLogViewer.svelte new file mode 100644 index 000000000..ef9b8d40d --- /dev/null +++ b/spd/app/frontend/src/components/investigations/ResearchLogViewer.svelte @@ -0,0 +1,223 @@ + + +
+ {#each contentBlocks as block, i (i)} + {#if block.type === "html"} + +
{@html block.content}
+ {:else if block.type === "graph"} + {@const artifact = artifacts[block.artifactId]} + {#if artifact} + + {:else if artifactsLoading} +
+ Loading graph: {block.artifactId}... +
+ {:else} +
+ Graph artifact not found: {block.artifactId} +
+ {/if} + {/if} + {/each} +
+ + diff --git a/spd/app/frontend/src/components/prompt-attr/ComponentNodeCard.svelte b/spd/app/frontend/src/components/prompt-attr/ComponentNodeCard.svelte index a0d663208..91083d851 100644 --- a/spd/app/frontend/src/components/prompt-attr/ComponentNodeCard.svelte +++ b/spd/app/frontend/src/components/prompt-attr/ComponentNodeCard.svelte @@ -1,6 +1,7 @@
@@ -208,30 +212,26 @@
- +
+ + {#if graphInterpLabel && componentData.graphInterpDetail.status === "loaded" && componentData.graphInterpDetail.data} + + {/if} +
- {#if componentData.componentDetail.status === "uninitialized"} - uninitialized - {:else if componentData.componentDetail.status === "loading"} - Loading details... - {:else if componentData.componentDetail.status === "loaded"} - {#if componentData.componentDetail.data.example_tokens.length > 0} - - {/if} - {:else if componentData.componentDetail.status === "error"} - Error loading details: {String(componentData.componentDetail.error)} + {#if activationExamples.status === "error"} + Error loading details: {String(activationExamples.error)} + {:else if activationExamples.status === "loaded" && activationExamples.data.tokens.length === 0} + + {:else} + {/if}
@@ -243,34 +243,29 @@ title="Prompt Attributions" incomingLabel="Incoming" outgoingLabel="Outgoing" - {incomingPositive} - {incomingNegative} - {outgoingPositive} - {outgoingNegative} + {incoming} + {outgoing} pageSize={COMPONENT_CARD_CONSTANTS.PROMPT_ATTRIBUTIONS_PAGE_SIZE} onClick={handleEdgeNodeClick} - {tokens} - {outputProbs} /> {/if} - {#if componentData.datasetAttributions.status === "uninitialized"} - uninitialized + {#if componentData.datasetAttributions.status === "loading" || componentData.datasetAttributions.status === "uninitialized"} +
+ +
+
+
+
+
{:else if componentData.datasetAttributions.status === "loaded"} {#if componentData.datasetAttributions.data !== null} - {:else} - No dataset attributions available. {/if} - {:else if componentData.datasetAttributions.status === "loading"} -
- - Loading... -
{:else if componentData.datasetAttributions.status === "error"}
@@ -282,7 +277,12 @@
{#if componentData.tokenStats === null || componentData.tokenStats.status === "loading"} - Loading token stats... +
+
+
+
+
+
{:else if componentData.tokenStats.status === "error"} Error: {String(componentData.tokenStats.error)} {:else} @@ -306,7 +306,10 @@
{#if componentData.correlations.status === "loading"} - Loading... +
+
+
+
{:else if componentData.correlations.status === "loaded" && componentData.correlations.data} diff --git a/spd/app/frontend/src/components/prompt-attr/ComputeProgressOverlay.svelte b/spd/app/frontend/src/components/prompt-attr/ComputeProgressOverlay.svelte index 1f3c09a01..23619c281 100644 --- a/spd/app/frontend/src/components/prompt-attr/ComputeProgressOverlay.svelte +++ b/spd/app/frontend/src/components/prompt-attr/ComputeProgressOverlay.svelte @@ -1,46 +1,54 @@
-
- {#each state.stages as stage, i (i)} - {@const isCurrent = i === state.currentStage} - {@const isComplete = i < state.currentStage} -
-
- {i + 1} - {stage.name} - {#if isComplete} - - {/if} -
- {#if isCurrent} -
- {#if stage.progress !== null} -
- {:else} -
+
+ {#if ciSnapshot} + + {/if} +
+ {#each state.stages as stage, i (i)} + {@const isCurrent = i === state.currentStage} + {@const isComplete = i < state.currentStage} +
+
+ {i + 1} + {stage.name} + {#if isComplete} + {/if}
- {:else if isComplete} -
-
-
- {:else} -
- {/if} -
- {/each} + {#if isCurrent} +
+ {#if stage.progress !== null} +
+ {:else} +
+ {/if} +
+ {:else if isComplete} +
+
+
+ {:else} +
+ {/if} +
+ {/each} +
@@ -54,6 +62,13 @@ z-index: 100; } + .content { + display: flex; + flex-direction: column; + align-items: center; + gap: var(--space-6); + } + .stages { display: flex; flex-direction: column; diff --git a/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte b/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte index 95433d57e..8f85b9a16 100644 --- a/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte +++ b/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte @@ -1,10 +1,11 @@
+ {#if displaySettings.showEdgeAttributions && wteOutgoing.length > 0} + {}} + /> + {/if} {:else if isOutput} - + {:else if !hideNodeCard} {#key `${hoveredNode.layer}:${hoveredNode.cIdx}`} diff --git a/spd/app/frontend/src/components/prompt-attr/OptimizationGrid.svelte b/spd/app/frontend/src/components/prompt-attr/OptimizationGrid.svelte new file mode 100644 index 000000000..3fa6a0213 --- /dev/null +++ b/spd/app/frontend/src/components/prompt-attr/OptimizationGrid.svelte @@ -0,0 +1,134 @@ + + +
+
+ + Step {snapshot.step}/{snapshot.total_steps} + + + L0: {Math.round(snapshot.l0_total)} / {initialL0} + ({(fractionRemaining * 100).toFixed(0)}%) + + {#if snapshot.loss > 0} + loss: {snapshot.loss.toFixed(4)} + {/if} +
+ +
+ + diff --git a/spd/app/frontend/src/components/prompt-attr/OptimizationParams.svelte b/spd/app/frontend/src/components/prompt-attr/OptimizationParams.svelte index fa3413320..83d9f0594 100644 --- a/spd/app/frontend/src/components/prompt-attr/OptimizationParams.svelte +++ b/spd/app/frontend/src/components/prompt-attr/OptimizationParams.svelte @@ -21,42 +21,86 @@
- steps{optimization.steps} - imp_min{optimization.imp_min_coeff} - pnorm{optimization.pnorm} - beta{optimization.beta} - mask{optimization.mask_type} - + steps{optimization.steps} + imp_min{optimization.imp_min_coeff} + pnorm{optimization.pnorm} + beta{optimization.beta} + mask{optimization.mask_type} + {optimization.loss.type}{optimization.loss.coeff} - - pos{optimization.loss.position}{#if tokenAtPos !== null} - ({tokenAtPos}){/if} + + pos + {optimization.loss.position} + {#if tokenAtPos !== null} + ({tokenAtPos}) + {/if} - {#if optimization.loss.type === "ce"} - + {#if optimization.loss.type === "ce" || optimization.loss.type === "logit"} + label({optimization.loss.label_str}) {/if} - {#if optimization.adv_pgd_n_steps !== null} - - adv_steps{optimization.adv_pgd_n_steps} + {#if optimization.pgd} + + pgd_steps{optimization.pgd.n_steps} - - adv_lr{optimization.adv_pgd_step_size} + + pgd_lr{optimization.pgd.step_size} {/if} - + L0{optimization.metrics.l0_total.toFixed(1)} - {#if optimization.loss.type === "ce"} - + {#if optimization.loss.type === "ce" || optimization.loss.type === "logit"} + CI prob{formatProb(optimization.metrics.ci_masked_label_prob)} - + stoch prob{formatProb(optimization.metrics.stoch_masked_label_prob)} + + adv prob{formatProb(optimization.metrics.adv_pgd_label_prob)} + {/if}
diff --git a/spd/app/frontend/src/components/prompt-attr/OptimizationSettings.svelte b/spd/app/frontend/src/components/prompt-attr/OptimizationSettings.svelte index f762cdff9..3c15034b4 100644 --- a/spd/app/frontend/src/components/prompt-attr/OptimizationSettings.svelte +++ b/spd/app/frontend/src/components/prompt-attr/OptimizationSettings.svelte @@ -1,15 +1,19 @@
@@ -68,48 +68,57 @@ /> Cross-Entropy +
- -
- At position - { - if (e.currentTarget.value === "") return; - const position = parseInt(e.currentTarget.value); - onChange({ ...config, loss: { ...config.loss, position } }); - }} - min={0} - max={tokens.length - 1} - step={1} - /> - {#if tokenAtSeqPos !== null} - ({tokenAtSeqPos}) - {/if} - {#if config.loss.type === "ce"} - , predict - { - if (config.loss.type !== "ce") - throw new Error( - "inconsistent state: Token dropdown rendered but loss not type CE but no label token", - ); - - if (tokenId !== null) { - onChange({ - ...config, - loss: { ...config.loss, labelTokenId: tokenId, labelTokenText: tokenString }, - }); - } - }} - placeholder="token..." - /> - {/if} + +
+ +
+ {#each tokens as tok, i (i)} + {@const prob = getProbAtPosition(nextTokenProbs, i)} + + {/each} +
+
+ pos {config.loss.position} + {#if config.loss.type === "ce" || config.loss.type === "logit"} + {config.loss.type === "logit" ? "maximize" : "predict"} + { + if (config.loss.type !== "ce" && config.loss.type !== "logit") + throw new Error("inconsistent state: Token dropdown rendered but loss type has no label"); + + if (tokenId !== null) { + onChange({ + ...config, + loss: { ...config.loss, labelTokenId: tokenId, labelTokenText: tokenString }, + }); + } + }} + placeholder="token..." + /> + {/if} +
@@ -256,7 +265,7 @@ display: flex; flex-direction: column; gap: var(--space-3); - max-width: 400px; + max-width: 500px; } .loss-type-options { @@ -296,44 +305,99 @@ color: var(--text-primary); } - .target-section { + .position-section { display: flex; - align-items: center; + flex-direction: column; gap: var(--space-2); - flex-wrap: wrap; - padding: var(--space-2); - background: var(--bg-surface); - border: 1px solid var(--border-default); } - .target-label { + .section-label { font-size: var(--text-xs); + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.05em; color: var(--text-muted); } - .pos-input { - width: 50px; - padding: var(--space-1) var(--space-2); + .token-strip { + display: flex; + flex-wrap: wrap; + gap: 2px; + padding: var(--space-2); + background: var(--bg-inset); border: 1px solid var(--border-default); - background: var(--bg-base); - color: var(--text-primary); - font-size: var(--text-sm); font-family: var(--font-mono); + font-size: var(--text-sm); } - .pos-input:focus { - outline: none; - border-color: var(--accent-primary-dim); + .strip-token { + padding: 2px 2px; + border: 1px solid var(--border-subtle); + border-radius: 2px; + cursor: pointer; + white-space: pre; + font-family: inherit; + font-size: inherit; + color: var(--text-primary); + background: transparent; + position: relative; + transition: + border-color var(--transition-fast), + box-shadow var(--transition-fast); } - .token { - white-space: pre; + .strip-token:hover { + border-color: var(--border-strong); + } + + .strip-token.selected { + border-color: var(--accent-primary); + box-shadow: 0 0 0 1px var(--accent-primary); + z-index: 1; + } + + .strip-token::after { + content: attr(title); + position: absolute; + bottom: calc(100% + 4px); + left: 50%; + transform: translateX(-50%); + background: var(--bg-elevated); + border: 1px solid var(--border-strong); + color: var(--text-primary); + padding: var(--space-1) var(--space-2); + font-size: var(--text-xs); + 
white-space: nowrap; + opacity: 0; + pointer-events: none; + z-index: 100; + border-radius: var(--radius-sm); + } + + .strip-token:hover::after { + opacity: 1; + } + + .position-info { + display: flex; + align-items: center; + gap: var(--space-2); + } + + .pos-label { + font-size: var(--text-xs); font-family: var(--font-mono); + color: var(--text-muted); background: var(--bg-inset); - padding: 0 var(--space-1); + padding: var(--space-1) var(--space-2); border-radius: var(--radius-sm); } + .predict-label { + font-size: var(--text-xs); + color: var(--text-muted); + } + .slider-section { display: flex; flex-direction: column; @@ -346,14 +410,6 @@ align-items: center; } - .section-label { - font-size: var(--text-xs); - font-weight: 600; - text-transform: uppercase; - letter-spacing: 0.05em; - color: var(--text-muted); - } - .imp-min-input { width: 80px; padding: var(--space-1) var(--space-2); diff --git a/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte b/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte index 4fdd57b1b..328e5449d 100644 --- a/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte +++ b/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte @@ -1,14 +1,38 @@
@@ -61,13 +77,6 @@ 2, )})
- {#if singlePosEntry.adv_pgd_prob !== null && singlePosEntry.adv_pgd_logit !== null} -
- Adversarial: {(singlePosEntry.adv_pgd_prob * 100).toFixed(1)}% (logit: {singlePosEntry.adv_pgd_logit.toFixed( - 2, - )}) -
- {/if}

Position: @@ -83,10 +92,6 @@ Logit Target Logit - {#if hasAdvPgd} - Adv - Logit - {/if} @@ -97,15 +102,22 @@ {pos.logit.toFixed(2)} {(pos.target_prob * 100).toFixed(2)}% {pos.target_logit.toFixed(2)} - {#if hasAdvPgd} - {pos.adv_pgd_prob !== null ? (pos.adv_pgd_prob * 100).toFixed(2) + "%" : "—"} - {pos.adv_pgd_logit !== null ? pos.adv_pgd_logit.toFixed(2) : "—"} - {/if} {/each} {/if} + {#if displaySettings.showEdgeAttributions && outputIncoming.length > 0} + {}} + /> + {/if}

diff --git a/spd/app/frontend/src/components/ui/DisplaySettingsDropdown.svelte b/spd/app/frontend/src/components/ui/DisplaySettingsDropdown.svelte index 077917234..63a1062d0 100644 --- a/spd/app/frontend/src/components/ui/DisplaySettingsDropdown.svelte +++ b/spd/app/frontend/src/components/ui/DisplaySettingsDropdown.svelte @@ -1,16 +1,19 @@
Center on peak +
+
+

Edge Variant

+

Attribution target: value or |value|

+
+ {#each edgeVariants as variant (variant)} + + {/each} +
+

Component Filtering

Filter components in Components tab by mean CI

diff --git a/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte b/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte index 844cb7c04..ad3f821f5 100644 --- a/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte +++ b/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte @@ -1,5 +1,5 @@ -{#if hasAnyIncoming} +{#if incoming.length > 0}
-
- {#if incomingPositive.length > 0} -
- -
- {/if} - {#if incomingNegative.length > 0} -
- -
- {/if} -
+
{/if} -{#if hasAnyOutgoing} +{#if outgoing.length > 0}
-
- {#if outgoingPositive.length > 0} -
- -
- {/if} - {#if outgoingNegative.length > 0} -
- -
- {/if} -
+
{/if} @@ -110,17 +36,4 @@ flex-direction: column; gap: var(--space-2); } - - .pos-neg-row { - display: grid; - grid-template-columns: 1fr 1fr; - gap: var(--space-3); - } - - .edge-list { - min-width: 0; - display: flex; - flex-direction: column; - gap: var(--space-1); - } diff --git a/spd/app/frontend/src/components/ui/EdgeAttributionList.svelte b/spd/app/frontend/src/components/ui/EdgeAttributionList.svelte index aa98848b5..3809aceaf 100644 --- a/spd/app/frontend/src/components/ui/EdgeAttributionList.svelte +++ b/spd/app/frontend/src/components/ui/EdgeAttributionList.svelte @@ -1,7 +1,8 @@ + +
+ + + {#if expanded && detail} +
+
+
+ Input + {#if detail.input?.reasoning} +

{detail.input.reasoning}

+ {/if} + {#each incomingEdges as edge (edge.related_key)} +
+ {formatComponentKey(edge.related_key, edge.token_str)} + 0} + class:negative={edge.attribution < 0} + > + {edge.attribution > 0 ? "+" : ""}{edge.attribution.toFixed(3)} + + {#if edge.related_label} + {edge.related_label} + {/if} +
+ {/each} +
+
+ Output + {#if detail.output?.reasoning} +

{detail.output.reasoning}

+ {/if} + {#each outgoingEdges as edge (edge.related_key)} +
+ {formatComponentKey(edge.related_key, edge.token_str)} + 0} + class:negative={edge.attribution < 0} + > + {edge.attribution > 0 ? "+" : ""}{edge.attribution.toFixed(3)} + + {#if edge.related_label} + {edge.related_label} + {/if} +
+ {/each} +
+
+
+ {/if} +
+ + diff --git a/spd/app/frontend/src/components/ui/TokenSpan.svelte b/spd/app/frontend/src/components/ui/TokenSpan.svelte new file mode 100644 index 000000000..4a5fa9b81 --- /dev/null +++ b/spd/app/frontend/src/components/ui/TokenSpan.svelte @@ -0,0 +1,43 @@ + + +{sanitizeToken(token)} + + diff --git a/spd/app/frontend/src/lib/api/correlations.ts b/spd/app/frontend/src/lib/api/correlations.ts index 2e56c3c7e..8dcc63f04 100644 --- a/spd/app/frontend/src/lib/api/correlations.ts +++ b/spd/app/frontend/src/lib/api/correlations.ts @@ -3,7 +3,7 @@ */ import type { SubcomponentCorrelationsResponse, TokenStatsResponse } from "../promptAttributionsTypes"; -import { apiUrl, fetchJson } from "./index"; +import { ApiError, apiUrl, fetchJson } from "./index"; export async function getComponentCorrelations( layer: string, @@ -47,10 +47,18 @@ export async function getIntruderScores(): Promise> { return fetchJson>("/api/correlations/intruder_scores"); } -export async function getInterpretationDetail(layer: string, componentIdx: number): Promise { - return fetchJson( - `/api/correlations/interpretations/${encodeURIComponent(layer)}/${componentIdx}`, - ); +export async function getInterpretationDetail( + layer: string, + componentIdx: number, +): Promise { + try { + return await fetchJson( + `/api/correlations/interpretations/${encodeURIComponent(layer)}/${componentIdx}`, + ); + } catch (e) { + if (e instanceof ApiError && e.status === 404) return null; + throw e; + } } export async function requestComponentInterpretation( diff --git a/spd/app/frontend/src/lib/api/dataSources.ts b/spd/app/frontend/src/lib/api/dataSources.ts index e715af1b1..ac20b7220 100644 --- a/spd/app/frontend/src/lib/api/dataSources.ts +++ b/spd/app/frontend/src/lib/api/dataSources.ts @@ -20,15 +20,21 @@ export type AutointerpInfo = { export type AttributionsInfo = { subrun_id: string; - n_batches_processed: number; n_tokens_processed: number; ci_threshold: number; }; +export type GraphInterpInfoDS = { + 
subrun_id: string; + config: Record | null; + label_counts: Record; +}; + export type DataSourcesResponse = { harvest: HarvestInfo | null; autointerp: AutointerpInfo | null; attributions: AttributionsInfo | null; + graph_interp: GraphInterpInfoDS | null; }; export async function fetchDataSources(): Promise { diff --git a/spd/app/frontend/src/lib/api/datasetAttributions.ts b/spd/app/frontend/src/lib/api/datasetAttributions.ts index f995a33f6..030eae6c6 100644 --- a/spd/app/frontend/src/lib/api/datasetAttributions.ts +++ b/spd/app/frontend/src/lib/api/datasetAttributions.ts @@ -9,15 +9,23 @@ export type DatasetAttributionEntry = { layer: string; component_idx: number; value: number; + token_str: string | null; }; -export type ComponentAttributions = { +export type SignedAttributions = { positive_sources: DatasetAttributionEntry[]; negative_sources: DatasetAttributionEntry[]; positive_targets: DatasetAttributionEntry[]; negative_targets: DatasetAttributionEntry[]; }; +export type AttrMetric = "attr" | "attr_abs"; + +export type AllMetricAttributions = { + attr: SignedAttributions; + attr_abs: SignedAttributions; +}; + export type DatasetAttributionsMetadata = { available: boolean; }; @@ -30,8 +38,8 @@ export async function getComponentAttributions( layer: string, componentIdx: number, k: number = 10, -): Promise { +): Promise { const url = apiUrl(`/api/dataset_attributions/${encodeURIComponent(layer)}/${componentIdx}`); url.searchParams.set("k", String(k)); - return fetchJson(url.toString()); + return fetchJson(url.toString()); } diff --git a/spd/app/frontend/src/lib/api/graphInterp.ts b/spd/app/frontend/src/lib/api/graphInterp.ts new file mode 100644 index 000000000..8229e757c --- /dev/null +++ b/spd/app/frontend/src/lib/api/graphInterp.ts @@ -0,0 +1,81 @@ +/** + * API client for /api/graph_interp endpoints. 
+ */ + +import { fetchJson } from "./index"; + +export type GraphInterpHeadline = { + label: string; + confidence: string; + output_label: string | null; + input_label: string | null; +}; + +export type LabelDetail = { + label: string; + confidence: string; + reasoning: string; + prompt: string; +}; + +export type GraphInterpDetail = { + output: LabelDetail | null; + input: LabelDetail | null; + unified: LabelDetail | null; +}; + +export type PromptEdgeResponse = { + related_key: string; + pass_name: string; + attribution: number; + related_label: string | null; + related_confidence: string | null; + token_str: string | null; +}; + +export type GraphInterpComponentDetail = { + output: LabelDetail | null; + input: LabelDetail | null; + unified: LabelDetail | null; + edges: PromptEdgeResponse[]; +}; + +export type GraphNode = { + component_key: string; + label: string; + confidence: string; +}; + +export type GraphEdge = { + source: string; + target: string; + attribution: number; + pass_name: string; +}; + +export type ModelGraphResponse = { + nodes: GraphNode[]; + edges: GraphEdge[]; +}; + +export type GraphInterpInfo = { + subrun_id: string; + config: Record | null; + label_counts: Record; +}; + +export async function getAllGraphInterpLabels(): Promise> { + return fetchJson>("/api/graph_interp/labels"); +} + +export async function getGraphInterpDetail(layer: string, cIdx: number): Promise { + return fetchJson(`/api/graph_interp/labels/${layer}/${cIdx}`); +} + +export async function getGraphInterpComponentDetail(layer: string, cIdx: number): Promise { + return fetchJson(`/api/graph_interp/detail/${layer}/${cIdx}`); +} + +export async function getModelGraph(): Promise { + return fetchJson("/api/graph_interp/graph"); +} diff --git a/spd/app/frontend/src/lib/api/graphs.ts b/spd/app/frontend/src/lib/api/graphs.ts index 42490d531..125d0cd0e 100644 --- a/spd/app/frontend/src/lib/api/graphs.ts +++ b/spd/app/frontend/src/lib/api/graphs.ts @@ -2,11 +2,25 @@ * API client for 
/api/graphs endpoints. */ -import type { GraphData, TokenizeResponse, TokenInfo } from "../promptAttributionsTypes"; +import type { GraphData, EdgeData, TokenizeResponse, TokenSearchResult, CISnapshot } from "../promptAttributionsTypes"; import { buildEdgeIndexes } from "../promptAttributionsTypes"; -import { setArchitecture } from "../layerAliasing"; import { apiUrl, ApiError, fetchJson } from "./index"; +/** Hydrate a raw API graph response into a full GraphData with edge indexes. */ +function hydrateGraph(raw: Record): GraphData { + const g = raw as Omit; + const { edgesBySource, edgesByTarget } = buildEdgeIndexes(g.edges); + const edgesAbs = (g.edgesAbs as EdgeData[] | null) ?? null; + let edgesAbsBySource: Map | null = null; + let edgesAbsByTarget: Map | null = null; + if (edgesAbs) { + const absIndexes = buildEdgeIndexes(edgesAbs); + edgesAbsBySource = absIndexes.edgesBySource; + edgesAbsByTarget = absIndexes.edgesByTarget; + } + return { ...g, edgesBySource, edgesByTarget, edgesAbs, edgesAbsBySource, edgesAbsByTarget } as GraphData; +} + export type NormalizeType = "none" | "target" | "layer"; export type GraphProgress = { @@ -30,6 +44,7 @@ export type ComputeGraphParams = { async function parseGraphSSEStream( response: Response, onProgress?: (progress: GraphProgress) => void, + onCISnapshot?: (snapshot: CISnapshot) => void, ): Promise { const reader = response.body?.getReader(); if (!reader) { @@ -56,19 +71,12 @@ async function parseGraphSSEStream( if (data.type === "progress" && onProgress) { onProgress({ current: data.current, total: data.total, stage: data.stage }); + } else if (data.type === "ci_snapshot" && onCISnapshot) { + onCISnapshot(data as CISnapshot); } else if (data.type === "error") { throw new ApiError(data.error, 500); } else if (data.type === "complete") { - // Extract all unique layer names from edges to detect architecture - const layerNames = new Set(); - for (const edge of data.data.edges) { - layerNames.add(edge.src.split(":")[0]); - 
layerNames.add(edge.tgt.split(":")[0]); - } - setArchitecture(Array.from(layerNames)); - - const { edgesBySource, edgesByTarget } = buildEdgeIndexes(data.data.edges); - result = { ...data.data, edgesBySource, edgesByTarget }; + result = hydrateGraph(data.data); await reader.cancel(); break; } @@ -106,7 +114,7 @@ export async function computeGraphStream( } export type MaskType = "stochastic" | "ci"; -export type LossType = "ce" | "kl"; +export type LossType = "ce" | "kl" | "logit"; export type ComputeGraphOptimizedParams = { promptId: number; @@ -128,6 +136,7 @@ export type ComputeGraphOptimizedParams = { export async function computeGraphOptimizedStream( params: ComputeGraphOptimizedParams, onProgress?: (progress: GraphProgress) => void, + onCISnapshot?: (snapshot: CISnapshot) => void, ): Promise { const url = apiUrl("/api/graphs/optimized/stream"); url.searchParams.set("prompt_id", String(params.promptId)); @@ -157,26 +166,121 @@ export async function computeGraphOptimizedStream( throw new ApiError(error.detail || `HTTP ${response.status}`, response.status); } - return parseGraphSSEStream(response, onProgress); + return parseGraphSSEStream(response, onProgress, onCISnapshot); +} + +export type ComputeGraphOptimizedBatchParams = { + promptId: number; + impMinCoeffs: number[]; + steps: number; + pnorm: number; + beta: number; + normalize: NormalizeType; + ciThreshold: number; + maskType: MaskType; + lossType: LossType; + lossCoeff: number; + lossPosition: number; + labelToken?: number; + advPgdNSteps?: number; + advPgdStepSize?: number; +}; + +export async function computeGraphOptimizedBatchStream( + params: ComputeGraphOptimizedBatchParams, + onProgress?: (progress: GraphProgress) => void, + onCISnapshot?: (snapshot: CISnapshot) => void, +): Promise { + const url = apiUrl("/api/graphs/optimized/batch/stream"); + + const body: Record = { + prompt_id: params.promptId, + imp_min_coeffs: params.impMinCoeffs, + steps: params.steps, + pnorm: params.pnorm, + beta: 
params.beta, + normalize: params.normalize, + ci_threshold: params.ciThreshold, + mask_type: params.maskType, + loss_type: params.lossType, + loss_coeff: params.lossCoeff, + loss_position: params.lossPosition, + }; + if (params.labelToken !== undefined) body.label_token = params.labelToken; + if (params.advPgdNSteps !== undefined) body.adv_pgd_n_steps = params.advPgdNSteps; + if (params.advPgdStepSize !== undefined) body.adv_pgd_step_size = params.advPgdStepSize; + + const response = await fetch(url.toString(), { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(body), + }); + if (!response.ok) { + const error = await response.json(); + throw new ApiError(error.detail || `HTTP ${response.status}`, response.status); + } + + return parseBatchGraphSSEStream(response, onProgress, onCISnapshot); +} + +async function parseBatchGraphSSEStream( + response: Response, + onProgress?: (progress: GraphProgress) => void, + onCISnapshot?: (snapshot: CISnapshot) => void, +): Promise { + const reader = response.body?.getReader(); + if (!reader) { + throw new Error("Response body is not readable"); + } + + const decoder = new TextDecoder(); + let buffer = ""; + let result: GraphData[] | null = null; + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + + const lines = buffer.split("\n\n"); + buffer = lines.pop() || ""; + + for (const line of lines) { + if (!line.trim() || !line.startsWith("data: ")) continue; + + const data = JSON.parse(line.substring(6)); + + if (data.type === "progress" && onProgress) { + onProgress({ current: data.current, total: data.total, stage: data.stage }); + } else if (data.type === "ci_snapshot" && onCISnapshot) { + onCISnapshot(data as CISnapshot); + } else if (data.type === "error") { + throw new ApiError(data.error, 500); + } else if (data.type === "complete") { + const graphs: GraphData[] = data.data.graphs.map((g: Record) 
=> hydrateGraph(g)); + result = graphs; + await reader.cancel(); + break; + } + } + + if (result) break; + } + + if (!result) { + throw new Error("No result received from stream"); + } + + return result; } export async function getGraphs(promptId: number, normalize: NormalizeType, ciThreshold: number): Promise { const url = apiUrl(`/api/graphs/${promptId}`); url.searchParams.set("normalize", normalize); url.searchParams.set("ci_threshold", String(ciThreshold)); - const graphs = await fetchJson[]>(url.toString()); - return graphs.map((g) => { - // Extract all unique layer names from edges to detect architecture - const layerNames = new Set(); - for (const edge of g.edges) { - layerNames.add(edge.src.split(":")[0]); - layerNames.add(edge.tgt.split(":")[0]); - } - setArchitecture(Array.from(layerNames)); - - const { edgesBySource, edgesByTarget } = buildEdgeIndexes(g.edges); - return { ...g, edgesBySource, edgesByTarget }; - }); + const graphs = await fetchJson[]>(url.toString()); + return graphs.map((g) => hydrateGraph(g)); } export async function tokenizeText(text: string): Promise { @@ -185,15 +289,17 @@ export async function tokenizeText(text: string): Promise { return fetchJson(url.toString(), { method: "POST" }); } -export async function getAllTokens(): Promise { - const response = await fetchJson<{ tokens: TokenInfo[] }>("/api/graphs/tokens"); - return response.tokens; -} - -export async function searchTokens(query: string, limit: number = 10): Promise { +export async function searchTokens( + query: string, + promptId: number, + position: number, + limit: number = 20, +): Promise { const url = apiUrl("/api/graphs/tokens/search"); url.searchParams.set("q", query); url.searchParams.set("limit", String(limit)); - const response = await fetchJson<{ tokens: TokenInfo[] }>(url.toString()); + url.searchParams.set("prompt_id", String(promptId)); + url.searchParams.set("position", String(position)); + const response = await fetchJson<{ tokens: TokenSearchResult[] 
}>(url.toString()); return response.tokens; } diff --git a/spd/app/frontend/src/lib/api/index.ts b/spd/app/frontend/src/lib/api/index.ts index 773663636..d2d810283 100644 --- a/spd/app/frontend/src/lib/api/index.ts +++ b/spd/app/frontend/src/lib/api/index.ts @@ -51,5 +51,8 @@ export * from "./datasetAttributions"; export * from "./intervention"; export * from "./dataset"; export * from "./clusters"; +export * from "./investigations"; export * from "./dataSources"; +export * from "./graphInterp"; export * from "./pretrainInfo"; +export * from "./runRegistry"; diff --git a/spd/app/frontend/src/lib/api/intervention.ts b/spd/app/frontend/src/lib/api/intervention.ts index 689c29cc1..154228181 100644 --- a/spd/app/frontend/src/lib/api/intervention.ts +++ b/spd/app/frontend/src/lib/api/intervention.ts @@ -2,11 +2,7 @@ * API client for /api/intervention endpoints. */ -import type { - ForkedInterventionRunSummary, - InterventionRunSummary, - RunInterventionRequest, -} from "../interventionTypes"; +import type { InterventionRunSummary, RunInterventionRequest } from "../interventionTypes"; export async function runAndSaveIntervention(request: RunInterventionRequest): Promise { const response = await fetch("/api/intervention/run", { @@ -39,30 +35,3 @@ export async function deleteInterventionRun(runId: number): Promise { throw new Error(error.detail || "Failed to delete intervention run"); } } - -export async function forkInterventionRun( - runId: number, - tokenReplacements: [number, number][], - topK: number = 10, -): Promise { - const response = await fetch(`/api/intervention/runs/${runId}/fork`, { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ token_replacements: tokenReplacements, top_k: topK }), - }); - if (!response.ok) { - const error = await response.json(); - throw new Error(error.detail || "Failed to fork intervention run"); - } - return (await response.json()) as ForkedInterventionRunSummary; -} - -export async function 
deleteForkedInterventionRun(forkId: number): Promise { - const response = await fetch(`/api/intervention/forks/${forkId}`, { - method: "DELETE", - }); - if (!response.ok) { - const error = await response.json(); - throw new Error(error.detail || "Failed to delete forked intervention run"); - } -} diff --git a/spd/app/frontend/src/lib/api/investigations.ts b/spd/app/frontend/src/lib/api/investigations.ts new file mode 100644 index 000000000..42f1fb1f3 --- /dev/null +++ b/spd/app/frontend/src/lib/api/investigations.ts @@ -0,0 +1,101 @@ +/** + * API client for investigation results. + */ + +export interface InvestigationSummary { + id: string; // inv_id (e.g., "inv-abc12345") + wandb_path: string | null; + prompt: string | null; + created_at: string; + has_research_log: boolean; + has_explanations: boolean; + event_count: number; + last_event_time: string | null; + last_event_message: string | null; + // Agent-provided summary + title: string | null; + summary: string | null; + status: string | null; // in_progress, completed, inconclusive +} + +export interface EventEntry { + event_type: string; + timestamp: string; + message: string; + details: Record | null; +} + +export interface InvestigationDetail { + id: string; + wandb_path: string | null; + prompt: string | null; + created_at: string; + research_log: string | null; + events: EventEntry[]; + explanations: Record[]; + artifact_ids: string[]; // List of artifact IDs available for this investigation + // Agent-provided summary + title: string | null; + summary: string | null; + status: string | null; +} + +import type { EdgeData, OutputProbability } from "../promptAttributionsTypes"; + +/** Data for a graph artifact (subset of GraphData, self-contained for offline viewing) */ +export interface ArtifactGraphData { + tokens: string[]; + edges: EdgeData[]; + outputProbs: Record; + nodeCiVals: Record; + nodeSubcompActs: Record; + maxAbsAttr: number; + l0_total: number; +} + +export interface GraphArtifact { + type: 
"graph"; + id: string; + caption: string | null; + graph_id: number; + data: ArtifactGraphData; +} + +export interface LaunchResponse { + inv_id: string; + job_id: string; +} + +export async function launchInvestigation(prompt: string): Promise { + const res = await fetch("/api/investigations/launch", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ prompt }), + }); + if (!res.ok) throw new Error(`Failed to launch investigation: ${res.statusText}`); + return res.json(); +} + +export async function listInvestigations(): Promise { + const res = await fetch("/api/investigations"); + if (!res.ok) throw new Error(`Failed to list investigations: ${res.statusText}`); + return res.json(); +} + +export async function getInvestigation(invId: string): Promise { + const res = await fetch(`/api/investigations/${invId}`); + if (!res.ok) throw new Error(`Failed to get investigation: ${res.statusText}`); + return res.json(); +} + +export async function listArtifacts(invId: string): Promise { + const res = await fetch(`/api/investigations/${invId}/artifacts`); + if (!res.ok) throw new Error(`Failed to list artifacts: ${res.statusText}`); + return res.json(); +} + +export async function getArtifact(invId: string, artifactId: string): Promise { + const res = await fetch(`/api/investigations/${invId}/artifacts/${artifactId}`); + if (!res.ok) throw new Error(`Failed to get artifact: ${res.statusText}`); + return res.json(); +} diff --git a/spd/app/frontend/src/lib/api/pretrainInfo.ts b/spd/app/frontend/src/lib/api/pretrainInfo.ts index 7092c735a..0cd66bd97 100644 --- a/spd/app/frontend/src/lib/api/pretrainInfo.ts +++ b/spd/app/frontend/src/lib/api/pretrainInfo.ts @@ -20,6 +20,7 @@ export type TopologyInfo = { export type PretrainInfoResponse = { model_type: string; summary: string; + dataset_short: string | null; target_model_config: Record | null; pretrain_config: Record | null; pretrain_wandb_path: string | null; diff --git 
a/spd/app/frontend/src/lib/api/runRegistry.ts b/spd/app/frontend/src/lib/api/runRegistry.ts new file mode 100644 index 000000000..c727f4dcc --- /dev/null +++ b/spd/app/frontend/src/lib/api/runRegistry.ts @@ -0,0 +1,26 @@ +/** + * API client for /api/run_registry endpoint. + */ + +import { fetchJson } from "./index"; + +export type DataAvailability = { + harvest: boolean; + autointerp: boolean; + attributions: boolean; + graph_interp: boolean; +}; + +export type RunInfoResponse = { + wandb_run_id: string; + architecture: string | null; + availability: DataAvailability; +}; + +export async function fetchRunInfo(wandbRunIds: string[]): Promise { + return fetchJson("/api/run_registry", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(wandbRunIds), + }); +} diff --git a/spd/app/frontend/src/lib/api/runs.ts b/spd/app/frontend/src/lib/api/runs.ts index 1430632a4..d898c8671 100644 --- a/spd/app/frontend/src/lib/api/runs.ts +++ b/spd/app/frontend/src/lib/api/runs.ts @@ -14,6 +14,8 @@ export type LoadedRun = { backend_user: string; dataset_attributions_available: boolean; dataset_search_enabled: boolean; + graph_interp_available: boolean; + autointerp_available: boolean; }; export async function getStatus(): Promise { diff --git a/spd/app/frontend/src/lib/colors.ts b/spd/app/frontend/src/lib/colors.ts index e64cc696d..d15462693 100644 --- a/spd/app/frontend/src/lib/colors.ts +++ b/spd/app/frontend/src/lib/colors.ts @@ -7,17 +7,17 @@ */ export const colors = { - // Text - punchy contrast (matches --text-*) - textPrimary: "#111111", - textSecondary: "#555555", - textMuted: "#999999", + // Text - warm navy contrast (matches --text-*) + textPrimary: "#1d272a", + textSecondary: "#646464", + textMuted: "#b4b4b4", // Status colors for edges/data (matches --accent-primary, --status-negative) - positive: "#2563eb", + positive: "#4d65ff", negative: "#dc2626", // RGB components for dynamic opacity - positiveRgb: { r: 37, g: 99, b: 235 }, // 
blue - matches --accent-primary + positiveRgb: { r: 77, g: 101, b: 255 }, // vibrant blue - matches --accent-primary negativeRgb: { r: 220, g: 38, b: 38 }, // red - matches --status-negative // Output node gradient (green) - matches --status-positive @@ -28,10 +28,10 @@ export const colors = { tokenHighlightOpacity: 0.4, // Node default - nodeDefault: "#6b7280", + nodeDefault: "#8a8780", // Accent (for active states) - matches --accent-primary - accent: "#2563eb", + accent: "#7C4D33", // Set overlap visualization (A/B/intersection) setOverlap: { diff --git a/spd/app/frontend/src/lib/componentKeys.ts b/spd/app/frontend/src/lib/componentKeys.ts new file mode 100644 index 000000000..ff83bda06 --- /dev/null +++ b/spd/app/frontend/src/lib/componentKeys.ts @@ -0,0 +1,17 @@ +/** + * Utilities for component key display (e.g. rendering embed/output keys with token strings). + */ + +export function isTokenNode(key: string): boolean { + const layer = key.split(":")[0]; + return layer === "embed" || layer === "output"; +} + +export function formatComponentKey(key: string, tokenStr: string | null): string { + if (tokenStr && isTokenNode(key)) { + const layer = key.split(":")[0]; + const label = layer === "embed" ? 
"input" : "output"; + return `'${tokenStr}' (${label})`; + } + return key; +} diff --git a/spd/app/frontend/src/lib/displaySettings.svelte.ts b/spd/app/frontend/src/lib/displaySettings.svelte.ts index db3e3f7c9..6998214ee 100644 --- a/spd/app/frontend/src/lib/displaySettings.svelte.ts +++ b/spd/app/frontend/src/lib/displaySettings.svelte.ts @@ -13,6 +13,14 @@ export const NODE_COLOR_MODE_LABELS: Record = { subcomp_act: "Subcomp Act", }; +// Edge variant for attribution graphs +export type EdgeVariant = "signed" | "abs_target"; + +export const EDGE_VARIANT_LABELS: Record = { + signed: "Signed", + abs_target: "Abs Target", +}; + // Example color mode for activation contexts viewer export type ExampleColorMode = "ci" | "component_act" | "both"; @@ -48,6 +56,8 @@ export const displaySettings = $state({ meanCiCutoff: 1e-7, centerOnPeak: false, showAutoInterpPromptButton: false, + curvedEdges: true, + edgeVariant: "signed" as EdgeVariant, }); export function anyCorrelationStatsEnabled() { diff --git a/spd/app/frontend/src/lib/interventionTypes.ts b/spd/app/frontend/src/lib/interventionTypes.ts index c364be243..de8e1af4d 100644 --- a/spd/app/frontend/src/lib/interventionTypes.ts +++ b/spd/app/frontend/src/lib/interventionTypes.ts @@ -1,46 +1,110 @@ /** Types for the intervention forward pass feature */ -export type InterventionNode = { - layer: string; - seq_pos: number; - component_idx: number; -}; +/** Default eval PGD settings (distinct from training PGD which is an optimization regularizer) */ +export const EVAL_PGD_N_STEPS = 4; +export const EVAL_PGD_STEP_SIZE = 1.0; export type TokenPrediction = { token: string; token_id: number; - spd_prob: number; - target_prob: number; + prob: number; logit: number; + target_prob: number; target_logit: number; }; -export type InterventionResponse = { - input_tokens: string[]; - predictions_per_position: TokenPrediction[][]; +export type LabelPredictions = { + position: number; + ci: TokenPrediction; + stochastic: TokenPrediction; 
+ adversarial: TokenPrediction; + ablated: TokenPrediction | null; }; -/** A forked intervention run with modified tokens */ -export type ForkedInterventionRunSummary = { - id: number; - token_replacements: [number, number][]; // [(seq_pos, new_token_id), ...] - result: InterventionResponse; - created_at: string; +export type InterventionResult = { + input_tokens: string[]; + ci: TokenPrediction[][]; + stochastic: TokenPrediction[][]; + adversarial: TokenPrediction[][]; + ablated: TokenPrediction[][] | null; + ci_loss: number; + stochastic_loss: number; + adversarial_loss: number; + ablated_loss: number | null; + label: LabelPredictions | null; }; /** Persisted intervention run from the server */ export type InterventionRunSummary = { id: number; selected_nodes: string[]; // node keys (layer:seq:cIdx) - result: InterventionResponse; + result: InterventionResult; created_at: string; - forked_runs?: ForkedInterventionRunSummary[]; // child runs with modified tokens }; /** Request to run and save an intervention */ export type RunInterventionRequest = { graph_id: number; - text: string; selected_nodes: string[]; - top_k?: number; + nodes_to_ablate?: string[]; + top_k: number; + adv_pgd: { n_steps: number; step_size: number }; +}; + +// --- Frontend-only run lifecycle types --- + +import { SvelteSet } from "svelte/reactivity"; +import { isInterventableNode } from "./promptAttributionsTypes"; + +/** Draft run: cloned from a parent, editable node selection. No forwarded results yet. */ +export type DraftRun = { + kind: "draft"; + parentId: number; + selectedNodes: SvelteSet; +}; + +/** Baked run: forwarded and immutable. Wraps a persisted InterventionRunSummary. 
*/ +export type BakedRun = { + kind: "baked"; + id: number; + selectedNodes: Set; + result: InterventionResult; + createdAt: string; +}; + +export type InterventionRun = DraftRun | BakedRun; + +export type InterventionState = { + runs: InterventionRun[]; + activeIndex: number; }; + +/** Whether a run's selection is editable */ +export function isRunEditable(run: InterventionRun): run is DraftRun { + return run.kind === "draft"; +} + +/** Build initial InterventionState from persisted runs. + * The first persisted run is the base run (all CI > 0 nodes), auto-created during graph computation. */ +export function buildInterventionState(persistedRuns: InterventionRunSummary[]): InterventionState { + if (persistedRuns.length === 0) throw new Error("Graph must have at least one intervention run (the base run)"); + const runs: InterventionRun[] = persistedRuns.map( + (r): BakedRun => ({ + kind: "baked", + id: r.id, + selectedNodes: new Set(r.selected_nodes), + result: r.result, + createdAt: r.created_at, + }), + ); + return { runs, activeIndex: 0 }; +} + +/** Get all interventable node keys with CI > 0 from a nodeCiVals record */ +export function getInterventableNodes(nodeCiVals: Record): Set { + const nodes = new Set(); + for (const [nodeKey, ci] of Object.entries(nodeCiVals)) { + if (isInterventableNode(nodeKey) && ci > 0) nodes.add(nodeKey); + } + return nodes; +} diff --git a/spd/app/frontend/src/lib/layerAliasing.ts b/spd/app/frontend/src/lib/layerAliasing.ts deleted file mode 100644 index 2c5269543..000000000 --- a/spd/app/frontend/src/lib/layerAliasing.ts +++ /dev/null @@ -1,219 +0,0 @@ -/** - * Layer aliasing system - transforms internal module names to human-readable aliases. 
- * - * Formats: - * - Internal: "h.0.mlp.c_fc", "h.1.attn.q_proj" - * - Aliased: "L0.mlp.in", "L1.attn.q" - * - * Handles multiple architectures: - * - GPT-2: c_fc -> mlp.in, down_proj -> mlp.out - * - Llama SwiGLU: gate_proj -> mlp.gate, up_proj -> mlp.up, down_proj -> mlp.down - * - Attention: q_proj -> attn.q, k_proj -> attn.k, v_proj -> attn.v, o_proj -> attn.o - * - Special: lm_head -> W_U, embed/output unchanged - */ - -type Architecture = "gpt2" | "llama" | "unknown"; - -/** Mapping of internal module names to aliases by architecture */ -const ALIASES: Record> = { - gpt2: { - // MLP - c_fc: "in", - down_proj: "out", - // Attention - q_proj: "q", - k_proj: "k", - v_proj: "v", - o_proj: "o", - }, - llama: { - // MLP (SwiGLU) - gate_proj: "gate", - up_proj: "up", - down_proj: "down", - // Attention - q_proj: "q", - k_proj: "k", - v_proj: "v", - o_proj: "o", - }, - unknown: { - // Fallback - just do attention mappings - q_proj: "q", - k_proj: "k", - v_proj: "v", - o_proj: "o", - }, -}; - -/** Special layers with fixed display names */ -const SPECIAL_LAYERS: Record = { - lm_head: "W_U", - embed: "embed", - output: "output", -}; - -// Cache for detected architecture from the full model -let cachedArchitecture: Architecture | null = null; - -/** - * Detect architecture from a collection of layer names. - * Llama has gate_proj/up_proj, GPT-2 has c_fc. - * - * This should be called once with all available layer names to establish - * the architecture for the session, ensuring down_proj is aliased correctly. - */ -export function detectArchitectureFromLayers(layers: string[]): Architecture { - const hasLlamaLayers = layers.some((layer) => layer.includes("gate_proj") || layer.includes("up_proj")); - if (hasLlamaLayers) { - return "llama"; - } - - const hasGPT2Layers = layers.some((layer) => layer.includes("c_fc")); - if (hasGPT2Layers) { - return "gpt2"; - } - - return "unknown"; -} - -/** - * Set the architecture for aliasing operations. 
- * Call this when you have access to all layer names (e.g., when loading a graph). - */ -export function setArchitecture(layers: string[]): void { - cachedArchitecture = detectArchitectureFromLayers(layers); -} - -/** - * Detect architecture from layer name. - * Uses cached architecture if available (set via setArchitecture()), - * otherwise falls back to single-layer detection. - * - * Note: down_proj appears in both architectures with different meanings: - * - GPT-2: down_proj -> "out" (second MLP projection) - * - Llama: down_proj -> "down" (third MLP projection after gate/up) - * - * Single-layer detection cannot distinguish these cases reliably. - */ -function detectArchitecture(layer: string): Architecture { - // Use cached architecture if available - if (cachedArchitecture !== null) { - return cachedArchitecture; - } - - // Fallback: single-layer detection (less reliable for down_proj) - if (layer.includes("gate_proj") || layer.includes("up_proj")) { - return "llama"; - } - if (layer.includes("c_fc")) { - return "gpt2"; - } - // down_proj is ambiguous without context, default to GPT-2 - if (layer.includes("down_proj")) { - return "gpt2"; - } - return "unknown"; -} - -/** - * Parse a layer name into components. - * Returns null for special layers (embed, output, lm_head) or unrecognized formats. - */ -function parseLayerName(layer: string): { block: number; moduleType: string; submodule: string } | null { - if (layer in SPECIAL_LAYERS) { - return null; - } - - const match = layer.match(/^h\.(\d+)\.(attn|mlp)\.(\w+)$/); - if (!match) { - return null; - } - - const [, blockStr, moduleType, submodule] = match; - return { - block: parseInt(blockStr), - moduleType, - submodule, - }; -} - -/** - * Transform a layer name to its aliased form. 
- * - * Examples: - * - "h.0.mlp.c_fc" -> "L0.mlp.in" - * - "h.2.attn.q_proj" -> "L2.attn.q" - * - "lm_head" -> "W_U" - * - "embed" -> "embed" - */ -export function getLayerAlias(layer: string): string { - if (layer in SPECIAL_LAYERS) { - return SPECIAL_LAYERS[layer]; - } - - const parsed = parseLayerName(layer); - if (!parsed) { - return layer; - } - - const arch = detectArchitecture(layer); - const alias = ALIASES[arch][parsed.submodule]; - - if (!alias) { - return `L${parsed.block}.${parsed.moduleType}.${parsed.submodule}`; - } - - return `L${parsed.block}.${parsed.moduleType}.${alias}`; -} - -/** - * Get a row label for grouped display in graphs. - * - * @param layer - Internal layer name (e.g., "h.0.mlp.c_fc") - * @param isQkvGroup - Whether this represents a grouped QKV row - * @returns Label (e.g., "L0.mlp.in", "L2.attn.qkv") - * - * @example - * getAliasedRowLabel("h.0.mlp.c_fc") // => "L0.mlp.in" - * getAliasedRowLabel("h.2.attn.q_proj", true) // => "L2.attn.qkv" - */ -export function getAliasedRowLabel(layer: string, isQkvGroup = false): string { - if (layer in SPECIAL_LAYERS) { - return SPECIAL_LAYERS[layer]; - } - - const parsed = parseLayerName(layer); - if (!parsed) { - return layer; - } - - if (isQkvGroup) { - return `L${parsed.block}.${parsed.moduleType}.qkv`; - } - - const arch = detectArchitecture(layer); - const alias = ALIASES[arch][parsed.submodule]; - - if (!alias) { - return `L${parsed.block}.${parsed.moduleType}.${parsed.submodule}`; - } - - return `L${parsed.block}.${parsed.moduleType}.${alias}`; -} - -/** - * Format a node key with aliased layer names. - * - * Node keys are "layer:seq:cIdx" or "layer:cIdx" format. 
- * - * Examples: - * - "h.0.mlp.c_fc:3:5" -> "L0.mlp.in:3:5" - * - "h.1.attn.q_proj:2:10" -> "L1.attn.q:2:10" - */ -export function formatNodeKeyWithAliases(nodeKey: string): string { - const parts = nodeKey.split(":"); - const layer = parts[0]; - const aliasedLayer = getLayerAlias(layer); - return [aliasedLayer, ...parts.slice(1)].join(":"); -} diff --git a/spd/app/frontend/src/lib/promptAttributionsTypes.ts b/spd/app/frontend/src/lib/promptAttributionsTypes.ts index fc705fad0..b40a63641 100644 --- a/spd/app/frontend/src/lib/promptAttributionsTypes.ts +++ b/spd/app/frontend/src/lib/promptAttributionsTypes.ts @@ -20,6 +20,7 @@ export type EdgeAttribution = { key: string; // "layer:seq:cIdx" for prompt or "layer:cIdx" for dataset value: number; // raw attribution value (positive or negative) normalizedMagnitude: number; // |value| / maxAbsValue, for color intensity (0-1) + tokenStr: string | null; // resolved token string for embed/output layers }; export type OutputProbability = { @@ -27,11 +28,20 @@ export type OutputProbability = { logit: number; // CI-masked (SPD model) raw logit target_prob: number; // Target model probability target_logit: number; // Target model raw logit - adv_pgd_prob: number | null; // Adversarial PGD probability - adv_pgd_logit: number | null; // Adversarial PGD raw logit token: string; }; +export type CISnapshot = { + step: number; + total_steps: number; + layers: string[]; + seq_len: number; + initial_alive: number[][]; + current_alive: number[][]; + l0_total: number; + loss: number; +}; + export type GraphType = "standard" | "optimized" | "manual"; export type GraphData = { @@ -41,10 +51,15 @@ export type GraphData = { edges: EdgeData[]; edgesBySource: Map; // nodeKey -> edges where this node is source edgesByTarget: Map; // nodeKey -> edges where this node is target + // Absolute-target variant (∂|y|/∂x · x), null for old graphs + edgesAbs: EdgeData[] | null; + edgesAbsBySource: Map | null; + edgesAbsByTarget: Map | null; 
outputProbs: Record; // key is "seq:cIdx" nodeCiVals: Record; // node key -> CI value (or output prob for output nodes or 1 for wte node) nodeSubcompActs: Record; // node key -> subcomponent activation (v_i^T @ a) maxAbsAttr: number; // max absolute edge value + maxAbsAttrAbs: number | null; // max absolute edge value for abs-target variant maxAbsSubcompAct: number; // max absolute subcomponent activation for normalization l0_total: number; // total active components at current CI threshold optimization?: OptimizationResult; @@ -93,7 +108,15 @@ export type KLLossResult = { position: number; }; -export type LossResult = CELossResult | KLLossResult; +export type LogitLossResult = { + type: "logit"; + coeff: number; + position: number; + label_token: number; + label_str: string; +}; + +export type LossResult = CELossResult | KLLossResult | LogitLossResult; export type OptimizationMetrics = { ci_masked_label_prob: number | null; // Probability of label under CI mask (CE loss only) @@ -102,6 +125,11 @@ export type OptimizationMetrics = { l0_total: number; // Total L0 (active components) }; +export type PgdConfig = { + n_steps: number; + step_size: number; +}; + export type OptimizationResult = { imp_min_coeff: number; steps: number; @@ -110,8 +138,7 @@ export type OptimizationResult = { mask_type: MaskType; loss: LossResult; metrics: OptimizationMetrics; - adv_pgd_n_steps: number | null; - adv_pgd_step_size: number | null; + pgd: PgdConfig | null; }; export type SubcomponentMetadata = { @@ -169,11 +196,33 @@ export type TokenizeResponse = { next_token_probs: (number | null)[]; // Probability of next token (last is null) }; -export type TokenInfo = { +export type TokenSearchResult = { id: number; string: string; + prob: number; }; +/** Select active edge set based on variant preference. Falls back to signed if abs unavailable. 
*/ +export function getActiveEdges( + data: GraphData, + variant: "signed" | "abs_target", +): { edges: EdgeData[]; bySource: Map; byTarget: Map; maxAbsAttr: number } { + if (variant === "abs_target" && data.edgesAbs) { + return { + edges: data.edgesAbs, + bySource: data.edgesAbsBySource!, + byTarget: data.edgesAbsByTarget!, + maxAbsAttr: data.maxAbsAttrAbs || 1, + }; + } + return { + edges: data.edges, + bySource: data.edgesBySource, + byTarget: data.edgesByTarget, + maxAbsAttr: data.maxAbsAttr || 1, + }; +} + // Client-side computed types export type NodePosition = { @@ -233,7 +282,7 @@ export function formatNodeKeyForDisplay(nodeKey: string, displayNames: Record>({ status: "uninitialized" }); let tokenStats = $state>({ status: "uninitialized" }); - let datasetAttributions = $state>({ status: "uninitialized" }); + let datasetAttributions = $state>({ status: "uninitialized" }); let interpretationDetail = $state>({ status: "uninitialized" }); + let graphInterpDetail = $state>({ status: "uninitialized" }); // Current coords being loaded/displayed (for interpretation lookup) let currentCoords = $state(null); @@ -132,20 +134,40 @@ export function useComponentData() { datasetAttributions = { status: "loaded", data: null }; } - // Fetch interpretation detail (404 = no interpretation for this component) - getInterpretationDetail(layer, cIdx) - .then((data) => { - if (isStale()) return; - interpretationDetail = { status: "loaded", data }; - }) - .catch((error) => { - if (isStale()) return; - if (error instanceof ApiError && error.status === 404) { - interpretationDetail = { status: "loaded", data: null }; - } else { + const interpState = untrack(() => runState.getInterpretation(`${layer}:${cIdx}`)); + if (interpState.status === "loaded" && interpState.data.status !== "none") { + getInterpretationDetail(layer, cIdx) + .then((data) => { + if (isStale()) return; + interpretationDetail = { status: "loaded", data }; + }) + .catch((error) => { + if (isStale()) return; 
interpretationDetail = { status: "error", error }; - } - }); + }); + } else { + interpretationDetail = { status: "loaded", data: null }; + } + + // Fetch graph interp detail (skip if not available for this run) + if (runState.graphInterpAvailable) { + graphInterpDetail = { status: "loading" }; + getGraphInterpComponentDetail(layer, cIdx) + .then((data) => { + if (isStale()) return; + graphInterpDetail = { status: "loaded", data }; + }) + .catch((error) => { + if (isStale()) return; + if (error instanceof ApiError && error.status === 404) { + graphInterpDetail = { status: "loaded", data: null }; + } else { + graphInterpDetail = { status: "error", error }; + } + }); + } else { + graphInterpDetail = { status: "loaded", data: null }; + } } /** @@ -159,6 +181,7 @@ export function useComponentData() { tokenStats = { status: "uninitialized" }; datasetAttributions = { status: "uninitialized" }; interpretationDetail = { status: "uninitialized" }; + graphInterpDetail = { status: "uninitialized" }; } // Interpretation is derived from the global cache - reactive to both coords and cache @@ -212,6 +235,9 @@ export function useComponentData() { get interpretationDetail() { return interpretationDetail; }, + get graphInterpDetail() { + return graphInterpDetail; + }, load, reset, generateInterpretation, diff --git a/spd/app/frontend/src/lib/useComponentDataExpectCached.svelte.ts b/spd/app/frontend/src/lib/useComponentDataExpectCached.svelte.ts index f32dab70a..d76c5da9e 100644 --- a/spd/app/frontend/src/lib/useComponentDataExpectCached.svelte.ts +++ b/spd/app/frontend/src/lib/useComponentDataExpectCached.svelte.ts @@ -6,7 +6,7 @@ * examples (200). Dataset attributions and interpretation detail are on-demand. 
*/ -import { getContext } from "svelte"; +import { getContext, untrack } from "svelte"; import type { Loadable } from "."; import { ApiError, @@ -14,10 +14,11 @@ import { getComponentAttributions, getComponentCorrelations, getComponentTokenStats, + getGraphInterpComponentDetail, getInterpretationDetail, requestComponentInterpretation, } from "./api"; -import type { ComponentAttributions, InterpretationDetail } from "./api"; +import type { AllMetricAttributions, GraphInterpComponentDetail, InterpretationDetail } from "./api"; import type { SubcomponentCorrelationsResponse, SubcomponentActivationContexts, @@ -29,7 +30,7 @@ const DATASET_ATTRIBUTIONS_TOP_K = 20; /** Fetch more activation examples in background after initial cached load */ const ACTIVATION_EXAMPLES_FULL_LIMIT = 200; -export type { ComponentAttributions as DatasetAttributions }; +export type { AllMetricAttributions as DatasetAttributions }; export type ComponentCoords = { layer: string; cIdx: number }; @@ -39,8 +40,9 @@ export function useComponentDataExpectCached() { let componentDetail = $state>({ status: "uninitialized" }); let correlations = $state>({ status: "uninitialized" }); let tokenStats = $state>({ status: "uninitialized" }); - let datasetAttributions = $state>({ status: "uninitialized" }); + let datasetAttributions = $state>({ status: "uninitialized" }); let interpretationDetail = $state>({ status: "uninitialized" }); + let graphInterpDetail = $state>({ status: "uninitialized" }); let currentCoords = $state(null); let requestId = 0; @@ -87,21 +89,41 @@ export function useComponentDataExpectCached() { datasetAttributions = { status: "loaded", data: null }; } - // Fetch interpretation detail on-demand (not cached) - interpretationDetail = { status: "loading" }; - getInterpretationDetail(layer, cIdx) - .then((data) => { - if (isStale()) return; - interpretationDetail = { status: "loaded", data }; - }) - .catch((error) => { - if (isStale()) return; - if (error instanceof ApiError && error.status 
=== 404) { - interpretationDetail = { status: "loaded", data: null }; - } else { + const interpState = untrack(() => runState.getInterpretation(`${layer}:${cIdx}`)); + if (interpState.status === "loaded" && interpState.data.status !== "none") { + interpretationDetail = { status: "loading" }; + getInterpretationDetail(layer, cIdx) + .then((data) => { + if (isStale()) return; + interpretationDetail = { status: "loaded", data }; + }) + .catch((error) => { + if (isStale()) return; interpretationDetail = { status: "error", error }; - } - }); + }); + } else { + interpretationDetail = { status: "loaded", data: null }; + } + + // Fetch graph interp detail + if (runState.graphInterpAvailable) { + graphInterpDetail = { status: "loading" }; + getGraphInterpComponentDetail(layer, cIdx) + .then((data) => { + if (isStale()) return; + graphInterpDetail = { status: "loaded", data }; + }) + .catch((error) => { + if (isStale()) return; + if (error instanceof ApiError && error.status === 404) { + graphInterpDetail = { status: "loaded", data: null }; + } else { + graphInterpDetail = { status: "error", error }; + } + }); + } else { + graphInterpDetail = { status: "loaded", data: null }; + } } function load(layer: string, cIdx: number) { @@ -144,6 +166,7 @@ export function useComponentDataExpectCached() { tokenStats = { status: "uninitialized" }; datasetAttributions = { status: "uninitialized" }; interpretationDetail = { status: "uninitialized" }; + graphInterpDetail = { status: "uninitialized" }; } // Interpretation is derived from the global cache @@ -197,6 +220,9 @@ export function useComponentDataExpectCached() { get interpretationDetail() { return interpretationDetail; }, + get graphInterpDetail() { + return graphInterpDetail; + }, load, reset, generateInterpretation, diff --git a/spd/app/frontend/src/lib/useRun.svelte.ts b/spd/app/frontend/src/lib/useRun.svelte.ts index de6d20c7d..1cfc3cca6 100644 --- a/spd/app/frontend/src/lib/useRun.svelte.ts +++ 
b/spd/app/frontend/src/lib/useRun.svelte.ts @@ -7,13 +7,8 @@ import type { Loadable } from "."; import * as api from "./api"; -import type { LoadedRun as RunData, InterpretationHeadline } from "./api"; -import type { - PromptPreview, - SubcomponentActivationContexts, - TokenInfo, - SubcomponentMetadata, -} from "./promptAttributionsTypes"; +import type { LoadedRun as RunData, InterpretationHeadline, GraphInterpHeadline } from "./api"; +import type { PromptPreview, SubcomponentActivationContexts, SubcomponentMetadata } from "./promptAttributionsTypes"; /** Maps component keys to cluster IDs. Singletons (unclustered components) have null values. */ export type ClusterMappingData = Record; @@ -46,17 +41,15 @@ export function useRun() { /** Intruder eval scores keyed by component key */ let intruderScores = $state>>({ status: "uninitialized" }); + /** Graph interp labels keyed by component key (layer:cIdx) */ + let graphInterpLabels = $state>>({ status: "uninitialized" }); + /** Cluster mapping for the current run */ let clusterMapping = $state(null); /** Available prompts for the current run */ let prompts = $state>({ status: "uninitialized" }); - /** All tokens in the tokenizer for the current run */ - let allTokens = $state>({ status: "uninitialized" }); - - /** Model topology info for frontend layout */ - /** Activation contexts summary (null = harvest not available) */ let activationContextsSummary = $state | null>>({ status: "uninitialized", @@ -68,9 +61,9 @@ export function useRun() { /** Reset all run-scoped state */ function resetRunScopedState() { prompts = { status: "uninitialized" }; - allTokens = { status: "uninitialized" }; interpretations = { status: "uninitialized" }; intruderScores = { status: "uninitialized" }; + graphInterpLabels = { status: "uninitialized" }; activationContextsSummary = { status: "uninitialized" }; _componentDetailsCache = {}; clusterMapping = null; @@ -88,6 +81,9 @@ export function useRun() { api.getIntruderScores() .then((data) => 
(intruderScores = { status: "loaded", data })) .catch((error) => (intruderScores = { status: "error", error })); + api.getAllGraphInterpLabels() + .then((data) => (graphInterpLabels = { status: "loaded", data })) + .catch((error) => (graphInterpLabels = { status: "error", error })); api.getAllInterpretations() .then((i) => { interpretations = { @@ -106,14 +102,6 @@ export function useRun() { .catch((error) => (interpretations = { status: "error", error })); } - /** Fetch tokens - must complete before run is considered loaded */ - async function fetchTokens(): Promise { - allTokens = { status: "loading" }; - const tokens = await api.getAllTokens(); - allTokens = { status: "loaded", data: tokens }; - return tokens; - } - async function loadRun(wandbPath: string, contextLength: number) { run = { status: "loading" }; try { @@ -122,8 +110,6 @@ export function useRun() { if (status) { run = { status: "loaded", data: status }; fetchRunScopedData(); - // Fetch tokens in background (no longer blocks UI - used only by token search) - fetchTokens(); } else { run = { status: "error", error: "Failed to load run" }; } @@ -142,10 +128,6 @@ export function useRun() { try { const status = await api.getStatus(); if (status) { - // Fetch tokens and model info if we don't have them (e.g., page refresh) - if (allTokens.status === "uninitialized") { - await fetchTokens(); - } run = { status: "loaded", data: status }; // Fetch other run-scoped data if we don't have it if (interpretations.status === "uninitialized") { @@ -230,6 +212,11 @@ export function useRun() { return clusterMapping?.data[key] ?? null; } + function getGraphInterpLabel(componentKey: string): GraphInterpHeadline | null { + if (graphInterpLabels.status !== "loaded") return null; + return graphInterpLabels.data[componentKey] ?? 
null; + } + return { get run() { return run; @@ -237,21 +224,27 @@ export function useRun() { get interpretations() { return interpretations; }, + get graphInterpLabels() { + return graphInterpLabels; + }, get clusterMapping() { return clusterMapping; }, get prompts() { return prompts; }, - get allTokens() { - return allTokens; - }, get activationContextsSummary() { return activationContextsSummary; }, get datasetAttributionsAvailable() { return run.status === "loaded" && run.data.dataset_attributions_available; }, + get graphInterpAvailable() { + return run.status === "loaded" && run.data.graph_interp_available; + }, + get autoInterpAvailable() { + return run.status === "loaded" && run.data.autointerp_available; + }, loadRun, clearRun, syncStatus, @@ -259,6 +252,7 @@ export function useRun() { getInterpretation, setInterpretation, getIntruderScore, + getGraphInterpLabel, getActivationContextDetail, loadActivationContextsSummary, setClusterMapping, diff --git a/spd/app/frontend/vite.config.ts b/spd/app/frontend/vite.config.ts index fc72bbc92..a08d086fb 100644 --- a/spd/app/frontend/vite.config.ts +++ b/spd/app/frontend/vite.config.ts @@ -9,6 +9,7 @@ const backendUrl = process.env.BACKEND_URL || "http://localhost:8000"; export default defineConfig({ plugins: [svelte()], server: { + hmr: false, proxy: { "/api": { target: backendUrl, diff --git a/spd/app/run_app.py b/spd/app/run_app.py index 6aff0ce4c..c61174d1e 100755 --- a/spd/app/run_app.py +++ b/spd/app/run_app.py @@ -303,7 +303,7 @@ def spawn_frontend( return proc def monitor_child_liveness(self) -> None: - log_lines_to_show = 5 + log_lines_to_show = 20 prev_lines: list[str] = [] while True: diff --git a/spd/autointerp/db.py b/spd/autointerp/db.py index 4cba168b7..f05227f5c 100644 --- a/spd/autointerp/db.py +++ b/spd/autointerp/db.py @@ -1,6 +1,5 @@ """SQLite database for autointerp data (interpretations and scores). 
NFS-hosted, single writer then read-only.""" -import sqlite3 from pathlib import Path import orjson diff --git a/spd/clustering/CLAUDE.md b/spd/clustering/CLAUDE.md index a063596f5..f502f8785 100644 --- a/spd/clustering/CLAUDE.md +++ b/spd/clustering/CLAUDE.md @@ -108,6 +108,18 @@ DistancesArray # Float[np.ndarray, "n_iters n_ens n_ens"] - `matching_dist.py` - Optimal matching distance via Hungarian algorithm - `merge_pair_samplers.py` - Strategies for selecting which pair to merge +## Utility Scripts + +**`get_cluster_mapping.py`**: Extracts cluster assignments at a specific iteration from a clustering run, outputs JSON mapping component labels to cluster indices (singletons mapped to `null`). + +```bash +python -m spd.clustering.scripts.get_cluster_mapping /path/to/clustering_run --iteration 299 +``` + +## App Integration + +To make a cluster mapping available in the app's dropdown for a run, add its path to `CANONICAL_RUNS` in `spd/app/frontend/src/lib/registry.ts` under the corresponding run's `clusterMappings` array. 
+ ## Config Files Configs live in `spd/clustering/configs/`: diff --git a/spd/clustering/configs/pipeline-dev-simplestories.yaml b/spd/clustering/configs/pipeline-dev-simplestories.yaml index eccda019f..80e9c63bc 100644 --- a/spd/clustering/configs/pipeline-dev-simplestories.yaml +++ b/spd/clustering/configs/pipeline-dev-simplestories.yaml @@ -6,4 +6,4 @@ slurm_partition: null wandb_project: "spd" wandb_entity: "goodfire" create_git_snapshot: false -clustering_run_config_path: "spd/clustering/configs/crc/simplestories_dev.json" \ No newline at end of file +clustering_run_config_path: "spd/clustering/configs/crc/ss_llama_simple_mlp.json" \ No newline at end of file diff --git a/spd/dataset_attributions/harvest.py b/spd/dataset_attributions/harvest.py index 2633267c3..6a0ba3380 100644 --- a/spd/dataset_attributions/harvest.py +++ b/spd/dataset_attributions/harvest.py @@ -73,7 +73,6 @@ def harvest_attributions( rank: int, world_size: int, ) -> None: - device = torch.device(get_device()) logger.info(f"Loading model on {device}") diff --git a/spd/utils/sqlite.py b/spd/utils/sqlite.py index 7ac591780..ab388f4f4 100644 --- a/spd/utils/sqlite.py +++ b/spd/utils/sqlite.py @@ -26,9 +26,7 @@ def open_nfs_sqlite(path: Path, readonly: bool) -> sqlite3.Connection: Write: default DELETE journal (WAL breaks on NFS). 
""" if readonly: - conn = sqlite3.connect( - f"file:{path}?immutable=1", uri=True, check_same_thread=False - ) + conn = sqlite3.connect(f"file:{path}?immutable=1", uri=True, check_same_thread=False) else: conn = sqlite3.connect(str(path), check_same_thread=False) conn.row_factory = sqlite3.Row diff --git a/tests/app/test_server_api.py b/tests/app/test_server_api.py index f39ef385f..d2cd74057 100644 --- a/tests/app/test_server_api.py +++ b/tests/app/test_server_api.py @@ -16,6 +16,7 @@ from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.database import PromptAttrDB from spd.app.backend.routers import graphs as graphs_router +from spd.app.backend.routers import intervention as intervention_router from spd.app.backend.routers import runs as runs_router from spd.app.backend.server import app from spd.app.backend.state import RunState, StateManager @@ -54,6 +55,7 @@ def app_with_state(): # Patch DEVICE in all router modules to use CPU for tests with ( mock.patch.object(graphs_router, "DEVICE", DEVICE), + mock.patch.object(intervention_router, "DEVICE", DEVICE), mock.patch.object(runs_router, "DEVICE", DEVICE), ): db = PromptAttrDB(db_path=Path(":memory:"), check_same_thread=False) @@ -147,6 +149,7 @@ def app_with_state(): harvest=None, interp=None, attributions=None, + graph_interp=None, ) manager = StateManager.get() @@ -231,6 +234,49 @@ def test_compute_graph(app_with_prompt: tuple[TestClient, int]): assert "outputProbs" in data +def test_run_and_save_intervention_without_text(app_with_prompt: tuple[TestClient, int]): + """Run-and-save intervention should use graph-linked prompt tokens (no text in request).""" + client, prompt_id = app_with_prompt + + graph_response = client.post( + "/api/graphs", + params={"prompt_id": prompt_id, "normalize": "none", "ci_threshold": 0.0}, + ) + assert graph_response.status_code == 200 + events = [line for line in graph_response.text.strip().split("\n") if line.startswith("data:")] + final_data = 
json.loads(events[-1].replace("data: ", "")) + graph_data = final_data["data"] + graph_id = graph_data["id"] + + selected_nodes = [ + key + for key, ci in graph_data["nodeCiVals"].items() + if not key.startswith("embed:") and not key.startswith("output:") and ci > 0 + ] + assert len(selected_nodes) > 0 + + request = { + "graph_id": graph_id, + "selected_nodes": selected_nodes[: min(5, len(selected_nodes))], + "top_k": 5, + "adv_pgd": {"n_steps": 1, "step_size": 1.0}, + } + response = client.post("/api/intervention/run", json=request) + assert response.status_code == 200 + body = response.json() + assert body["selected_nodes"] == request["selected_nodes"] + result = body["result"] + assert len(result["input_tokens"]) > 0 + assert len(result["ci"]) > 0 + assert len(result["stochastic"]) > 0 + assert len(result["adversarial"]) > 0 + assert result["target_sans"] is None + assert "ci_loss" in result + assert "stochastic_loss" in result + assert "adversarial_loss" in result + assert result["target_sans_loss"] is None + + # ----------------------------------------------------------------------------- # Streaming: Prompt Generation # ----------------------------------------------------------------------------- diff --git a/tests/dataset_attributions/test_storage.py b/tests/dataset_attributions/test_storage.py index 75841ef2c..969973bd7 100644 --- a/tests/dataset_attributions/test_storage.py +++ b/tests/dataset_attributions/test_storage.py @@ -316,10 +316,12 @@ class TestMergeNumericCorrectness: def test_merge_equals_sum_of_parts(self, tmp_path: Path): """Two workers with known values; merged queries should equal manual computation.""" - s1 = _deterministic_storage(embed_val=4.0, ci_sum_val=20.0, act_sq_sum_val=100.0, - embed_count_val=80, n_tokens=40) - s2 = _deterministic_storage(embed_val=8.0, ci_sum_val=30.0, act_sq_sum_val=500.0, - embed_count_val=120, n_tokens=60) + s1 = _deterministic_storage( + embed_val=4.0, ci_sum_val=20.0, act_sq_sum_val=100.0, 
embed_count_val=80, n_tokens=40 + ) + s2 = _deterministic_storage( + embed_val=8.0, ci_sum_val=30.0, act_sq_sum_val=500.0, embed_count_val=120, n_tokens=60 + ) p1, p2 = tmp_path / "r0.pt", tmp_path / "r1.pt" s1.save(p1) From 3e746858eba7f42a8ddf9d0a0cff4a98d22c5a38 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Fri, 6 Mar 2026 12:46:34 +0000 Subject: [PATCH 02/20] Cleanup: delete one-off scripts, fix silent failures, remove TODO.md - Delete scripts/{migrate_harvest_data,test_abs_grad_trick,parse_transformer_circuits_post}.py - Remove spd/app/TODO.md (moved to ~/app-todo-2026-03-04.md for reference) - Remove hardcoded partition="h200-reserved" in investigations.py - Narrow bare except Exception to json.JSONDecodeError in investigations.py - Add exhaustive match default in graph_interp.py (was NameError on unexpected pass_name) Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/migrate_harvest_data.py | 369 --------------------- scripts/parse_transformer_circuits_post.py | 363 -------------------- scripts/test_abs_grad_trick.py | 150 --------- spd/app/TODO.md | 172 ---------- spd/app/backend/routers/graph_interp.py | 2 + spd/app/backend/routers/investigations.py | 34 +- 6 files changed, 17 insertions(+), 1073 deletions(-) delete mode 100644 scripts/migrate_harvest_data.py delete mode 100644 scripts/parse_transformer_circuits_post.py delete mode 100644 scripts/test_abs_grad_trick.py delete mode 100644 spd/app/TODO.md diff --git a/scripts/migrate_harvest_data.py b/scripts/migrate_harvest_data.py deleted file mode 100644 index c0da50dcf..000000000 --- a/scripts/migrate_harvest_data.py +++ /dev/null @@ -1,369 +0,0 @@ -"""Migrate legacy harvest + autointerp data to the new layout. - -Copies data into new directories/DBs without modifying the originals. 
- -Legacy layout: - harvest//activation_contexts/harvest.db (schema: mean_ci, ci_values, component_acts) - harvest//correlations/*.pt - harvest//eval/intruder/*.jsonl - autointerp//interp.db (top-level, with scores already merged) - -New layout: - harvest//h-/harvest.db (schema: firing_density, mean_activations, firings, activations) - harvest//h-/*.pt - autointerp//a-/interp.db -""" - -import shutil -import sqlite3 -from dataclasses import dataclass, field -from pathlib import Path - -import fire -import orjson -import torch - -from spd.harvest.storage import TokenStatsStorage -from spd.settings import SPD_OUT_DIR - - -@dataclass -class RunDiagnostic: - run_id: str - harvest_db: bool = False - token_stats: bool = False - correlations: bool = False - n_components: int = 0 - n_tokens: int = 0 - ci_threshold: str = "" - intruder_scores: int = 0 - interp_db: bool = False - n_interpretations: int = 0 - interp_score_types: list[str] = field(default_factory=list) - already_migrated: bool = False - problems: list[str] = field(default_factory=list) - - @property - def ready(self) -> bool: - return not self.problems and not self.already_migrated - - -def diagnose_run(run_id: str, timestamp: str = "20260218_000000") -> RunDiagnostic: - harvest_root = SPD_OUT_DIR / "harvest" / run_id - autointerp_root = SPD_OUT_DIR / "autointerp" / run_id - diag = RunDiagnostic(run_id=run_id) - - # Already migrated? 
- if (harvest_root / f"h-{timestamp}").exists(): - diag.already_migrated = True - return diag - - # Harvest DB - old_db = harvest_root / "activation_contexts" / "harvest.db" - diag.harvest_db = old_db.exists() - if not diag.harvest_db: - diag.problems.append("missing harvest.db") - return diag - - conn = sqlite3.connect(f"file:{old_db}?immutable=1", uri=True) - diag.n_components = conn.execute("SELECT COUNT(*) FROM components").fetchone()[0] - threshold_row = conn.execute("SELECT value FROM config WHERE key = 'ci_threshold'").fetchone() - diag.ci_threshold = threshold_row[0] if threshold_row else "0.0 (default)" - conn.close() - - # Token stats - ts_path = harvest_root / "correlations" / "token_stats.pt" - diag.token_stats = ts_path.exists() - if diag.token_stats: - ts_data = torch.load(ts_path, weights_only=False) - diag.n_tokens = ts_data["n_tokens"] - else: - diag.problems.append("missing token_stats.pt") - - # Correlations - diag.correlations = (harvest_root / "correlations" / "component_correlations.pt").exists() - - # Intruder scores - intruder_dir = harvest_root / "eval" / "intruder" - if intruder_dir.exists(): - jsonl_files = list(intruder_dir.glob("*.jsonl")) - if jsonl_files: - largest = max(jsonl_files, key=lambda f: f.stat().st_size) - with open(largest, "rb") as f: - diag.intruder_scores = sum(1 for _ in f) - - # Autointerp - interp_db = autointerp_root / "interp.db" - diag.interp_db = interp_db.exists() - if diag.interp_db: - conn = sqlite3.connect(f"file:{interp_db}?immutable=1", uri=True) - tables = [ - r[0] - for r in conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall() - ] - if "interpretations" in tables: - diag.n_interpretations = conn.execute( - "SELECT COUNT(*) FROM interpretations" - ).fetchone()[0] - if "scores" in tables: - diag.interp_score_types = [ - r[0] for r in conn.execute("SELECT DISTINCT score_type FROM scores").fetchall() - ] - conn.close() - - return diag - - -def print_diagnostics(diagnostics: 
list[RunDiagnostic]) -> None: - ready = [d for d in diagnostics if d.ready] - skipped = [d for d in diagnostics if d.already_migrated] - blocked = [d for d in diagnostics if d.problems] - - if skipped: - print(f"Already migrated ({len(skipped)}): {', '.join(d.run_id for d in skipped)}\n") - - if blocked: - print(f"BLOCKED ({len(blocked)}):") - for d in blocked: - print(f" {d.run_id}: {', '.join(d.problems)}") - print() - - if ready: - print(f"Ready to migrate ({len(ready)}):") - print( - f"{'run_id':20s} {'comps':>6s} {'n_tokens':>14s} {'threshold':>10s} " - f"{'intruder':>8s} {'interps':>7s} {'scores':>20s} {'corr':>5s}" - ) - print("-" * 95) - for d in ready: - scores_str = ", ".join(d.interp_score_types) if d.interp_score_types else "-" - print( - f"{d.run_id:20s} {d.n_components:>6d} {d.n_tokens:>14,} {d.ci_threshold:>10s} " - f"{d.intruder_scores:>8d} {d.n_interpretations:>7d} {scores_str:>20s} " - f"{'yes' if d.correlations else 'no':>5s}" - ) - print() - - -def migrate_harvest_db(old_db_path: Path, new_db_path: Path, token_stats_path: Path) -> int: - """Copy harvest DB, transforming schema from legacy to new format. 
- - Old schema: mean_ci REAL, activation_examples with {token_ids, ci_values, component_acts} - New schema: firing_density REAL, mean_activations TEXT, - activation_examples with {token_ids, firings, activations} - """ - old_conn = sqlite3.connect(f"file:{old_db_path}?immutable=1", uri=True) - old_conn.row_factory = sqlite3.Row - - new_conn = sqlite3.connect(str(new_db_path)) - new_conn.execute("PRAGMA journal_mode=WAL") - new_conn.executescript("""\ - CREATE TABLE IF NOT EXISTS components ( - component_key TEXT PRIMARY KEY, - layer TEXT NOT NULL, - component_idx INTEGER NOT NULL, - firing_density REAL NOT NULL, - mean_activations TEXT NOT NULL, - activation_examples TEXT NOT NULL, - input_token_pmi TEXT NOT NULL, - output_token_pmi TEXT NOT NULL - ); - CREATE TABLE IF NOT EXISTS config ( - key TEXT PRIMARY KEY, - value TEXT NOT NULL - ); - CREATE TABLE IF NOT EXISTS scores ( - component_key TEXT NOT NULL, - score_type TEXT NOT NULL, - score REAL NOT NULL, - details TEXT NOT NULL, - PRIMARY KEY (component_key, score_type) - ); - """) - - # Load firing densities from token_stats.pt - assert token_stats_path.exists(), f"No token_stats.pt at {token_stats_path}" - ts = TokenStatsStorage.load(token_stats_path) - firing_density_map: dict[str, float] = {} - for i, key in enumerate(ts.component_keys): - firing_density_map[key] = ts.firing_counts[i].item() / ts.n_tokens - - # Read activation threshold from old DB config - threshold_row = old_conn.execute( - "SELECT value FROM config WHERE key = 'ci_threshold'" - ).fetchone() - activation_threshold = float(threshold_row["value"]) if threshold_row else 0.0 - - # Migrate config - for row in old_conn.execute("SELECT key, value FROM config").fetchall(): - key = row["key"] - value = row["value"] - if key == "ci_threshold": - key = "activation_threshold" - new_conn.execute("INSERT OR REPLACE INTO config VALUES (?, ?)", (key, value)) - - # Migrate components row-by-row - n = 0 - rows = old_conn.execute("SELECT * FROM 
components").fetchall() - for row in rows: - old_examples = orjson.loads(row["activation_examples"]) - new_examples = [] - for ex in old_examples: - ci_values = ex["ci_values"] - component_acts = ex["component_acts"] - new_examples.append( - { - "token_ids": ex["token_ids"], - "firings": [v > activation_threshold for v in ci_values], - "activations": { - "causal_importance": ci_values, - "component_activation": component_acts, - }, - } - ) - - mean_ci = row["mean_ci"] - component_key = row["component_key"] - assert component_key in firing_density_map, f"{component_key} missing from token_stats.pt" - firing_density = firing_density_map[component_key] - mean_activations = {"causal_importance": mean_ci} - - new_conn.execute( - "INSERT OR REPLACE INTO components VALUES (?, ?, ?, ?, ?, ?, ?, ?)", - ( - component_key, - row["layer"], - row["component_idx"], - firing_density, - orjson.dumps(mean_activations).decode(), - orjson.dumps(new_examples).decode(), - row["input_token_pmi"], - row["output_token_pmi"], - ), - ) - n += 1 - - new_conn.commit() - - # Migrate intruder scores from JSONL into harvest DB scores table - old_intruder_dir = old_db_path.parent.parent / "eval" / "intruder" - if old_intruder_dir.exists(): - # Use the largest file (most complete run) - jsonl_files = sorted(old_intruder_dir.glob("*.jsonl"), key=lambda f: f.stat().st_size) - if jsonl_files: - intruder_file = jsonl_files[-1] - n_scores = 0 - with open(intruder_file, "rb") as f: - for line in f: - record = orjson.loads(line) - new_conn.execute( - "INSERT OR REPLACE INTO scores VALUES (?, ?, ?, ?)", - ( - record["component_key"], - "intruder", - record["score"], - orjson.dumps(record.get("trials", [])).decode(), - ), - ) - n_scores += 1 - new_conn.commit() - print(f" Migrated {n_scores} intruder scores from {intruder_file.name}") - - old_conn.close() - new_conn.close() - return n - - -def migrate_autointerp_db(old_db_path: Path, new_db_path: Path) -> int: - """Copy autointerp DB (schema is compatible, 
just copy and strip intruder scores).""" - shutil.copy2(old_db_path, new_db_path) - - # Remove intruder scores — those belong in harvest.db in the new layout - conn = sqlite3.connect(str(new_db_path)) - conn.execute("DELETE FROM scores WHERE score_type = 'intruder'") - conn.commit() - n = conn.execute("SELECT COUNT(*) FROM interpretations").fetchone()[0] - conn.close() - return n - - -def migrate_run(run_id: str, timestamp: str = "20260218_000000") -> None: - """Migrate a single run's harvest + autointerp data to the new layout.""" - harvest_root = SPD_OUT_DIR / "harvest" / run_id - autointerp_root = SPD_OUT_DIR / "autointerp" / run_id - - print(f"Migrating {run_id}...") - - # --- Harvest --- - old_harvest_db = harvest_root / "activation_contexts" / "harvest.db" - assert old_harvest_db.exists(), f"No legacy harvest DB at {old_harvest_db}" - - new_subrun = harvest_root / f"h-{timestamp}" - assert not new_subrun.exists(), f"Target already exists: {new_subrun}" - new_subrun.mkdir(parents=True) - - token_stats_path = harvest_root / "correlations" / "token_stats.pt" - assert token_stats_path.exists(), f"No token_stats.pt at {token_stats_path}" - print(f" Harvest DB: {old_harvest_db} -> {new_subrun / 'harvest.db'}") - n_components = migrate_harvest_db(old_harvest_db, new_subrun / "harvest.db", token_stats_path) - print(f" Migrated {n_components} components") - - # Copy .pt files - old_corr_dir = harvest_root / "correlations" - for pt_file in ["component_correlations.pt", "token_stats.pt"]: - src = old_corr_dir / pt_file - if src.exists(): - dst = new_subrun / pt_file - shutil.copy2(src, dst) - print(f" Copied {pt_file}") - - # --- Autointerp --- - old_interp_db = autointerp_root / "interp.db" - if old_interp_db.exists(): - new_autointerp_subrun = autointerp_root / f"a-{timestamp}" - assert not new_autointerp_subrun.exists(), f"Target already exists: {new_autointerp_subrun}" - new_autointerp_subrun.mkdir(parents=True) - - print(f" Interp DB: {old_interp_db} -> 
{new_autointerp_subrun / 'interp.db'}") - n_interps = migrate_autointerp_db(old_interp_db, new_autointerp_subrun / "interp.db") - print(f" Migrated {n_interps} interpretations (detection + fuzzing scores preserved)") - else: - print(" No top-level interp.db found, skipping autointerp") - - print("Done!") - - -def _find_legacy_run_ids() -> list[str]: - harvest_root = SPD_OUT_DIR / "harvest" - return sorted( - d.name - for d in harvest_root.iterdir() - if (d / "activation_contexts" / "harvest.db").exists() - ) - - -def diagnose(run_id: str | None = None, timestamp: str = "20260218_000000") -> None: - """Print diagnostic report without migrating anything.""" - run_ids = [run_id] if run_id else _find_legacy_run_ids() - diagnostics = [diagnose_run(rid, timestamp) for rid in run_ids] - print_diagnostics(diagnostics) - - -def migrate_all(timestamp: str = "20260218_000000", dry_run: bool = False) -> None: - """Discover and migrate all legacy harvest runs.""" - run_ids = _find_legacy_run_ids() - diagnostics = [diagnose_run(rid, timestamp) for rid in run_ids] - - print_diagnostics(diagnostics) - - ready = [d for d in diagnostics if d.ready] - if dry_run or not ready: - return - - for d in ready: - migrate_run(d.run_id, timestamp=timestamp) - print(f"\nAll done — migrated {len(ready)} runs") - - -if __name__ == "__main__": - fire.Fire({"run": migrate_run, "all": migrate_all, "diagnose": diagnose}) diff --git a/scripts/parse_transformer_circuits_post.py b/scripts/parse_transformer_circuits_post.py deleted file mode 100644 index 1f528179e..000000000 --- a/scripts/parse_transformer_circuits_post.py +++ /dev/null @@ -1,363 +0,0 @@ -#!/usr/bin/env python3 -"""Export a Transformer Circuits post to local Markdown + downloaded image assets. 
- -Example: - .venv/bin/python scripts/parse_transformer_circuits_post.py \ - --url https://transformer-circuits.pub/2025/attribution-graphs/biology.html \ - --output-md papers/biology_source/biology.md \ - --assets-dir papers/biology_source/assets -""" - -from __future__ import annotations - -import argparse -import re -from collections.abc import Iterable -from pathlib import Path -from urllib.parse import urljoin, urlparse - -import requests -from bs4 import BeautifulSoup, NavigableString, Tag - -DEFAULT_URL = "https://transformer-circuits.pub/2025/attribution-graphs/biology.html" -USER_AGENT = "spd-biology-exporter/1.0" - - -def normalize_whitespace(text: str) -> str: - return re.sub(r"\s+", " ", text).strip() - - -def should_insert_space(left: str, right: str) -> bool: - if not left or not right: - return False - left = left.rstrip() - right = right.lstrip() - if not left or not right: - return False - if left.endswith((" ", "\n", "\t", "/", "(", "[", "{", "-", "“", '"', "'")): - return False - if right.startswith( - (" ", "\n", "\t", ".", ",", ":", ";", "!", "?", ")", "]", "}", "-", "”", '"', "'") - ): - return False - if left.endswith(" "): - return False - left_char = left[-1] - right_char = right[0] - return bool( - re.match(r"[A-Za-z0-9\]\)]", left_char) and re.match(r"[A-Za-z0-9\[\(]", right_char) - ) - - -class Exporter: - def __init__( - self, - *, - base_url: str, - output_md: Path, - assets_dir: Path, - download_assets: bool = True, - timeout: float = 30.0, - ) -> None: - self.base_url = base_url - self.output_md = output_md - self.assets_dir = assets_dir - self.download_assets = download_assets - self.timeout = timeout - self.session = requests.Session() - self.session.headers.update({"User-Agent": USER_AGENT}) - self.asset_map: dict[str, str] = {} - self.asset_counter = 0 - - def fetch_html(self) -> str: - response = self.session.get(self.base_url, timeout=self.timeout) - response.raise_for_status() - # Distill pages sometimes produce mojibake via 
default charset guessing. - return response.content.decode("utf-8", errors="replace") - - def _relative_asset_path(self, local_path: Path) -> str: - return local_path.relative_to(self.output_md.parent).as_posix() - - def _unique_asset_filename(self, remote_url: str) -> str: - parsed = urlparse(remote_url) - basename = Path(parsed.path).name or f"asset_{self.asset_counter}" - if "." not in basename: - basename = f"{basename}.bin" - candidate = basename - while (self.assets_dir / candidate).exists(): - self.asset_counter += 1 - stem = Path(basename).stem - suffix = Path(basename).suffix - candidate = f"{stem}_{self.asset_counter}{suffix}" - return candidate - - def download_asset(self, src: str) -> str: - if not src: - return "" - remote_url = urljoin(self.base_url, src) - if remote_url in self.asset_map: - return self.asset_map[remote_url] - if remote_url.startswith("data:"): - return remote_url - local_rel = remote_url - if self.download_assets: - self.assets_dir.mkdir(parents=True, exist_ok=True) - preferred = Path(urlparse(remote_url).path).name - if preferred and "." 
in preferred and (self.assets_dir / preferred).exists(): - local_path = self.assets_dir / preferred - else: - filename = self._unique_asset_filename(remote_url) - local_path = self.assets_dir / filename - response = self.session.get(remote_url, timeout=self.timeout) - response.raise_for_status() - local_path.write_bytes(response.content) - local_rel = self._relative_asset_path(local_path) - self.asset_map[remote_url] = local_rel - return local_rel - - def render_inline(self, node: Tag | NavigableString) -> str: - if isinstance(node, NavigableString): - return str(node) - if not isinstance(node, Tag): - return "" - name = node.name.lower() - if name == "br": - return " \n" - if name in {"b", "strong"}: - return f"**{self.render_children_inline(node.children)}**" - if name in {"i", "em"}: - return f"*{self.render_children_inline(node.children)}*" - if name == "code": - text = normalize_whitespace(self.render_children_inline(node.children)) - return f"`{text}`" if text else "" - if name == "a": - href = (node.get("href") or "").strip() - text = normalize_whitespace(self.render_children_inline(node.children)) - if not text: - text = href - if not href: - return text - full_href = urljoin(self.base_url, href) - return f"[{text}]({full_href})" - if name == "d-cite": - cite = normalize_whitespace(self.render_children_inline(node.children)) - return f"[{cite}]" if cite else "[citation]" - if name == "d-footnote": - note = normalize_whitespace(self.render_children_inline(node.children)) - return f" (Footnote: {note}) " if note else "" - return self.render_children_inline(node.children) - - def render_children_inline(self, nodes: Iterable[Tag | NavigableString]) -> str: - pieces = [self.render_inline(child) for child in nodes] - raw = "" - for piece in pieces: - if not piece: - continue - if raw and should_insert_space(raw, piece): - raw += " " - raw += piece - # Keep explicit markdown line breaks but normalize other whitespace. 
- parts = raw.split(" \n") - return " \n".join(normalize_whitespace(part) for part in parts) - - def render_list(self, list_tag: Tag, level: int = 0) -> list[str]: - lines: list[str] = [] - ordered = list_tag.name.lower() == "ol" - counter = 1 - indent = " " * level - for li in list_tag.find_all("li", recursive=False): - inline_nodes: list[Tag | NavigableString] = [] - nested_lists: list[Tag] = [] - for child in li.children: - if isinstance(child, Tag) and child.name and child.name.lower() in {"ul", "ol"}: - nested_lists.append(child) - else: - inline_nodes.append(child) - text = normalize_whitespace(self.render_children_inline(inline_nodes)) - marker = f"{counter}." if ordered else "-" - if text: - lines.append(f"{indent}{marker} {text}") - else: - lines.append(f"{indent}{marker}") - for nested in nested_lists: - lines.extend(self.render_list(nested, level + 1)) - counter += 1 - lines.append("") - return lines - - def render_table(self, table: Tag) -> list[str]: - rows: list[list[str]] = [] - for tr in table.find_all("tr"): - row: list[str] = [] - cells = tr.find_all(["th", "td"]) - for cell in cells: - row.append(normalize_whitespace(self.render_children_inline(cell.children))) - if row: - rows.append(row) - if not rows: - return [] - width = max(len(row) for row in rows) - padded = [row + [""] * (width - len(row)) for row in rows] - header = padded[0] - sep = ["---"] * width - lines = [ - "| " + " | ".join(header) + " |", - "| " + " | ".join(sep) + " |", - ] - for row in padded[1:]: - lines.append("| " + " | ".join(row) + " |") - lines.append("") - return lines - - def render_figure(self, figure: Tag) -> list[str]: - lines: list[str] = [] - imgs = figure.find_all("img") - for img in imgs: - src = (img.get("src") or "").strip() - alt = normalize_whitespace(img.get("alt") or "") - local_src = self.download_asset(src) - if local_src: - lines.append(f"![{alt}]({local_src})") - caption = figure.find("figcaption") - if caption: - caption_text = 
normalize_whitespace(self.render_children_inline(caption.children)) - if caption_text: - lines.append(f"_Figure: {caption_text}_") - if lines: - lines.append("") - return lines - - def render_block(self, node: Tag | NavigableString) -> list[str]: - if isinstance(node, NavigableString): - text = normalize_whitespace(str(node)) - return [text, ""] if text else [] - if not isinstance(node, Tag): - return [] - name = node.name.lower() - if name in {"style", "script"}: - return [] - if name in {"h1", "h2", "h3", "h4", "h5", "h6"}: - level = int(name[1]) - text = normalize_whitespace(self.render_children_inline(node.children)) - if not text: - return [] - anchor = node.get("id") - anchor_suffix = f" " if anchor else "" - return [f"{'#' * level} {text}{anchor_suffix}", ""] - if name == "p": - text = normalize_whitespace(self.render_children_inline(node.children)) - return [text, ""] if text else [] - if name in {"ul", "ol"}: - return self.render_list(node) - if name == "figure": - return self.render_figure(node) - if name == "table": - return self.render_table(node) - if name == "hr": - return ["---", ""] - if name == "br": - return [""] - if name == "d-contents": - return [] - if name in {"div", "section", "nav", "d-appendix", "d-article"}: - lines: list[str] = [] - for child in node.children: - lines.extend(self.render_block(child)) - return lines - text = normalize_whitespace(self.render_children_inline(node.children)) - return [text, ""] if text else [] - - def export(self, include_appendix: bool = True) -> tuple[str, int]: - html = self.fetch_html() - soup = BeautifulSoup(html, "html.parser") - title = normalize_whitespace(soup.title.string if soup.title else "") or "Untitled" - article = soup.find("d-article") - if article is None: - raise RuntimeError("Could not find in the page") - lines: list[str] = [ - f"# {title}", - "", - f"Source: {self.base_url}", - "", - "> Auto-generated by scripts/parse_transformer_circuits_post.py", - "", - ] - for child in 
article.children: - lines.extend(self.render_block(child)) - if include_appendix: - appendix = soup.find("d-appendix") - if appendix is not None: - appendix_lines: list[str] = [] - for child in appendix.children: - appendix_lines.extend(self.render_block(child)) - appendix_content = [line for line in appendix_lines if line.strip()] - if appendix_content: - lines.extend(["## Appendix", ""]) - lines.extend(appendix_lines) - # Strip trailing whitespace and collapse excessive blank lines. - cleaned: list[str] = [] - blank_run = 0 - for line in lines: - stripped = line.rstrip() - if not stripped: - blank_run += 1 - if blank_run <= 1: - cleaned.append("") - else: - blank_run = 0 - cleaned.append(stripped) - markdown = "\n".join(cleaned).strip() + "\n" - self.output_md.parent.mkdir(parents=True, exist_ok=True) - self.output_md.write_text(markdown, encoding="utf-8") - return markdown, len(self.asset_map) - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("--url", default=DEFAULT_URL, help="Post URL to parse") - parser.add_argument( - "--output-md", - default="papers/biology_source/biology.md", - help="Path for generated markdown", - ) - parser.add_argument( - "--assets-dir", - default="papers/biology_source/assets", - help="Directory for downloaded assets", - ) - parser.add_argument( - "--skip-appendix", - action="store_true", - help="Do not include d-appendix content", - ) - parser.add_argument( - "--skip-download-assets", - action="store_true", - help="Keep remote asset links instead of downloading files", - ) - parser.add_argument("--timeout", type=float, default=30.0, help="HTTP timeout in seconds") - return parser.parse_args() - - -def main() -> None: - args = parse_args() - output_md = Path(args.output_md).resolve() - assets_dir = Path(args.assets_dir).resolve() - exporter = Exporter( - base_url=args.url, - output_md=output_md, - assets_dir=assets_dir, - download_assets=not 
args.skip_download_assets, - timeout=args.timeout, - ) - _, asset_count = exporter.export(include_appendix=not args.skip_appendix) - print(f"Wrote markdown: {output_md}") - if args.skip_download_assets: - print("Assets were not downloaded (--skip-download-assets set).") - else: - print(f"Downloaded/linked assets: {asset_count}") - print(f"Assets dir: {assets_dir}") - - -if __name__ == "__main__": - main() diff --git a/scripts/test_abs_grad_trick.py b/scripts/test_abs_grad_trick.py deleted file mode 100644 index b3d15d296..000000000 --- a/scripts/test_abs_grad_trick.py +++ /dev/null @@ -1,150 +0,0 @@ -"""Verify that ∂|y|/∂x = sign(y) · ∂y/∂x for a scalar y, even through nonlinearities. - -The chain rule: ∂|y|/∂x = (d|y|/dy) · (∂y/∂x) = sign(y) · ∂y/∂x - -This holds regardless of what nonlinear computation sits between x and y, -because ∂y/∂x already accounts for all intermediate nonlinearities. -The sign(y) factor is just the outermost link in the chain. -""" - -import torch -from torch import nn - - -def test_simple_linear(): - """Linear: y = Wx, trivial case.""" - x = torch.randn(5, requires_grad=True) - W = torch.randn(3, 5) - y_vec = W @ x - y = y_vec[1] # pick one scalar - - grad = torch.autograd.grad(y, x, retain_graph=True)[0] - grad_abs = torch.autograd.grad(y.abs(), x, retain_graph=True)[0] - grad_trick = y.sign() * grad - - assert torch.allclose(grad_abs, grad_trick, atol=1e-7), ( - f"FAIL: {(grad_abs - grad_trick).abs().max()}" - ) - print(f" linear: max diff = {(grad_abs - grad_trick).abs().max():.2e} ✓") - - -def test_deep_nonlinear(): - """Deep net with ReLU, tanh, and GELU — representative of a transformer.""" - torch.manual_seed(42) - net = nn.Sequential( - nn.Linear(8, 16), - nn.ReLU(), - nn.Linear(16, 16), - nn.Tanh(), - nn.Linear(16, 16), - nn.GELU(), - nn.Linear(16, 4), - ) - x = torch.randn(8, requires_grad=True) - y_vec = net(x) - y = y_vec[2] # scalar output - - grad = torch.autograd.grad(y, x, retain_graph=True)[0] - grad_abs = 
torch.autograd.grad(y.abs(), x, retain_graph=True)[0] - grad_trick = y.sign() * grad - - assert torch.allclose(grad_abs, grad_trick, atol=1e-6), ( - f"FAIL: {(grad_abs - grad_trick).abs().max()}" - ) - print( - f" deep nonlinear (y={y.item():.4f}): max diff = {(grad_abs - grad_trick).abs().max():.2e} ✓" - ) - - -def test_negative_target(): - """Ensure it works when y < 0 (sign flips the gradient).""" - torch.manual_seed(99) - net = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1)) - # Find an input that gives negative output - for _seed in range(200): - x = torch.randn(4, requires_grad=True) - y = net(x).squeeze() - if y.item() < -0.1: - break - assert y.item() < 0, "Couldn't find negative output" - - grad = torch.autograd.grad(y, x, retain_graph=True)[0] - grad_abs = torch.autograd.grad(y.abs(), x, retain_graph=True)[0] - grad_trick = y.sign() * grad - - assert torch.allclose(grad_abs, grad_trick, atol=1e-6), ( - f"FAIL: {(grad_abs - grad_trick).abs().max()}" - ) - print( - f" negative target (y={y.item():.4f}): max diff = {(grad_abs - grad_trick).abs().max():.2e} ✓" - ) - - -def test_multiple_inputs(): - """Multiple input tensors (mirrors the app's in_post_detaches list).""" - torch.manual_seed(7) - x1 = torch.randn(3, 4, requires_grad=True) - x2 = torch.randn(3, 4, requires_grad=True) - - # Nonlinear function of both inputs - h = torch.relu(x1) + torch.tanh(x2) - y = (h @ torch.randn(4, 1)).sum() # scalar - - grads = torch.autograd.grad(y, [x1, x2], retain_graph=True) - grads_abs = torch.autograd.grad(y.abs(), [x1, x2], retain_graph=True) - - for i, (g, g_abs) in enumerate(zip(grads, grads_abs, strict=True)): - g_trick = y.sign() * g - assert torch.allclose(g_abs, g_trick, atol=1e-6), ( - f"FAIL input {i}: {(g_abs - g_trick).abs().max()}" - ) - - print(f" multiple inputs (y={y.item():.4f}): all match ✓") - - -def test_sum_of_abs_DOES_NOT_work(): - """Show that the trick FAILS for sum-of-abs (dataset attributions case). 
- - ∂(Σ|y_i|)/∂x ≠ sign(Σy_i) · ∂(Σy_i)/∂x - because each y_i has a different sign. - """ - torch.manual_seed(42) - x = torch.randn(4, requires_grad=True) - W = torch.randn(3, 4) - y_vec = W @ x # [3] - - target_signed = y_vec.sum() - target_abs = y_vec.abs().sum() - - grad_signed = torch.autograd.grad(target_signed, x, retain_graph=True)[0] - grad_abs = torch.autograd.grad(target_abs, x, retain_graph=True)[0] - - # The WRONG trick: use sign of the sum - grad_wrong = target_signed.sign() * grad_signed - - # The correct per-element version - grad_correct = sum( - y_vec[i].sign() * torch.autograd.grad(y_vec[i], x, retain_graph=True)[0] - for i in range(len(y_vec)) - ) - - wrong_diff = (grad_abs - grad_wrong).abs().max() - correct_diff = (grad_abs - grad_correct).abs().max() - print( - f" sum-of-abs: wrong trick diff = {wrong_diff:.4f}, correct per-element diff = {correct_diff:.2e}" - ) - assert wrong_diff > 0.01, "Expected the wrong trick to fail for sum-of-abs" - assert correct_diff < 1e-6, "Per-element version should match" - print(" → confirms: trick works for scalar y, NOT for sum-of-abs ✓") - - -if __name__ == "__main__": - print("Testing ∂|y|/∂x = sign(y) · ∂y/∂x for scalar y:\n") - test_simple_linear() - test_deep_nonlinear() - test_negative_target() - test_multiple_inputs() - print() - print("Testing that the trick does NOT work for sum-of-abs:\n") - test_sum_of_abs_DOES_NOT_work() - print("\nAll tests passed.") diff --git a/spd/app/TODO.md b/spd/app/TODO.md deleted file mode 100644 index 21a8ba1fd..000000000 --- a/spd/app/TODO.md +++ /dev/null @@ -1,172 +0,0 @@ -# App Backend Review & Action Items - -Review date: 2026-03-04. Scope: `spd/app/backend/` — all Python files. - -Context: the app is a **researcher-first local tool** (frontend + backend launched together, opened in browser). Errors should be loud, silent failures absent, the prompt DB is deletable short-term state, no backwards compatibility needed. 
- -## Overview - -The backend is ~6,500 lines across 18 Python files. The core architecture (FastAPI + SQLite + singleton state + SSE streaming) is sound. The main concerns are: a few real bugs, accumulated dead code, some silent failures that violate the "loud errors" principle, and a few design seams where complexity hides. - -### File size inventory - -| File | Lines | Risk | -|---|---|---| -| `routers/mcp.py` | 1637 | High — mixed concerns, largest file | -| `routers/graphs.py` | 1036 | Medium — streaming complexity | -| `compute.py` | 920 | Low — core algorithm, well-structured | -| `database.py` | 827 | Medium — manual serialization | -| `optim_cis.py` | 504 | Low | -| `routers/dataset_search.py` | 473 | Medium — hardcoded dataset names | -| `routers/correlations.py` | 386 | Low | -| `routers/graph_interp.py` | 373 | Low | -| `routers/investigations.py` | 317 | Low | -| `routers/pretrain_info.py` | 246 | Low | -| `server.py` | 212 | Low — clean | -| `routers/activation_contexts.py` | 207 | Low | -| `routers/runs.py` | 191 | Low | -| `routers/dataset_attributions.py` | 170 | Low | -| `routers/intervention.py` | 169 | Low — clean | -| `state.py` | 132 | Low — clean | -| `app_tokenizer.py` | 119 | Low | -| `routers/prompts.py` | 115 | Low | - ---- - -## Bugs - -### 1. `dataset_search.py:262` — KeyError on tokenized results - -`get_tokenized_results` accesses `result["story"]` but `search_dataset` stores results with key `"text"` (line 137: `results.append({"text": text, ...})`). This will crash with `KeyError: 'story'` whenever tokenized results are requested. - -**Fix:** Change line 262 from `result["story"]` to `result["text"]`. Also line 287: the metadata exclusion list references `"story"` — should be `"text"`. - -### 2. `dataset_search.py` — random endpoints hardcode SimpleStories - -`get_random_samples` (line 336) and `get_random_samples_with_loss` (line 415) both hardcode `load_dataset("lennart-finke/SimpleStories", ...)` and access `item_dict["story"]`. 
Since primary models are now Pile-trained, these endpoints are broken for current research. They also don't use `DepLoadedRun` to get the dataset name from the run config like `search_dataset` does. - -**Fix:** Make them take `DepLoadedRun`, read `task_config.dataset_name` and `task_config.column_name`, and use those instead of hardcoded values. Or, if the random endpoints aren't used with Pile models, consider deleting them. - ---- - -## Dead code to delete - -### 3. `ForkedInterventionRunRecord` + `forked_intervention_runs` table - -`database.py:117-125` defines `ForkedInterventionRunRecord`. Lines 256-265 create the `forked_intervention_runs` table. Lines 744-827 implement 3 CRUD methods (`save_forked_intervention_run`, `get_forked_intervention_runs`, `delete_forked_intervention_run`). No router references any of these — the fork endpoints were removed. Delete all of it. - -Files: `database.py` - -### 4. `optim_cis.py:500-504` — `get_out_dir()` never called - -Dead utility function that creates a local `out/` directory. Nothing references it. - -Files: `optim_cis.py` - -### 5. Unused schemas in `graphs.py:188-209` - -`ComponentStats`, `PromptSearchQuery`, and `PromptSearchResponse` are defined but no endpoint uses them. They appear to be leftovers from a removed prompt search feature. The `PromptPreview` in `graphs.py:114` also duplicates the one in `prompts.py:25`. - -Files: `routers/graphs.py` - -### 6. `spd/app/TODO.md` was empty - -(This file — now repurposed for this review.) - ---- - -## Design issues - -### 7. `OptimizationParams` mixes config inputs with computed outputs - -`database.py:69-82` — Fields like `imp_min_coeff`, `steps`, `pnorm`, `beta` are optimization *inputs*. Fields like `ci_masked_label_prob`, `stoch_masked_label_prob`, `adv_pgd_label_prob` are computed *outputs*. These metrics are mutated in-place after construction in `graphs.py:759-761`. - -This makes the object's contract unclear — is it immutable config or mutable state? 
- -**Suggestion:** Either nest the metrics in a sub-model (`metrics: OptimMetrics | None`), or at minimum stop mutating after construction (compute the metrics before constructing `OptimizationParams`). - -### 8. `StoredGraph.id = -1` sentinel value - -`database.py:90` uses `-1` as "unsaved graph". If a graph is accidentally used before being saved, that `-1` leaks into API responses or DB queries. `id: int | None = None` is more honest and lets the type system catch misuse. - -### 9. GPU lock accessed two different ways - -- `graphs.py:603,844` — `stream_computation(work, manager._gpu_lock)` reaches into the private lock directly -- `intervention.py:86` — `with manager.gpu_lock():` uses the context manager - -The stream pattern is inherently different (hold lock across SSE generator lifetime), but accessing `_gpu_lock` directly breaks encapsulation. - -**Suggestion:** Add a `stream_with_gpu_lock(work)` method on `StateManager` that encapsulates the lock acquisition + SSE streaming pattern. Then `graphs.py` calls `manager.stream_with_gpu_lock(work)` instead of reaching into privates. - -### 10. `load_run` returns untyped dicts - -`runs.py:96,139` returns `{"status": "loaded", "run_id": ...}` and `{"status": "already_loaded", ...}`. No response model, so the frontend has no type-safe contract for this endpoint. - -**Fix:** Define a `LoadRunResponse(BaseModel)` with `status`, `run_id`, `wandb_path`. - -### 11. Edge truncation is invisible to the user - -`graphs.py:903` logs a warning when edges exceed `GLOBAL_EDGE_LIMIT = 50_000` and are truncated, but this info only goes to server logs. The researcher never sees it. - -**Fix:** Add `edges_truncated: bool` (or `total_edge_count: int`) to `GraphData` so the frontend can show a notice. - -### 12. Module-level `DEVICE = get_device()` in multiple files - -`graphs.py:266`, `intervention.py:48`, `dataset_search.py`, `prompts.py:18` all call `get_device()` at import time. 
Fine in practice but makes testing and non-GPU imports impossible. - -**Suggestion:** Move to a function call or lazily-evaluated property when/if this becomes a testing bottleneck. Low priority. - -### 13. `_GRAPH_INTERP_MOCK_MODE` cross-router import - -`runs.py:13` imports `MOCK_MODE` from `routers/graph_interp.py` and uses it in the status endpoint (line 174). The TODO comment says to remove it. This cross-router dependency for a mock flag should be cleaned up — the mock mode should either be a config flag on `StateManager` or deleted entirely. - ---- - -## Silent failure patterns (violate "loud errors" principle) - -### 14. `compute.py:79-86` — output node capping is silent - -`compute_layer_alive_info` caps output nodes to `MAX_OUTPUT_NODES_PER_POS = 15` per position without any logging or indication. If a researcher has >15 high-probability output tokens at a position, they silently lose some. - -At minimum, log when capping occurs. - -### 15. `correlations.py:291,302` — token stats returns `None` silently - -`get_component_token_stats` returns `None` when token stats haven't been harvested. This means the endpoint returns a `200 null` response, which the frontend has to special-case. An explicit 404 with a message is more honest. - -### 16. `correlations.py:112,260` — interpretations/intruder scores return `{}` silently - -`get_all_interpretations` and `get_intruder_scores` return empty dicts when data isn't available. This is defensible for bulk endpoints (the frontend can check emptiness), but it means the researcher has no way to distinguish "no interpretations exist" from "interpretations not yet generated." Consider logging or adding a `has_interpretations` flag to `LoadedRun`. - -Note: `LoadedRun.autointerp_available` already partially addresses this. But the endpoints themselves don't use it — they independently check `loaded.interp is None`. - ---- - -## Lower priority / nice-to-haves - -### 17. 
`extract_node_ci_vals` Python double loop - -`compute.py:640-648` iterates every `(seq_pos, component_idx)` pair in Python. For large models (39K components × 512 seq), this is a lot of Python overhead. Could be vectorized to only extract non-zero entries. - -### 18. `database.py` manual graph get-or-create race - -Lines 528-539: catches `IntegrityError` on manual graph save, then re-queries. There's a small race window between the failed insert and the re-query. Acceptable for a single-user local app but worth noting. - -### 19. `mcp.py` is 1637 lines - -The MCP router is the largest file, mixing tool definitions, implementation logic, and JSON-RPC handling. It has module-level global state (`_investigation_config`). This file would benefit from being split, but it's also likely to be rewritten when MCP tooling matures, so the ROI of refactoring now is debatable. - ---- - -## Suggested priority order for implementation - -1. Fix `result["story"]` KeyError (bug #1) — 2 min -2. Delete dead code (items #3-5) — 10 min -3. Fix random dataset endpoints or delete if unused (#2) — 15 min -4. Add `edges_truncated` to GraphData (#11) — 10 min -5. Type the `load_run` response (#10) — 5 min -6. Clean up `_GRAPH_INTERP_MOCK_MODE` (#13) — 5 min -7. Deduplicate `MAX_OUTPUT_NODES_PER_POS` (#5 partial) — 2 min -8. `StoredGraph.id` sentinel → `None` (#8) — 10 min -9. Split `OptimizationParams` (#7) — 20 min -10. 
GPU lock encapsulation (#9) — 15 min diff --git a/spd/app/backend/routers/graph_interp.py b/spd/app/backend/routers/graph_interp.py index 525dbce9c..f948a571a 100644 --- a/spd/app/backend/routers/graph_interp.py +++ b/spd/app/backend/routers/graph_interp.py @@ -242,6 +242,8 @@ def get_model_graph(loaded: DepLoadedRun) -> ModelGraphResponse: source, target = comp_canon, rel_canon case "input": source, target = rel_canon, comp_canon + case _: + assert False, f"unexpected pass_name: {e.pass_name}" if source not in node_keys or target not in node_keys: continue diff --git a/spd/app/backend/routers/investigations.py b/spd/app/backend/routers/investigations.py index 0784a0eb4..296452b46 100644 --- a/spd/app/backend/routers/investigations.py +++ b/spd/app/backend/routers/investigations.py @@ -71,7 +71,7 @@ def _parse_metadata(inv_path: Path) -> dict[str, Any] | None: try: data: dict[str, Any] = json.loads(metadata_path.read_text()) return data - except Exception: + except json.JSONDecodeError: return None @@ -84,21 +84,18 @@ def _get_last_event(events_path: Path) -> tuple[str | None, str | None, int]: last_msg = None count = 0 - try: - with open(events_path) as f: - for line in f: - line = line.strip() - if not line: - continue - count += 1 - try: - event = json.loads(line) - last_time = event.get("timestamp") - last_msg = event.get("message") - except json.JSONDecodeError: - continue - except Exception: - pass + with open(events_path) as f: + for line in f: + line = line.strip() + if not line: + continue + count += 1 + try: + event = json.loads(line) + last_time = event.get("timestamp") + last_msg = event.get("message") + except json.JSONDecodeError: + continue return last_time, last_msg, count @@ -111,7 +108,7 @@ def _parse_task_summary(inv_path: Path) -> tuple[str | None, str | None, str | N try: data: dict[str, Any] = json.loads(summary_path.read_text()) return data.get("title"), data.get("summary"), data.get("status") - except Exception: + except 
json.JSONDecodeError: return None, None, None @@ -134,7 +131,7 @@ def _get_created_at(inv_path: Path, metadata: dict[str, Any] | None) -> str: event = json.loads(first_line) if "timestamp" in event: return event["timestamp"] - except Exception: + except json.JSONDecodeError: pass if metadata and "created_at" in metadata: @@ -285,7 +282,6 @@ def launch_investigation_endpoint(request: LaunchRequest, loaded: DepLoadedRun) prompt=request.prompt, context_length=loaded.context_length, max_turns=50, - partition="h200-reserved", time="8:00:00", job_suffix=None, ) From 65ca30fa3e1235c0d09199911b458bd6e26d410a Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Fri, 6 Mar 2026 12:57:09 +0000 Subject: [PATCH 03/20] Use ValueError instead of assert False for unexpected pass_name Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/backend/routers/graph_interp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spd/app/backend/routers/graph_interp.py b/spd/app/backend/routers/graph_interp.py index f948a571a..7df23a10f 100644 --- a/spd/app/backend/routers/graph_interp.py +++ b/spd/app/backend/routers/graph_interp.py @@ -243,7 +243,7 @@ def get_model_graph(loaded: DepLoadedRun) -> ModelGraphResponse: case "input": source, target = rel_canon, comp_canon case _: - assert False, f"unexpected pass_name: {e.pass_name}" + raise ValueError(f"unexpected pass_name: {e.pass_name}") if source not in node_keys or target not in node_keys: continue From 3313f6bb955ba2633b50d88ae089123df4fe4d8a Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Fri, 6 Mar 2026 14:06:44 +0000 Subject: [PATCH 04/20] =?UTF-8?q?Remove=20unreachable=20default=20case=20?= =?UTF-8?q?=E2=80=94=20pyright=20proves=20pass=5Fname=20is=20exhaustive?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/backend/routers/graph_interp.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff 
--git a/spd/app/backend/routers/graph_interp.py b/spd/app/backend/routers/graph_interp.py index 7df23a10f..272bd4d80 100644 --- a/spd/app/backend/routers/graph_interp.py +++ b/spd/app/backend/routers/graph_interp.py @@ -8,10 +8,10 @@ from fastapi import APIRouter, HTTPException from pydantic import BaseModel -from spd.graph_interp.schemas import LabelResult from spd.app.backend.dependencies import DepLoadedRun from spd.app.backend.utils import log_errors +from spd.graph_interp.schemas import LabelResult from spd.topology import TransformerTopology # TODO(oli): Remove MOCK_MODE after real data is available @@ -242,8 +242,6 @@ def get_model_graph(loaded: DepLoadedRun) -> ModelGraphResponse: source, target = comp_canon, rel_canon case "input": source, target = rel_canon, comp_canon - case _: - raise ValueError(f"unexpected pass_name: {e.pass_name}") if source not in node_keys or target not in node_keys: continue From 90b145667a934a865bf621adbc55789263c00422 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Fri, 6 Mar 2026 14:10:55 +0000 Subject: [PATCH 05/20] Make PR self-contained: pass partition to launch_investigation Matches the signature in #428 (investigate module). Uses DEFAULT_PARTITION_NAME instead of hardcoded string. TODO to remove when investigate module drops the required partition param. make check now passes with 0 errors. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/backend/routers/investigations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/spd/app/backend/routers/investigations.py b/spd/app/backend/routers/investigations.py index 296452b46..3fa8297d0 100644 --- a/spd/app/backend/routers/investigations.py +++ b/spd/app/backend/routers/investigations.py @@ -13,7 +13,7 @@ from pydantic import BaseModel from spd.app.backend.dependencies import DepLoadedRun -from spd.settings import SPD_OUT_DIR +from spd.settings import DEFAULT_PARTITION_NAME, SPD_OUT_DIR from spd.utils.wandb_utils import parse_wandb_run_path router = APIRouter(prefix="/api/investigations", tags=["investigations"]) @@ -282,6 +282,7 @@ def launch_investigation_endpoint(request: LaunchRequest, loaded: DepLoadedRun) prompt=request.prompt, context_length=loaded.context_length, max_turns=50, + partition=DEFAULT_PARTITION_NAME, # TODO: remove when investigate module drops required partition time="8:00:00", job_suffix=None, ) From d5c7181ed2f470c817663a4022c95f6d71a5dab2 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Fri, 6 Mar 2026 14:12:36 +0000 Subject: [PATCH 06/20] Delete dead code: commented-out MCP tool, unused DEVICE, dead schema - Remove ~87 lines of commented-out _tool_get_component_attributions in mcp.py - Remove unused DEVICE constant + get_device import + stale TODO in prompts.py - Remove unused ActivationContextsGenerationConfig from schemas.py Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/backend/routers/mcp.py | 86 ------------------------------ spd/app/backend/routers/prompts.py | 10 ---- spd/app/backend/schemas.py | 11 ---- 3 files changed, 107 deletions(-) diff --git a/spd/app/backend/routers/mcp.py b/spd/app/backend/routers/mcp.py index 448162199..488e48b4f 100644 --- a/spd/app/backend/routers/mcp.py +++ b/spd/app/backend/routers/mcp.py @@ -1309,92 +1309,6 @@ def _tool_get_component_activation_examples(params: dict[str, Any]) -> dict[str, } -# def 
_tool_get_component_attributions(params: dict[str, Any]) -> dict[str, Any]: -# """Get dataset-level component dependencies.""" -# _, loaded = _get_state() - -# layer = params["layer"] -# component_idx = params["component_idx"] -# k = params.get("k", 10) - -# assert loaded.attributions is not None, "dataset attributions not loaded" -# storage = loaded.attributions.get_attributions() - -# concrete_layer = loaded.topology.canon_to_target(layer) if layer != "output" else "output" -# component_key = f"{concrete_layer}:{component_idx}" - -# _log_event( -# "tool_call", -# f"get_component_attributions: {component_key}", -# {"layer": layer, "component_idx": component_idx, "k": k}, -# ) - -# is_source = storage.has_source(component_key) -# is_target = storage.has_target(component_key) - -# assert is_source or is_target, f"Component {component_key} not found in attributions" - -# w_unembed = loaded.topology.get_unembed_weight() if is_source else None - -# def _entries_to_dicts( -# entries: list[Any], -# ) -> list[dict[str, Any]]: -# return [ -# { -# "component_key": f"{_canonicalize_layer(e.layer, loaded)}:{e.component_idx}", -# "layer": _canonicalize_layer(e.layer, loaded), -# "component_idx": e.component_idx, -# "value": e.value, -# } -# for e in entries -# ] - -# positive_sources = ( -# _entries_to_dicts(storage.get_top_sources(component_key, k, "positive")) -# if is_target -# else [] -# ) -# negative_sources = ( -# _entries_to_dicts(storage.get_top_sources(component_key, k, "negative")) -# if is_target -# else [] -# ) -# positive_targets = ( -# _entries_to_dicts( -# storage.get_top_targets( -# component_key, -# k, -# "positive", -# w_unembed=w_unembed, -# include_outputs=w_unembed is not None, -# ) -# ) -# if is_source -# else [] -# ) -# negative_targets = ( -# _entries_to_dicts( -# storage.get_top_targets( -# component_key, -# k, -# "negative", -# w_unembed=w_unembed, -# include_outputs=w_unembed is not None, -# ) -# ) -# if is_source -# else [] -# ) - -# return { -# 
"component_key": component_key, -# "positive_sources": positive_sources, -# "negative_sources": negative_sources, -# "positive_targets": positive_targets, -# "negative_targets": negative_targets, -# } - - def _tool_get_model_info(_params: dict[str, Any]) -> dict[str, Any]: """Get architecture details about the pretrained model.""" _, loaded = _get_state() diff --git a/spd/app/backend/routers/prompts.py b/spd/app/backend/routers/prompts.py index 7f06bcf68..2cf0da197 100644 --- a/spd/app/backend/routers/prompts.py +++ b/spd/app/backend/routers/prompts.py @@ -6,16 +6,6 @@ from spd.app.backend.dependencies import DepLoadedRun, DepStateManager from spd.app.backend.utils import log_errors -from spd.utils.distributed_utils import get_device - -# TODO: Re-enable these endpoints when dependencies are available: -# - extract_active_from_ci from database -# - PromptSearchQuery, PromptSearchResponse from schemas -# - DatasetConfig, LMTaskConfig from configs -# - create_data_loader, extract_batch_data from data -# - logger from utils - -DEVICE = get_device() # ============================================================================= # Schemas diff --git a/spd/app/backend/schemas.py b/spd/app/backend/schemas.py index 61fa3d9d2..bf5f3e9b2 100644 --- a/spd/app/backend/schemas.py +++ b/spd/app/backend/schemas.py @@ -27,17 +27,6 @@ class OutputProbability(BaseModel): # ============================================================================= -class ActivationContextsGenerationConfig(BaseModel): - """Configuration for generating activation contexts.""" - - importance_threshold: float = 0.01 - n_batches: int = 100 - batch_size: int = 32 - n_tokens_either_side: int = 5 - topk_examples: int = 20 - separation_tokens: int = 0 - - class SubcomponentMetadata(BaseModel): """Lightweight metadata for a subcomponent (without examples/token_prs)""" From 3b6d0197bdd73b022d8c4c425a63e847f42b726c Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Fri, 6 Mar 2026 14:15:05 +0000 
Subject: [PATCH 07/20] =?UTF-8?q?Revert=20unnecessary=20migration=20?= =?UTF-8?q?=E2=80=94=20real=20DB=20already=20has=20the=20columns?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Checked SPD_OUT_DIR/app/prompt_attr.db: ci_masked_label_prob, stoch_masked_label_prob, adv_pgd_label_prob all exist. Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/backend/database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spd/app/backend/database.py b/spd/app/backend/database.py index 61be17ad5..24a823a3c 100644 --- a/spd/app/backend/database.py +++ b/spd/app/backend/database.py @@ -287,7 +287,7 @@ def init_schema(self) -> None: ON forked_intervention_runs(intervention_run_id); """) - # Migration: add edges_data_abs column if missing (backwards compat with existing DBs) + # Migration: add edges_data_abs column if missing columns = {row[1] for row in conn.execute("PRAGMA table_info(graphs)").fetchall()} if "edges_data_abs" not in columns: conn.execute("ALTER TABLE graphs ADD COLUMN edges_data_abs TEXT") From c0b985d86ce3065b9306909bcf38aef832763e23 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Fri, 6 Mar 2026 14:16:28 +0000 Subject: [PATCH 08/20] =?UTF-8?q?Remove=20stale=20edges=5Fdata=5Fabs=20mig?= =?UTF-8?q?ration=20=E2=80=94=20column=20is=20in=20CREATE=20TABLE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The CREATE TABLE statement already includes edges_data_abs (and all metric columns). The real DB at SPD_OUT_DIR has all columns present. No legacy DBs without these columns exist. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/backend/database.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/spd/app/backend/database.py b/spd/app/backend/database.py index 24a823a3c..c2178076b 100644 --- a/spd/app/backend/database.py +++ b/spd/app/backend/database.py @@ -287,11 +287,6 @@ def init_schema(self) -> None: ON forked_intervention_runs(intervention_run_id); """) - # Migration: add edges_data_abs column if missing - columns = {row[1] for row in conn.execute("PRAGMA table_info(graphs)").fetchall()} - if "edges_data_abs" not in columns: - conn.execute("ALTER TABLE graphs ADD COLUMN edges_data_abs TEXT") - conn.commit() # ------------------------------------------------------------------------- From 1c15f5276819a2c9868f56e637abfc65b6885455 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Fri, 6 Mar 2026 14:17:29 +0000 Subject: [PATCH 09/20] Update app CLAUDE.md: DB is persistent, requires manual migrations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The prompt DB is no longer disposable — it's shared team state on NFS. Schema changes need manual ALTER TABLE with backups. CREATE TABLE statements are the source of truth. Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/CLAUDE.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/spd/app/CLAUDE.md b/spd/app/CLAUDE.md index 42f745de6..9a7d8e4a5 100644 --- a/spd/app/CLAUDE.md +++ b/spd/app/CLAUDE.md @@ -11,8 +11,9 @@ Web-based visualization and analysis tool for exploring neural network component This is a **rapidly iterated research tool**. Key implications: -- **Please do not code for backwards compatibility**: Schema changes don't need migrations -- **Database is shared state**: Lives at `SPD_OUT_DIR/app/prompt_attr.db` on NFS, accessible by multiple backends. Do not delete without checking with the team. 
Uses DELETE journal mode (NFS-safe) with `fcntl.flock` write locking for concurrent access +- **Database is persistent shared state**: Lives at `SPD_OUT_DIR/app/prompt_attr.db` on NFS, shared across the team. Do not delete. Uses DELETE journal mode (NFS-safe) with `fcntl.flock` write locking for concurrent access. + - **Schema changes require manual migration**: Update the `CREATE TABLE IF NOT EXISTS` statements to match the desired schema, then manually `ALTER TABLE` the real DB (back it up first). No automatic migration framework — just SQL. + - Keep the CREATE TABLE statements as the source of truth for the schema. - **Prefer simplicity**: Avoid over-engineering for hypothetical future needs - **Fail loud and fast**: The users are a small team of highly technical people. Errors are good. We want to know immediately if something is wrong. No soft failing, assert, assert, assert - **Token display**: Always ship token strings rendered server-side via `AppTokenizer`, never raw token IDs. For embed/output layers, `component_idx` is a token ID — resolve it to a display string in the backend response. From 98706898f265c6b5e3ca1e307539d342746746ca Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Fri, 6 Mar 2026 14:20:03 +0000 Subject: [PATCH 10/20] Fix MeanKL loss mismatch and add GPU lock to search_tokens MeanKL: single version used F.kl_div(reduction="batchmean") which for [1, seq, vocab] gives sum over all positions. Batched version used .sum(-1).mean(-1) giving mean over positions. These differ by a factor of seq_len. Fixed single version to match batched (mean over positions). search_tokens: was running model forward pass without GPU lock, risking concurrent CUDA ops with graph computation. Added manager.gpu_lock(). 
Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/backend/optim_cis.py | 3 ++- spd/app/backend/routers/graphs.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/spd/app/backend/optim_cis.py b/spd/app/backend/optim_cis.py index 4c7e42fbb..5ec2f15b5 100644 --- a/spd/app/backend/optim_cis.py +++ b/spd/app/backend/optim_cis.py @@ -91,7 +91,8 @@ def compute_recon_loss( case MeanKLLossConfig(): target_probs = F.softmax(target_out, dim=-1) pred_log_probs = F.log_softmax(logits, dim=-1) - return F.kl_div(pred_log_probs, target_probs, reduction="batchmean") + # sum over vocab, mean over positions (consistent with batched version) + return F.kl_div(pred_log_probs, target_probs, reduction="none").sum(dim=-1).mean(dim=-1) @dataclass diff --git a/spd/app/backend/routers/graphs.py b/spd/app/backend/routers/graphs.py index e2b439da1..db2dd68cb 100644 --- a/spd/app/backend/routers/graphs.py +++ b/spd/app/backend/routers/graphs.py @@ -456,7 +456,7 @@ def search_tokens( device = next(loaded.model.parameters()).device tokens_tensor = torch.tensor([prompt.token_ids], device=device) - with torch.no_grad(): + with manager.gpu_lock(), torch.no_grad(): logits = loaded.model(tokens_tensor) probs = torch.softmax(logits[0, position], dim=-1) From 24b086ff828f6e0a9b7bd78c0fe55864501c9d98 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Fri, 6 Mar 2026 14:21:19 +0000 Subject: [PATCH 11/20] Gracefully skip base intervention run when no interventable nodes Previously asserted and crashed after the graph was already saved to the DB, leaving an orphaned graph with no base intervention run. Now logs a warning and returns early. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/backend/routers/graphs.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/spd/app/backend/routers/graphs.py b/spd/app/backend/routers/graphs.py index db2dd68cb..6ec5cd08f 100644 --- a/spd/app/backend/routers/graphs.py +++ b/spd/app/backend/routers/graphs.py @@ -73,7 +73,9 @@ def _save_base_intervention_run( for k, ci in node_ci_vals.items() if k.split(":")[0] not in NON_INTERVENTABLE_LAYERS and ci > 0 ] - assert len(interventable_keys) > 0, "No interventable nodes with CI > 0" + if not interventable_keys: + logger.warning(f"Graph {graph_id}: no interventable nodes with CI > 0, skipping base intervention run") + return active_nodes: list[tuple[str, int, int]] = [] for key in interventable_keys: From 1e2598d9442dc7a20cae5acbf9c0eea350bd24a4 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Fri, 6 Mar 2026 14:23:43 +0000 Subject: [PATCH 12/20] Remove MOCK_MODE and all mock code from graph_interp router ~112 lines of dead mock data, mock functions, and MOCK_MODE branches. Also remove the cross-router MOCK_MODE import in runs.py. Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/backend/routers/graph_interp.py | 129 ------------------------ spd/app/backend/routers/runs.py | 7 +- 2 files changed, 2 insertions(+), 134 deletions(-) diff --git a/spd/app/backend/routers/graph_interp.py b/spd/app/backend/routers/graph_interp.py index 272bd4d80..003075b4d 100644 --- a/spd/app/backend/routers/graph_interp.py +++ b/spd/app/backend/routers/graph_interp.py @@ -4,8 +4,6 @@ prompt-edge graph produced by the graph_interp pipeline. 
""" -import random - from fastapi import APIRouter, HTTPException from pydantic import BaseModel @@ -14,9 +12,6 @@ from spd.graph_interp.schemas import LabelResult from spd.topology import TransformerTopology -# TODO(oli): Remove MOCK_MODE after real data is available -MOCK_MODE = False - MAX_GRAPH_NODES = 500 @@ -103,9 +98,6 @@ class ModelGraphResponse(BaseModel): @router.get("/labels") @log_errors def get_all_labels(loaded: DepLoadedRun) -> dict[str, GraphInterpHeadline]: - if MOCK_MODE: - return _mock_all_labels(loaded) - repo = loaded.graph_interp if repo is None: return {} @@ -151,9 +143,6 @@ def _to_detail(label: LabelResult | None) -> LabelDetail | None: @router.get("/labels/{layer}/{c_idx}") @log_errors def get_label_detail(layer: str, c_idx: int, loaded: DepLoadedRun) -> GraphInterpDetail: - if MOCK_MODE: - return _mock_label_detail(layer, c_idx) - repo = loaded.graph_interp if repo is None: raise HTTPException(status_code=404, detail="Graph interp data not available") @@ -207,9 +196,6 @@ def get_component_detail( @router.get("/graph") @log_errors def get_model_graph(loaded: DepLoadedRun) -> ModelGraphResponse: - if MOCK_MODE: - return _mock_model_graph(loaded) - repo = loaded.graph_interp if repo is None: raise HTTPException(status_code=404, detail="Graph interp data not available") @@ -256,118 +242,3 @@ def get_model_graph(loaded: DepLoadedRun) -> ModelGraphResponse: ) return ModelGraphResponse(nodes=nodes, edges=edges) - - -# -- Mock data (TODO: remove after real data available) ------------------------ - -_MOCK_LABELS = [ - "sentence-final punctuation", - "proper noun completion", - "emotional adjective selection", - "temporal adverb prediction", - "morphological suffix (-ing/-ed)", - "determiner before noun", - "dialogue quotation marks", - "plural noun suffix", - "clause boundary detection", - "verb tense agreement", - "spatial preposition", - "possessive pronoun", - "narrative action verb", - "abstract emotion noun", - "comparative adjective form", 
- "subject-verb agreement", - "article selection (a/the)", - "comma splice detection", - "pronoun resolution", - "negation scope", -] - -_MOCK_INPUT_LABELS = [ - "sentence-initial capitals", - "mid-sentence verb position", - "adjective-noun boundary", - "clause-final position", - "article-noun sequence", - "subject pronoun at boundary", - "preposition-object pair", - "verb stem before suffix", - "quotation boundary", - "comma-separated items", -] - - -def _mock_all_labels(loaded: DepLoadedRun) -> dict[str, GraphInterpHeadline]: - rng = random.Random(42) - topology = loaded.topology - confidences = ["high", "high", "high", "medium", "medium", "low"] - - result: dict[str, GraphInterpHeadline] = {} - for target_path, components in loaded.model.components.items(): - canon = topology.target_to_canon(target_path) - n_components = components.C - n_mock = min(n_components, rng.randint(5, 20)) - indices = sorted(rng.sample(range(n_components), n_mock)) - for idx in indices: - key = f"{canon}:{idx}" - result[key] = GraphInterpHeadline( - label=rng.choice(_MOCK_LABELS), - confidence=rng.choice(confidences), - output_label=rng.choice(_MOCK_LABELS), - input_label=rng.choice(_MOCK_INPUT_LABELS), - ) - return result - - -def _mock_label_detail(layer: str, c_idx: int) -> GraphInterpDetail: - rng = random.Random(hash((layer, c_idx))) - conf = rng.choice(["high", "medium", "low"]) - return GraphInterpDetail( - output=LabelDetail( - label=rng.choice(_MOCK_LABELS), - confidence=conf, - reasoning=f"Output: Component {layer}:{c_idx} writes {rng.choice(_MOCK_LABELS).lower()} tokens to the residual stream.", - prompt="(mock prompt)", - ), - input=LabelDetail( - label=rng.choice(_MOCK_INPUT_LABELS), - confidence=conf, - reasoning=f"Input: Component {layer}:{c_idx} fires on {rng.choice(_MOCK_INPUT_LABELS).lower()} patterns.", - prompt="(mock prompt)", - ), - unified=LabelDetail( - label=rng.choice(_MOCK_LABELS), - confidence=conf, - reasoning=f"Unified: Combines output 
({rng.choice(_MOCK_LABELS).lower()}) and input ({rng.choice(_MOCK_INPUT_LABELS).lower()}) functions.", - prompt="(mock prompt)", - ), - ) - - -def _mock_model_graph(loaded: DepLoadedRun) -> ModelGraphResponse: - labels = _mock_all_labels(loaded) - - nodes = [ - GraphNode(component_key=key, label=h.label, confidence=h.confidence) - for key, h in labels.items() - ] - - rng = random.Random(42) - keys = list(labels.keys()) - edges: list[GraphEdge] = [] - - for key in keys: - layer = key.rsplit(":", 1)[0] - later_keys = [k for k in keys if k.rsplit(":", 1)[0] != layer] - n_edges = rng.randint(1, 4) - for target in rng.sample(later_keys, min(n_edges, len(later_keys))): - edges.append( - GraphEdge( - source=key, - target=target, - attribution=rng.uniform(-1.0, 1.0), - pass_name=rng.choice(["output", "input"]), - ) - ) - - return ModelGraphResponse(nodes=nodes, edges=edges) diff --git a/spd/app/backend/routers/runs.py b/spd/app/backend/routers/runs.py index 914998a4e..3b9c2300e 100644 --- a/spd/app/backend/routers/runs.py +++ b/spd/app/backend/routers/runs.py @@ -7,16 +7,15 @@ import yaml from fastapi import APIRouter from pydantic import BaseModel -from spd.graph_interp.repo import GraphInterpRepo from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.dependencies import DepStateManager -from spd.app.backend.routers.graph_interp import MOCK_MODE as _GRAPH_INTERP_MOCK_MODE from spd.app.backend.state import RunState from spd.app.backend.utils import log_errors from spd.autointerp.repo import InterpRepo from spd.configs import LMTaskConfig from spd.dataset_attributions.repo import AttributionRepo +from spd.graph_interp.repo import GraphInterpRepo from spd.harvest.repo import HarvestRepo from spd.log import logger from spd.models.component_model import ComponentModel, SPDRunInfo @@ -170,9 +169,7 @@ def get_status(manager: DepStateManager) -> LoadedRun | None: backend_user=getpass.getuser(), dataset_attributions_available=manager.run_state.attributions is 
not None, dataset_search_enabled=dataset_search_enabled, - # TODO(oli): Remove MOCK_MODE import after real data available - graph_interp_available=manager.run_state.graph_interp is not None - or _GRAPH_INTERP_MOCK_MODE, + graph_interp_available=manager.run_state.graph_interp is not None, autointerp_available=manager.run_state.interp is not None, ) From c51e62fba048c340741ccb2cb94b3e57830783aa Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Fri, 6 Mar 2026 14:31:10 +0000 Subject: [PATCH 13/20] Delete dead fork code, unused DB methods, deduplicate MAX_OUTPUT_NODES_PER_POS database.py: - Remove ForkedInterventionRunRecord class - Remove forked_intervention_runs table + index from schema - Remove fork cleanup from delete_prompt - Remove save/get/delete_forked_intervention_run methods - Remove unused: delete_graphs_for_prompt, delete_graphs_for_run, delete_intervention_runs_for_graph, get_intervention_run graphs.py: - Import MAX_OUTPUT_NODES_PER_POS from compute.py instead of redefining Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/backend/database.py | 144 ------------------------------ spd/app/backend/routers/graphs.py | 4 +- 2 files changed, 1 insertion(+), 147 deletions(-) diff --git a/spd/app/backend/database.py b/spd/app/backend/database.py index c2178076b..3d6f519ac 100644 --- a/spd/app/backend/database.py +++ b/spd/app/backend/database.py @@ -123,15 +123,6 @@ class InterventionRunRecord(BaseModel): created_at: str -class ForkedInterventionRunRecord(BaseModel): - """A forked intervention run with modified tokens (currently unused).""" - - id: int - intervention_run_id: int - token_replacements: list[tuple[int, int]] # [(seq_pos, new_token_id), ...] - result_json: str - created_at: str - class PromptAttrDB: """SQLite database for storing and querying prompt attribution data. 
@@ -274,17 +265,6 @@ def init_schema(self) -> None: CREATE INDEX IF NOT EXISTS idx_intervention_runs_graph ON intervention_runs(graph_id); - - CREATE TABLE IF NOT EXISTS forked_intervention_runs ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - intervention_run_id INTEGER NOT NULL REFERENCES intervention_runs(id) ON DELETE CASCADE, - token_replacements TEXT NOT NULL, -- JSON array of [seq_pos, new_token_id] tuples - result TEXT NOT NULL, -- JSON InterventionResponse - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ); - - CREATE INDEX IF NOT EXISTS idx_forked_intervention_runs_parent - ON forked_intervention_runs(intervention_run_id); """) conn.commit() @@ -691,13 +671,6 @@ def delete_prompt(self, prompt_id: int) -> None: with self._write_lock(): conn = self._get_conn() graph_ids_query = "SELECT id FROM graphs WHERE prompt_id = ?" - intervention_ids_query = ( - f"SELECT id FROM intervention_runs WHERE graph_id IN ({graph_ids_query})" - ) - conn.execute( - f"DELETE FROM forked_intervention_runs WHERE intervention_run_id IN ({intervention_ids_query})", - (prompt_id,), - ) conn.execute( f"DELETE FROM intervention_runs WHERE graph_id IN ({graph_ids_query})", (prompt_id,), @@ -706,26 +679,6 @@ def delete_prompt(self, prompt_id: int) -> None: conn.execute("DELETE FROM prompts WHERE id = ?", (prompt_id,)) conn.commit() - def delete_graphs_for_prompt(self, prompt_id: int) -> int: - """Delete all graphs for a prompt. Returns the number of deleted rows.""" - with self._write_lock(): - conn = self._get_conn() - cursor = conn.execute("DELETE FROM graphs WHERE prompt_id = ?", (prompt_id,)) - conn.commit() - return cursor.rowcount - - def delete_graphs_for_run(self, run_id: int) -> int: - """Delete all graphs for all prompts in a run. 
Returns the number of deleted rows.""" - with self._write_lock(): - conn = self._get_conn() - cursor = conn.execute( - """DELETE FROM graphs - WHERE prompt_id IN (SELECT id FROM prompts WHERE run_id = ?)""", - (run_id,), - ) - conn.commit() - return cursor.rowcount - # ------------------------------------------------------------------------- # Intervention run operations # ------------------------------------------------------------------------- @@ -794,101 +747,4 @@ def delete_intervention_run(self, run_id: int) -> None: conn.execute("DELETE FROM intervention_runs WHERE id = ?", (run_id,)) conn.commit() - def delete_intervention_runs_for_graph(self, graph_id: int) -> int: - """Delete all intervention runs for a graph. Returns count deleted.""" - with self._write_lock(): - conn = self._get_conn() - cursor = conn.execute("DELETE FROM intervention_runs WHERE graph_id = ?", (graph_id,)) - conn.commit() - return cursor.rowcount - - # ------------------------------------------------------------------------- - # Forked intervention run operations - # ------------------------------------------------------------------------- - - def save_forked_intervention_run( - self, - intervention_run_id: int, - token_replacements: list[tuple[int, int]], - result_json: str, - ) -> int: - """Save a forked intervention run. - - Args: - intervention_run_id: The parent intervention run ID. - token_replacements: List of (seq_pos, new_token_id) tuples. - result_json: JSON-encoded InterventionResponse. - Returns: - The forked intervention run ID. 
- """ - with self._write_lock(): - conn = self._get_conn() - cursor = conn.execute( - """INSERT INTO forked_intervention_runs (intervention_run_id, token_replacements, result) - VALUES (?, ?, ?)""", - (intervention_run_id, json.dumps(token_replacements), result_json), - ) - conn.commit() - fork_id = cursor.lastrowid - assert fork_id is not None - return fork_id - - def get_forked_intervention_runs( - self, intervention_run_id: int - ) -> list[ForkedInterventionRunRecord]: - """Get all forked runs for an intervention run. - - Args: - intervention_run_id: The parent intervention run ID. - - Returns: - List of forked intervention run records, ordered by creation time. - """ - conn = self._get_conn() - rows = conn.execute( - """SELECT id, intervention_run_id, token_replacements, result, created_at - FROM forked_intervention_runs - WHERE intervention_run_id = ? - ORDER BY created_at""", - (intervention_run_id,), - ).fetchall() - - return [ - ForkedInterventionRunRecord( - id=row["id"], - intervention_run_id=row["intervention_run_id"], - token_replacements=json.loads(row["token_replacements"]), - result_json=row["result"], - created_at=row["created_at"], - ) - for row in rows - ] - - def get_intervention_run(self, run_id: int) -> InterventionRunRecord | None: - """Get a single intervention run by ID.""" - conn = self._get_conn() - row = conn.execute( - """SELECT id, graph_id, selected_nodes, result, created_at - FROM intervention_runs - WHERE id = ?""", - (run_id,), - ).fetchone() - - if row is None: - return None - - return InterventionRunRecord( - id=row["id"], - graph_id=row["graph_id"], - selected_nodes=json.loads(row["selected_nodes"]), - result_json=row["result"], - created_at=row["created_at"], - ) - - def delete_forked_intervention_run(self, fork_id: int) -> None: - """Delete a forked intervention run.""" - with self._write_lock(): - conn = self._get_conn() - conn.execute("DELETE FROM forked_intervention_runs WHERE id = ?", (fork_id,)) - conn.commit() diff --git 
a/spd/app/backend/routers/graphs.py b/spd/app/backend/routers/graphs.py index 6ec5cd08f..141ab7785 100644 --- a/spd/app/backend/routers/graphs.py +++ b/spd/app/backend/routers/graphs.py @@ -20,6 +20,7 @@ from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.compute import ( DEFAULT_EVAL_PGD_CONFIG, + MAX_OUTPUT_NODES_PER_POS, Edge, compute_intervention, compute_prompt_attributions, @@ -281,9 +282,6 @@ class BatchGraphResult(BaseModel): ProgressCallback = Callable[[int, int, str], None] -MAX_OUTPUT_NODES_PER_POS = 15 - - def _build_out_probs( ci_masked_out_logits: torch.Tensor, target_out_logits: torch.Tensor, From 56f975e67e6c7fd3707fe3626c398a366fb53603 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Fri, 6 Mar 2026 14:39:59 +0000 Subject: [PATCH 14/20] =?UTF-8?q?Make=20harvest=5Fsubrun=5Fid=20required?= =?UTF-8?q?=20everywhere=20=E2=80=94=20no=20more=20"use=20most=20recent"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit All CLI entrypoints and pipeline functions now require explicit harvest_subrun_id. Eliminates silent fallback to open_most_recent() which could pick up stale data from a different config. 
- autointerp run_interpret.py: main() and get_command() require it - autointerp run_slurm.py: submit_autointerp() requires it - autointerp run_slurm_cli.py: CLI requires --harvest_subrun_id - autointerp scoring/run_label_scoring.py: main() and get_command() require it - dataset_attributions config.py: harvest_subrun_id is required on config - dataset_attributions harvest.py: _build_alive_masks requires it - dataset_attributions run_slurm.py: removed harvest_subrun_id param (now in config) - postprocess __init__.py: sets harvest_subrun_id on attr config from harvest result Co-Authored-By: Claude Opus 4.6 (1M context) --- .../scoring/scripts/run_label_scoring.py | 24 +++++++------------ spd/autointerp/scripts/run_interpret.py | 14 ++++------- spd/autointerp/scripts/run_slurm.py | 2 +- spd/autointerp/scripts/run_slurm_cli.py | 5 ++-- spd/dataset_attributions/config.py | 2 +- spd/dataset_attributions/harvest.py | 8 ++----- spd/dataset_attributions/scripts/run_slurm.py | 24 ++----------------- spd/postprocess/__init__.py | 7 ++++-- 8 files changed, 27 insertions(+), 59 deletions(-) diff --git a/spd/autointerp/scoring/scripts/run_label_scoring.py b/spd/autointerp/scoring/scripts/run_label_scoring.py index be2efa388..fd95f9763 100644 --- a/spd/autointerp/scoring/scripts/run_label_scoring.py +++ b/spd/autointerp/scoring/scripts/run_label_scoring.py @@ -25,7 +25,7 @@ def main( decomposition_id: str, scorer_type: LabelScorerType, config_json: dict[str, Any], - harvest_subrun_id: str | None = None, + harvest_subrun_id: str, ) -> None: assert isinstance(config_json, dict), f"Expected dict from fire, got {type(config_json)}" load_dotenv() @@ -44,15 +44,11 @@ def main( # Separate writable DB for saving scores (the repo's DB is readonly/immutable) score_db = InterpDB(interp_repo._subrun_dir / "interp.db") - if harvest_subrun_id is not None: - harvest = HarvestRepo( - decomposition_id=decomposition_id, - subrun_id=harvest_subrun_id, - readonly=True, - ) - else: - harvest = 
HarvestRepo.open_most_recent(decomposition_id, readonly=True) - assert harvest is not None, f"No harvest data for {decomposition_id}" + harvest = HarvestRepo( + decomposition_id=decomposition_id, + subrun_id=harvest_subrun_id, + readonly=True, + ) components = harvest.get_all_components() @@ -99,18 +95,16 @@ def get_command( decomposition_id: str, scorer_type: LabelScorerType, config: AutointerpEvalConfig, - harvest_subrun_id: str | None = None, + harvest_subrun_id: str, ) -> str: config_json = config.model_dump_json(exclude_none=True) - cmd = ( + return ( f"python -m spd.autointerp.scoring.scripts.run_label_scoring " f"--decomposition_id {decomposition_id} " f"--scorer_type {scorer_type} " f"--config_json '{config_json}' " + f"--harvest_subrun_id {harvest_subrun_id} " ) - if harvest_subrun_id is not None: - cmd += f" --harvest_subrun_id {harvest_subrun_id} " - return cmd if __name__ == "__main__": diff --git a/spd/autointerp/scripts/run_interpret.py b/spd/autointerp/scripts/run_interpret.py index da056dc35..263329172 100644 --- a/spd/autointerp/scripts/run_interpret.py +++ b/spd/autointerp/scripts/run_interpret.py @@ -22,7 +22,7 @@ def main( decomposition_id: str, config_json: dict[str, Any], - harvest_subrun_id: str | None = None, + harvest_subrun_id: str, autointerp_subrun_id: str | None = None, ) -> None: assert isinstance(config_json, dict), f"Expected dict from fire, got {type(config_json)}" @@ -32,12 +32,7 @@ def main( openrouter_api_key = os.environ.get("OPENROUTER_API_KEY") assert openrouter_api_key, "OPENROUTER_API_KEY not set" - if harvest_subrun_id is not None: - harvest = HarvestRepo(decomposition_id, subrun_id=harvest_subrun_id, readonly=False) - else: - harvest = HarvestRepo.open_most_recent(decomposition_id, readonly=False) - if harvest is None: - raise ValueError(f"No harvest data found for {decomposition_id}") + harvest = HarvestRepo(decomposition_id, subrun_id=harvest_subrun_id, readonly=False) if autointerp_subrun_id is not None: subrun_dir = 
get_autointerp_dir(decomposition_id) / autointerp_subrun_id @@ -76,7 +71,7 @@ def main( def get_command( decomposition_id: str, config: AutointerpConfig, - harvest_subrun_id: str | None = None, + harvest_subrun_id: str, autointerp_subrun_id: str | None = None, ) -> str: config_json = config.model_dump_json(exclude_none=True) @@ -84,9 +79,8 @@ def get_command( "python -m spd.autointerp.scripts.run_interpret " f"--decomposition_id {decomposition_id} " f"--config_json '{config_json}' " + f"--harvest_subrun_id {harvest_subrun_id} " ) - if harvest_subrun_id is not None: - cmd += f"--harvest_subrun_id {harvest_subrun_id} " if autointerp_subrun_id is not None: cmd += f"--autointerp_subrun_id {autointerp_subrun_id} " return cmd diff --git a/spd/autointerp/scripts/run_slurm.py b/spd/autointerp/scripts/run_slurm.py index bcd83b94c..4a1cd0bdc 100644 --- a/spd/autointerp/scripts/run_slurm.py +++ b/spd/autointerp/scripts/run_slurm.py @@ -30,9 +30,9 @@ class AutointerpSubmitResult: def submit_autointerp( decomposition_id: str, config: AutointerpSlurmConfig, + harvest_subrun_id: str, dependency_job_id: str | None = None, snapshot_branch: str | None = None, - harvest_subrun_id: str | None = None, ) -> AutointerpSubmitResult: """Submit the autointerp pipeline to SLURM. diff --git a/spd/autointerp/scripts/run_slurm_cli.py b/spd/autointerp/scripts/run_slurm_cli.py index fffc75c60..56db16499 100644 --- a/spd/autointerp/scripts/run_slurm_cli.py +++ b/spd/autointerp/scripts/run_slurm_cli.py @@ -10,18 +10,19 @@ import fire -def main(decomposition_id: str, config: str) -> None: +def main(decomposition_id: str, config: str, harvest_subrun_id: str) -> None: """Submit autointerp pipeline (interpret + evals) to SLURM. Args: decomposition_id: ID of the target decomposition run. config: Path to AutointerpSlurmConfig YAML/JSON. + harvest_subrun_id: Harvest subrun to use (e.g. "h-20260306_120000"). 
""" from spd.autointerp.config import AutointerpSlurmConfig from spd.autointerp.scripts.run_slurm import submit_autointerp slurm_config = AutointerpSlurmConfig.from_file(config) - submit_autointerp(decomposition_id, slurm_config) + submit_autointerp(decomposition_id, slurm_config, harvest_subrun_id=harvest_subrun_id) def cli() -> None: diff --git a/spd/dataset_attributions/config.py b/spd/dataset_attributions/config.py index 6f02df0f9..8a515ab7e 100644 --- a/spd/dataset_attributions/config.py +++ b/spd/dataset_attributions/config.py @@ -14,7 +14,7 @@ class DatasetAttributionConfig(BaseConfig): spd_run_wandb_path: str - harvest_subrun_id: str | None = None + harvest_subrun_id: str n_batches: int | Literal["whole_dataset"] = 10_000 batch_size: int = 32 ci_threshold: float = 0.0 diff --git a/spd/dataset_attributions/harvest.py b/spd/dataset_attributions/harvest.py index 6a0ba3380..da2f53505 100644 --- a/spd/dataset_attributions/harvest.py +++ b/spd/dataset_attributions/harvest.py @@ -36,7 +36,7 @@ def _build_alive_masks( model: ComponentModel, run_id: str, - harvest_subrun_id: str | None, + harvest_subrun_id: str, ) -> dict[str, Bool[Tensor, " n_components"]]: """Build masks of alive components (firing_density > 0) per target layer. 
@@ -48,11 +48,7 @@ def _build_alive_masks( for layer in model.target_module_paths } - if harvest_subrun_id is not None: - harvest = HarvestRepo(decomposition_id=run_id, subrun_id=harvest_subrun_id, readonly=True) - else: - harvest = HarvestRepo.open_most_recent(run_id, readonly=True) - assert harvest is not None, f"No harvest data for {run_id}" + harvest = HarvestRepo(decomposition_id=run_id, subrun_id=harvest_subrun_id, readonly=True) summary = harvest.get_summary() assert summary is not None, "Harvest summary not available" diff --git a/spd/dataset_attributions/scripts/run_slurm.py b/spd/dataset_attributions/scripts/run_slurm.py index 9e04cac46..961b652c8 100644 --- a/spd/dataset_attributions/scripts/run_slurm.py +++ b/spd/dataset_attributions/scripts/run_slurm.py @@ -45,25 +45,8 @@ def submit_attributions( job_suffix: str | None = None, snapshot_branch: str | None = None, dependency_job_id: str | None = None, - harvest_subrun_id: str | None = None, ) -> AttributionsSubmitResult: - """Submit multi-GPU attribution harvesting job to SLURM. - - Submits a job array where each task processes a subset of batches, then - submits a merge job that depends on all workers completing. Creates a git - snapshot to ensure consistent code across all workers. - - Args: - wandb_path: WandB run path for the target decomposition run. - config: Attribution SLURM configuration. - job_suffix: Optional suffix for SLURM job names (e.g., "1h" -> "spd-attr-1h"). - snapshot_branch: Git snapshot branch to use. If None, creates a new snapshot. - dependency_job_id: SLURM job to wait for before starting (e.g. harvest merge). - harvest_subrun_id: Harvest subrun for alive masks. If None, uses most recent. - - Returns: - AttributionsSubmitResult with array, merge results and subrun ID. 
- """ + """Submit multi-GPU attribution harvesting job to SLURM.""" n_gpus = config.n_gpus partition = config.partition time = config.time @@ -80,10 +63,7 @@ def submit_attributions( suffix = f"-{job_suffix}" if job_suffix else "" array_job_name = f"spd-attr{suffix}" - inner_config = config.config - if harvest_subrun_id is not None and inner_config.harvest_subrun_id is None: - inner_config = inner_config.model_copy(update={"harvest_subrun_id": harvest_subrun_id}) - config_json = inner_config.model_dump_json(exclude_none=True) + config_json = config.config.model_dump_json(exclude_none=True) # SLURM arrays are 1-indexed, so task ID 1 -> rank 0, etc. worker_commands = [] diff --git a/spd/postprocess/__init__.py b/spd/postprocess/__init__.py index 35922436b..ccd209fcd 100644 --- a/spd/postprocess/__init__.py +++ b/spd/postprocess/__init__.py @@ -99,12 +99,15 @@ def postprocess(config: PostprocessConfig, dependency_job_id: str | None = None) attr_result = None if config.attributions is not None: assert isinstance(decomp_cfg, SPDHarvestConfig) + attr_inner = config.attributions.config.model_copy( + update={"harvest_subrun_id": harvest_result.subrun_id} + ) + attr_slurm = config.attributions.model_copy(update={"config": attr_inner}) attr_result = submit_attributions( wandb_path=decomp_cfg.wandb_path, - config=config.attributions, + config=attr_slurm, snapshot_branch=snapshot_branch, dependency_job_id=harvest_result.merge_result.job_id, - harvest_subrun_id=harvest_result.subrun_id, ) # === 5. 
Graph interp (depends on harvest merge + attribution merge) === From c959635f049d06e7beb4dc60daeb5584d46c4624 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Fri, 6 Mar 2026 14:42:15 +0000 Subject: [PATCH 15/20] remove dead scripts --- scripts/export_circuit_json.py | 200 --------------------------------- scripts/render_circuit_html.py | 69 ------------ 2 files changed, 269 deletions(-) delete mode 100644 scripts/export_circuit_json.py delete mode 100644 scripts/render_circuit_html.py diff --git a/scripts/export_circuit_json.py b/scripts/export_circuit_json.py deleted file mode 100644 index 144727a33..000000000 --- a/scripts/export_circuit_json.py +++ /dev/null @@ -1,200 +0,0 @@ -"""Export an OptimizedPromptAttributionResult to JSON for the circuit graph renderer. - -Usage from a notebook / script: - - from scripts.export_circuit_json import export_circuit_json - - export_circuit_json( - circuit=circuit, - token_ids=tokens.tolist(), - token_strings=tok.get_spans(tokens.tolist()), - output_path=Path("king_circuit.json"), - interp=interp, # optional InterpRepo - graph_interp=gi_repo, # optional GraphInterpRepo - ) -""" - -import json -from pathlib import Path - -from spd.graph_interp.repo import GraphInterpRepo - -from spd.app.backend.compute import Edge, OptimizedPromptAttributionResult -from spd.autointerp.repo import InterpRepo - - -def _parse_node_key(key: str) -> tuple[str, int, int]: - """Parse 'h.3.attn.o_proj:11:361' -> ('h.3.attn.o_proj', 11, 361).""" - parts = key.split(":") - layer = ":".join(parts[:-2]) - seq = int(parts[-2]) - cidx = int(parts[-1]) - return layer, seq, cidx - - -def _concrete_to_canonical(layer: str) -> str: - """Map concrete model path to canonical address used by graph layout. 
- - h.3.attn.o_proj -> 3.attn.o - h.3.attn.v_proj -> 3.attn.v - h.3.attn.q_proj -> 3.attn.q - h.3.attn.k_proj -> 3.attn.k - h.3.mlp.c_fc -> 3.mlp.up - h.3.mlp.down_proj -> 3.mlp.down - h.3.mlp.gate_proj -> 3.glu.gate - h.3.mlp.up_proj -> 3.glu.up - wte -> embed - lm_head -> output - """ - if layer == "wte": - return "embed" - if layer == "lm_head": - return "output" - - PROJ_MAP = { - "q_proj": "q", - "k_proj": "k", - "v_proj": "v", - "o_proj": "o", - "c_fc": "up", - "down_proj": "down", - "gate_proj": "gate", - "up_proj": "up", - } - - # h.{block}.{sublayer}.{proj_name} - parts = layer.split(".") - assert len(parts) == 4, f"Expected h.N.sublayer.proj, got {layer!r}" - block_idx = parts[1] - sublayer = parts[2] # "attn" or "mlp" - proj_name = parts[3] - - canonical_proj = PROJ_MAP.get(proj_name) - assert canonical_proj is not None, f"Unknown projection: {proj_name!r} in {layer!r}" - - # Determine canonical sublayer - if sublayer == "attn": - canonical_sublayer = "attn" - elif sublayer == "mlp": - canonical_sublayer = "glu" if proj_name in ("gate_proj", "up_proj") else "mlp" - else: - canonical_sublayer = sublayer - - return f"{block_idx}.{canonical_sublayer}.{canonical_proj}" - - -def _edge_to_dict(e: Edge) -> dict[str, object]: - return { - "source": str(e.source), - "target": str(e.target), - "attribution": e.strength, - "is_cross_seq": e.is_cross_seq, - } - - -def _get_label( - layer: str, - cidx: int, - interp: InterpRepo | None, - graph_interp: GraphInterpRepo | None, -) -> str | None: - component_key = f"{layer}:{cidx}" - if graph_interp is not None: - unified = graph_interp.get_unified_label(component_key) - if unified is not None: - return unified.label - output = graph_interp.get_output_label(component_key) - if output is not None: - return output.label - if interp is not None: - ir = interp.get_interpretation(component_key) - if ir is not None: - return ir.label - return None - - -def export_circuit_json( - circuit: OptimizedPromptAttributionResult, - 
token_ids: list[int], - token_strings: list[str], - output_path: Path, - min_ci: float = 0.3, - min_edge_attr: float = 0.1, - interp: InterpRepo | None = None, - graph_interp: GraphInterpRepo | None = None, -) -> None: - """Export circuit to JSON for the standalone HTML renderer. - - Args: - circuit: The optimized circuit result from EditableModel.optimize_circuit. - token_ids: Raw token IDs for each position. - token_strings: Display strings for each position (from tok.get_spans). - output_path: Where to write the JSON. - min_ci: Minimum CI to include a node. - min_edge_attr: Minimum |attribution| to include an edge. - interp: Optional InterpRepo for autointerp labels. - graph_interp: Optional GraphInterpRepo for graph-interp labels. - """ - # Build tokens list - tokens = [ - {"pos": i, "id": tid, "string": tstr} - for i, (tid, tstr) in enumerate(zip(token_ids, token_strings, strict=True)) - ] - - # Build nodes from ci_vals, filtering by min_ci - nodes = [] - node_keys_kept = set() - for key, ci in circuit.node_ci_vals.items(): - if ci < min_ci: - continue - layer, seq, cidx = _parse_node_key(key) - canonical = _concrete_to_canonical(layer) - act = circuit.node_subcomp_acts.get(key, 0.0) - - label = _get_label(layer, cidx, interp, graph_interp) - - nodes.append( - { - "key": key, - "graph_key": f"{layer}:{cidx}", - "layer": layer, - "canonical": canonical, - "seq": seq, - "cidx": cidx, - "ci": round(ci, 4), - "activation": round(act, 4), - "token": token_strings[seq] if seq < len(token_strings) else "?", - "label": label, - } - ) - node_keys_kept.add(key) - - # Build edges, filtering by min_edge_attr and requiring both endpoints in kept nodes - edges = [] - for e in circuit.edges: - if abs(e.strength) < min_edge_attr: - continue - src_key = str(e.source) - tgt_key = str(e.target) - if src_key not in node_keys_kept or tgt_key not in node_keys_kept: - continue - edges.append(_edge_to_dict(e)) - - metrics: dict[str, float] = {"l0_total": circuit.metrics.l0_total} - 
if circuit.metrics.ci_masked_label_prob is not None: - metrics["ci_masked_label_prob"] = circuit.metrics.ci_masked_label_prob - if circuit.metrics.stoch_masked_label_prob is not None: - metrics["stoch_masked_label_prob"] = circuit.metrics.stoch_masked_label_prob - - data = { - "tokens": tokens, - "nodes": nodes, - "edges": edges, - "metrics": metrics, - } - - output_path.parent.mkdir(parents=True, exist_ok=True) - with open(output_path, "w") as f: - json.dump(data, f, indent=2) - - print(f"Exported {len(nodes)} nodes, {len(edges)} edges to {output_path}") diff --git a/scripts/render_circuit_html.py b/scripts/render_circuit_html.py deleted file mode 100644 index 677c9cb57..000000000 --- a/scripts/render_circuit_html.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Render a circuit JSON file as a self-contained HTML page. - -This is a thin wrapper that copies the circuit.html template and embeds the JSON -data inline, so the result is a single self-contained HTML file (no separate data file needed). - -Usage: - python scripts/render_circuit_html.py data/king_circuit.json -o circuit_standalone.html - -Or from Python: - from scripts.render_circuit_html import render_circuit_html - render_circuit_html(Path("data/king_circuit.json"), Path("circuit_standalone.html")) -""" - -import argparse -import json -from pathlib import Path - - -def render_circuit_html(json_path: Path, output_path: Path, title: str = "Circuit Graph") -> None: - """Render a circuit JSON as a self-contained HTML file. - - Takes the interactive circuit.html template and replaces the fetch() call - with inline data, producing a single portable HTML file. 
- """ - with open(json_path) as f: - data = json.load(f) - - # Read the template - template_path = Path(__file__).parent.parent / "scripts" / "_circuit_template.html" - if not template_path.exists(): - # Fallback: read from www - from spd.settings import SPD_OUT_DIR - - template_path = SPD_OUT_DIR / "www" / "pile-editing" / "circuit.html" - - assert template_path.exists(), f"Template not found: {template_path}" - template = template_path.read_text() - - # Replace the fetch with inline data - inline_js = f"data = {json.dumps(data)}; init();" - template = template.replace( - "fetch(DATA_URL)\n" - " .then(r => { if (!r.ok) throw new Error(`HTTP ${r.status}`); return r.json(); })\n" - " .then(d => { data = d; init(); })\n" - " .catch(e => { document.getElementById('stats').textContent = `Error: ${e.message}`; });", - inline_js, - ) - - # Update title - template = template.replace( - 'Circuit Graph — King → "he"', f"{title}" - ) - - output_path.parent.mkdir(parents=True, exist_ok=True) - with open(output_path, "w") as f: - f.write(template) - - print(f"Wrote self-contained HTML to {output_path}") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Render circuit JSON as self-contained HTML") - parser.add_argument("json_path", type=Path, help="Path to circuit JSON file") - parser.add_argument( - "-o", "--output", type=Path, default=Path("circuit.html"), help="Output HTML path" - ) - parser.add_argument("--title", default="Circuit Graph", help="Page title") - args = parser.parse_args() - render_circuit_html(args.json_path, args.output, args.title) From 48b91acc491abd2846f3e9d88a14097dcccfbdd5 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Fri, 6 Mar 2026 15:54:20 +0000 Subject: [PATCH 16/20] Add Md DSL to spd/utils/ for shared use across prompt builders Same block-based DSL from graph_interp, canonical location for autointerp strategies to import from too. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/utils/markdown.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 spd/utils/markdown.py diff --git a/spd/utils/markdown.py b/spd/utils/markdown.py new file mode 100644 index 000000000..f4f270250 --- /dev/null +++ b/spd/utils/markdown.py @@ -0,0 +1,43 @@ +"""Minimal Markdown document builder for prompt construction. + +Atomic unit is a block (paragraph, heading, list, etc.). +build() joins blocks with double newlines. +""" + + +class Md: + """Accumulates Markdown blocks with a fluent API. + + Each method appends a block and returns self for chaining. + Call .build() to get the final string (blocks joined by blank lines). + """ + + def __init__(self) -> None: + self._blocks: list[str] = [] + + def h2(self, text: str) -> "Md": + self._blocks.append(f"## {text}") + return self + + def h3(self, text: str) -> "Md": + self._blocks.append(f"### {text}") + return self + + def p(self, text: str) -> "Md": + self._blocks.append(text) + return self + + def bullets(self, items: list[str]) -> "Md": + self._blocks.append("\n".join(f"- {item}" for item in items)) + return self + + def numbered(self, items: list[str]) -> "Md": + self._blocks.append("\n".join(f"{i}. {item}" for i, item in enumerate(items, 1))) + return self + + def extend(self, other: "Md") -> "Md": + self._blocks.extend(other._blocks) + return self + + def build(self) -> str: + return "\n\n".join(self._blocks) From c024988348356bbef41f6ab154a431753774c0dc Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Fri, 6 Mar 2026 16:14:43 +0000 Subject: [PATCH 17/20] Rewrite all LLM prompt builders with Md DSL, fix pre-existing warnings Convert compact_skeptical, dual_view, and graph_interp prompt formatters from f-string concatenation to the Md block-based DSL. Extract shared token_pmi_pairs helper into prompt_helpers. Add labeled_list to Md for the common bold-header + bullet-items pattern. 
Also fix two pre-existing basedpyright warnings: DONE_MARKER import path in graph_interp/repo.py and unused param in test_storage.py. Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/backend/database.py | 3 - spd/app/backend/routers/graphs.py | 4 +- spd/app/backend/state.py | 2 +- spd/autointerp/prompt_helpers.py | 97 +++++----- .../strategies/compact_skeptical.py | 154 +++++++-------- spd/autointerp/strategies/dual_view.py | 115 +++++++----- spd/graph_interp/interpret.py | 44 +++-- spd/graph_interp/prompts.py | 177 ++++++++++-------- spd/graph_interp/repo.py | 3 +- spd/harvest/analysis.py | 1 - spd/utils/markdown.py | 13 +- tests/dataset_attributions/test_storage.py | 2 +- 12 files changed, 324 insertions(+), 291 deletions(-) diff --git a/spd/app/backend/database.py b/spd/app/backend/database.py index 3d6f519ac..f64593237 100644 --- a/spd/app/backend/database.py +++ b/spd/app/backend/database.py @@ -123,7 +123,6 @@ class InterventionRunRecord(BaseModel): created_at: str - class PromptAttrDB: """SQLite database for storing and querying prompt attribution data. 
@@ -746,5 +745,3 @@ def delete_intervention_run(self, run_id: int) -> None: conn = self._get_conn() conn.execute("DELETE FROM intervention_runs WHERE id = ?", (run_id,)) conn.commit() - - diff --git a/spd/app/backend/routers/graphs.py b/spd/app/backend/routers/graphs.py index 141ab7785..fd43b3b1f 100644 --- a/spd/app/backend/routers/graphs.py +++ b/spd/app/backend/routers/graphs.py @@ -75,7 +75,9 @@ def _save_base_intervention_run( if k.split(":")[0] not in NON_INTERVENTABLE_LAYERS and ci > 0 ] if not interventable_keys: - logger.warning(f"Graph {graph_id}: no interventable nodes with CI > 0, skipping base intervention run") + logger.warning( + f"Graph {graph_id}: no interventable nodes with CI > 0, skipping base intervention run" + ) return active_nodes: list[tuple[str, int, int]] = [] diff --git a/spd/app/backend/state.py b/spd/app/backend/state.py index dd70d22ca..2cdabda73 100644 --- a/spd/app/backend/state.py +++ b/spd/app/backend/state.py @@ -12,13 +12,13 @@ from typing import Any from fastapi import HTTPException -from spd.graph_interp.repo import GraphInterpRepo from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.database import PromptAttrDB, Run from spd.autointerp.repo import InterpRepo from spd.configs import Config from spd.dataset_attributions.repo import AttributionRepo +from spd.graph_interp.repo import GraphInterpRepo from spd.harvest.repo import HarvestRepo from spd.models.component_model import ComponentModel from spd.topology import TransformerTopology diff --git a/spd/autointerp/prompt_helpers.py b/spd/autointerp/prompt_helpers.py index ac1f412fd..823bf38ef 100644 --- a/spd/autointerp/prompt_helpers.py +++ b/spd/autointerp/prompt_helpers.py @@ -9,6 +9,17 @@ from spd.app.backend.utils import delimit_tokens from spd.harvest.analysis import TokenPRLift from spd.harvest.schemas import ComponentData +from spd.utils.markdown import Md + + +def token_pmi_pairs( + app_tok: AppTokenizer, + token_pmi_top: list[tuple[int, float]] | 
None, +) -> list[tuple[str, float]] | None: + if not token_pmi_top: + return None + return [(app_tok.get_tok_display(tid), pmi) for tid, pmi in token_pmi_top] + DATASET_DESCRIPTIONS: dict[str, str] = { "SimpleStories/SimpleStories": ( @@ -43,11 +54,7 @@ def ordinal(n: int) -> str: def human_layer_desc(canonical: str, n_blocks: int) -> str: - """Convert canonical layer string to human-readable description. - - '0.mlp.up' -> 'MLP up-projection in the 1st of 4 blocks' - '1.attn.q' -> 'attention query projection in the 2nd of 4 blocks' - """ + """'0.mlp.up' -> 'MLP up-projection in the 1st of 4 blocks'""" m = re.match(r"(\d+)\.(.*)", canonical) if not m: return canonical @@ -58,7 +65,6 @@ def human_layer_desc(canonical: str, n_blocks: int) -> str: def layer_position_note(canonical: str, n_blocks: int) -> str: - """Brief note about what layer position means for interpretation.""" m = re.match(r"(\d+)\.", canonical) if not m: return "" @@ -86,75 +92,72 @@ def density_note(firing_density: float) -> str: def build_output_section( output_stats: TokenPRLift, output_pmi: list[tuple[str, float]] | None, -) -> str: - section = "" - +) -> Md: + md = Md() if output_pmi: - section += ( + md.labeled_list( "**Output PMI (pointwise mutual information, in nats: how much more likely " "a token is to be produced when this component fires, vs its base rate. 
" - "0 = no association, 1 = ~3x more likely, 2 = ~7x, 3 = ~20x):**\n" + "0 = no association, 1 = ~3x more likely, 2 = ~7x, 3 = ~20x):**", + [f"{repr(tok)}: {pmi:.2f}" for tok, pmi in output_pmi[:10]], ) - for tok, pmi in output_pmi[:10]: - section += f"- {repr(tok)}: {pmi:.2f}\n" - if output_stats.top_precision: - section += "\n**Output precision — of all probability mass for token X, what fraction is at positions where this component fires?**\n" - for tok, prec in output_stats.top_precision[:10]: - section += f"- {repr(tok)}: {prec * 100:.0f}%\n" - - return section + md.labeled_list( + "**Output precision — of all probability mass for token X, what fraction " + "is at positions where this component fires?**", + [f"{repr(tok)}: {prec * 100:.0f}%" for tok, prec in output_stats.top_precision[:10]], + ) + return md def build_input_section( input_stats: TokenPRLift, input_pmi: list[tuple[str, float]] | None, -) -> str: - section = "" - +) -> Md: + md = Md() if input_pmi: - section += "**Input PMI (same metric as above, for input tokens):**\n" - for tok, pmi in input_pmi[:6]: - section += f"- {repr(tok)}: {pmi:.2f}\n" - + md.labeled_list( + "**Input PMI (same metric as above, for input tokens):**", + [f"{repr(tok)}: {pmi:.2f}" for tok, pmi in input_pmi[:6]], + ) if input_stats.top_precision: - section += "\n**Input precision — probability the component fires given the current token is X:**\n" - for tok, prec in input_stats.top_precision[:8]: - section += f"- {repr(tok)}: {prec * 100:.0f}%\n" - - return section + md.labeled_list( + "**Input precision — probability the component fires given the current token is X:**", + [f"{repr(tok)}: {prec * 100:.0f}%" for tok, prec in input_stats.top_precision[:8]], + ) + return md def build_fires_on_examples( component: ComponentData, app_tok: AppTokenizer, max_examples: int, -) -> str: - section = "" - examples = component.activation_examples[:max_examples] - - for i, ex in enumerate(examples): +) -> Md: + lines: list[str] = [] + 
for i, ex in enumerate(component.activation_examples[:max_examples]): if any(ex.firings): spans = app_tok.get_spans(ex.token_ids) tokens = list(zip(spans, ex.firings, strict=True)) - section += f"{i + 1}. {delimit_tokens(tokens)}\n" - - return section + lines.append(f"{i + 1}. {delimit_tokens(tokens)}") + md = Md() + if lines: + md.p("\n".join(lines)) + return md def build_says_examples( component: ComponentData, app_tok: AppTokenizer, max_examples: int, -) -> str: - section = "" - examples = component.activation_examples[:max_examples] - - for i, ex in enumerate(examples): +) -> Md: + lines: list[str] = [] + for i, ex in enumerate(component.activation_examples[:max_examples]): if any(ex.firings): spans = app_tok.get_spans(ex.token_ids) shifted_firings = [False] + ex.firings[:-1] tokens = list(zip(spans, shifted_firings, strict=True)) - section += f"{i + 1}. {delimit_tokens(tokens)}\n" - - return section + lines.append(f"{i + 1}. {delimit_tokens(tokens)}") + md = Md() + if lines: + md.p("\n".join(lines)) + return md diff --git a/spd/autointerp/strategies/compact_skeptical.py b/spd/autointerp/strategies/compact_skeptical.py index 6998bfda8..857a11144 100644 --- a/spd/autointerp/strategies/compact_skeptical.py +++ b/spd/autointerp/strategies/compact_skeptical.py @@ -6,10 +6,15 @@ from spd.app.backend.app_tokenizer import AppTokenizer from spd.autointerp.config import CompactSkepticalConfig -from spd.autointerp.prompt_helpers import DATASET_DESCRIPTIONS, build_fires_on_examples +from spd.autointerp.prompt_helpers import ( + DATASET_DESCRIPTIONS, + build_fires_on_examples, + token_pmi_pairs, +) from spd.autointerp.schemas import ModelMetadata from spd.harvest.analysis import TokenPRLift from spd.harvest.schemas import ComponentData +from spd.utils.markdown import Md SPD_CONTEXT = ( "Each component has a causal importance (CI) value per token position. 
" @@ -29,29 +34,18 @@ def format_prompt( output_pmi: list[tuple[str, float]] | None = None if config.include_pmi: - input_pmi = ( - [(app_tok.get_tok_display(tid), pmi) for tid, pmi in component.input_token_pmi.top] - if component.input_token_pmi.top - else None - ) - output_pmi = ( - [(app_tok.get_tok_display(tid), pmi) for tid, pmi in component.output_token_pmi.top] - if component.output_token_pmi.top - else None - ) + input_pmi = token_pmi_pairs(app_tok, component.input_token_pmi.top) + output_pmi = token_pmi_pairs(app_tok, component.output_token_pmi.top) input_section = _build_input_section(input_token_stats, input_pmi) output_section = _build_output_section(output_token_stats, output_pmi) - examples_section = build_fires_on_examples( - component, - app_tok, - config.max_examples, - ) + examples_section = build_fires_on_examples(component, app_tok, config.max_examples) - if component.firing_density > 0.0: - rate_str = f"~1 in {int(1 / component.firing_density)} tokens" - else: - rate_str = "extremely rare" # TODO(oli) make this string better. does this even happen? + rate_str = ( + f"~1 in {int(1 / component.firing_density)} tokens" + if component.firing_density > 0.0 + else "extremely rare" + ) layer_desc = model_metadata.layer_descriptions.get(component.layer, component.layer) @@ -60,83 +54,89 @@ def format_prompt( dataset_desc = DATASET_DESCRIPTIONS[model_metadata.dataset_name] dataset_line = f", dataset: {dataset_desc}" - spd_context_block = f"\n{SPD_CONTEXT}\n" if config.include_spd_context else "" - forbidden = ", ".join(config.forbidden_words) if config.forbidden_words else "(none)" - return f"""\ -Label this neural network component. 
-{spd_context_block} -## Context -- Model: {model_metadata.model_class} ({model_metadata.n_blocks} blocks){dataset_line} -- Component location: {layer_desc} -- Component firing rate: {component.firing_density * 100:.2f}% ({rate_str}) - -## Token correlations + md = Md() + md.p("Label this neural network component.") -{input_section} -{output_section} + if config.include_spd_context: + md.p(SPD_CONTEXT) -## Activation examples (active tokens in <>) - -{examples_section} - -## Task + md.h(2, "Context").bullets( + [ + f"Model: {model_metadata.model_class} ({model_metadata.n_blocks} blocks){dataset_line}", + f"Component location: {layer_desc}", + f"Component firing rate: {component.firing_density * 100:.2f}% ({rate_str})", + ] + ) -Give a 2-{config.label_max_words} word label for what this component detects. + md.h(2, "Token correlations") + md.extend(input_section).extend(output_section) -Be SKEPTICAL. If you can't identify specific tokens or a tight grammatical pattern, say "unclear". + md.h(2, "Activation examples (active tokens in <>)") + md.extend(examples_section) -Rules: -1. Good labels name SPECIFIC tokens: "'the'", "##ing suffix", "she/her pronouns" -2. Say "unclear" if: tokens are too varied, pattern is abstract, or evidence is weak -3. FORBIDDEN words (too vague): {forbidden} -4. Lowercase only -5. Confidence: "high" = clear, specific pattern with strong evidence; "medium" = plausible but noisy; "low" = speculative + md.h(2, "Task") + md.p(f"Give a 2-{config.label_max_words} word label for what this component detects.") + md.p( + "Be SKEPTICAL. If you can't identify specific tokens or a tight grammatical " + 'pattern, say "unclear".' 
+ ) + md.p("Rules:") + md.numbered( + [ + 'Good labels name SPECIFIC tokens: "\'the\'", "##ing suffix", "she/her pronouns"', + 'Say "unclear" if: tokens are too varied, pattern is abstract, or evidence is weak', + f"FORBIDDEN words (too vague): {forbidden}", + "Lowercase only", + 'Confidence: "high" = clear, specific pattern with strong evidence; ' + '"medium" = plausible but noisy; "low" = speculative', + ] + ) + md.p( + 'GOOD: "##ed suffix", "\'and\' conjunction", "she/her/hers", "period then capital", "unclear"\n' + 'BAD: "various words and punctuation", "verbs and adjectives", "tokens near commas"' + ) -GOOD: "##ed suffix", "'and' conjunction", "she/her/hers", "period then capital", "unclear" -BAD: "various words and punctuation", "verbs and adjectives", "tokens near commas" -""" + return md.build() def _build_input_section( input_stats: TokenPRLift, input_pmi: list[tuple[str, float]] | None, -) -> str: - section = "" - +) -> Md: + md = Md() if input_stats.top_recall: - section += "**Input tokens with highest recall (most common current tokens when the component is firing)**\n" - for tok, recall in input_stats.top_recall[:8]: - section += f"- {repr(tok)}: {recall * 100:.0f}%\n" - + md.labeled_list( + "**Input tokens with highest recall (most common current tokens when the component is firing)**", + [f"{repr(tok)}: {recall * 100:.0f}%" for tok, recall in input_stats.top_recall[:8]], + ) if input_stats.top_precision: - section += "\n**Input tokens with highest precision (probability the component fires given the current token is X)**\n" - for tok, prec in input_stats.top_precision[:8]: - section += f"- {repr(tok)}: {prec * 100:.0f}%\n" - + md.labeled_list( + "**Input tokens with highest precision (probability the component fires given the current token is X)**", + [f"{repr(tok)}: {prec * 100:.0f}%" for tok, prec in input_stats.top_precision[:8]], + ) if input_pmi: - section += "\n**Input tokens with highest PMI (pointwise mutual information. 
Tokens with higher-than-base-rate likelihood of co-occurrence with the component firing)**\n" - for tok, pmi in input_pmi[:6]: - section += f"- {repr(tok)}: {pmi:.2f}\n" - - return section + md.labeled_list( + "**Input tokens with highest PMI (pointwise mutual information. Tokens with higher-than-base-rate likelihood of co-occurrence with the component firing)**", + [f"{repr(tok)}: {pmi:.2f}" for tok, pmi in input_pmi[:6]], + ) + return md def _build_output_section( output_stats: TokenPRLift, output_pmi: list[tuple[str, float]] | None, -) -> str: - section = "" - +) -> Md: + md = Md() if output_stats.top_precision: - section += "**Output precision — of all predicted probability for token X, what fraction is at positions where this component fires?**\n" - for tok, prec in output_stats.top_precision[:10]: - section += f"- {repr(tok)}: {prec * 100:.0f}%\n" - + md.labeled_list( + "**Output precision — of all predicted probability for token X, what fraction is at positions where this component fires?**", + [f"{repr(tok)}: {prec * 100:.0f}%" for tok, prec in output_stats.top_precision[:10]], + ) if output_pmi: - section += "\n**Output PMI — tokens the model predicts at higher-than-base-rate when this component fires:**\n" - for tok, pmi in output_pmi[:6]: - section += f"- {repr(tok)}: {pmi:.2f}\n" - - return section + md.labeled_list( + "**Output PMI — tokens the model predicts at higher-than-base-rate when this component fires:**", + [f"{repr(tok)}: {pmi:.2f}" for tok, pmi in output_pmi[:6]], + ) + return md diff --git a/spd/autointerp/strategies/dual_view.py b/spd/autointerp/strategies/dual_view.py index 1206c73f5..8256f740d 100644 --- a/spd/autointerp/strategies/dual_view.py +++ b/spd/autointerp/strategies/dual_view.py @@ -18,10 +18,12 @@ density_note, human_layer_desc, layer_position_note, + token_pmi_pairs, ) from spd.autointerp.schemas import ModelMetadata from spd.harvest.analysis import TokenPRLift from spd.harvest.schemas import ComponentData +from 
spd.utils.markdown import Md def format_prompt( @@ -36,26 +38,19 @@ def format_prompt( output_pmi: list[tuple[str, float]] | None = None if config.include_pmi: - input_pmi = ( - [(app_tok.get_tok_display(tid), pmi) for tid, pmi in component.input_token_pmi.top] - if component.input_token_pmi.top - else None - ) - output_pmi = ( - [(app_tok.get_tok_display(tid), pmi) for tid, pmi in component.output_token_pmi.top] - if component.output_token_pmi.top - else None - ) + input_pmi = token_pmi_pairs(app_tok, component.input_token_pmi.top) + output_pmi = token_pmi_pairs(app_tok, component.output_token_pmi.top) output_section = build_output_section(output_token_stats, output_pmi) input_section = build_input_section(input_token_stats, input_pmi) fires_on_examples = build_fires_on_examples(component, app_tok, config.max_examples) says_examples = build_says_examples(component, app_tok, config.max_examples) - if component.firing_density > 0.0: - rate_str = f"~1 in {int(1 / component.firing_density)} tokens" - else: - rate_str = "extremely rare" + rate_str = ( + f"~1 in {int(1 / component.firing_density)} tokens" + if component.firing_density > 0.0 + else "extremely rare" + ) canonical = model_metadata.layer_descriptions.get(component.layer, component.layer) layer_desc = human_layer_desc(canonical, model_metadata.n_blocks) @@ -77,48 +72,64 @@ def format_prompt( else "" ) - return f"""\ -Describe what this neural network component does. - -Each component is a learned linear transformation inside a weight matrix. It has an input function (what causes it to fire) and an output function (what tokens it causes the model to produce). These are often different — a component might fire on periods but produce sentence-opening words, or fire on prepositions but produce abstract nouns. - -Consider all of the evidence below critically. Token statistics can be noisy, especially for high-density components. The activation examples are sampled and may not be representative. 
Look for patterns that are consistent across multiple sources of evidence. - -## Context -- Model: {model_metadata.model_class} ({model_metadata.n_blocks} blocks){dataset_line} -- Component location: {layer_desc} -- Component firing rate: {component.firing_density * 100:.2f}% ({rate_str}) - -{context_notes} - -## Output tokens (what the model produces when this component fires) - -{output_section} -## Input tokens (what causes this component to fire) - -{input_section} -## Activation examples — where the component fires - -<> mark tokens where this component is active. + md = Md() + md.p( + "Describe what this neural network component does.\n\n" + "Each component is a learned linear transformation inside a weight matrix. " + "It has an input function (what causes it to fire) and an output function " + "(what tokens it causes the model to produce). These are often different — " + "a component might fire on periods but produce sentence-opening words, or " + "fire on prepositions but produce abstract nouns.\n\n" + "Consider all of the evidence below critically. Token statistics can be noisy, " + "especially for high-density components. The activation examples are sampled " + "and may not be representative. Look for patterns that are consistent across " + "multiple sources of evidence." + ) -{fires_on_examples} -## Activation examples — what the model produces + md.h(2, "Context").bullets( + [ + f"Model: {model_metadata.model_class} ({model_metadata.n_blocks} blocks){dataset_line}", + f"Component location: {layer_desc}", + f"Component firing rate: {component.firing_density * 100:.2f}% ({rate_str})", + ] + ) + if context_notes: + md.p(context_notes) -Same examples with <> shifted right by one — showing the token that follows each firing position. 
+ md.h(2, "Output tokens (what the model produces when this component fires)") + md.extend(output_section) -{says_examples} + md.h(2, "Input tokens (what causes this component to fire)") + md.extend(input_section) -## Task + md.h(2, "Activation examples — where the component fires") + md.p("<> mark tokens where this component is active.") + md.extend(fires_on_examples) -Give a {config.label_max_words}-word-or-fewer label describing this component's function. The label should read like a short description of the job this component does in the network. Use both the input and output evidence. + md.h(2, "Activation examples — what the model produces") + md.p( + "Same examples with <> shifted right by one — " + "showing the token that follows each firing position." + ) + md.extend(says_examples) -Examples of good labels across different component types: -- "word stem completion (stems → suffixes)" -- "closes dialogue with quotation marks" -- "object pronouns after verbs" -- "story-ending moral resolution vocabulary" -- "aquatic scene vocabulary (frog, river, pond)" -- "'of course' and abstract nouns after prepositions" + md.h(2, "Task") + md.p( + f"Give a {config.label_max_words}-word-or-fewer label describing this component's " + "function. The label should read like a short description of the job this component " + "does in the network. Use both the input and output evidence." + ) + md.p( + "Examples of good labels across different component types:\n" + '- "word stem completion (stems → suffixes)"\n' + '- "closes dialogue with quotation marks"\n' + '- "object pronouns after verbs"\n' + '- "story-ending moral resolution vocabulary"\n' + '- "aquatic scene vocabulary (frog, river, pond)"\n' + "- \"'of course' and abstract nouns after prepositions\"" + ) + md.p( + f'Say "unclear" if the evidence is too weak or diffuse. {forbidden_sentence}Lowercase only.' + ) -Say "unclear" if the evidence is too weak or diffuse. {forbidden_sentence}Lowercase only. 
-""" + return md.build() diff --git a/spd/graph_interp/interpret.py b/spd/graph_interp/interpret.py index 00b6e243b..fa052873c 100644 --- a/spd/graph_interp/interpret.py +++ b/spd/graph_interp/interpret.py @@ -154,14 +154,19 @@ def jobs() -> Iterable[LLMJob]: assert o_stats is not None, f"No output token stats for {key}" related = get_related(key, labels_so_far) - db.save_prompt_edges([ - PromptEdge( - component_key=key, related_key=r.component_key, - pass_name="output", attribution=r.attribution, - related_label=r.label, related_confidence=r.confidence, - ) - for r in related - ]) + db.save_prompt_edges( + [ + PromptEdge( + component_key=key, + related_key=r.component_key, + pass_name="output", + attribution=r.attribution, + related_label=r.label, + related_confidence=r.confidence, + ) + for r in related + ] + ) prompt = format_output_prompt( component=component, model_metadata=model_metadata, @@ -189,14 +194,19 @@ def jobs() -> Iterable[LLMJob]: assert i_stats is not None, f"No input token stats for {key}" related = get_related(key, labels_so_far) - db.save_prompt_edges([ - PromptEdge( - component_key=key, related_key=r.component_key, - pass_name="input", attribution=r.attribution, - related_label=r.label, related_confidence=r.confidence, - ) - for r in related - ]) + db.save_prompt_edges( + [ + PromptEdge( + component_key=key, + related_key=r.component_key, + pass_name="input", + attribution=r.attribution, + related_label=r.label, + related_confidence=r.confidence, + ) + for r in related + ] + ) prompt = format_input_prompt( component=component, model_metadata=model_metadata, @@ -368,5 +378,3 @@ def _check_error_rate(n_errors: int, n_done: int) -> None: raise RuntimeError( f"Error rate {n_errors / total:.0%} ({n_errors}/{total}) exceeds 5% threshold" ) - - diff --git a/spd/graph_interp/prompts.py b/spd/graph_interp/prompts.py index f2160ed0c..e5f44f2dd 100644 --- a/spd/graph_interp/prompts.py +++ b/spd/graph_interp/prompts.py @@ -15,12 +15,14 @@ density_note, 
human_layer_desc, layer_position_note, + token_pmi_pairs, ) from spd.autointerp.schemas import ModelMetadata from spd.graph_interp.graph_context import RelatedComponent from spd.graph_interp.schemas import LabelResult from spd.harvest.analysis import TokenPRLift from spd.harvest.schemas import ComponentData +from spd.utils.markdown import Md LABEL_SCHEMA: dict[str, object] = { "type": "object", @@ -33,11 +35,17 @@ "additionalProperties": False, } +JSON_INSTRUCTION = ( + 'Respond with JSON: {"label": "...", "confidence": "low|medium|high", "reasoning": "..."}' +) + +UNCLEAR_NOTE = 'Say "unclear" if the evidence is too weak.' + def _component_header( component: ComponentData, model_metadata: ModelMetadata, -) -> str: +) -> Md: canonical = model_metadata.layer_descriptions.get(component.layer, component.layer) layer_desc = human_layer_desc(canonical, model_metadata.n_blocks) position_note = layer_position_note(canonical, model_metadata.n_blocks) @@ -49,13 +57,17 @@ def _component_header( else "extremely rare" ) + md = Md() + md.h(2, "Context").bullets( + [ + f"Component: {layer_desc} (component {component.component_idx}), {model_metadata.n_blocks}-block model", + f"Firing rate: {component.firing_density * 100:.2f}% ({rate_str})", + ] + ) context_notes = " ".join(filter(None, [position_note, dens_note])) - - return f"""\ -## Context -- Component: {layer_desc} (component {component.component_idx}), {model_metadata.n_blocks}-block model -- Firing rate: {component.firing_density * 100:.2f}% ({rate_str}) -{context_notes}""" + if context_notes: + md.p(context_notes) + return md def format_output_prompt( @@ -67,36 +79,35 @@ def format_output_prompt( label_max_words: int, max_examples: int, ) -> str: - header = _component_header(component, model_metadata) + output_pmi = token_pmi_pairs(app_tok, component.output_token_pmi.top) - output_pmi = ( - [(app_tok.get_tok_display(tid), pmi) for tid, pmi in component.output_token_pmi.top] - if component.output_token_pmi.top - else None 
+ md = Md() + md.p( + "You are analyzing a component in a neural network to understand its " + "OUTPUT FUNCTION — what it does when it fires." ) - output_section = build_output_section(output_token_stats, output_pmi) - says = build_says_examples(component, app_tok, max_examples) - related_table = _format_related_table(related, model_metadata, app_tok) + md.extend(_component_header(component, model_metadata)) - return f"""\ -You are analyzing a component in a neural network to understand its OUTPUT FUNCTION — what it does when it fires. + md.h(2, "Output tokens (what the model produces when this component fires)") + md.extend(build_output_section(output_token_stats, output_pmi)) -{header} + md.h(2, "Activation examples — what the model produces") + md.extend(build_says_examples(component, app_tok, max_examples)) -## Output tokens (what the model produces when this component fires) -{output_section} -## Activation examples — what the model produces -{says} -## Downstream components (what this component influences) -These components in later layers are most influenced by this component (by gradient attribution): -{related_table} -## Task -Give a {label_max_words}-word-or-fewer label describing this component's OUTPUT FUNCTION — what it does when it fires. + md.h(2, "Downstream components (what this component influences)") + md.p( + "These components in later layers are most influenced by this component (by gradient attribution):" + ) + md.extend(_format_related(related, model_metadata, app_tok)) -Say "unclear" if the evidence is too weak. + md.h(2, "Task") + md.p( + f"Give a {label_max_words}-word-or-fewer label describing this component's " + "OUTPUT FUNCTION — what it does when it fires." 
+ ) + md.p(UNCLEAR_NOTE).p(JSON_INSTRUCTION) -Respond with JSON: {{"label": "...", "confidence": "low|medium|high", "reasoning": "..."}} -""" + return md.build() def format_input_prompt( @@ -108,36 +119,33 @@ def format_input_prompt( label_max_words: int, max_examples: int, ) -> str: - header = _component_header(component, model_metadata) + input_pmi = token_pmi_pairs(app_tok, component.input_token_pmi.top) - input_pmi = ( - [(app_tok.get_tok_display(tid), pmi) for tid, pmi in component.input_token_pmi.top] - if component.input_token_pmi.top - else None + md = Md() + md.p( + "You are analyzing a component in a neural network to understand its " + "INPUT FUNCTION — what triggers it to fire." ) - input_section = build_input_section(input_token_stats, input_pmi) - fires_on = build_fires_on_examples(component, app_tok, max_examples) - related_table = _format_related_table(related, model_metadata, app_tok) + md.extend(_component_header(component, model_metadata)) - return f"""\ -You are analyzing a component in a neural network to understand its INPUT FUNCTION — what triggers it to fire. + md.h(2, "Input tokens (what causes this component to fire)") + md.extend(build_input_section(input_token_stats, input_pmi)) -{header} + md.h(2, "Activation examples — where the component fires") + md.extend(build_fires_on_examples(component, app_tok, max_examples)) -## Input tokens (what causes this component to fire) -{input_section} -## Activation examples — where the component fires -{fires_on} -## Upstream components (what feeds into this component) -These components in earlier layers most strongly attribute to this component: -{related_table} -## Task -Give a {label_max_words}-word-or-fewer label describing this component's INPUT FUNCTION — what conditions trigger it to fire. 
+ md.h(2, "Upstream components (what feeds into this component)") + md.p("These components in earlier layers most strongly attribute to this component:") + md.extend(_format_related(related, model_metadata, app_tok)) -Say "unclear" if the evidence is too weak. + md.h(2, "Task") + md.p( + f"Give a {label_max_words}-word-or-fewer label describing this component's " + "INPUT FUNCTION — what conditions trigger it to fire." + ) + md.p(UNCLEAR_NOTE).p(JSON_INSTRUCTION) -Respond with JSON: {{"label": "...", "confidence": "low|medium|high", "reasoning": "..."}} -""" + return md.build() def format_unification_prompt( @@ -149,45 +157,47 @@ def format_unification_prompt( label_max_words: int, max_examples: int, ) -> str: - header = _component_header(component, model_metadata) - fires_on = build_fires_on_examples(component, app_tok, max_examples) - says = build_says_examples(component, app_tok, max_examples) - - return f"""\ -A neural network component has been analyzed from two perspectives. - -{header} - -## Activation examples — where the component fires -{fires_on} -## Activation examples — what the model produces -{says} -## Two-perspective analysis - -OUTPUT FUNCTION: "{output_label.label}" (confidence: {output_label.confidence}) - Reasoning: {output_label.reasoning} + md = Md() + md.p("A neural network component has been analyzed from two perspectives.") + md.extend(_component_header(component, model_metadata)) + + md.h(2, "Activation examples — where the component fires") + md.extend(build_fires_on_examples(component, app_tok, max_examples)) + + md.h(2, "Activation examples — what the model produces") + md.extend(build_says_examples(component, app_tok, max_examples)) + + md.h(2, "Two-perspective analysis") + md.p( + f'OUTPUT FUNCTION: "{output_label.label}" (confidence: {output_label.confidence})\n' + f" Reasoning: {output_label.reasoning}\n\n" + f'INPUT FUNCTION: "{input_label.label}" (confidence: {input_label.confidence})\n' + f" Reasoning: {input_label.reasoning}" + 
) -INPUT FUNCTION: "{input_label.label}" (confidence: {input_label.confidence}) - Reasoning: {input_label.reasoning} + md.h(2, "Task") + md.p( + f"Synthesize these into a single unified label (max {label_max_words} words) " + "that captures the component's complete role. If input and output suggest the " + "same concept, unify them. If they describe genuinely different aspects " + "(e.g. fires on X, produces Y), combine both." + ) + md.p(JSON_INSTRUCTION) -## Task -Synthesize these into a single unified label (max {label_max_words} words) that captures the component's complete role. If input and output suggest the same concept, unify them. If they describe genuinely different aspects (e.g. fires on X, produces Y), combine both. + return md.build() -Respond with JSON: {{"label": "...", "confidence": "low|medium|high", "reasoning": "..."}} -""" - -def _format_related_table( +def _format_related( components: list[RelatedComponent], model_metadata: ModelMetadata, app_tok: AppTokenizer, -) -> str: - # Filter: only show labeled components and token entries (embed/output) +) -> Md: visible = [n for n in components if n.label is not None or _is_token_entry(n.component_key)] + md = Md() if not visible: - return "(no related components with labels found)\n" + md.p("(no related components with labels found)") + return md - # Normalize attributions: strongest = 1.0 max_attr = max(abs(n.attribution) for n in visible) norm = max_attr if max_attr > 0 else 1.0 @@ -206,7 +216,8 @@ def _format_related_table( line += f'\n label: "{n.label}" (confidence: {n.confidence})' lines.append(line) - return "\n".join(lines) + "\n" + md.p("\n".join(lines)) + return md def _is_token_entry(key: str) -> bool: diff --git a/spd/graph_interp/repo.py b/spd/graph_interp/repo.py index 6667c4e1e..a76590e3d 100644 --- a/spd/graph_interp/repo.py +++ b/spd/graph_interp/repo.py @@ -11,7 +11,8 @@ import yaml -from spd.graph_interp.db import DONE_MARKER, GraphInterpDB +from spd.autointerp.db import DONE_MARKER 
+from spd.graph_interp.db import GraphInterpDB from spd.graph_interp.schemas import LabelResult, PromptEdge, get_graph_interp_dir diff --git a/spd/harvest/analysis.py b/spd/harvest/analysis.py index d0d92aac3..739e78fe9 100644 --- a/spd/harvest/analysis.py +++ b/spd/harvest/analysis.py @@ -106,7 +106,6 @@ def get_correlated_components( return output - def has_component(storage: CorrelationStorage, component_key: str) -> bool: """Check if a component exists in the storage.""" return component_key in storage.key_to_idx diff --git a/spd/utils/markdown.py b/spd/utils/markdown.py index f4f270250..0098c4d8e 100644 --- a/spd/utils/markdown.py +++ b/spd/utils/markdown.py @@ -15,12 +15,8 @@ class Md: def __init__(self) -> None: self._blocks: list[str] = [] - def h2(self, text: str) -> "Md": - self._blocks.append(f"## {text}") - return self - - def h3(self, text: str) -> "Md": - self._blocks.append(f"### {text}") + def h(self, level: int, text: str) -> "Md": + self._blocks.append(f"{'#' * level} {text}") return self def p(self, text: str) -> "Md": @@ -31,6 +27,11 @@ def bullets(self, items: list[str]) -> "Md": self._blocks.append("\n".join(f"- {item}" for item in items)) return self + def labeled_list(self, label: str, items: list[str]) -> "Md": + lines = [label] + [f"- {item}" for item in items] + self._blocks.append("\n".join(lines)) + return self + def numbered(self, items: list[str]) -> "Md": self._blocks.append("\n".join(f"{i}. 
{item}" for i, item in enumerate(items, 1))) return self diff --git a/tests/dataset_attributions/test_storage.py b/tests/dataset_attributions/test_storage.py index 969973bd7..a972c1107 100644 --- a/tests/dataset_attributions/test_storage.py +++ b/tests/dataset_attributions/test_storage.py @@ -195,7 +195,7 @@ def test_single_file(self, tmp_path: Path): def _deterministic_storage( - regular_val: float = 10.0, + _regular_val: float = 10.0, embed_val: float = 6.0, ci_sum_val: float = 50.0, act_sq_sum_val: float = 400.0, From b3f99ee3c160d1dc8016cca63ba509fdc9804e3b Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Fri, 6 Mar 2026 17:28:33 +0000 Subject: [PATCH 18/20] Simplify non-app Python modules: deduplicate graph_interp, autointerp, investigate - graph_interp/db.py: Extract parameterized _save_label/_get_label/_get_all_labels from 3x3 duplicated CRUD methods - graph_interp/interpret.py: Unify process_output_layer/process_input_layer via _make_process_layer factory - autointerp/prompt_helpers.py: Deduplicate build_fires_on_examples/build_says_examples into _build_examples - graph_interp/prompts.py: Simplify _format_related string building with f-string - investigate/agent_prompt.py: Replace repetitive config blocks with data-driven loop - investigate/scripts/run_agent.py: Remove obvious docstrings, simplify fetch_model_info Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/autointerp/prompt_helpers.py | 33 ++--- spd/graph_interp/db.py | 81 +++++------- spd/graph_interp/interpret.py | 183 ++++++++++++--------------- spd/graph_interp/prompts.py | 9 +- spd/investigate/agent_prompt.py | 26 ++-- spd/investigate/scripts/run_agent.py | 13 +- 6 files changed, 146 insertions(+), 199 deletions(-) diff --git a/spd/autointerp/prompt_helpers.py b/spd/autointerp/prompt_helpers.py index 823bf38ef..7b2da592e 100644 --- a/spd/autointerp/prompt_helpers.py +++ b/spd/autointerp/prompt_helpers.py @@ -128,36 +128,37 @@ def build_input_section( return md -def 
build_fires_on_examples( +def _build_examples( component: ComponentData, app_tok: AppTokenizer, max_examples: int, + shift_firings: bool, ) -> Md: lines: list[str] = [] for i, ex in enumerate(component.activation_examples[:max_examples]): - if any(ex.firings): - spans = app_tok.get_spans(ex.token_ids) - tokens = list(zip(spans, ex.firings, strict=True)) - lines.append(f"{i + 1}. {delimit_tokens(tokens)}") + if not any(ex.firings): + continue + spans = app_tok.get_spans(ex.token_ids) + firings = [False] + ex.firings[:-1] if shift_firings else ex.firings + tokens = list(zip(spans, firings, strict=True)) + lines.append(f"{i + 1}. {delimit_tokens(tokens)}") md = Md() if lines: md.p("\n".join(lines)) return md +def build_fires_on_examples( + component: ComponentData, + app_tok: AppTokenizer, + max_examples: int, +) -> Md: + return _build_examples(component, app_tok, max_examples, shift_firings=False) + + def build_says_examples( component: ComponentData, app_tok: AppTokenizer, max_examples: int, ) -> Md: - lines: list[str] = [] - for i, ex in enumerate(component.activation_examples[:max_examples]): - if any(ex.firings): - spans = app_tok.get_spans(ex.token_ids) - shifted_firings = [False] + ex.firings[:-1] - tokens = list(zip(spans, shifted_firings, strict=True)) - lines.append(f"{i + 1}. {delimit_tokens(tokens)}") - md = Md() - if lines: - md.p("\n".join(lines)) - return md + return _build_examples(component, app_tok, max_examples, shift_firings=True) diff --git a/spd/graph_interp/db.py b/spd/graph_interp/db.py index 1d10cda29..1a6a83e37 100644 --- a/spd/graph_interp/db.py +++ b/spd/graph_interp/db.py @@ -48,6 +48,9 @@ """ +_LABEL_TABLES = ("output_labels", "input_labels", "unified_labels") + + class GraphInterpDB: """NFS-hosted. Uses open_nfs_sqlite (no WAL). 
Single writer, then read-only.""" @@ -60,11 +63,12 @@ def __init__(self, db_path: Path, readonly: bool = False) -> None: def mark_done(self) -> None: (self._db_path.parent / DONE_MARKER).touch() - # -- Output labels --------------------------------------------------------- + # -- Label CRUD (shared across output/input/unified) ----------------------- - def save_output_label(self, result: LabelResult) -> None: + def _save_label(self, table: str, result: LabelResult) -> None: + assert table in _LABEL_TABLES self._conn.execute( - "INSERT OR REPLACE INTO output_labels VALUES (?, ?, ?, ?, ?, ?)", + f"INSERT OR REPLACE INTO {table} VALUES (?, ?, ?, ?, ?, ?)", ( result.component_key, result.label, @@ -76,73 +80,52 @@ def save_output_label(self, result: LabelResult) -> None: ) self._conn.commit() - def get_output_label(self, component_key: str) -> LabelResult | None: + def _get_label(self, table: str, component_key: str) -> LabelResult | None: + assert table in _LABEL_TABLES row = self._conn.execute( - "SELECT * FROM output_labels WHERE component_key = ?", (component_key,) + f"SELECT * FROM {table} WHERE component_key = ?", (component_key,) ).fetchone() if row is None: return None return _row_to_label_result(row) - def get_all_output_labels(self) -> dict[str, LabelResult]: - rows = self._conn.execute("SELECT * FROM output_labels").fetchall() + def _get_all_labels(self, table: str) -> dict[str, LabelResult]: + assert table in _LABEL_TABLES + rows = self._conn.execute(f"SELECT * FROM {table}").fetchall() return {row["component_key"]: _row_to_label_result(row) for row in rows} + # -- Output labels --------------------------------------------------------- + + def save_output_label(self, result: LabelResult) -> None: + self._save_label("output_labels", result) + + def get_output_label(self, component_key: str) -> LabelResult | None: + return self._get_label("output_labels", component_key) + + def get_all_output_labels(self) -> dict[str, LabelResult]: + return 
self._get_all_labels("output_labels") + # -- Input labels ---------------------------------------------------------- def save_input_label(self, result: LabelResult) -> None: - self._conn.execute( - "INSERT OR REPLACE INTO input_labels VALUES (?, ?, ?, ?, ?, ?)", - ( - result.component_key, - result.label, - result.confidence, - result.reasoning, - result.raw_response, - result.prompt, - ), - ) - self._conn.commit() + self._save_label("input_labels", result) def get_input_label(self, component_key: str) -> LabelResult | None: - row = self._conn.execute( - "SELECT * FROM input_labels WHERE component_key = ?", (component_key,) - ).fetchone() - if row is None: - return None - return _row_to_label_result(row) + return self._get_label("input_labels", component_key) def get_all_input_labels(self) -> dict[str, LabelResult]: - rows = self._conn.execute("SELECT * FROM input_labels").fetchall() - return {row["component_key"]: _row_to_label_result(row) for row in rows} + return self._get_all_labels("input_labels") # -- Unified labels -------------------------------------------------------- def save_unified_label(self, result: LabelResult) -> None: - self._conn.execute( - "INSERT OR REPLACE INTO unified_labels VALUES (?, ?, ?, ?, ?, ?)", - ( - result.component_key, - result.label, - result.confidence, - result.reasoning, - result.raw_response, - result.prompt, - ), - ) - self._conn.commit() + self._save_label("unified_labels", result) def get_unified_label(self, component_key: str) -> LabelResult | None: - row = self._conn.execute( - "SELECT * FROM unified_labels WHERE component_key = ?", (component_key,) - ).fetchone() - if row is None: - return None - return _row_to_label_result(row) + return self._get_label("unified_labels", component_key) def get_all_unified_labels(self) -> dict[str, LabelResult]: - rows = self._conn.execute("SELECT * FROM unified_labels").fetchall() - return {row["component_key"]: _row_to_label_result(row) for row in rows} + return 
self._get_all_labels("unified_labels") def get_completed_unified_keys(self) -> set[str]: rows = self._conn.execute("SELECT component_key FROM unified_labels").fetchall() @@ -178,12 +161,10 @@ def get_all_prompt_edges(self) -> list[PromptEdge]: rows = self._conn.execute("SELECT * FROM prompt_edges").fetchall() return [_row_to_prompt_edge(row) for row in rows] - # -- Config ---------------------------------------------------------------- - # -- Stats ----------------------------------------------------------------- def get_label_count(self, table: str) -> int: - assert table in ("output_labels", "input_labels", "unified_labels") + assert table in _LABEL_TABLES row = self._conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone() assert row is not None return row[0] diff --git a/spd/graph_interp/interpret.py b/spd/graph_interp/interpret.py index fa052873c..f0c92725a 100644 --- a/spd/graph_interp/interpret.py +++ b/spd/graph_interp/interpret.py @@ -12,7 +12,6 @@ import asyncio from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable -from functools import partial from pathlib import Path from typing import Literal @@ -36,13 +35,15 @@ format_unification_prompt, ) from spd.graph_interp.schemas import LabelResult, PromptEdge -from spd.harvest.analysis import get_input_token_stats, get_output_token_stats +from spd.harvest.analysis import TokenPRLift, get_input_token_stats, get_output_token_stats from spd.harvest.repo import HarvestRepo +from spd.harvest.schemas import ComponentData from spd.harvest.storage import CorrelationStorage, TokenStatsStorage from spd.log import logger GetRelated = Callable[[str, dict[str, LabelResult]], list[RelatedComponent]] Step = Callable[[list[str], dict[str, LabelResult]], Awaitable[dict[str, LabelResult]]] +MakePrompt = Callable[["ComponentData", "TokenPRLift", list[RelatedComponent]], str] def run_graph_interp( @@ -106,23 +107,13 @@ def _to_canon(concrete_key: str) -> str: layer, idx = concrete_key.rsplit(":", 1) return 
f"{concrete_to_canon[layer]}:{idx}" - def _make_get_targets(metric: AttrMetric) -> "graph_context.GetAttributed": + def _make_get_attributed( + method: Callable[..., list[DatasetAttributionEntry]], metric: AttrMetric + ) -> "graph_context.GetAttributed": def get( key: str, k: int, sign: Literal["positive", "negative"] ) -> list[DatasetAttributionEntry]: - return _translate_entries( - attribution_storage.get_top_targets(_to_canon(key), k=k, sign=sign, metric=metric) - ) - - return get - - def _make_get_sources(metric: AttrMetric) -> "graph_context.GetAttributed": - def get( - key: str, k: int, sign: Literal["positive", "negative"] - ) -> list[DatasetAttributionEntry]: - return _translate_entries( - attribution_storage.get_top_sources(_to_canon(key), k=k, sign=sign, metric=metric) - ) + return _translate_entries(method(_to_canon(key), k=k, sign=sign, metric=metric)) return get @@ -138,87 +129,49 @@ def get(key: str, labels_so_far: dict[str, LabelResult]) -> list[RelatedComponen return get - # -- Layer processors ------------------------------------------------------ - - async def process_output_layer( - get_related: GetRelated, - save_label: Callable[[LabelResult], None], - pending: list[str], - labels_so_far: dict[str, LabelResult], - ) -> dict[str, LabelResult]: - def jobs() -> Iterable[LLMJob]: - for key in pending: - component = harvest.get_component(key) - assert component is not None, f"Component {key} not found in harvest DB" - o_stats = get_output_token_stats(token_stats, key, app_tok, top_k=50) - assert o_stats is not None, f"No output token stats for {key}" - - related = get_related(key, labels_so_far) - db.save_prompt_edges( - [ - PromptEdge( - component_key=key, - related_key=r.component_key, - pass_name="output", - attribution=r.attribution, - related_label=r.label, - related_confidence=r.confidence, - ) - for r in related - ] - ) - prompt = format_output_prompt( - component=component, - model_metadata=model_metadata, - app_tok=app_tok, - 
output_token_stats=o_stats, - related=related, - label_max_words=config.label_max_words, - max_examples=config.max_examples, - ) - yield LLMJob(prompt=prompt, schema=LABEL_SCHEMA, key=key) - - return await _collect_labels(llm_map, jobs(), len(pending), save_label) + # -- Layer processor (shared for output and input passes) -------------------- - async def process_input_layer( + def _make_process_layer( get_related: GetRelated, save_label: Callable[[LabelResult], None], - pending: list[str], - labels_so_far: dict[str, LabelResult], - ) -> dict[str, LabelResult]: - def jobs() -> Iterable[LLMJob]: - for key in pending: - component = harvest.get_component(key) - assert component is not None, f"Component {key} not found in harvest DB" - i_stats = get_input_token_stats(token_stats, key, app_tok, top_k=20) - assert i_stats is not None, f"No input token stats for {key}" - - related = get_related(key, labels_so_far) - db.save_prompt_edges( - [ - PromptEdge( - component_key=key, - related_key=r.component_key, - pass_name="input", - attribution=r.attribution, - related_label=r.label, - related_confidence=r.confidence, - ) - for r in related - ] - ) - prompt = format_input_prompt( - component=component, - model_metadata=model_metadata, - app_tok=app_tok, - input_token_stats=i_stats, - related=related, - label_max_words=config.label_max_words, - max_examples=config.max_examples, - ) - yield LLMJob(prompt=prompt, schema=LABEL_SCHEMA, key=key) - - return await _collect_labels(llm_map, jobs(), len(pending), save_label) + pass_name: Literal["output", "input"], + get_token_stats: Callable[[str], TokenPRLift | None], + make_prompt: MakePrompt, + ) -> Step: + async def process( + pending: list[str], + labels_so_far: dict[str, LabelResult], + ) -> dict[str, LabelResult]: + def jobs() -> Iterable[LLMJob]: + for key in pending: + component = harvest.get_component(key) + assert component is not None, f"Component {key} not found in harvest DB" + stats = get_token_stats(key) + assert stats 
is not None, f"No {pass_name} token stats for {key}" + + related = get_related(key, labels_so_far) + db.save_prompt_edges( + [ + PromptEdge( + component_key=key, + related_key=r.component_key, + pass_name=pass_name, + attribution=r.attribution, + related_label=r.label, + related_confidence=r.confidence, + ) + for r in related + ] + ) + yield LLMJob( + prompt=make_prompt(component, stats, related), + schema=LABEL_SCHEMA, + key=key, + ) + + return await _collect_labels(llm_map, jobs(), len(pending), save_label) + + return process # -- Scan (fold over layers) ----------------------------------------------- @@ -292,18 +245,48 @@ def jobs() -> Iterable[LLMJob]: db = GraphInterpDB(db_path) metric = config.attr_metric - get_targets = _make_get_targets(metric) - get_sources = _make_get_sources(metric) + get_targets = _make_get_attributed(attribution_storage.get_top_targets, metric) + get_sources = _make_get_attributed(attribution_storage.get_top_sources, metric) + + def _output_prompt( + component: ComponentData, stats: TokenPRLift, related: list[RelatedComponent] + ) -> str: + return format_output_prompt( + component=component, + model_metadata=model_metadata, + app_tok=app_tok, + output_token_stats=stats, + related=related, + label_max_words=config.label_max_words, + max_examples=config.max_examples, + ) + + def _input_prompt( + component: ComponentData, stats: TokenPRLift, related: list[RelatedComponent] + ) -> str: + return format_input_prompt( + component=component, + model_metadata=model_metadata, + app_tok=app_tok, + input_token_stats=stats, + related=related, + label_max_words=config.label_max_words, + max_examples=config.max_examples, + ) - label_output = partial( - process_output_layer, + label_output = _make_process_layer( _get_related(get_targets), db.save_output_label, + "output", + lambda key: get_output_token_stats(token_stats, key, app_tok, top_k=50), + _output_prompt, ) - label_input = partial( - process_input_layer, + label_input = _make_process_layer( 
_get_related(get_sources), db.save_input_label, + "input", + lambda key: get_input_token_stats(token_stats, key, app_tok, top_k=20), + _input_prompt, ) async def _run() -> None: diff --git a/spd/graph_interp/prompts.py b/spd/graph_interp/prompts.py index e5f44f2dd..f00a0d671 100644 --- a/spd/graph_interp/prompts.py +++ b/spd/graph_interp/prompts.py @@ -205,13 +205,8 @@ def _format_related( for n in visible: display = _component_display(n.component_key, model_metadata, app_tok) rel_attr = n.attribution / norm - - parts = [f" {display} (relative attribution: {rel_attr:+.2f}"] - if n.pmi is not None: - parts.append(f", co-firing PMI: {n.pmi:.2f}") - parts.append(")") - - line = "".join(parts) + pmi_str = f", co-firing PMI: {n.pmi:.2f}" if n.pmi is not None else "" + line = f" {display} (relative attribution: {rel_attr:+.2f}{pmi_str})" if n.label is not None: line += f'\n label: "{n.label}" (confidence: {n.confidence})' lines.append(line) diff --git a/spd/investigate/agent_prompt.py b/spd/investigate/agent_prompt.py index d53a47ac3..1e68073f1 100644 --- a/spd/investigate/agent_prompt.py +++ b/spd/investigate/agent_prompt.py @@ -160,27 +160,23 @@ def _format_model_info(model_info: dict[str, Any]) -> str: - """Format model architecture info for inclusion in the agent prompt.""" parts = [f"- **Architecture**: {model_info.get('summary', 'Unknown')}"] - target_config = model_info.get("target_model_config") - if target_config: - if "n_layer" in target_config: - parts.append(f"- **Layers**: {target_config['n_layer']}") - if "n_embd" in target_config: - parts.append(f"- **Hidden dim**: {target_config['n_embd']}") - if "vocab_size" in target_config: - parts.append(f"- **Vocab size**: {target_config['vocab_size']}") - if "n_ctx" in target_config: - parts.append(f"- **Context length**: {target_config['n_ctx']}") + tc = model_info.get("target_model_config", {}) + for key, label in [ + ("n_layer", "Layers"), + ("n_embd", "Hidden dim"), + ("vocab_size", "Vocab size"), + ("n_ctx", 
"Context length"), + ]: + if key in tc: + parts.append(f"- **{label}**: {tc[key]}") topology = model_info.get("topology") if topology and topology.get("block_structure"): block = topology["block_structure"][0] - attn = ", ".join(block.get("attn_projections", [])) - ffn = ", ".join(block.get("ffn_projections", [])) - parts.append(f"- **Attention projections**: {attn}") - parts.append(f"- **FFN projections**: {ffn}") + parts.append(f"- **Attention projections**: {', '.join(block.get('attn_projections', []))}") + parts.append(f"- **FFN projections**: {', '.join(block.get('ffn_projections', []))}") return "\n".join(parts) diff --git a/spd/investigate/scripts/run_agent.py b/spd/investigate/scripts/run_agent.py index 54806ed36..1cf9910fb 100644 --- a/spd/investigate/scripts/run_agent.py +++ b/spd/investigate/scripts/run_agent.py @@ -30,7 +30,6 @@ def write_mcp_config(inv_dir: Path, port: int) -> Path: - """Write MCP configuration file for Claude Code.""" mcp_config = { "mcpServers": { "spd": { @@ -45,7 +44,6 @@ def write_mcp_config(inv_dir: Path, port: int) -> Path: def write_claude_settings(inv_dir: Path) -> None: - """Write Claude Code settings to pre-grant MCP tool permissions.""" claude_dir = inv_dir / ".claude" claude_dir.mkdir(exist_ok=True) settings = {"permissions": {"allow": ["mcp__spd__*"]}} @@ -53,7 +51,6 @@ def write_claude_settings(inv_dir: Path) -> None: def find_available_port(start_port: int = 8000, max_attempts: int = 100) -> int: - """Find an available port starting from start_port.""" for offset in range(max_attempts): port = start_port + offset with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -68,7 +65,6 @@ def find_available_port(start_port: int = 8000, max_attempts: int = 100) -> int: def wait_for_backend(port: int, timeout: float = 120.0) -> bool: - """Wait for the backend to become healthy.""" url = f"http://localhost:{port}/api/health" start = time.time() while time.time() - start < timeout: @@ -83,7 +79,6 @@ def 
wait_for_backend(port: int, timeout: float = 120.0) -> bool: def load_run(port: int, wandb_path: str, context_length: int) -> None: - """Load the SPD run into the backend. Raises on failure.""" url = f"http://localhost:{port}/api/runs/load" params = {"wandb_path": wandb_path, "context_length": context_length} resp = requests.post(url, params=params, timeout=300) @@ -93,15 +88,12 @@ def load_run(port: int, wandb_path: str, context_length: int) -> None: def fetch_model_info(port: int) -> dict[str, Any]: - """Fetch model architecture info from the backend.""" resp = requests.get(f"http://localhost:{port}/api/pretrain_info/loaded", timeout=30) assert resp.status_code == 200, f"Failed to fetch model info: {resp.status_code} {resp.text}" - result: dict[str, Any] = resp.json() - return result + return resp.json() def log_event(events_path: Path, event: InvestigationEvent) -> None: - """Append an event to the events log.""" with open(events_path, "a") as f: f.write(event.model_dump_json() + "\n") @@ -174,8 +166,7 @@ def run_agent( stderr=subprocess.STDOUT, ) - def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: - _ = frame + def cleanup(signum: int | None = None, _frame: FrameType | None = None) -> None: logger.info(f"[{inv_id}] Cleaning up...") if backend_proc.poll() is None: backend_proc.terminate() From f51c0fcb28a9b9fc0dffd01a13ad95ca9a64c6c9 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Fri, 6 Mar 2026 17:33:58 +0000 Subject: [PATCH 19/20] Simplify app backend + frontend: deduplicate, remove dead code, fix stale docs Backend: - graphs.py: Extract _build_loss_config, _build_loss_result, _maybe_pgd_config, _maybe_adv_pgd helpers - server.py: Move deferred stdlib imports to module-level - __init__.py: Fix __all__ ordering - CLAUDE.md: Remove duplicate router entries - sqlite.py: Fix stale docstring referencing old DB location Frontend components: - Deduplicate getTopEdgeAttributions into shared topEdgeAttributions() in 
promptAttributionsTypes.ts - Extract generic parseSSEStream() in graphs.ts, eliminating ~50 lines of duplicated SSE parsing - Extract AVAILABILITY_COLUMNS in RunSelector, reducing ~60 lines of duplicated template - Eliminate redundant computeMaxAbsComponentAct in ActivationContextsViewer + ClusterComponentCard - Fix unreachable null check in ClusterComponentCard - Fix mid-file import in ComponentNodeCard - Remove dead fork handler stubs in PromptAttributionsTab - Remove unused isRunEditable export, 5 unused CSS selectors, 12+ unnecessary comments Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/CLAUDE.md | 6 +- spd/app/backend/routers/__init__.py | 2 +- spd/app/backend/routers/graphs.py | 206 ++++++------------ spd/app/backend/server.py | 5 +- .../ActivationContextsViewer.svelte | 2 +- .../components/ClusterComponentCard.svelte | 4 +- .../src/components/ClustersViewer.svelte | 5 +- .../src/components/InvestigationsTab.svelte | 3 - .../components/PromptAttributionsTab.svelte | 4 - .../src/components/RunSelector.svelte | 100 +++------ .../investigations/ArtifactGraph.svelte | 16 +- .../prompt-attr/ComponentNodeCard.svelte | 24 +- .../prompt-attr/InterventionsView.svelte | 30 --- .../components/prompt-attr/NodeTooltip.svelte | 18 +- .../prompt-attr/OutputNodeCard.svelte | 17 +- spd/app/frontend/src/lib/api/graphs.ts | 82 ++----- spd/app/frontend/src/lib/interventionTypes.ts | 5 - .../src/lib/promptAttributionsTypes.ts | 17 ++ spd/utils/sqlite.py | 8 +- tests/app/test_server_api.py | 2 +- 20 files changed, 170 insertions(+), 386 deletions(-) diff --git a/spd/app/CLAUDE.md b/spd/app/CLAUDE.md index 9a7d8e4a5..7e86aa0fd 100644 --- a/spd/app/CLAUDE.md +++ b/spd/app/CLAUDE.md @@ -52,11 +52,9 @@ backend/ ├── intervention.py # Selective component activation ├── correlations.py # Component correlations + token stats + interpretations ├── clusters.py # Component clustering - ├── dataset_search.py # SimpleStories dataset search - ├── agents.py # Various useful endpoints 
that AI agents should look at when helping - ├── mcp.py # MCP (Model Context Protocol) endpoint for Claude Code ├── dataset_search.py # Dataset search (reads dataset from run config) - └── agents.py # Various useful endpoints that AI agents should look at when helping + ├── agents.py # Various useful endpoints that AI agents should look at when helping + └── mcp.py # MCP (Model Context Protocol) endpoint for Claude Code ``` Note: Activation contexts, correlations, and token stats are now loaded from pre-harvested data (see `spd/harvest/`). The app no longer computes these on-the-fly. diff --git a/spd/app/backend/routers/__init__.py b/spd/app/backend/routers/__init__.py index 7b1729fbb..7a3dfadcf 100644 --- a/spd/app/backend/routers/__init__.py +++ b/spd/app/backend/routers/__init__.py @@ -24,8 +24,8 @@ "correlations_router", "data_sources_router", "dataset_attributions_router", - "graph_interp_router", "dataset_search_router", + "graph_interp_router", "graphs_router", "intervention_router", "investigations_router", diff --git a/spd/app/backend/routers/graphs.py b/spd/app/backend/routers/graphs.py index fd43b3b1f..0727b3612 100644 --- a/spd/app/backend/routers/graphs.py +++ b/spd/app/backend/routers/graphs.py @@ -45,6 +45,7 @@ MaskType, MeanKLLossConfig, OptimCIConfig, + PositionalLossConfig, ) from spd.app.backend.schemas import OutputProbability from spd.app.backend.utils import log_errors @@ -174,6 +175,58 @@ class LogitLossResult(BaseModel): label_str: str +LossType = Literal["ce", "kl", "logit"] +LossResult = CELossResult | KLLossResult | LogitLossResult + + +def _build_loss_config( + loss_type: LossType, + loss_coeff: float, + loss_position: int, + label_token: int | None, +) -> PositionalLossConfig: + match loss_type: + case "ce": + assert label_token is not None, "label_token is required for CE loss" + return CELossConfig(coeff=loss_coeff, position=loss_position, label_token=label_token) + case "kl": + return KLLossConfig(coeff=loss_coeff, 
position=loss_position) + case "logit": + assert label_token is not None, "label_token is required for logit loss" + return LogitLossConfig( + coeff=loss_coeff, position=loss_position, label_token=label_token + ) + + +def _build_loss_result( + loss_config: PositionalLossConfig, + tok_display: Callable[[int], str], +) -> LossResult: + match loss_config: + case CELossConfig(coeff=coeff, position=pos, label_token=label_tok): + return CELossResult( + coeff=coeff, position=pos, label_token=label_tok, label_str=tok_display(label_tok) + ) + case KLLossConfig(coeff=coeff, position=pos): + return KLLossResult(coeff=coeff, position=pos) + case LogitLossConfig(coeff=coeff, position=pos, label_token=label_tok): + return LogitLossResult( + coeff=coeff, position=pos, label_token=label_tok, label_str=tok_display(label_tok) + ) + + +def _maybe_pgd_config(n_steps: int | None, step_size: float | None) -> PgdConfig | None: + if n_steps is not None and step_size is not None: + return PgdConfig(n_steps=n_steps, step_size=step_size) + return None + + +def _maybe_adv_pgd(n_steps: int | None, step_size: float | None) -> AdvPGDConfig | None: + if n_steps is not None and step_size is not None: + return AdvPGDConfig(n_steps=n_steps, step_size=step_size, init="random") + return None + + class OptimizationMetricsResult(BaseModel): """Final loss metrics from CI optimization.""" @@ -670,9 +723,6 @@ def get_group_key(edge: Edge) -> str: return out_edges -LossType = Literal["ce", "kl", "logit"] - - @router.post("/optimized/stream") @log_errors def compute_graph_optimized_stream( @@ -699,27 +749,8 @@ def compute_graph_optimized_stream( label_token is required when loss_type is "ce". adv_pgd_n_steps and adv_pgd_step_size enable adversarial PGD when both are provided. 
""" - # Build loss config based on type - loss_config: LossConfig - match loss_type: - case "ce": - if label_token is None: - raise HTTPException(status_code=400, detail="label_token is required for CE loss") - loss_config = CELossConfig( - coeff=loss_coeff, position=loss_position, label_token=label_token - ) - case "kl": - loss_config = KLLossConfig(coeff=loss_coeff, position=loss_position) - case "logit": - if label_token is None: - raise HTTPException( - status_code=400, detail="label_token is required for logit loss" - ) - loss_config = LogitLossConfig( - coeff=loss_coeff, position=loss_position, label_token=label_token - ) - - lr = 1e-2 + loss_config = _build_loss_config(loss_type, loss_coeff, loss_position, label_token) + pgd = _maybe_pgd_config(adv_pgd_n_steps, adv_pgd_step_size) db = manager.db prompt = db.get_prompt(prompt_id) @@ -733,11 +764,9 @@ def compute_graph_optimized_stream( detail=f"loss_position {loss_position} out of bounds for prompt with {len(token_ids)} tokens", ) - label_str = loaded.tokenizer.get_tok_display(label_token) if label_token is not None else None spans = loaded.tokenizer.get_spans(token_ids) tokens_tensor = torch.tensor([token_ids], device=DEVICE) - # Slice tokens to only include positions <= loss_position num_tokens = loss_position + 1 spans_sliced = spans[:num_tokens] @@ -748,14 +777,12 @@ def compute_graph_optimized_stream( beta=beta, mask_type=mask_type, loss=loss_config, - pgd=PgdConfig(n_steps=adv_pgd_n_steps, step_size=adv_pgd_step_size) - if adv_pgd_n_steps is not None and adv_pgd_step_size is not None - else None, + pgd=pgd, ) optim_config = OptimCIConfig( seed=0, - lr=lr, + lr=1e-2, steps=steps, weight_decay=0.0, lr_schedule="cosine", @@ -767,9 +794,7 @@ def compute_graph_optimized_stream( sampling=loaded.config.sampling, ce_kl_rounding_threshold=0.5, mask_type=mask_type, - adv_pgd=AdvPGDConfig(n_steps=adv_pgd_n_steps, step_size=adv_pgd_step_size, init="random") - if adv_pgd_n_steps is not None and adv_pgd_step_size is 
not None - else None, + adv_pgd=_maybe_adv_pgd(adv_pgd_n_steps, adv_pgd_step_size), ) def work( @@ -833,28 +858,6 @@ def work( raw_edges_abs=result.edges_abs, ) - # Build loss result based on config type - loss_result: CELossResult | KLLossResult | LogitLossResult - match loss_config: - case CELossConfig(coeff=coeff, position=pos, label_token=label_tok): - assert label_str is not None - loss_result = CELossResult( - coeff=coeff, - position=pos, - label_token=label_tok, - label_str=label_str, - ) - case KLLossConfig(coeff=coeff, position=pos): - loss_result = KLLossResult(coeff=coeff, position=pos) - case LogitLossConfig(coeff=coeff, position=pos, label_token=label_tok): - assert label_str is not None - loss_result = LogitLossResult( - coeff=coeff, - position=pos, - label_token=label_tok, - label_str=label_str, - ) - return GraphDataWithOptimization( id=graph_id, graphType="optimized", @@ -874,16 +877,14 @@ def work( pnorm=pnorm, beta=beta, mask_type=mask_type, - loss=loss_result, + loss=_build_loss_result(loss_config, loaded.tokenizer.get_tok_display), metrics=OptimizationMetricsResult( ci_masked_label_prob=result.metrics.ci_masked_label_prob, stoch_masked_label_prob=result.metrics.stoch_masked_label_prob, adv_pgd_label_prob=result.metrics.adv_pgd_label_prob, l0_total=result.metrics.l0_total, ), - pgd=PgdConfig(n_steps=adv_pgd_n_steps, step_size=adv_pgd_step_size) - if adv_pgd_n_steps is not None and adv_pgd_step_size is not None - else None, + pgd=pgd, ), ) @@ -924,22 +925,11 @@ def compute_graph_optimized_batch_stream( assert len(body.imp_min_coeffs) > 0, "At least one coefficient required" assert len(body.imp_min_coeffs) <= 20, "Too many coefficients (max 20)" - loss_config: LossConfig - match body.loss_type: - case "ce": - assert body.label_token is not None, "label_token is required for CE loss" - loss_config = CELossConfig( - coeff=body.loss_coeff, position=body.loss_position, label_token=body.label_token - ) - case "kl": - loss_config = 
KLLossConfig(coeff=body.loss_coeff, position=body.loss_position) - case "logit": - assert body.label_token is not None, "label_token is required for logit loss" - loss_config = LogitLossConfig( - coeff=body.loss_coeff, position=body.loss_position, label_token=body.label_token - ) - - lr = 1e-2 + loss_config = _build_loss_config( + body.loss_type, body.loss_coeff, body.loss_position, body.label_token + ) + pgd = _maybe_pgd_config(body.adv_pgd_n_steps, body.adv_pgd_step_size) + adv_pgd = _maybe_adv_pgd(body.adv_pgd_n_steps, body.adv_pgd_step_size) db = manager.db prompt = db.get_prompt(body.prompt_id) @@ -950,25 +940,16 @@ def compute_graph_optimized_batch_stream( f"loss_position {body.loss_position} out of bounds for prompt with {len(token_ids)} tokens" ) - label_str = ( - loaded.tokenizer.get_tok_display(body.label_token) if body.label_token is not None else None - ) spans = loaded.tokenizer.get_spans(token_ids) tokens_tensor = torch.tensor([token_ids], device=DEVICE) num_tokens = body.loss_position + 1 spans_sliced = spans[:num_tokens] - adv_pgd = ( - AdvPGDConfig(n_steps=body.adv_pgd_n_steps, step_size=body.adv_pgd_step_size, init="random") - if body.adv_pgd_n_steps is not None and body.adv_pgd_step_size is not None - else None - ) - configs = [ OptimCIConfig( seed=0, - lr=lr, + lr=1e-2, steps=body.steps, weight_decay=0.0, lr_schedule="cosine", @@ -1014,9 +995,7 @@ def work( beta=body.beta, mask_type=body.mask_type, loss=loss_config, - pgd=PgdConfig(n_steps=body.adv_pgd_n_steps, step_size=body.adv_pgd_step_size) - if body.adv_pgd_n_steps is not None and body.adv_pgd_step_size is not None - else None, + pgd=pgd, ) opt_params.ci_masked_label_prob = result.metrics.ci_masked_label_prob opt_params.stoch_masked_label_prob = result.metrics.stoch_masked_label_prob @@ -1061,21 +1040,6 @@ def work( raw_edges_abs=result.edges_abs, ) - loss_result: CELossResult | KLLossResult | LogitLossResult - match loss_config: - case CELossConfig(coeff=lc, position=pos, 
label_token=label_tok): - assert label_str is not None - loss_result = CELossResult( - coeff=lc, position=pos, label_token=label_tok, label_str=label_str - ) - case KLLossConfig(coeff=lc, position=pos): - loss_result = KLLossResult(coeff=lc, position=pos) - case LogitLossConfig(coeff=lc, position=pos, label_token=label_tok): - assert label_str is not None - loss_result = LogitLossResult( - coeff=lc, position=pos, label_token=label_tok, label_str=label_str - ) - graphs.append( GraphDataWithOptimization( id=graph_id, @@ -1096,18 +1060,14 @@ def work( pnorm=body.pnorm, beta=body.beta, mask_type=body.mask_type, - loss=loss_result, + loss=_build_loss_result(loss_config, loaded.tokenizer.get_tok_display), metrics=OptimizationMetricsResult( ci_masked_label_prob=result.metrics.ci_masked_label_prob, stoch_masked_label_prob=result.metrics.stoch_masked_label_prob, adv_pgd_label_prob=result.metrics.adv_pgd_label_prob, l0_total=result.metrics.l0_total, ), - pgd=PgdConfig( - n_steps=body.adv_pgd_n_steps, step_size=body.adv_pgd_step_size - ) - if body.adv_pgd_n_steps is not None and body.adv_pgd_step_size is not None - else None, + pgd=pgd, ), ) ) @@ -1246,28 +1206,6 @@ def stored_graph_to_response( assert graph.optimization_params is not None opt = graph.optimization_params - # Build loss result based on stored config type - loss_result: CELossResult | KLLossResult | LogitLossResult - match opt.loss: - case CELossConfig(coeff=coeff, position=pos, label_token=label_tok): - label_str = tokenizer.get_tok_display(label_tok) - loss_result = CELossResult( - coeff=coeff, - position=pos, - label_token=label_tok, - label_str=label_str, - ) - case KLLossConfig(coeff=coeff, position=pos): - loss_result = KLLossResult(coeff=coeff, position=pos) - case LogitLossConfig(coeff=coeff, position=pos, label_token=label_tok): - label_str = tokenizer.get_tok_display(label_tok) - loss_result = LogitLossResult( - coeff=coeff, - position=pos, - label_token=label_tok, - label_str=label_str, - ) - return 
GraphDataWithOptimization( id=graph.id, graphType=graph.graph_type, @@ -1287,16 +1225,14 @@ def stored_graph_to_response( pnorm=opt.pnorm, beta=opt.beta, mask_type=opt.mask_type, - loss=loss_result, + loss=_build_loss_result(opt.loss, tokenizer.get_tok_display), metrics=OptimizationMetricsResult( l0_total=float(fg.l0_total), ci_masked_label_prob=opt.ci_masked_label_prob, stoch_masked_label_prob=opt.stoch_masked_label_prob, adv_pgd_label_prob=opt.adv_pgd_label_prob, ), - pgd=PgdConfig(n_steps=opt.pgd.n_steps, step_size=opt.pgd.step_size) - if opt.pgd is not None - else None, + pgd=opt.pgd, ), ) diff --git a/spd/app/backend/server.py b/spd/app/backend/server.py index 89ac602b3..afbff5db6 100644 --- a/spd/app/backend/server.py +++ b/spd/app/backend/server.py @@ -8,10 +8,12 @@ python -m spd.app.backend.server --port 8000 """ +import os import time import traceback from collections.abc import Awaitable, Callable from contextlib import asynccontextmanager +from pathlib import Path import fire import torch @@ -53,9 +55,6 @@ @asynccontextmanager async def lifespan(app: FastAPI): # pyright: ignore[reportUnusedParameter] """Initialize DB connection at startup. 
Model loaded on-demand via /api/runs/load.""" - import os - from pathlib import Path - from spd.app.backend.routers.mcp import InvestigationConfig, set_investigation_config manager = StateManager.get() diff --git a/spd/app/frontend/src/components/ActivationContextsViewer.svelte b/spd/app/frontend/src/components/ActivationContextsViewer.svelte index 232e4cd39..3df0eeee7 100644 --- a/spd/app/frontend/src/components/ActivationContextsViewer.svelte +++ b/spd/app/frontend/src/components/ActivationContextsViewer.svelte @@ -292,7 +292,7 @@ tokens: d.example_tokens, ci: d.example_ci, componentActs: d.example_component_acts, - maxAbsComponentAct: computeMaxAbsComponentAct(d.example_component_acts), + maxAbsComponentAct, }), ), ); diff --git a/spd/app/frontend/src/components/ClusterComponentCard.svelte b/spd/app/frontend/src/components/ClusterComponentCard.svelte index d14015ace..e5c6769b5 100644 --- a/spd/app/frontend/src/components/ClusterComponentCard.svelte +++ b/spd/app/frontend/src/components/ClusterComponentCard.svelte @@ -89,7 +89,7 @@ tokens: d.example_tokens, ci: d.example_ci, componentActs: d.example_component_acts, - maxAbsComponentAct: computeMaxAbsComponentAct(d.example_component_acts), + maxAbsComponentAct, }), ), ); @@ -150,7 +150,7 @@
- {#if componentData.tokenStats === null || componentData.tokenStats.status === "loading"} + {#if componentData.tokenStats.status === "loading" || componentData.tokenStats.status === "uninitialized"} Loading token stats... {:else if componentData.tokenStats.status === "error"} Error: {String(componentData.tokenStats.error)} diff --git a/spd/app/frontend/src/components/ClustersViewer.svelte b/spd/app/frontend/src/components/ClustersViewer.svelte index 324b04e77..6ff194c82 100644 --- a/spd/app/frontend/src/components/ClustersViewer.svelte +++ b/spd/app/frontend/src/components/ClustersViewer.svelte @@ -30,10 +30,7 @@ if (clusterId === null) { singletons.push(member); } else { - if (!groups[clusterId]) { - groups[clusterId] = []; - } - groups[clusterId].push(member); + (groups[clusterId] ??= []).push(member); } } diff --git a/spd/app/frontend/src/components/InvestigationsTab.svelte b/spd/app/frontend/src/components/InvestigationsTab.svelte index b7752cb5f..a7dea1423 100644 --- a/spd/app/frontend/src/components/InvestigationsTab.svelte +++ b/spd/app/frontend/src/components/InvestigationsTab.svelte @@ -9,14 +9,12 @@ import type { Loadable } from "../lib"; import ResearchLogViewer from "./investigations/ResearchLogViewer.svelte"; - // State let investigations = $state>({ status: "uninitialized" }); let selected = $state | null>(null); let activeTab = $state<"research" | "events">("research"); let loadedArtifacts = $state>({}); let artifactsLoading = $state(false); - // Launch state let launchPrompt = $state(""); let launchState = $state>({ status: "uninitialized" }); @@ -34,7 +32,6 @@ } } - // Load investigations on mount $effect(() => { loadInvestigations(); }); diff --git a/spd/app/frontend/src/components/PromptAttributionsTab.svelte b/spd/app/frontend/src/components/PromptAttributionsTab.svelte index 3f7299a93..32c91f4eb 100644 --- a/spd/app/frontend/src/components/PromptAttributionsTab.svelte +++ b/spd/app/frontend/src/components/PromptAttributionsTab.svelte @@ 
-557,10 +557,6 @@ } } - // Fork handlers commented out — functionality disabled for now - // async function handleForkRun(runId: number, tokenReplacements: [number, number][]) { ... } - // async function handleDeleteFork(forkId: number) { ... } - async function handleGenerateGraphFromSelection() { if (!activeCard || !activeGraph) return; const state = interventionStates[activeGraph.id]; diff --git a/spd/app/frontend/src/components/RunSelector.svelte b/spd/app/frontend/src/components/RunSelector.svelte index ef75c771f..19db5f3bc 100644 --- a/spd/app/frontend/src/components/RunSelector.svelte +++ b/spd/app/frontend/src/components/RunSelector.svelte @@ -1,7 +1,22 @@ diff --git a/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte b/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte index 328e5449d..9dcab7f0e 100644 --- a/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte +++ b/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte @@ -1,5 +1,5 @@