diff --git a/.cursor/worktrees.json b/.cursor/worktrees.json deleted file mode 100644 index 47947e5dc..000000000 --- a/.cursor/worktrees.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "setup-worktree": [ - "make install-dev", - "make install-app" - ] -} diff --git a/.gitignore b/.gitignore index 51097c54d..3581e751a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,10 @@ spd/scripts/sweep_params.yaml docs/coverage/** notebooks/** +scratch/ + +# Script outputs (generated files, often large) +scripts/outputs/ **/out/ neuronpedia_outputs/ @@ -173,4 +177,7 @@ cython_debug/ #.idea/ **/*.db -**/*.db* \ No newline at end of file +**/*.db* +*.schema.json + +.claude/worktrees \ No newline at end of file diff --git a/.mcp.json b/.mcp.json index fefb52c9a..700113020 100644 --- a/.mcp.json +++ b/.mcp.json @@ -1,8 +1,3 @@ { - "mcpServers": { - "svelte-llm": { - "type": "http", - "url": "https://svelte-llm.stanislav.garden/mcp/mcp" - } - } + "mcpServers": {} } \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md index a31fb66dd..13bdcf4d0 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -3,14 +3,18 @@ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. ## Environment Setup + **IMPORTANT**: Always activate the virtual environment before running Python or git operations: + ```bash source .venv/bin/activate ``` -Repo requires `.env` file with WandB credentials (see `.env.example`) +If working in a worktree, make sure there's a local `.venv` first by running `uv sync` in the worktree directory. Do NOT `cd` to the main repo — all commands (including git) should run in the worktree. +Repo requires `.env` file with WandB credentials (see `.env.example`) ## Project Overview + SPD (Stochastic Parameter Decomposition) is a research framework for analyzing neural network components and their interactions through sparse parameter decomposition techniques. 
- Target model parameters are decomposed as a sum of `parameter components`
@@ -36,6 +40,8 @@ The codebase supports three experimental domains: TMS (Toy Model of Superpositio
- `ss_llama_simple_mlp`, `ss_llama_simple_mlp-1L`, `ss_llama_simple_mlp-2L` - Llama MLP-only variants
- `ss_gpt2`, `ss_gpt2_simple`, `ss_gpt2_simple_noln` - Simple Stories GPT-2 variants
- `ss_gpt2_simple-1L`, `ss_gpt2_simple-2L` - GPT-2 simple layer variants
+ - `pile_llama_simple_mlp-2L`, `pile_llama_simple_mlp-4L`, `pile_llama_simple_mlp-12L` - Pile Llama MLP-only variants
+ - `pile_gpt2_simple-2L_global_reverse` - Pile GPT-2 with global reverse
- `gpt2` - Standard GPT-2
- `ts` - TinyStories
@@ -46,7 +52,7 @@ This repository implements methods from two key research papers on parameter dec
**Stochastic Parameter Decomposition (SPD)** - [`papers/Stochastic_Parameter_Decomposition/spd_paper.md`](papers/Stochastic_Parameter_Decomposition/spd_paper.md)
-- A version of this repository was used to run the experiments in this paper. But we continue to develop on the code, so it no longer is limited to the implementation used for this paper.
+- A version of this repository was used to run the experiments in this paper. But we have continued to develop the code, so it is no longer limited to the implementation used for this paper.
- Introduces the core SPD framework
- Details the stochastic masking approach and optimization techniques used throughout the codebase
- Useful reading for understanding the implementation details, though may be outdated. 
@@ -95,6 +101,7 @@ This repository implements methods from two key research papers on parameter dec ## Architecture Overview **Core SPD Framework:** + - `spd/run_spd.py` - Main SPD optimization logic called by all experiments - `spd/configs.py` - Pydantic config classes for all experiment types - `spd/registry.py` - Centralized experiment registry with all experiment configurations @@ -104,12 +111,18 @@ This repository implements methods from two key research papers on parameter dec - `spd/metrics.py` - Metrics for logging to WandB (e.g. CI-L0, KL divergence, etc.) - `spd/figures.py` - Figures for logging to WandB (e.g. CI histograms, Identity plots, etc.) +**Terminology: Sources vs Masks:** + +- **Sources** (`adv_sources`, `PPGDSources`, `self.sources`): The raw values that PGD optimizes adversarially. These are interpolated with CI to produce component masks: `mask = ci + (1 - ci) * source`. Used in both regular PGD (`spd/metrics/pgd_utils.py`) and persistent PGD (`spd/persistent_pgd.py`). +- **Masks** (`component_masks`, `RoutingMasks`, `make_mask_infos`, `n_mask_samples`): The materialized per-component masks used during forward passes. These are produced from sources (in PGD) or from stochastic sampling, and are a general SPD concept across the whole codebase. 
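
The interpolation described above can be sketched in a few lines (the tensor values here are illustrative; the actual implementations live in `spd/metrics/pgd_utils.py` and `spd/persistent_pgd.py`):

```python
import torch

# Causal importances (CI) per component, each in [0, 1].
ci = torch.tensor([0.9, 0.1, 0.0])
# Adversarial sources, the raw values PGD optimizes, also in [0, 1].
source = torch.tensor([0.5, 0.5, 0.5])

# Interpolate sources with CI to materialize the component mask:
# high-CI components stay close to fully on, while low-CI components
# are controlled almost entirely by the source value.
mask = ci + (1 - ci) * source
```

In the stochastic (non-PGD) case, masks of the same shape are produced by sampling rather than from optimized sources.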
+ **Experiment Structure:** Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: + - `models.py` - Experiment-specific model classes and pretrained loading - `*_decomposition.py` - Main SPD execution script -- `train_*.py` - Training script for target models +- `train_*.py` - Training script for target models - `*_config.yaml` - Configuration files - `plotting.py` - Visualization utilities @@ -123,7 +136,7 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: **Configuration System:** - YAML configs define all experiment parameters -- Pydantic models provide type safety and validation +- Pydantic models provide type safety and validation - WandB integration for experiment tracking and model storage - Supports both local paths and `wandb:project/runs/run_id` format for model loading - Centralized experiment registry (`spd/registry.py`) manages all experiment configurations @@ -133,8 +146,9 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: - `spd/harvest/` - Offline GPU pipeline for collecting component statistics (correlations, token stats, activation examples) - `spd/autointerp/` - LLM-based automated interpretation of components - `spd/dataset_attributions/` - Multi-GPU pipeline for computing component-to-component attribution strengths aggregated over training data -- Data stored at `SPD_OUT_DIR/{harvest,autointerp,dataset_attributions}//` -- See `spd/harvest/CLAUDE.md`, `spd/autointerp/CLAUDE.md`, and `spd/dataset_attributions/CLAUDE.md` for details +- `spd/graph_interp/` - Context-aware component labeling using graph structure (attributions + correlations) +- Data stored at `SPD_OUT_DIR/{harvest,autointerp,dataset_attributions,graph_interp}//` +- See `spd/harvest/CLAUDE.md`, `spd/autointerp/CLAUDE.md`, `spd/dataset_attributions/CLAUDE.md`, and `spd/graph_interp/CLAUDE.md` for details **Output Directory (`SPD_OUT_DIR`):** @@ -156,11 +170,14 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: ├── 
scripts/ # Standalone utility scripts ├── tests/ # Test suite ├── spd/ # Main source code +│ ├── investigate/ # Agent investigation (see investigate/CLAUDE.md) │ ├── app/ # Web visualization app (see app/CLAUDE.md) │ ├── autointerp/ # LLM interpretation (see autointerp/CLAUDE.md) │ ├── clustering/ # Component clustering (see clustering/CLAUDE.md) │ ├── dataset_attributions/ # Dataset attributions (see dataset_attributions/CLAUDE.md) │ ├── harvest/ # Statistics collection (see harvest/CLAUDE.md) +│ ├── postprocess/ # Unified postprocessing pipeline (harvest + attributions + autointerp) +│ ├── graph_interp/ # Context-aware interpretation (see graph_interp/CLAUDE.md) │ ├── pretrain/ # Target model pretraining (see pretrain/CLAUDE.md) │ ├── experiments/ # Experiment implementations │ │ ├── tms/ # Toy Model of Superposition @@ -193,16 +210,20 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: | `spd-run` | `spd/scripts/run.py` | SLURM-based experiment runner | | `spd-local` | `spd/scripts/run_local.py` | Local experiment runner | | `spd-harvest` | `spd/harvest/scripts/run_slurm_cli.py` | Submit harvest SLURM job | -| `spd-autointerp` | `spd/autointerp/scripts/cli.py` | Submit autointerp SLURM job | +| `spd-autointerp` | `spd/autointerp/scripts/run_slurm_cli.py` | Submit autointerp SLURM job | | `spd-attributions` | `spd/dataset_attributions/scripts/run_slurm_cli.py` | Submit dataset attribution SLURM job | +| `spd-postprocess` | `spd/postprocess/cli.py` | Unified postprocessing pipeline (harvest + attributions + interpret + evals) | +| `spd-graph-interp` | `spd/graph_interp/scripts/run_slurm_cli.py` | Submit graph interpretation SLURM job | | `spd-clustering` | `spd/clustering/scripts/run_pipeline.py` | Clustering pipeline | | `spd-pretrain` | `spd/pretrain/scripts/run_slurm_cli.py` | Pretrain target models | +| `spd-investigate` | `spd/investigate/scripts/run_slurm_cli.py` | Launch investigation agent | ### Files to Skip When Searching Use `spd/` as 
the search root (not repo root) to avoid noise. **Always skip:** + - `.venv/` - Virtual environment - `__pycache__/`, `.pytest_cache/`, `.ruff_cache/` - Build artifacts - `node_modules/` - Frontend dependencies @@ -212,27 +233,37 @@ Use `spd/` as the search root (not repo root) to avoid noise. - `wandb/` - WandB local files **Usually skip unless relevant:** + - `tests/` - Test files (unless debugging test failures) - `papers/` - Research paper drafts ### Common Call Chains **Running Experiments:** + - `spd-run` → `spd/scripts/run.py` → `spd/utils/slurm.py` → SLURM → `spd/run_spd.py` - `spd-local` → `spd/scripts/run_local.py` → `spd/run_spd.py` directly **Harvest Pipeline:** + - `spd-harvest` → `spd/harvest/scripts/run_slurm_cli.py` → `spd/utils/slurm.py` → SLURM array → `spd/harvest/scripts/run.py` → `spd/harvest/harvest.py` **Autointerp Pipeline:** -- `spd-autointerp` → `spd/autointerp/scripts/cli.py` → `spd/utils/slurm.py` → `spd/autointerp/interpret.py` + +- `spd-autointerp` → `spd/autointerp/scripts/run_slurm_cli.py` → `spd/utils/slurm.py` → `spd/autointerp/interpret.py` **Dataset Attributions Pipeline:** + - `spd-attributions` → `spd/dataset_attributions/scripts/run_slurm_cli.py` → `spd/utils/slurm.py` → SLURM array → `spd/dataset_attributions/harvest.py` **Clustering Pipeline:** + - `spd-clustering` → `spd/clustering/scripts/run_pipeline.py` → `spd/utils/slurm.py` → `spd/clustering/scripts/run_clustering.py` +**Investigation Pipeline:** + +- `spd-investigate` → `spd/investigate/scripts/run_slurm_cli.py` → `spd/utils/slurm.py` → SLURM → `spd/investigate/scripts/run_agent.py` → Claude Code + ## Common Usage Patterns ### Running Experiments Locally (`spd-local`) @@ -279,6 +310,56 @@ spd-autointerp # Submit SLURM job to interpret component Requires `OPENROUTER_API_KEY` env var. See `spd/autointerp/CLAUDE.md` for details. 
+### Agent Investigation (`spd-investigate`) + +Launch a Claude Code agent to investigate a specific question about an SPD model: + +```bash +spd-investigate "How does the model handle gendered pronouns?" +spd-investigate "What components are involved in verb agreement?" --time 4:00:00 +``` + +Each investigation: + +- Runs in its own SLURM job with 1 GPU +- Starts an isolated app backend instance +- Investigates the specific research question using SPD tools via MCP +- Writes findings to append-only JSONL files + +Output: `SPD_OUT_DIR/investigations//` + +For parallel investigations, run the command multiple times with different prompts. + +See `spd/investigate/CLAUDE.md` for details. + +### Unified Postprocessing (`spd-postprocess`) + +Run all postprocessing steps for a completed SPD run with a single command: + +```bash +spd-postprocess # Run everything with default config +spd-postprocess --config custom_config.yaml # Use custom config +``` + +Defaults are defined in `PostprocessConfig` (`spd/postprocess/config.py`). Pass a custom YAML/JSON config to override. 
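
A custom override might look like the following sketch (section names follow the defaults in `spd/postprocess/config.py`; the exact nesting is illustrative):

```yaml
# Hypothetical custom postprocess config: skip dataset attributions,
# run autointerp's interpret step but skip its evals.
attributions: null
autointerp:
  evals: null
```

Pass it with `spd-postprocess --config custom_config.yaml`.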
Set any section to `null` to skip it: + +- `attributions: null` — skip dataset attributions +- `autointerp: null` — skip autointerp entirely (interpret + evals) +- `autointerp.evals: null` — skip evals but still run interpret +- `intruder: null` — skip intruder eval + +SLURM dependency graph: + +``` +harvest (GPU array → merge) +├── intruder eval (CPU, depends on harvest merge, label-free) +└── autointerp (depends on harvest merge) + ├── interpret (CPU, LLM calls) + │ ├── detection (CPU, depends on interpret) + │ └── fuzzing (CPU, depends on interpret) +attributions (GPU array → merge, parallel with harvest) +``` + ### Running on SLURM Cluster (`spd-run`) For the core team, `spd-run` provides full-featured SLURM orchestration: @@ -290,6 +371,7 @@ spd-run # Run all experiments ``` All `spd-run` executions: + - Submit jobs to SLURM - Create a git snapshot for reproducibility - Create W&B workspace views @@ -310,6 +392,7 @@ spd-run --experiments --sweep --n_agents [--cpu] ``` Examples: + ```bash spd-run --experiments tms_5-2 --sweep --n_agents 4 # Run TMS 5-2 sweep with 4 GPU agents spd-run --experiments resid_mlp2 --sweep --n_agents 3 --cpu # Run ResidualMLP2 sweep with 3 CPU agents @@ -331,6 +414,7 @@ spd-run --experiments tms_5-2 --sweep custom.yaml --n_agents 2 # Use custom swee - Default sweep parameters are loaded from `spd/scripts/sweep_params.yaml` - You can specify a custom sweep parameters file by passing its path to `--sweep` - Sweep parameters support both experiment-specific and global configurations: + ```yaml # Global parameters applied to all experiments global: @@ -343,7 +427,7 @@ spd-run --experiments tms_5-2 --sweep custom.yaml --n_agents 2 # Use custom swee # Experiment-specific parameters (override global) tms_5-2: seed: - values: [100, 200] # Overrides global seed + values: [100, 200] # Overrides global seed task_config: feature_probability: values: [0.05, 0.1] @@ -369,6 +453,7 @@ model = ComponentModel.from_run_info(run_info) # Local paths work 
too model = ComponentModel.from_pretrained("/path/to/checkpoint.pt") ``` + **Path Formats:** - WandB: `wandb:entity/project/run_id` or `wandb:entity/project/runs/run_id` @@ -382,14 +467,14 @@ Downloaded runs are cached in `SPD_OUT_DIR/runs/-/`. - This includes not setting off multiple sweeps/evals that total >8 GPUs - Monitor jobs with: `squeue --format="%.18i %.9P %.15j %.12u %.12T %.10M %.9l %.6D %b %R" --me` - ## Coding Guidelines & Software Engineering Principles **This is research code, not production. Prioritize simplicity and fail-fast over defensive programming.** Core principles: + - **Fail fast** - assert assumptions, crash on violations, don't silently recover -- **No backwards compat** - delete unused code, don't deprecate or add migration shims +- **No legacy support** - delete unused code, don't add fallbacks for old formats or migration shims - **Narrow types** - avoid `| None` unless null is semantically meaningful; use discriminated unions over bags of optional fields - **No try/except for control flow** - check preconditions explicitly, then trust them - **YAGNI** - don't add abstractions, config options, or flexibility for hypothetical futures @@ -418,22 +503,26 @@ config = get_config(path) value = config.key ``` - ### Tests + - The point of tests in this codebase is to ensure that the code is working as expected, not to prevent production outages - there's no deployment here. Therefore, don't worry about lots of larger integration/end-to-end tests. These often require too much overhead for what it's worth in our case, and this codebase is interactively run so often that issues will likely be caught by the user at very little cost. ### Assertions and error handling + - If you have an invariant in your head, assert it. Are you afraid to assert? Sounds like your program might already be broken. Assert, assert, assert. Never soft fail. - Do not write: `if everythingIsOk: continueHappyPath()`. 
Instead do `assert everythingIsOk`
- You should have a VERY good reason to handle an error gracefully. If your program isn't working like it should then it shouldn't be running, you should be fixing it.
- Do not write `try-catch` blocks unless it definitely makes sense
+- **Write for the golden path.** Never let edge cases bloat the code. Before handling them, just raise an exception. If an edge case becomes annoying enough, we'll handle it then — but write first and foremost for the common case.
 
 ### Control Flow
+
- Keep I/O as high up as possible. Make as many functions as possible pure.
- Prefer `match` over `if/elif/else` chains when dispatching on conditions - more declarative and makes cases explicit
- If you either have (a and b) or neither, don't make them both independently optional. Instead, put them in an optional tuple
 
 ### Types, Arguments, and Defaults
+
- Write your invariants into types as much as possible.
- Use jaxtyping for tensor shapes (though for now we don't do runtime checking)
- Always use the PEP 604 typing format of `|` for unions and `type | None` over `Optional`.
@@ -444,16 +533,16 @@ value = config.key
- good: {: }
- bad: {"tokens": …, "loss": …}
- Default args are rarely a good idea. Avoid them unless necessary. You should have a very good reason for having a default value for an argument, especially if its caller also defaults to the same thing
-- This repo uses basedpyright (not mypy) 
+- This repo uses basedpyright (not mypy)
- Keep defaults high in the call stack.
+- Don't use `from __future__ import annotations` — use string quotes for forward references instead.
 
 ### Tensor Operations
+
- Try to use einops by default for clarity.
- Assert shapes liberally
- Document complex tensor manipulations
-
-
 ### Comments
- Comments hide sloppy code. 
If you feel the need to write a comment, consider that you should instead @@ -463,10 +552,11 @@ value = config.key - separate an inlined computation into a meaningfully named variable - Don’t write dialogic / narrativised comments or code. Instead, write comments that describe the code as is, not the diff you're making. Examples of narrativising comments: - - `# the function now uses y instead of x` - - `# changed to be faster` - - `# we now traverse in reverse` + - `# the function now uses y instead of x` + - `# changed to be faster` + - `# we now traverse in reverse` - Here's an example of a bad diff, where the new comment makes reference to a change in code, not just the state of the code: + ``` 95 - # Reservoir states 96 - reservoir_states: list[ReservoirState] @@ -474,14 +564,15 @@ value = config.key 96 + reservoir: TensorReservoirState ``` - ### Other Important Software Development Practices -- Backwards compatibility that adds complexity should be avoided. -- Delete unused code. + +- Don't add legacy fallbacks or migration code - just change it and let old data be manually migrated if needed. +- Delete unused code. - If an argument is always x, strongly consider removing as an argument and just inlining - **Update CLAUDE.md files** when changing code structure, adding/removing files, or modifying key interfaces. Update the CLAUDE.md in the same directory (or nearest parent) as the changed files. ### GitHub + - To view github issues and PRs, use the github cli (e.g. `gh issue view 28` or `gh pr view 30`). - When making PRs, use the github template defined in `.github/pull_request_template.md`. - Before committing, ALWAYS ensure you are on the correct branch and do not use `git add .` to add all unstaged files. Instead, add only the individual files you changed, don't commit all files. 
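
The staging rule above can be demonstrated in a throwaway repo (the file names and commit message here are made up for illustration):

```shell
set -eu
repo="$(mktemp -d)"
cd "$repo"
git init -q
git config user.email "dev@example.com"
git config user.name "Dev"

printf 'x = 1\n' > configs.py    # the file you actually changed
printf 'notes\n' > scratch.txt   # unrelated local clutter

git status --short               # review what changed before staging
git add configs.py               # stage only what you changed; never `git add .`
git commit -qm "Update configs"
git status --short               # scratch.txt stays untracked
```

Reviewing `git status --short` before and after staging makes it obvious when an unrelated file is about to sneak into a commit.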
diff --git a/Makefile b/Makefile index c771616a6..31a05a304 100644 --- a/Makefile +++ b/Makefile @@ -6,8 +6,10 @@ install: copy-templates .PHONY: install-dev install-dev: copy-templates uv sync - pre-commit install + uv run pre-commit install +.PHONY: install-all +install-all: install-dev install-app # special install for CI (GitHub Actions) that reduces disk usage and install time # 1. create a fresh venv with `--clear` -- this is mostly only for local testing of the CI install diff --git a/pyproject.toml b/pyproject.toml index cb29155fe..cef1fc201 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,9 +30,12 @@ dependencies = [ "scipy>=1.14.1", "fastapi", "uvicorn", + "orjson", + "aiolimiter>=1.2", "openrouter>=0.1.1", "httpx>=0.28.0", - "zstandard" # For streaming datasets + "zstandard", # For streaming datasets + "kaleido==0.2.1", ] [dependency-groups] @@ -54,6 +57,9 @@ spd-clustering = "spd.clustering.scripts.run_pipeline:cli" spd-harvest = "spd.harvest.scripts.run_slurm_cli:cli" spd-autointerp = "spd.autointerp.scripts.run_slurm_cli:cli" spd-attributions = "spd.dataset_attributions.scripts.run_slurm_cli:cli" +spd-investigate = "spd.investigate.scripts.run_slurm_cli:cli" +spd-postprocess = "spd.postprocess.cli:cli" +spd-graph-interp = "spd.graph_interp.scripts.run_slurm_cli:cli" [build-system] requires = ["setuptools", "wheel"] diff --git a/scripts/plot_component_activations.py b/scripts/plot_component_activations.py deleted file mode 100644 index 3dab2e8fa..000000000 --- a/scripts/plot_component_activations.py +++ /dev/null @@ -1,242 +0,0 @@ -"""Plot component activations vs component ID for high-CI datapoints. 
- -Creates scatter plots (one per layer) where: -- X-axis: Component rank (ordered by median normalized activation) -- Y-axis: Component activation (normalized per-component to [0, 1]) -- Filter: Only plots datapoints where CI > threshold -""" - -import argparse -import json -from collections import defaultdict -from pathlib import Path - -import matplotlib.pyplot as plt -import numpy as np -import torch - -from spd.harvest.schemas import ( - ActivationExample, - ComponentData, - ComponentTokenPMI, - get_activation_contexts_dir, -) -from spd.settings import SPD_OUT_DIR - - -def load_activation_contexts(run_id: str) -> dict[str, ComponentData]: - """Load all activation contexts.""" - ctx_dir = get_activation_contexts_dir(run_id) - path = ctx_dir / "components.jsonl" - assert path.exists(), f"No harvest data found for run {run_id}" - - components: dict[str, ComponentData] = {} - with open(path) as f: - for line in f: - data = json.loads(line) - data["activation_examples"] = [ - ActivationExample( - token_ids=ex["token_ids"], - ci_values=ex["ci_values"], - component_acts=ex.get("component_acts", [0.0] * len(ex["token_ids"])), - ) - for ex in data["activation_examples"] - ] - data["input_token_pmi"] = ComponentTokenPMI(**data["input_token_pmi"]) - data["output_token_pmi"] = ComponentTokenPMI(**data["output_token_pmi"]) - comp = ComponentData(**data) - components[comp.component_key] = comp - return components - - -def load_firing_counts(run_id: str) -> dict[str, int]: - """Load pre-calculated firing counts from harvest data.""" - token_stats_path = SPD_OUT_DIR / "harvest" / run_id / "correlations" / "token_stats.pt" - assert token_stats_path.exists(), f"No token stats found for run {run_id}" - - data = torch.load(token_stats_path) - component_keys = data["component_keys"] - firing_counts = data["firing_counts"] - - return {key: int(count) for key, count in zip(component_keys, firing_counts, strict=True)} - - -def extract_activations( - contexts: dict[str, 
ComponentData], - ci_threshold: float, -) -> tuple[dict[str, dict[str, list[float]]], dict[str, dict[str, list[float]]]]: - """Extract component activations, separating all vs above-threshold. - - Returns: - Tuple of: - - all_activations: layer -> component_key -> all activation values (for normalization) - - filtered_activations: layer -> component_key -> activations where CI > threshold - """ - all_activations: dict[str, dict[str, list[float]]] = defaultdict(lambda: defaultdict(list)) - filtered_activations: dict[str, dict[str, list[float]]] = defaultdict(lambda: defaultdict(list)) - - for component_key, component_data in contexts.items(): - layer = component_data.layer - for example in component_data.activation_examples: - for ci_val, act_val in zip(example.ci_values, example.component_acts, strict=True): - all_activations[layer][component_key].append(act_val) - if ci_val > ci_threshold: - filtered_activations[layer][component_key].append(act_val) - - return dict(all_activations), dict(filtered_activations) - - -def normalize_per_component( - all_activations: dict[str, list[float]], - filtered_activations: dict[str, list[float]], -) -> dict[str, np.ndarray]: - """Normalize filtered activations to [0, 1] using min-max from all activations.""" - normalized = {} - for key, filtered_acts in filtered_activations.items(): - if not filtered_acts: - continue - all_acts = np.array(all_activations[key]) - filtered_arr = np.array(filtered_acts) - min_val = all_acts.min() - max_val = all_acts.max() - if max_val > min_val: - normalized[key] = (filtered_arr - min_val) / (max_val - min_val) - else: - normalized[key] = np.full_like(filtered_arr, 0.5) - return normalized - - -def order_by_median(normalized: dict[str, np.ndarray]) -> list[str]: - """Order component keys by median of their normalized activations (descending).""" - medians = [(key, np.median(acts)) for key, acts in normalized.items()] - medians.sort(key=lambda x: x[1], reverse=True) - return [key for key, _ in 
medians] - - -def order_by_frequency( - normalized: dict[str, np.ndarray], firing_counts: dict[str, int] -) -> list[str]: - """Order component keys by pre-calculated firing counts (descending).""" - freqs = [(key, firing_counts.get(key, 0)) for key in normalized] - freqs.sort(key=lambda x: x[1], reverse=True) - return [key for key, _ in freqs] - - -def create_layer_scatter_plot( - normalized_by_key: dict[str, np.ndarray], - ordered_keys: list[str], - layer_name: str, - run_id: str, - output_path: Path, - x_label: str = "Component Rank (by median activation)", - y_label: str = "Normalized Component Activation", -) -> None: - """Create scatter plot for a single layer.""" - x_vals = [] - y_vals = [] - for rank, key in enumerate(ordered_keys): - acts = normalized_by_key[key] - x_vals.extend([rank] * len(acts)) - y_vals.extend(acts.tolist()) - - fig, ax = plt.subplots(figsize=(14, 8)) - ax.scatter(x_vals, y_vals, alpha=0.3, s=1, marker=".") - ax.set_xlabel(x_label) - ax.set_ylabel(y_label) - ax.set_title(f"Layer: {layer_name} ||| Run id: {run_id}") - - n_components = len(ordered_keys) - n_points = len(x_vals) - ax.text( - 0.02, - 0.98, - f"Components: {n_components}\nDatapoints: {n_points}", - transform=ax.transAxes, - verticalalignment="top", - fontsize=10, - bbox={"boxstyle": "round", "facecolor": "wheat", "alpha": 0.5}, - ) - - fig.tight_layout() - fig.savefig(output_path, dpi=300, bbox_inches="tight") - plt.close(fig) - - -def main(): - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("run_id", help="WandB run ID (e.g., 's-7884efcc')") - parser.add_argument( - "--ci-threshold", - type=float, - default=0.1, - help="Minimum CI value to include (default: 0.1)", - ) - args = parser.parse_args() - - base_output_dir = Path("scripts/outputs") / args.run_id / "component-act-scatter" - output_dir_median = base_output_dir / "order-by-median" - output_dir_freq = base_output_dir / "order-by-freq" - output_dir_median.mkdir(parents=True, exist_ok=True) 
- output_dir_freq.mkdir(parents=True, exist_ok=True) - - print(f"Loading activation contexts for run {args.run_id}...") - contexts = load_activation_contexts(args.run_id) - print(f"Loaded {len(contexts)} components") - - print("Loading firing counts...") - firing_counts = load_firing_counts(args.run_id) - - print("Extracting activations...") - all_by_layer, filtered_by_layer = extract_activations(contexts, args.ci_threshold) - - n_layers = len(filtered_by_layer) - n_total = sum(sum(len(v) for v in layer.values()) for layer in filtered_by_layer.values()) - print(f"Found {n_total} datapoints across {n_layers} layers with CI > {args.ci_threshold}") - - if n_total == 0: - print("No datapoints found above threshold. Try lowering --ci-threshold.") - return - - # Create plots ordered by median normalized activation - print(f"Creating per-layer plots (ordered by median) in {output_dir_median}/...") - for layer_name in sorted(all_by_layer.keys()): - all_acts = all_by_layer[layer_name] - filtered_acts = filtered_by_layer.get(layer_name, {}) - normalized = normalize_per_component(all_acts, filtered_acts) - if not normalized: - continue - ordered_keys = order_by_median(normalized) - safe_name = layer_name.replace(".", "_") - output_path = output_dir_median / f"{safe_name}.png" - create_layer_scatter_plot(normalized, ordered_keys, layer_name, args.run_id, output_path) - print(f" {output_path}") - - # Create plots ordered by CI activation frequency (with abs distance from midpoint) - print(f"Creating per-layer plots (ordered by frequency) in {output_dir_freq}/...") - for layer_name in sorted(all_by_layer.keys()): - all_acts = all_by_layer[layer_name] - filtered_acts = filtered_by_layer.get(layer_name, {}) - normalized = normalize_per_component(all_acts, filtered_acts) - if not normalized: - continue - # Transform to absolute distance from midpoint - abs_from_midpoint = {key: np.abs(acts - 0.5) for key, acts in normalized.items()} - ordered_keys = 
order_by_frequency(abs_from_midpoint, firing_counts)
-        safe_name = layer_name.replace(".", "_")
-        output_path = output_dir_freq / f"{safe_name}.png"
-        create_layer_scatter_plot(
-            abs_from_midpoint,
-            ordered_keys,
-            layer_name,
-            args.run_id,
-            output_path,
-            x_label="Component Rank (by firing frequency)",
-            y_label="|Normalized Component Activation - 0.5|",
-        )
-        print(f"  {output_path}")
-
-    print("Done!")
-
-
-if __name__ == "__main__":
-    main()
diff --git a/spd/adapters/__init__.py b/spd/adapters/__init__.py
new file mode 100644
index 000000000..aded4d188
--- /dev/null
+++ b/spd/adapters/__init__.py
@@ -0,0 +1,24 @@
+"""Harvest method adapters: method-specific logic for the generic harvest pipeline.
+
+Each decomposition method (SPD, CLT, MOLT) provides an adapter that knows how to:
+- Load the model and build a dataloader
+- Compute firings and activations from a batch (harvest_fn)
+- Report layer structure and vocab size
+
+Construct via adapter_from_id(id).
+"""
+
+from spd.adapters.base import DecompositionAdapter
+
+
+def adapter_from_id(id: str) -> DecompositionAdapter:
+    from spd.adapters.spd import SPDAdapter
+
+    if id.startswith("s-"):
+        return SPDAdapter(id)
+    elif id.startswith("clt-"):
+        raise NotImplementedError("CLT adapter not implemented yet")
+    elif id.startswith("molt-"):
+        raise NotImplementedError("MOLT adapter not implemented yet")
+
+    raise ValueError(f"Unsupported decomposition ID: {id}")
diff --git a/spd/adapters/base.py b/spd/adapters/base.py
new file mode 100644
index 000000000..937ea453e
--- /dev/null
+++ b/spd/adapters/base.py
@@ -0,0 +1,31 @@
+from abc import ABC, abstractmethod
+
+import torch
+from torch.utils.data import DataLoader
+
+from spd.autointerp.schemas import ModelMetadata
+
+
+class DecompositionAdapter(ABC):
+    @property
+    @abstractmethod
+    def decomposition_id(self) -> str: ...
+
+    @property
+    @abstractmethod
+    def vocab_size(self) -> int: ...
+ + @property + @abstractmethod + def layer_activation_sizes(self) -> list[tuple[str, int]]: ... + + @property + @abstractmethod + def tokenizer_name(self) -> str: ... + + @property + @abstractmethod + def model_metadata(self) -> ModelMetadata: ... + + @abstractmethod + def dataloader(self, batch_size: int) -> DataLoader[torch.Tensor]: ... diff --git a/spd/adapters/spd.py b/spd/adapters/spd.py new file mode 100644 index 000000000..69599214c --- /dev/null +++ b/spd/adapters/spd.py @@ -0,0 +1,72 @@ +from functools import cached_property +from typing import override + +import torch +from torch.utils.data import DataLoader + +from spd.adapters.base import DecompositionAdapter +from spd.autointerp.schemas import ModelMetadata +from spd.configs import LMTaskConfig +from spd.data import train_loader_and_tokenizer +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.topology import TransformerTopology +from spd.utils.general_utils import runtime_cast + + +class SPDAdapter(DecompositionAdapter): + def __init__(self, run_id: str): + self._run_id = run_id + + @cached_property + def spd_run_info(self): + return SPDRunInfo.from_path(f"goodfire/spd/runs/{self._run_id}") + + @cached_property + def component_model(self): + return ComponentModel.from_run_info(self.spd_run_info) + + @cached_property + def _topology(self) -> TransformerTopology: + return TransformerTopology(self.component_model.target_model) + + @property + @override + def decomposition_id(self) -> str: + return self._run_id + + @property + @override + def vocab_size(self) -> int: + return self._topology.embedding_module.num_embeddings + + @property + @override + def layer_activation_sizes(self) -> list[tuple[str, int]]: + cm = self.component_model + return list(cm.module_to_c.items()) + + @override + def dataloader(self, batch_size: int) -> DataLoader[torch.Tensor]: + return train_loader_and_tokenizer(self.spd_run_info.config, batch_size)[0] + + @property + @override + def 
tokenizer_name(self) -> str: + cfg = self.spd_run_info.config + assert cfg.tokenizer_name is not None + return cfg.tokenizer_name + + @property + @override + def model_metadata(self) -> ModelMetadata: + cfg = self.spd_run_info.config + task_cfg = runtime_cast(LMTaskConfig, cfg.task_config) + return ModelMetadata( + n_blocks=self._topology.n_blocks, + model_class=cfg.pretrained_model_class, + dataset_name=task_cfg.dataset_name, + layer_descriptions={ + path: self._topology.target_to_canon(path) + for path in self.component_model.target_module_paths + }, + ) diff --git a/spd/app/CLAUDE.md b/spd/app/CLAUDE.md index 95d6bc1b3..7e86aa0fd 100644 --- a/spd/app/CLAUDE.md +++ b/spd/app/CLAUDE.md @@ -4,16 +4,19 @@ Web-based visualization and analysis tool for exploring neural network component - **Backend**: Python FastAPI (`backend/`) - **Frontend**: Svelte 5 + TypeScript (`frontend/`) -- **Database**: SQLite at `.data/app/prompt_attr.db` (relative to repo root) +- **Database**: SQLite at `SPD_OUT_DIR/app/prompt_attr.db` (shared across team via NFS) +- **TODOs**: See `TODO.md` for open work items ## Project Context This is a **rapidly iterated research tool**. Key implications: -- **Please do not code for backwards compatibility**: Schema changes don't need migrations, expect state can be deleted, etc. -- **Database is disposable**: Delete `.data/app/prompt_attr.db` if schema changes break things +- **Database is persistent shared state**: Lives at `SPD_OUT_DIR/app/prompt_attr.db` on NFS, shared across the team. Do not delete. Uses DELETE journal mode (NFS-safe) with `fcntl.flock` write locking for concurrent access. + - **Schema changes require manual migration**: Update the `CREATE TABLE IF NOT EXISTS` statements to match the desired schema, then manually `ALTER TABLE` the real DB (back it up first). No automatic migration framework — just SQL. + - Keep the CREATE TABLE statements as the source of truth for the schema. 
- **Prefer simplicity**: Avoid over-engineering for hypothetical future needs - **Fail loud and fast**: The users are a small team of highly technical people. Errors are good. We want to know immediately if something is wrong. No soft failing, assert, assert, assert +- **Token display**: Always ship token strings rendered server-side via `AppTokenizer`, never raw token IDs. For embed/output layers, `component_idx` is a token ID — resolve it to a display string in the backend response. ## Running the App @@ -32,23 +35,26 @@ This launches both backend (FastAPI/uvicorn) and frontend (Vite) dev servers. ``` backend/ ├── server.py # FastAPI app, CORS, routers -├── state.py # Singleton StateManager + HarvestCache (lazy-loaded harvest data) -├── compute.py # Core attribution computation +├── state.py # Singleton StateManager + HarvestRepo (lazy-loaded harvest data) +├── compute.py # Core attribution computation + intervention evaluation +├── app_tokenizer.py # AppTokenizer: wraps HF tokenizers for display/encoding +├── (topology lives at spd/topology.py — TransformerTopology) ├── schemas.py # Pydantic API models ├── dependencies.py # FastAPI dependency injection ├── utils.py # Logging/timing utilities ├── database.py # SQLite interface -├── optim_cis.py # Sparse CI optimization +├── optim_cis.py # Sparse CI optimization, loss configs, PGD └── routers/ - ├── runs.py # Load W&B runs + ├── runs.py # Load W&B runs + GET /api/model_info ├── graphs.py # Compute attribution graphs ├── prompts.py # Prompt management ├── activation_contexts.py # Serves pre-harvested activation contexts ├── intervention.py # Selective component activation ├── correlations.py # Component correlations + token stats + interpretations ├── clusters.py # Component clustering - ├── dataset_search.py # SimpleStories dataset search - └── agents.py # Various useful endpoints that AI agents should look at when helping + ├── dataset_search.py # Dataset search (reads dataset from run config) + ├── agents.py # 
Various useful endpoints that AI agents should look at when helping + └── mcp.py # MCP (Model Context Protocol) endpoint for Claude Code ``` Note: Activation contexts, correlations, and token stats are now loaded from pre-harvested data (see `spd/harvest/`). The app no longer computes these on-the-fly. @@ -70,6 +76,7 @@ frontend/src/ │ │ ├── dataset.ts # Dataset search │ │ └── clusters.ts # Component clustering │ ├── index.ts # Shared utilities (Loadable pattern) +│ ├── graphLayout.ts # Shared graph layout (parseLayer, row sorting) │ ├── promptAttributionsTypes.ts # TypeScript types │ ├── interventionTypes.ts │ ├── colors.ts # Color utilities @@ -84,9 +91,9 @@ frontend/src/ ├── ActivationContextsTab.svelte # Component firing patterns tab ├── ActivationContextsViewer.svelte ├── ActivationContextsPagedTable.svelte - ├── DatasetSearchTab.svelte # SimpleStories search UI + ├── DatasetSearchTab.svelte # Dataset search UI ├── DatasetSearchResults.svelte - ├── ClusterPathInput.svelte # Cluster path selector + ├── ClusterPathInput.svelte # Cluster path selector (dropdown populated from registry.ts) ├── ComponentProbeInput.svelte # Component probe UI ├── TokenHighlights.svelte # Token highlighting ├── prompt-attr/ @@ -150,8 +157,13 @@ Edge(source: Node, target: Node, strength: float, is_cross_seq: bool) # strength = gradient * activation # is_cross_seq = True for k/v → o_proj (attention pattern) -PromptAttributionResult(edges: list[Edge], output_probs: Tensor[seq, vocab], node_ci_vals: dict[str, float]) -# node_ci_vals maps "layer:seq:c_idx" → CI value +PromptAttributionResult(edges, ci_masked_out_logits, target_out_logits, node_ci_vals, node_subcomp_acts) + +TokenPrediction(token, token_id, prob, logit, target_prob, target_logit) + +InterventionResult(input_tokens, ci, stochastic, adversarial, ci_loss, stochastic_loss, adversarial_loss) +# ci/stochastic/adversarial are list[list[TokenPrediction]] (per-position top-k) +# losses are evaluated using the graph's implied loss 
context ``` ### Frontend Types (`promptAttributionsTypes.ts`) @@ -188,7 +200,7 @@ GraphData = { - `strength = grad * source_activation` - Create Edge for each alive source component -**Cross-sequence edges**: `is_kv_to_o_pair()` detects k/v → o_proj in same attention block. +**Cross-sequence edges**: `topology.is_cross_seq_pair()` detects k/v → o_proj in same attention block. These have gradients across sequence positions (causal attention pattern). ### Causal Importance (CI) @@ -207,13 +219,31 @@ Finds sparse CI mask that: - Minimizes L0 (active component count) - Uses importance minimality + CE loss (or KL loss) -### Intervention Forward +### Interventions (`compute.py → compute_intervention`) + +A single unified function evaluates a node selection under three masking regimes: + +- **CI**: mask = selection (binary on/off) +- **Stochastic**: mask = selection + (1-selection) × Uniform(0,1) +- **Adversarial**: PGD optimizes alive-but-unselected components to maximize loss; non-alive get Uniform(0,1) + +Returns `InterventionResult` with top-k `TokenPrediction`s per position for each regime, plus per-regime loss values. + +**Loss context**: Every graph has an implied loss that interventions evaluate against: -`compute_intervention_forward()`: +- **Standard/manual graphs** → `MeanKLLossConfig` (mean KL divergence from target across all positions) +- **Optimized graphs** → the graph's optimization loss (CE for a specific token at a position, or KL at a position) -1. Build component masks (all zeros) -2. Set mask=1.0 for selected nodes -3. Forward pass → top-k predictions per position +This loss is used for two things: (1) what PGD maximizes during adversarial evaluation, and (2) the `ci_loss`/`stochastic_loss`/`adversarial_loss` metrics reported in `InterventionResult`. + +**Alive masks**: `compute_intervention` recomputes the model's natural CI (one forward pass + `calc_causal_importances`) and binarizes at 0 to get alive masks. 
This ensures the alive set is always the full model's CI — not the graph's potentially sparse optimized CI. PGD can only manipulate alive-but-unselected components. + +**Training PGD vs Eval PGD**: The PGD settings in the graph optimization config (`adv_pgd_n_steps`, +`adv_pgd_step_size`) are a _training_ regularizer — they make CI optimization robust. The PGD in +`compute_intervention` is an _eval_ metric — it measures worst-case performance for a given node +selection. Eval PGD defaults are in `compute.py` (`DEFAULT_EVAL_PGD_CONFIG`). + +**Base intervention run**: Created automatically during graph computation. Uses all interventable nodes with CI > 0. Persisted as an `intervention_run` so predictions are available synchronously. --- @@ -241,24 +271,29 @@ POST /api/graphs ### Intervention ``` -POST /api/intervention {text, nodes: ["h.0.attn.q_proj:3:5", ...]} - → compute_intervention_forward() - ← InterventionResponse with top-k predictions +POST /api/intervention/run {graph_id, selected_nodes, top_k, adv_pgd} + → compute_intervention(active_nodes, graph_alive_masks, loss_config) + ← InterventionRunSummary {id, selected_nodes, result: InterventionResult} + +InterventionResult = { + input_tokens, ci, stochastic, adversarial, // TokenPrediction[][] per regime + ci_loss, stochastic_loss, adversarial_loss // loss under each regime +} ``` ### Component Correlations & Interpretations ``` GET /api/correlations/components/{layer}/{component_idx} - → Load from HarvestCache (pre-harvested data) + → Load from HarvestRepo (pre-harvested data) ← ComponentCorrelationsResponse (precision, recall, jaccard, pmi) GET /api/correlations/token_stats/{layer}/{component_idx} - → Load from HarvestCache + → Load from HarvestRepo ← TokenStatsResponse (input/output token associations) GET /api/correlations/interpretation/{layer}/{component_idx} - → Load from HarvestCache (autointerp results) + → Load from HarvestRepo (autointerp results) ← InterpretationResponse (label, confidence, 
reasoning) ``` @@ -266,25 +301,25 @@ GET /api/correlations/interpretation/{layer}/{component_idx} ``` POST /api/dataset/search?query=... - → Search SimpleStories dataset - ← DatasetSearchMetadata + → Search the loaded run's training dataset (reads dataset_name from config) + ← DatasetSearchMetadata (includes dataset_name) GET /api/dataset/results?page=1&page_size=20 - ← Paginated search results + ← Paginated search results (text + generic metadata dict) ``` --- ## Database Schema -Located at `.data/app/prompt_attr.db`. Delete this file if schema changes cause issues. +Located at `SPD_OUT_DIR/app/prompt_attr.db` (shared via NFS). Uses DELETE journal mode with `fcntl.flock` write locking for safe concurrent access from multiple backends. -| Table | Key | Purpose | -| ------------------ | ---------------------------------- | ------------------------------------------------- | -| `runs` | `wandb_path` | W&B run references | -| `prompts` | `(run_id, context_length)` | Token sequences | -| `graphs` | `(prompt_id, optimization_params)` | Attribution edges + output probs + node CI values | -| `intervention_runs`| `graph_id` | Saved intervention results | +| Table | Key | Purpose | +| ------------------- | ---------------------------------- | -------------------------------------------------------- | +| `runs` | `wandb_path` | W&B run references | +| `prompts` | `(run_id, context_length)` | Token sequences | +| `graphs` | `(prompt_id, optimization_params)` | Attribution edges + CI/target logits + node CI values | +| `intervention_runs` | `graph_id` | Saved `InterventionResult` JSON (single `result` column) | Note: Activation contexts, correlations, token stats, and interpretations are loaded from pre-harvested data at `SPD_OUT_DIR/{harvest,autointerp}/` (see `spd/harvest/` and `spd/autointerp/`). 
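The NFS-safe write pattern described above (DELETE journal mode plus an `fcntl.flock` advisory lock so multiple backends serialize writes) can be sketched in isolation. This is a minimal illustration, not the app's actual `database.py` API — `locked_write` and the lock-file path are hypothetical names:

```python
import fcntl
import sqlite3
import tempfile
from pathlib import Path


def locked_write(db_path: Path, lock_path: Path, sql: str, params: tuple = ()) -> None:
    """Hypothetical sketch: hold an exclusive flock for the duration of a write."""
    with open(lock_path, "w") as lock_file:
        fcntl.flock(lock_file, fcntl.LOCK_EX)  # blocks until the write lock is held
        try:
            conn = sqlite3.connect(db_path)
            # DELETE journal mode avoids WAL's shared-memory file, which is
            # unsafe on NFS; the flock handles cross-process write exclusion.
            conn.execute("PRAGMA journal_mode=DELETE")
            conn.execute(sql, params)
            conn.commit()
            conn.close()
        finally:
            fcntl.flock(lock_file, fcntl.LOCK_UN)


tmp = Path(tempfile.mkdtemp())
db, lock = tmp / "prompt_attr.db", tmp / "prompt_attr.lock"
locked_write(db, lock, "CREATE TABLE IF NOT EXISTS runs (wandb_path TEXT PRIMARY KEY)")
locked_write(db, lock, "INSERT INTO runs VALUES (?)", ("my-entity/spd/run1",))
rows = sqlite3.connect(db).execute("SELECT wandb_path FROM runs").fetchall()
assert rows == [("my-entity/spd/run1",)]
```

Note that `flock` is advisory: it only protects against writers that also take the lock, which is why every backend process must go through the same locking helper.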
@@ -299,13 +334,14 @@ StateManager.get() → AppState: - db: PromptAttrDB (always available) - run_state: RunState | None - model: ComponentModel - - tokenizer: PreTrainedTokenizerBase + - topology: TransformerTopology # Model topology (embedding, unembed, cross-seq roles) + - tokenizer: AppTokenizer # Token display, encoding, span construction - sources_by_target: dict[target_layer → source_layers] - - config, context_length, token_strings - - harvest: HarvestCache # Lazy-loaded pre-harvested data + - config, context_length + - harvest: HarvestRepo # Lazy-loaded pre-harvested data - dataset_search_state: DatasetSearchState | None # Cached search results -HarvestCache: # Lazy-loads from SPD_OUT_DIR/harvest// +HarvestRepo: # Lazy-loads from SPD_OUT_DIR/harvest// - correlations: CorrelationStorage | None - token_stats: TokenStatsStorage | None - activation_contexts: dict[str, ComponentData] | None @@ -332,6 +368,6 @@ HarvestCache: # Lazy-loads from SPD_OUT_DIR/harvest// ## Performance Notes -- **Edge limit**: `GLOBAL_EDGE_LIMIT = 5000` in graph visualization +- **Edge limit**: `GLOBAL_EDGE_LIMIT = 50000` in graph visualization - **SSE streaming**: Long computations stream progress updates - **Lazy loading**: Component details fetched on hover/pin diff --git a/spd/app/backend/app_tokenizer.py b/spd/app/backend/app_tokenizer.py new file mode 100644 index 000000000..acfa4d7eb --- /dev/null +++ b/spd/app/backend/app_tokenizer.py @@ -0,0 +1,119 @@ +"""Tokenizer wrapper that isolates HuggingFace tokenizer quirks from the rest of the app. + +The core problem: `"".join(tokenizer.decode([t]) for t in ids)` != `tokenizer.decode(ids)` +because tokenizers encode word boundaries in family-specific ways (BPE's Ġ prefix, +WordPiece's ## prefix, SentencePiece's ▁ prefix, byte-level token splitting, etc.). 
+ +AppTokenizer provides two clean interfaces: +- get_spans(token_ids): per-token strings that concatenate to the full decoded text +- get_tok_display(token_id): single-token display string for vocab browsers / hover labels +""" + +from typing import Self + +from transformers import AutoTokenizer +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +_CONTROL_CHAR_MAP = { + "\t": "⇥", + "\n": "↵", + "\r": "⏎", + "\x00": "␀", +} + + +def escape_for_display(s: str) -> str: + """Escape control characters for human-readable display.""" + for char, replacement in _CONTROL_CHAR_MAP.items(): + s = s.replace(char, replacement) + return s + + +class AppTokenizer: + """Wraps a HuggingFace tokenizer. All decoding grossness lives here.""" + + def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None: + self._tok = tokenizer + self._is_fast = hasattr(tokenizer, "backend_tokenizer") + + @classmethod + def from_pretrained(cls, tokenizer_name: str) -> Self: + hf_tok = AutoTokenizer.from_pretrained(tokenizer_name) + assert isinstance(hf_tok, PreTrainedTokenizerBase) + return cls(hf_tok) + + @property + def hf_tokenizer(self) -> PreTrainedTokenizerBase: + """The underlying HuggingFace tokenizer, for APIs that require it directly.""" + return self._tok + + @property + def vocab_size(self) -> int: + size = self._tok.vocab_size + assert isinstance(size, int) + return size + + @property + def eos_token_id(self) -> int: + eos = self._tok.eos_token_id + assert isinstance(eos, int) + return eos + + def encode(self, text: str) -> list[int]: + return self._tok.encode(text, add_special_tokens=False) + + def decode(self, token_ids: list[int]) -> str: + return self._tok.decode(token_ids, skip_special_tokens=False) + + def get_spans(self, token_ids: list[int]) -> list[str]: + """Decode token_ids into per-token display strings that concatenate to the full text. 
+ + Uses offset_mapping (from the Rust tokenizer backend) when available, with dedup + for overlapping byte-token spans. Falls back to per-token decode otherwise. + """ + if not token_ids: + return [] + + if not self._is_fast: + return self._fallback_spans(token_ids) + + text = self._tok.decode(token_ids, skip_special_tokens=False) + re_encoded = self._tok(text, return_offsets_mapping=True, add_special_tokens=False) + + if re_encoded.input_ids != token_ids: + return self._fallback_spans(token_ids) + + offsets: list[tuple[int, int]] = re_encoded.offset_mapping + assert len(offsets) == len(token_ids) + + spans: list[str] = [] + prev_end = 0 + for start, end in offsets: + if start >= prev_end: + # Include any gap characters (spaces, etc.) as prefix of this span + spans.append(text[prev_end:end]) + prev_end = end + else: + # Multi-byte char split across tokens: first token claimed the full char, + # continuation byte-tokens get empty string + spans.append("") + + assert "".join(spans) == text, f"span concat mismatch: {''.join(spans)!r} != {text!r}" + return [escape_for_display(span) for span in spans] + + def get_tok_display(self, token_id: int) -> str: + """Single token -> display string for vocab browsers and hover labels.""" + return escape_for_display(self._tok.decode([token_id], skip_special_tokens=False)) + + def _fallback_spans(self, token_ids: list[int]) -> list[str]: + """Incremental decode: each span = decode(:i+1) - decode(:i). + + O(n²) but correct for all tokenizer families (BPE, WordPiece, SentencePiece). 
+ """ + spans: list[str] = [] + prev = "" + for i in range(len(token_ids)): + current = self._tok.decode(token_ids[: i + 1], skip_special_tokens=False) + spans.append(escape_for_display(current[len(prev) :])) + prev = current + return spans diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index 6ef65dc99..c99de5a02 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -4,6 +4,7 @@ to avoid importing script files with global execution. """ +import time from collections import defaultdict from collections.abc import Callable from dataclasses import dataclass @@ -11,14 +12,29 @@ import torch from jaxtyping import Bool, Float +from pydantic import BaseModel from torch import Tensor, nn -from tqdm.auto import tqdm -from transformers.tokenization_utils_base import PreTrainedTokenizerBase -from spd.app.backend.optim_cis import OptimCIConfig, compute_label_prob, optimize_ci_values +from spd.app.backend.app_tokenizer import AppTokenizer +from spd.app.backend.optim_cis import ( + AdvPGDConfig, + CELossConfig, + CISnapshotCallback, + LogitLossConfig, + LossConfig, + OptimCIConfig, + OptimizationMetrics, + compute_recon_loss, + optimize_ci_values, + optimize_ci_values_batched, + run_adv_pgd, +) from spd.configs import SamplingType +from spd.log import logger +from spd.metrics.pgd_utils import interpolate_pgd_mask from spd.models.component_model import ComponentModel, OutputWithCache from spd.models.components import make_mask_infos +from spd.topology import TransformerTopology from spd.utils.general_utils import bf16_autocast @@ -30,6 +46,9 @@ class LayerAliveInfo: alive_c_idxs: list[int] # Components alive at any position +MAX_OUTPUT_NODES_PER_POS = 15 + + def compute_layer_alive_info( layer_name: str, ci_lower_leaky: dict[str, Tensor], @@ -37,20 +56,36 @@ def compute_layer_alive_info( output_prob_threshold: float, n_seq: int, device: str, + topology: TransformerTopology, ) -> LayerAliveInfo: - """Compute alive info for a layer. 
Handles regular, wte, and output layers. + """Compute alive info for a layer. Handles regular, embedding, and unembed layers. For CI layers, all components with CI > 0 are considered alive. Filtering by CI threshold is done at display time, not computation time. + + For unembed layer, caps at MAX_OUTPUT_NODES_PER_POS per position to keep + edge computation tractable with large vocabularies. """ - if layer_name == "wte": - # WTE: single pseudo-component, always alive at all positions + embed_path = topology.path_schema.embedding_path + unembed_path = topology.path_schema.unembed_path + + if layer_name == embed_path: alive_mask = torch.ones(n_seq, 1, device=device, dtype=torch.bool) alive_c_idxs = [0] - elif layer_name == "output": + elif layer_name == unembed_path: assert ci_masked_out_probs is not None assert ci_masked_out_probs.shape[0] == 1 - alive_mask = ci_masked_out_probs[0] > output_prob_threshold + probs = ci_masked_out_probs[0] # [seq, vocab] + alive_mask = probs > output_prob_threshold + # Cap per position: keep only top-k per seq pos + for s in range(n_seq): + pos_alive = torch.where(alive_mask[s])[0] + if len(pos_alive) > MAX_OUTPUT_NODES_PER_POS: + pos_probs = probs[s, pos_alive] + _, keep_local = torch.topk(pos_probs, MAX_OUTPUT_NODES_PER_POS) + keep_idxs = pos_alive[keep_local] + alive_mask[s] = False + alive_mask[s, keep_idxs] = True alive_c_idxs = torch.where(alive_mask.any(dim=0))[0].tolist() else: ci = ci_lower_leaky[layer_name] @@ -72,6 +107,13 @@ def __str__(self) -> str: return f"{self.layer}:{self.seq_pos}:{self.component_idx}" +def _get_seq_pos(node_key: str) -> int: + """Extract sequence position from node key format 'layer:seq:cIdx'.""" + parts = node_key.split(":") + assert len(parts) == 3, f"Invalid node key format: {node_key}" + return int(parts[1]) + + @dataclass class Edge: """Edge in the attribution graph.""" @@ -87,6 +129,7 @@ class PromptAttributionResult: """Result of computing prompt attributions for a prompt.""" edges: list[Edge] 
+ edges_abs: list[Edge] # absolute-target variant: ∂|y|/∂x · x ci_masked_out_probs: Float[Tensor, "seq vocab"] # CI-masked (SPD model) softmax probabilities ci_masked_out_logits: Float[Tensor, "seq vocab"] # CI-masked (SPD model) raw logits target_out_probs: Float[Tensor, "seq vocab"] # Target model softmax probabilities @@ -100,153 +143,36 @@ class OptimizedPromptAttributionResult: """Result of computing prompt attributions with optimized CI values.""" edges: list[Edge] + edges_abs: list[Edge] # absolute-target variant: ∂|y|/∂x · x ci_masked_out_probs: Float[Tensor, "seq vocab"] # CI-masked (SPD model) softmax probabilities ci_masked_out_logits: Float[Tensor, "seq vocab"] # CI-masked (SPD model) raw logits target_out_probs: Float[Tensor, "seq vocab"] # Target model softmax probabilities target_out_logits: Float[Tensor, "seq vocab"] # Target model raw logits - label_prob: float | None # P(label_token) with optimized CI mask, None if KL-only node_ci_vals: dict[str, float] # layer:seq:c_idx -> ci_val node_subcomp_acts: dict[str, float] # layer:seq:c_idx -> subcomponent activation (v_i^T @ a) - - -def is_kv_to_o_pair(in_layer: str, out_layer: str) -> bool: - """Check if pair requires cross-sequence gradient computation. - - For k/v → o_proj within the same attention block, output at s_out - has gradients w.r.t. inputs at all s_in ≤ s_out (causal attention). - """ - in_is_kv = any(x in in_layer for x in ["k_proj", "v_proj"]) - out_is_o = "o_proj" in out_layer - if not (in_is_kv and out_is_o): - return False - - # Check same attention block: "h.{idx}.attn.{proj}" - in_block = in_layer.split(".")[1] - out_block = out_layer.split(".")[1] - return in_block == out_block - - -def get_sources_by_target( - model: ComponentModel, - device: str, - sampling: SamplingType, -) -> dict[str, list[str]]: - """Find valid gradient connections grouped by target layer. - - Includes wte (input embeddings) as a source and output (logits) as a target. 
- - Returns: - Dict mapping out_layer -> list of in_layers that have gradient flow to it. - """ - # Use a small dummy batch - we only need to trace gradient connections - batch: Float[Tensor, "batch seq"] = torch.zeros(2, 3, dtype=torch.long, device=device) - - with torch.no_grad(), bf16_autocast(): - output_with_cache: OutputWithCache = model(batch, cache_type="input") - - ci = model.calc_causal_importances( - pre_weight_acts=output_with_cache.cache, - sampling=sampling, - detach_inputs=False, - ) - - # Create masks so we can use all components - mask_infos = make_mask_infos( - component_masks={k: torch.ones_like(v) for k, v in ci.lower_leaky.items()}, - routing_masks="all", - ) - - # Hook to capture wte output with gradients - wte_cache: dict[str, Tensor] = {} - - def wte_hook( - _module: nn.Module, _args: tuple[Any, ...], _kwargs: dict[Any, Any], output: Tensor - ) -> Any: - output.requires_grad_(True) - wte_cache["wte_post_detach"] = output - return output - - assert isinstance(model.target_model.wte, nn.Module), "wte is not a module" - wte_handle = model.target_model.wte.register_forward_hook(wte_hook, with_kwargs=True) - - with torch.enable_grad(), bf16_autocast(): - comp_output_with_cache: OutputWithCache = model( - batch, - mask_infos=mask_infos, - cache_type="component_acts", - ) - - wte_handle.remove() - - cache = comp_output_with_cache.cache - cache["wte_post_detach"] = wte_cache["wte_post_detach"] - cache["output_pre_detach"] = comp_output_with_cache.output - - # Build layer list: wte first, component layers, output last - layers = ["wte"] - component_layer_names = [ - "attn.q_proj", - "attn.k_proj", - "attn.v_proj", - "attn.o_proj", - "mlp.c_fc", - "mlp.down_proj", - ] - n_blocks = get_model_n_blocks(model.target_model) - for i in range(n_blocks): - layers.extend([f"h.{i}.{layer_name}" for layer_name in component_layer_names]) - - # Add lm_head if it exists in target_module_paths (unembedding matrix) - if "lm_head" in model.target_module_paths: - 
layers.append("lm_head") - - layers.append("output") - - # Test all pairs: wte can feed into anything, anything can feed into output - test_pairs = [] - for in_layer in layers[:-1]: # Don't include "output" as source - for out_layer in layers[1:]: # Don't include "wte" as target - if layers.index(in_layer) < layers.index(out_layer): - test_pairs.append((in_layer, out_layer)) - - sources_by_target: dict[str, list[str]] = defaultdict(list) - for in_layer, out_layer in test_pairs: - out_pre_detach = cache[f"{out_layer}_pre_detach"] - in_post_detach = cache[f"{in_layer}_post_detach"] - out_value = out_pre_detach[0, 0, 0] - grads = torch.autograd.grad( - outputs=out_value, - inputs=in_post_detach, - retain_graph=True, - allow_unused=True, - ) - assert len(grads) == 1 - grad = grads[0] - if grad is not None: # pyright: ignore[reportUnnecessaryComparison] - sources_by_target[out_layer].append(in_layer) - return dict(sources_by_target) + metrics: OptimizationMetrics # Final loss metrics from optimization ProgressCallback = Callable[[int, int, str], None] # (current, total, stage) -def _setup_wte_hook() -> tuple[Callable[..., Any], list[Tensor]]: - """Create hook to capture wte output with gradients. +def _setup_embed_hook() -> tuple[Callable[..., Any], list[Tensor]]: + """Create hook to capture embedding output with gradients. Returns the hook function and a mutable container for the cached output. The container is a list to allow mutation from the hook closure. 
""" - wte_cache: list[Tensor] = [] + embed_cache: list[Tensor] = [] - def wte_hook( + def embed_hook( _module: nn.Module, _args: tuple[Any, ...], _kwargs: dict[Any, Any], output: Tensor ) -> Any: output.requires_grad_(True) - assert len(wte_cache) == 0, "wte output should be cached only once" - wte_cache.append(output) + assert len(embed_cache) == 0, "embedding output should be cached only once" + embed_cache.append(output) return output - return wte_hook, wte_cache + return embed_hook, embed_cache def _compute_edges_for_target( @@ -255,57 +181,90 @@ def _compute_edges_for_target( target_info: LayerAliveInfo, source_infos: list[LayerAliveInfo], cache: dict[str, Tensor], - n_seq: int, -) -> list[Edge]: + loss_seq_pos: int, + topology: TransformerTopology, +) -> tuple[list[Edge], list[Edge]]: """Compute all edges flowing into a single target layer. For each alive (s_out, c_out) in the target layer, computes gradient-based - attribution strengths from all alive source components. + attribution strengths from all alive source components. Computes both signed + (∂y/∂x · x) and absolute-target (∂|y|/∂x · x) variants. + + Args: + loss_seq_pos: Maximum sequence position to include (inclusive). + Only compute edges for target positions <= loss_seq_pos. + + Returns: + (edges, edges_abs): Signed and absolute-target edge lists. 
""" edges: list[Edge] = [] + edges_abs: list[Edge] = [] out_pre_detach: Float[Tensor, "1 s C"] = cache[f"{target}_pre_detach"] in_post_detaches: list[Float[Tensor, "1 s C"]] = [ cache[f"{source}_post_detach"] for source in sources ] - for s_out in range(n_seq): + for s_out in range(loss_seq_pos + 1): s_out_alive_c = [c for c in target_info.alive_c_idxs if target_info.alive_mask[s_out, c]] if not s_out_alive_c: continue for c_out in s_out_alive_c: + target_val = out_pre_detach[0, s_out, c_out] grads = torch.autograd.grad( - outputs=out_pre_detach[0, s_out, c_out], + outputs=target_val, inputs=in_post_detaches, retain_graph=True, ) + # ∂|y|/∂x = sign(y) · ∂y/∂x — avoids a second backward pass. + # This works because target_val is a single scalar. In dataset_attributions/ + # harvester.py, the target is sum(|y_i|) over batch+seq — there each y_i has a + # different sign, so you can't factor out one scalar. The issue isn't the chain + # rule (sign·grad is always valid per-element), it's that abs breaks the + # grad(sum)=sum(grad) trick that makes the batch reduction a single backward pass. 
+ target_sign = target_val.sign() with torch.no_grad(): + canonical_target = topology.target_to_canon(target) for source, source_info, grad, in_post_detach in zip( sources, source_infos, grads, in_post_detaches, strict=True ): - is_cross_seq = is_kv_to_o_pair(source, target) + canonical_source = topology.target_to_canon(source) + is_cross_seq = topology.is_cross_seq_pair(canonical_source, canonical_target) weighted: Float[Tensor, "s C"] = (grad * in_post_detach)[0] - if source == "wte": + weighted_abs: Float[Tensor, "s C"] = weighted * target_sign + if canonical_source == "embed": weighted = weighted.sum(dim=1, keepdim=True) + weighted_abs = weighted_abs.sum(dim=1, keepdim=True) s_in_range = range(s_out + 1) if is_cross_seq else [s_out] for s_in in s_in_range: for c_in in source_info.alive_c_idxs: if not source_info.alive_mask[s_in, c_in]: continue + src = Node(layer=canonical_source, seq_pos=s_in, component_idx=c_in) + tgt = Node(layer=canonical_target, seq_pos=s_out, component_idx=c_out) edges.append( Edge( - source=Node(layer=source, seq_pos=s_in, component_idx=c_in), - target=Node(layer=target, seq_pos=s_out, component_idx=c_out), + source=src, + target=tgt, strength=weighted[s_in, c_in].item(), is_cross_seq=is_cross_seq, ) ) - return edges + edges_abs.append( + Edge( + source=src, + target=tgt, + strength=weighted_abs[s_in, c_in].item(), + is_cross_seq=is_cross_seq, + ) + ) + return edges, edges_abs def compute_edges_from_ci( model: ComponentModel, + topology: TransformerTopology, tokens: Float[Tensor, "1 seq"], ci_lower_leaky: dict[str, Float[Tensor, "1 seq C"]], pre_weight_acts: dict[str, Float[Tensor, "1 seq d_in"]], @@ -314,8 +273,8 @@ def compute_edges_from_ci( target_out_logits: Float[Tensor, "1 seq vocab"], output_prob_threshold: float, device: str, - show_progress: bool, on_progress: ProgressCallback | None = None, + loss_seq_pos: int | None = None, ) -> PromptAttributionResult: """Core edge computation from pre-computed CI values. 
@@ -331,19 +290,30 @@ def compute_edges_from_ci( We compute CI-masked output probs separately (for display) before running the unmasked forward pass used for gradient computation. + + Args: + loss_seq_pos: Maximum sequence position to include (inclusive). + If None, includes all positions (default behavior). """ n_seq = tokens.shape[1] + if loss_seq_pos is None: + loss_seq_pos = n_seq - 1 # Compute CI-masked output probs (for display) before the gradient computation + t0 = time.perf_counter() with torch.no_grad(), bf16_autocast(): ci_masks = make_mask_infos(component_masks=ci_lower_leaky) ci_masked_logits: Tensor = model(tokens, mask_infos=ci_masks) ci_masked_out_probs = torch.softmax(ci_masked_logits, dim=-1) + logger.info(f"[perf] CI-masked forward: {time.perf_counter() - t0:.2f}s") + + embed_path = topology.path_schema.embedding_path + unembed_path = topology.path_schema.unembed_path - # Setup wte hook and run forward pass for gradient computation - wte_hook, wte_cache = _setup_wte_hook() - assert isinstance(model.target_model.wte, nn.Module), "wte is not a module" - wte_handle = model.target_model.wte.register_forward_hook(wte_hook, with_kwargs=True) + # Setup embedding hook and run forward pass for gradient computation + t0 = time.perf_counter() + embed_hook, embed_cache = _setup_embed_hook() + embed_handle = topology.embedding_module.register_forward_hook(embed_hook, with_kwargs=True) weight_deltas = model.calc_weight_deltas() weight_deltas_and_masks = { @@ -358,14 +328,16 @@ def compute_edges_from_ci( tokens, mask_infos=unmasked_masks, cache_type="component_acts" ) - wte_handle.remove() - assert len(wte_cache) == 1, "wte output should be cached" + embed_handle.remove() + assert len(embed_cache) == 1, "embedding output should be cached" cache = comp_output_with_cache.cache - cache["wte_post_detach"] = wte_cache[0] - cache["output_pre_detach"] = comp_output_with_cache.output + cache[f"{embed_path}_post_detach"] = embed_cache[0] + 
cache[f"{unembed_path}_pre_detach"] = comp_output_with_cache.output + logger.info(f"[perf] Gradient forward pass: {time.perf_counter() - t0:.2f}s") # Compute alive info for all layers upfront + t0 = time.perf_counter() all_layers: set[str] = set(sources_by_target.keys()) for sources in sources_by_target.values(): all_layers.update(sources) @@ -378,55 +350,74 @@ def compute_edges_from_ci( output_prob_threshold=output_prob_threshold, n_seq=n_seq, device=device, + topology=topology, ) for layer in all_layers } + total_alive = sum(len(info.alive_c_idxs) for info in alive_info.values()) + unembed_alive = len( + alive_info.get(unembed_path, LayerAliveInfo(torch.tensor([]), [])).alive_c_idxs + ) + logger.info( + f"[perf] Alive info: {time.perf_counter() - t0:.2f}s " + f"({total_alive} alive components, {unembed_alive} output nodes)" + ) # Compute edges for each target layer + t0 = time.perf_counter() edges: list[Edge] = [] + edges_abs: list[Edge] = [] total_source_layers = sum(len(sources) for sources in sources_by_target.values()) progress_count = 0 - pbar = ( - tqdm(total=total_source_layers, desc="Source layers by target", leave=True) - if show_progress - else None - ) for target, sources in sources_by_target.items(): - if pbar is not None: - pbar.set_description(f"Source layers by target: {target}") - - target_edges = _compute_edges_for_target( + t_target = time.perf_counter() + target_edges, target_edges_abs = _compute_edges_for_target( target=target, sources=sources, target_info=alive_info[target], source_infos=[alive_info[source] for source in sources], cache=cache, - n_seq=n_seq, + loss_seq_pos=loss_seq_pos, + topology=topology, ) edges.extend(target_edges) + edges_abs.extend(target_edges_abs) + canonical_target = topology.target_to_canon(target) + logger.info( + f"[perf] {canonical_target}: {time.perf_counter() - t_target:.2f}s, " + f"{len(target_edges)} edges" + ) progress_count += len(sources) - if pbar is not None: - pbar.update(len(sources)) if on_progress is 
not None: on_progress(progress_count, total_source_layers, target) - if pbar is not None: - pbar.close() + logger.info( + f"[perf] Edge computation total: {time.perf_counter() - t0:.2f}s ({len(edges)} edges)" + ) - node_ci_vals = extract_node_ci_vals(ci_lower_leaky) + t0 = time.perf_counter() + node_ci_vals = extract_node_ci_vals(ci_lower_leaky, topology) component_acts = model.get_all_component_acts(pre_weight_acts) node_subcomp_acts = extract_node_subcomp_acts( - component_acts, ci_threshold=0.0, ci_lower_leaky=ci_lower_leaky + component_acts, ci_threshold=0.0, ci_lower_leaky=ci_lower_leaky, topology=topology ) + logger.info(f"[perf] Node CI/subcomp extraction: {time.perf_counter() - t0:.2f}s") + + # Filter nodes and output tensors to only include positions <= loss_seq_pos + node_ci_vals = {k: v for k, v in node_ci_vals.items() if _get_seq_pos(k) <= loss_seq_pos} + node_subcomp_acts = { + k: v for k, v in node_subcomp_acts.items() if _get_seq_pos(k) <= loss_seq_pos + } return PromptAttributionResult( edges=edges, - ci_masked_out_probs=ci_masked_out_probs[0], - ci_masked_out_logits=comp_output_with_cache.output[0], - target_out_probs=target_out_probs[0], - target_out_logits=target_out_logits[0], + edges_abs=edges_abs, + ci_masked_out_probs=ci_masked_out_probs[0, : loss_seq_pos + 1], + ci_masked_out_logits=ci_masked_logits[0, : loss_seq_pos + 1], + target_out_probs=target_out_probs[0, : loss_seq_pos + 1], + target_out_logits=target_out_logits[0, : loss_seq_pos + 1], node_ci_vals=node_ci_vals, node_subcomp_acts=node_subcomp_acts, ) @@ -491,14 +482,15 @@ def filter_ci_to_included_nodes( def compute_prompt_attributions( model: ComponentModel, + topology: TransformerTopology, tokens: Float[Tensor, "1 seq"], sources_by_target: dict[str, list[str]], output_prob_threshold: float, sampling: SamplingType, device: str, - show_progress: bool, on_progress: ProgressCallback | None = None, included_nodes: set[str] | None = None, + loss_seq_pos: int | None = None, ) -> 
PromptAttributionResult: """Compute prompt attributions using the model's natural CI values. @@ -508,7 +500,12 @@ def compute_prompt_attributions( If included_nodes is provided, CI values for non-included nodes are zeroed out before edge computation. This efficiently filters to only compute edges between the specified nodes (useful for generating graphs from a selection). + + Args: + loss_seq_pos: Maximum sequence position to include (inclusive). + If None, includes all positions (default behavior). """ + t0 = time.perf_counter() with torch.no_grad(), bf16_autocast(): output_with_cache = model(tokens, cache_type="input") pre_weight_acts = output_with_cache.cache @@ -519,6 +516,7 @@ def compute_prompt_attributions( sampling=sampling, detach_inputs=False, ) + logger.info(f"[perf] CI forward pass: {time.perf_counter() - t0:.2f}s") ci_lower_leaky = ci.lower_leaky if included_nodes is not None: @@ -526,6 +524,7 @@ def compute_prompt_attributions( return compute_edges_from_ci( model=model, + topology=topology, tokens=tokens, ci_lower_leaky=ci_lower_leaky, pre_weight_acts=pre_weight_acts, @@ -534,20 +533,21 @@ def compute_prompt_attributions( target_out_logits=target_out_logits, output_prob_threshold=output_prob_threshold, device=device, - show_progress=show_progress, on_progress=on_progress, + loss_seq_pos=loss_seq_pos, ) def compute_prompt_attributions_optimized( model: ComponentModel, + topology: TransformerTopology, tokens: Float[Tensor, "1 seq"], sources_by_target: dict[str, list[str]], optim_config: OptimCIConfig, output_prob_threshold: float, device: str, - show_progress: bool, on_progress: ProgressCallback | None = None, + on_ci_snapshot: CISnapshotCallback | None = None, ) -> OptimizedPromptAttributionResult: """Compute prompt attributions using optimized sparse CI values. 
@@ -562,22 +562,15 @@ def compute_prompt_attributions_optimized( target_logits = model(tokens) target_out_probs = torch.softmax(target_logits, dim=-1) - ci_params = optimize_ci_values( + optim_result = optimize_ci_values( model=model, tokens=tokens, config=optim_config, device=device, on_progress=on_progress, + on_ci_snapshot=on_ci_snapshot, ) - ci_outputs = ci_params.create_ci_outputs(model, device) - - # Get label probability with optimized CI mask (if CE loss is used) - label_prob: float | None = None - if optim_config.ce_loss_config is not None: - with torch.no_grad(): - label_prob = compute_label_prob( - model, tokens, ci_outputs.lower_leaky, optim_config.ce_loss_config.label_token - ) + ci_outputs = optim_result.params.create_ci_outputs(model, device) # Signal transition to graph computation stage if on_progress is not None: @@ -587,8 +580,12 @@ def compute_prompt_attributions_optimized( with torch.no_grad(), bf16_autocast(): pre_weight_acts = model(tokens, cache_type="input").cache + # Extract loss_seq_pos from optimization config + loss_seq_pos = optim_config.loss_config.position + result = compute_edges_from_ci( model=model, + topology=topology, tokens=tokens, ci_lower_leaky=ci_outputs.lower_leaky, pre_weight_acts=pre_weight_acts, @@ -597,22 +594,95 @@ def compute_prompt_attributions_optimized( target_out_logits=target_logits, output_prob_threshold=output_prob_threshold, device=device, - show_progress=show_progress, on_progress=on_progress, + loss_seq_pos=loss_seq_pos, ) return OptimizedPromptAttributionResult( edges=result.edges, + edges_abs=result.edges_abs, ci_masked_out_probs=result.ci_masked_out_probs, ci_masked_out_logits=result.ci_masked_out_logits, target_out_probs=result.target_out_probs, target_out_logits=result.target_out_logits, - label_prob=label_prob, node_ci_vals=result.node_ci_vals, node_subcomp_acts=result.node_subcomp_acts, + metrics=optim_result.metrics, ) +def compute_prompt_attributions_optimized_batched( + model: ComponentModel, + 
topology: TransformerTopology, + tokens: Float[Tensor, "1 seq"], + sources_by_target: dict[str, list[str]], + configs: list[OptimCIConfig], + output_prob_threshold: float, + device: str, + on_progress: ProgressCallback | None = None, + on_ci_snapshot: CISnapshotCallback | None = None, +) -> list[OptimizedPromptAttributionResult]: + """Compute prompt attributions for multiple sparsity coefficients in one batched optimization.""" + with torch.no_grad(), bf16_autocast(): + target_logits = model(tokens) + target_out_probs = torch.softmax(target_logits, dim=-1) + + optim_results = optimize_ci_values_batched( + model=model, + tokens=tokens, + configs=configs, + device=device, + on_progress=on_progress, + on_ci_snapshot=on_ci_snapshot, + ) + + if on_progress is not None: + on_progress(0, len(optim_results), "graph") + + with torch.no_grad(), bf16_autocast(): + pre_weight_acts = model(tokens, cache_type="input").cache + + loss_seq_pos = configs[0].loss_config.position + + results: list[OptimizedPromptAttributionResult] = [] + for i, optim_result in enumerate(optim_results): + ci_outputs = optim_result.params.create_ci_outputs(model, device) + + result = compute_edges_from_ci( + model=model, + topology=topology, + tokens=tokens, + ci_lower_leaky=ci_outputs.lower_leaky, + pre_weight_acts=pre_weight_acts, + sources_by_target=sources_by_target, + target_out_probs=target_out_probs, + target_out_logits=target_logits, + output_prob_threshold=output_prob_threshold, + device=device, + on_progress=on_progress, + loss_seq_pos=loss_seq_pos, + ) + + results.append( + OptimizedPromptAttributionResult( + edges=result.edges, + edges_abs=result.edges_abs, + ci_masked_out_probs=result.ci_masked_out_probs, + ci_masked_out_logits=result.ci_masked_out_logits, + target_out_probs=result.target_out_probs, + target_out_logits=result.target_out_logits, + node_ci_vals=result.node_ci_vals, + node_subcomp_acts=result.node_subcomp_acts, + metrics=optim_result.metrics, + ) + ) + + if on_progress is not 
None: + on_progress(i + 1, len(optim_results), "graph") + + return results + + @dataclass class CIOnlyResult: """Result of computing CI values only (no attribution graph).""" @@ -661,22 +731,20 @@ def compute_ci_only( def extract_node_ci_vals( ci_lower_leaky: dict[str, Float[Tensor, "1 seq n_components"]], + topology: TransformerTopology, ) -> dict[str, float]: """Extract per-node CI values from CI tensors. - Args: - ci_lower_leaky: Dict mapping layer name to CI tensor [1, seq, n_components]. - - Returns: - Dict mapping "layer:seq:c_idx" to CI value. + Returns dict mapping canonical node key to CI value. """ node_ci_vals: dict[str, float] = {} for layer_name, ci_tensor in ci_lower_leaky.items(): + canonical = topology.target_to_canon(layer_name) n_seq = ci_tensor.shape[1] n_components = ci_tensor.shape[2] for seq_pos in range(n_seq): for c_idx in range(n_components): - key = f"{layer_name}:{seq_pos}:{c_idx}" + key = f"{canonical}:{seq_pos}:{c_idx}" node_ci_vals[key] = float(ci_tensor[0, seq_pos, c_idx].item()) return node_ci_vals @@ -685,187 +753,269 @@ def extract_node_subcomp_acts( component_acts: dict[str, Float[Tensor, "1 seq C"]], ci_threshold: float, ci_lower_leaky: dict[str, Float[Tensor, "1 seq C"]], + topology: TransformerTopology, ) -> dict[str, float]: """Extract per-node subcomponent activations from pre-computed component acts. - Args: - component_acts: Dict mapping layer name to component activations [1, seq, C]. - ci_threshold: Threshold for filtering nodes by CI value. - ci_lower_leaky: Dict mapping layer name to CI tensor [1, seq, C]. - - Returns: - Dict mapping "layer:seq:c_idx" to subcomponent activation value. + Returns dict mapping canonical node key to subcomponent activation value. 
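The canonical node keys built by `extract_node_ci_vals` follow a `layer:seq:c_idx` convention, and `compute_edges_from_ci` later filters them by position via a `_get_seq_pos` helper that is referenced but not shown in this diff. A minimal sketch of that convention (the helper implementation here is an assumption, as is the `make_node_key` name):

```python
def make_node_key(layer: str, seq_pos: int, c_idx: int) -> str:
    """Format a node key the way extract_node_ci_vals does: 'layer:seq:c_idx'."""
    return f"{layer}:{seq_pos}:{c_idx}"


def get_seq_pos(key: str) -> int:
    """Recover the sequence position from a node key.

    Split from the right so a layer path that itself contains ':' would still parse.
    """
    return int(key.rsplit(":", 2)[1])


# Mirror the filtering done in compute_edges_from_ci: drop nodes past loss_seq_pos
node_ci_vals = {make_node_key("h.0.mlp", s, c): 0.5 for s in range(4) for c in range(2)}
loss_seq_pos = 1
kept = {k: v for k, v in node_ci_vals.items() if get_seq_pos(k) <= loss_seq_pos}
assert len(kept) == 4
```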
""" node_subcomp_acts: dict[str, float] = {} for layer_name, subcomp_acts in component_acts.items(): + canonical = topology.target_to_canon(layer_name) ci = ci_lower_leaky[layer_name] alive_mask = ci[0] > ci_threshold # [seq, C] alive_seq_indices, alive_c_indices = torch.where(alive_mask) for seq_pos, c_idx in zip( alive_seq_indices.tolist(), alive_c_indices.tolist(), strict=True ): - key = f"{layer_name}:{seq_pos}:{c_idx}" + key = f"{canonical}:{seq_pos}:{c_idx}" node_subcomp_acts[key] = float(subcomp_acts[0, seq_pos, c_idx].item()) return node_subcomp_acts -def extract_active_from_ci( - ci_lower_leaky: dict[str, Float[Tensor, "1 seq n_components"]], - target_out_probs: Float[Tensor, "1 seq vocab"], - ci_threshold: float, - output_prob_threshold: float, - n_seq: int, -) -> dict[str, tuple[float, list[int]]]: - """Build inverted index data directly from CI values. +class TokenPrediction(BaseModel): + """A single token prediction with probability.""" - For regular component layers, a component is active at positions where CI > threshold. - For the output layer, a token is active at positions where prob > threshold. - For wte, a single pseudo-component (idx 0) is always active at all positions. + token: str + token_id: int + prob: float + logit: float + target_prob: float + target_logit: float - Args: - ci_lower_leaky: Dict mapping layer name to CI tensor [1, seq, n_components]. - target_out_probs: Target model output probability tensor [1, seq, vocab]. - ci_threshold: Threshold for component activation. - output_prob_threshold: Threshold for output token activation. - n_seq: Sequence length. - Returns: - Dict mapping component_key ("layer:c_idx") to (max_ci, positions). 
- """ - active: dict[str, tuple[float, list[int]]] = {} - - # Regular component layers - for layer, ci_tensor in ci_lower_leaky.items(): - n_components = ci_tensor.shape[-1] - for c_idx in range(n_components): - ci_per_pos = ci_tensor[0, :, c_idx] - positions = torch.where(ci_per_pos > ci_threshold)[0].tolist() - if positions: - key = f"{layer}:{c_idx}" - max_ci = float(ci_per_pos.max().item()) - active[key] = (max_ci, positions) - - # Output layer - use probability threshold - for c_idx in range(target_out_probs.shape[-1]): - prob_per_pos = target_out_probs[0, :, c_idx] - positions = torch.where(prob_per_pos > output_prob_threshold)[0].tolist() - if positions: - key = f"output:{c_idx}" - max_prob = float(prob_per_pos.max().item()) - active[key] = (max_prob, positions) - - # WTE - single pseudo-component always active at all positions - active["wte:0"] = (1.0, list(range(n_seq))) - - return active - - -def get_model_n_blocks(model: nn.Module) -> int: - """Get the number of blocks in the model.""" - from transformers.models.gpt2 import GPT2LMHeadModel - - from spd.pretrain.models import GPT2, GPT2Simple, LlamaSimple, LlamaSimpleMLP - - match model: - case GPT2LMHeadModel(): - return len(model.transformer.h) - case GPT2() | GPT2Simple() | LlamaSimple() | LlamaSimpleMLP(): - return len(model.h) - case _ if hasattr(model, "h"): - return len(model.h) # pyright: ignore[reportArgumentType] - case _: - raise ValueError(f"Unsupported model: {type(model)}") +class LabelPredictions(BaseModel): + """Prediction stats for the CE label token at the optimized position, per masking regime.""" + position: int + ci: TokenPrediction + stochastic: TokenPrediction + adversarial: TokenPrediction + ablated: TokenPrediction | None -@dataclass -class InterventionResult: - """Result of intervention forward pass.""" + +class InterventionResult(BaseModel): + """Unified result of an intervention evaluation under multiple masking regimes.""" input_tokens: list[str] - predictions_per_position: 
list[ - list[tuple[str, int, float, float, float, float]] - ] # [(token, id, spd_prob, logit, target_prob, target_logit)] + ci: list[list[TokenPrediction]] + stochastic: list[list[TokenPrediction]] + adversarial: list[list[TokenPrediction]] + ablated: list[list[TokenPrediction]] | None + ci_loss: float + stochastic_loss: float + adversarial_loss: float + ablated_loss: float | None + label: LabelPredictions | None + + +# Default eval PGD settings (distinct from optimization PGD which is a training regularizer) +DEFAULT_EVAL_PGD_CONFIG = AdvPGDConfig(n_steps=4, step_size=1.0, init="random") + + +def _extract_topk_predictions( + logits: Float[Tensor, "1 seq vocab"], + target_logits: Float[Tensor, "1 seq vocab"], + tokenizer: AppTokenizer, + top_k: int, +) -> list[list[TokenPrediction]]: + """Extract top-k token predictions per position, paired with target probs.""" + probs = torch.softmax(logits, dim=-1) + target_probs = torch.softmax(target_logits, dim=-1) + result: list[list[TokenPrediction]] = [] + for pos in range(probs.shape[1]): + top_vals, top_ids = torch.topk(probs[0, pos], top_k) + pos_preds: list[TokenPrediction] = [] + for p, tid_t in zip(top_vals, top_ids, strict=True): + tid = int(tid_t.item()) + pos_preds.append( + TokenPrediction( + token=tokenizer.get_tok_display(tid), + token_id=tid, + prob=float(p.item()), + logit=float(logits[0, pos, tid].item()), + target_prob=float(target_probs[0, pos, tid].item()), + target_logit=float(target_logits[0, pos, tid].item()), + ) + ) + result.append(pos_preds) + return result + + +def _extract_label_prediction( + logits: Float[Tensor, "1 seq vocab"], + target_logits: Float[Tensor, "1 seq vocab"], + tokenizer: AppTokenizer, + position: int, + label_token: int, +) -> TokenPrediction: + """Extract the prediction for a specific token at a specific position.""" + probs = torch.softmax(logits[0, position], dim=-1) + target_probs = torch.softmax(target_logits[0, position], dim=-1) + return TokenPrediction( + 
token=tokenizer.get_tok_display(label_token), + token_id=label_token, + prob=float(probs[label_token].item()), + logit=float(logits[0, position, label_token].item()), + target_prob=float(target_probs[label_token].item()), + target_logit=float(target_logits[0, position, label_token].item()), + ) -def compute_intervention_forward( +def compute_intervention( model: ComponentModel, tokens: Float[Tensor, "1 seq"], - active_nodes: list[tuple[str, int, int]], # [(layer, seq_pos, component_idx)] + active_nodes: list[tuple[str, int, int]], + nodes_to_ablate: list[tuple[str, int, int]] | None, + tokenizer: AppTokenizer, + adv_pgd_config: AdvPGDConfig, + loss_config: LossConfig, + sampling: SamplingType, top_k: int, - tokenizer: PreTrainedTokenizerBase, ) -> InterventionResult: - """Forward pass with only specified nodes active. + """Unified intervention evaluation: CI, stochastic, adversarial, and optionally ablated. Args: - model: ComponentModel to run intervention on. - tokens: Input tokens of shape [1, seq]. - active_nodes: List of (layer, seq_pos, component_idx) tuples specifying which nodes to activate. + active_nodes: (concrete_path, seq_pos, component_idx) tuples for selected nodes. + Used for CI, stochastic, and adversarial masking. + nodes_to_ablate: If provided, nodes to ablate in the ablated regime (the full target model minus these). + The frontend computes this as all_graph_nodes - selected_nodes. + If None, the ablated regime is skipped. + loss_config: Loss for the PGD adversary to maximize and for reported metrics. + sampling: Sampling type for CI computation. top_k: Number of top predictions to return per position. - tokenizer: Tokenizer for decoding tokens. - - Returns: - InterventionResult with input tokens and top-k predictions per position.
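The stochastic regime evaluated by `compute_intervention` interpolates between the selected binary mask and uniform noise, `m = ci + (1 - ci) * u`: selected nodes stay fully on while unselected ones get a random coefficient in [0, 1). A toy sketch of that formula on plain Python lists (the real code applies it per layer with `torch.rand_like`):

```python
import random


def stochastic_mask(ci: list[float], rng: random.Random) -> list[float]:
    # m = ci + (1 - ci) * u: entries with ci == 1 stay fully on, while
    # entries with ci == 0 are replaced by a uniform draw in [0, 1).
    return [c + (1.0 - c) * rng.random() for c in ci]


rng = random.Random(0)
mask = stochastic_mask([1.0, 0.0, 1.0, 0.0], rng)
assert mask[0] == 1.0 and mask[2] == 1.0  # selected nodes untouched
assert all(0.0 <= v < 1.0 for v in (mask[1], mask[3]))  # unselected randomized
```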
""" - seq_len = tokens.shape[1] device = tokens.device - # Build component masks: all zeros, then set 1s for active nodes - component_masks: dict[str, Float[Tensor, "1 seq C"]] = {} - for layer_name, C in model.module_to_c.items(): - component_masks[layer_name] = torch.zeros(1, seq_len, C, device=device) + # Compute natural CI alive masks (the model's own binarized CI, independent of graph) + with torch.no_grad(), bf16_autocast(): + output_with_cache: OutputWithCache = model(tokens, cache_type="input") + ci_outputs = model.calc_causal_importances( + pre_weight_acts=output_with_cache.cache, + sampling=sampling, + detach_inputs=False, + ) + alive_masks: dict[str, Bool[Tensor, "1 seq C"]] = { + k: v > 0 for k, v in ci_outputs.lower_leaky.items() + } + # Build binary CI masks from active nodes (selected = 1, rest = 0) + ci_masks: dict[str, Float[Tensor, "1 seq C"]] = {} + for layer_name, C in model.module_to_c.items(): + ci_masks[layer_name] = torch.zeros(1, seq_len, C, device=device) for layer, seq_pos, c_idx in active_nodes: - assert layer in component_masks, f"Layer {layer} not in model" - assert 0 <= seq_pos < seq_len, f"seq_pos {seq_pos} out of bounds [0, {seq_len})" - assert 0 <= c_idx < model.module_to_c[layer], ( - f"component_idx {c_idx} out of bounds [0, {model.module_to_c[layer]})" + ci_masks[layer][0, seq_pos, c_idx] = 1.0 + assert alive_masks[layer][0, seq_pos, c_idx], ( + f"Selected node {layer}:{seq_pos}:{c_idx} is not alive (CI=0)" ) - component_masks[layer][0, seq_pos, c_idx] = 1.0 - mask_infos = make_mask_infos(component_masks, routing_masks="all") + with torch.no_grad(), bf16_autocast(): + # Target forward (unmasked) + target_logits: Float[Tensor, "1 seq vocab"] = model(tokens) + + # CI forward (binary mask) + ci_mask_infos = make_mask_infos(ci_masks, routing_masks="all") + ci_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=ci_mask_infos) + + # Stochastic forward: ci + (1-ci) * uniform + stoch_masks = { + layer: ci_masks[layer] + (1 - 
ci_masks[layer]) * torch.rand_like(ci_masks[layer]) + for layer in ci_masks + } + stoch_mask_infos = make_mask_infos(stoch_masks, routing_masks="all") + stoch_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=stoch_mask_infos) + + # Target-sans forward (only if nodes_to_ablate provided) + ts_logits: Float[Tensor, "1 seq vocab"] | None = None + if nodes_to_ablate is not None: + ts_masks: dict[str, Float[Tensor, "1 seq C"]] = {} + for layer_name in ci_masks: + ts_masks[layer_name] = torch.ones_like(ci_masks[layer_name]) + for layer, seq_pos, c_idx in nodes_to_ablate: + ts_masks[layer][0, seq_pos, c_idx] = 0.0 + weight_deltas = model.calc_weight_deltas() + ts_wd = { + k: (v, torch.ones(tokens.shape, device=device)) for k, v in weight_deltas.items() + } + ts_mask_infos = make_mask_infos( + ts_masks, routing_masks="all", weight_deltas_and_masks=ts_wd + ) + ts_logits = model(tokens, mask_infos=ts_mask_infos) + # Adversarial: PGD optimizes alive-but-unselected components + adv_sources = run_adv_pgd( + model=model, + tokens=tokens, + ci=ci_masks, + alive_masks=alive_masks, + adv_config=adv_pgd_config, + target_out=target_logits, + loss_config=loss_config, + ) + # Non-alive positions get uniform random fill + adv_masks = interpolate_pgd_mask(ci_masks, adv_sources) + with torch.no_grad(): + for layer in adv_masks: + non_alive = ~alive_masks[layer] + adv_masks[layer][non_alive] = torch.rand(int(non_alive.sum().item()), device=device) with torch.no_grad(), bf16_autocast(): - # SPD model forward pass (with component masks) - spd_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=mask_infos) - spd_probs: Float[Tensor, "1 seq vocab"] = torch.softmax(spd_logits, dim=-1) + adv_mask_infos = make_mask_infos(adv_masks, routing_masks="all") + adv_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=adv_mask_infos) + + # Extract predictions and loss metrics + device_str = str(device) + with torch.no_grad(): + ci_preds = 
_extract_topk_predictions(ci_logits, target_logits, tokenizer, top_k) + stoch_preds = _extract_topk_predictions(stoch_logits, target_logits, tokenizer, top_k) + adv_preds = _extract_topk_predictions(adv_logits, target_logits, tokenizer, top_k) + + ci_loss = float( + compute_recon_loss(ci_logits, loss_config, target_logits, device_str).item() + ) + stoch_loss = float( + compute_recon_loss(stoch_logits, loss_config, target_logits, device_str).item() + ) + adv_loss = float( + compute_recon_loss(adv_logits, loss_config, target_logits, device_str).item() + ) - # Target model forward pass (no masks) - target_logits: Float[Tensor, "1 seq vocab"] = model(tokens) - target_out_probs: Float[Tensor, "1 seq vocab"] = torch.softmax(target_logits, dim=-1) - - # Get top-k predictions per position (based on SPD model's top-k) - predictions_per_position: list[list[tuple[str, int, float, float, float, float]]] = [] - for pos in range(seq_len): - pos_spd_probs = spd_probs[0, pos] - pos_spd_logits = spd_logits[0, pos] - pos_target_out_probs = target_out_probs[0, pos] - pos_target_logits = target_logits[0, pos] - top_probs, top_ids = torch.topk(pos_spd_probs, top_k) - - pos_predictions: list[tuple[str, int, float, float, float, float]] = [] - for spd_prob, token_id in zip(top_probs, top_ids, strict=True): - tid = int(token_id.item()) - token_str = tokenizer.decode([tid]) - target_prob = float(pos_target_out_probs[tid].item()) - target_logit = float(pos_target_logits[tid].item()) - pos_predictions.append( - ( - token_str, - tid, - float(spd_prob.item()), - float(pos_spd_logits[tid].item()), - target_prob, - target_logit, - ) + ts_preds: list[list[TokenPrediction]] | None = None + ts_loss: float | None = None + if ts_logits is not None: + ts_preds = _extract_topk_predictions(ts_logits, target_logits, tokenizer, top_k) + ts_loss = float( + compute_recon_loss(ts_logits, loss_config, target_logits, device_str).item() ) - predictions_per_position.append(pos_predictions) - # Decode input 
tokens - input_tokens = [tokenizer.decode([int(t.item())]) for t in tokens[0]] + label: LabelPredictions | None = None + if isinstance(loss_config, CELossConfig | LogitLossConfig): + pos, tid = loss_config.position, loss_config.label_token + ts_label = ( + _extract_label_prediction(ts_logits, target_logits, tokenizer, pos, tid) + if ts_logits is not None + else None + ) + label = LabelPredictions( + position=pos, + ci=_extract_label_prediction(ci_logits, target_logits, tokenizer, pos, tid), + stochastic=_extract_label_prediction(stoch_logits, target_logits, tokenizer, pos, tid), + adversarial=_extract_label_prediction(adv_logits, target_logits, tokenizer, pos, tid), + ablated=ts_label, + ) + + input_tokens = tokenizer.get_spans([int(t.item()) for t in tokens[0]]) return InterventionResult( input_tokens=input_tokens, - predictions_per_position=predictions_per_position, + ci=ci_preds, + stochastic=stoch_preds, + adversarial=adv_preds, + ablated=ts_preds, + ci_loss=ci_loss, + stochastic_loss=stoch_loss, + adversarial_loss=adv_loss, + ablated_loss=ts_loss, + label=label, ) diff --git a/spd/app/backend/database.py b/spd/app/backend/database.py index dae7c736e..f64593237 100644 --- a/spd/app/backend/database.py +++ b/spd/app/backend/database.py @@ -6,25 +6,49 @@ Interpretations are stored separately at SPD_OUT_DIR/autointerp//. 
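`_extract_topk_predictions` above ranks tokens by the masked model's probabilities while carrying the target model's probability for the same token ids, so the two can be compared per token. A dependency-free sketch of that pairing (a list-based stand-in for the tensor version):

```python
import math


def softmax(logits: list[float]) -> list[float]:
    # Numerically stable softmax: shift by the max before exponentiating
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def topk_predictions(
    logits: list[float], target_logits: list[float], top_k: int
) -> list[tuple[int, float, float]]:
    # Rank token ids by the masked model's probability, and report the target
    # model's probability for the same ids alongside.
    probs = softmax(logits)
    target_probs = softmax(target_logits)
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    return [(i, probs[i], target_probs[i]) for i in order]


preds = topk_predictions([2.0, 0.5, 1.0], [1.0, 1.0, 1.0], top_k=2)
assert [tid for tid, _, _ in preds] == [0, 2]
```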
""" +import fcntl import hashlib +import io import json +import os import sqlite3 -from dataclasses import asdict +from contextlib import contextmanager from pathlib import Path from typing import Literal +import torch from pydantic import BaseModel from spd.app.backend.compute import Edge, Node -from spd.app.backend.optim_cis import MaskType -from spd.app.backend.schemas import OutputProbability -from spd.settings import REPO_ROOT +from spd.app.backend.optim_cis import ( + CELossConfig, + KLLossConfig, + LogitLossConfig, + MaskType, + PositionalLossConfig, +) +from spd.settings import SPD_OUT_DIR GraphType = Literal["standard", "optimized", "manual"] -# Persistent data directories -_APP_DATA_DIR = REPO_ROOT / ".data" / "app" -DEFAULT_DB_PATH = _APP_DATA_DIR / "prompt_attr.db" +_DEFAULT_DB_PATH = SPD_OUT_DIR / "app" / "prompt_attr.db" + + +def get_default_db_path() -> Path: + """Get the default database path. + + Checks env vars in order: + 1. SPD_INVESTIGATION_DIR - investigation mode, db at dir/app.db + 2. SPD_APP_DB_PATH - explicit override + 3. 
Default: SPD_OUT_DIR/app/prompt_attr.db + """ + investigation_dir = os.environ.get("SPD_INVESTIGATION_DIR") + if investigation_dir: + return Path(investigation_dir) / "app.db" + env_path = os.environ.get("SPD_APP_DB_PATH") + if env_path: + return Path(env_path) + return _DEFAULT_DB_PATH class Run(BaseModel): @@ -43,6 +67,11 @@ class PromptRecord(BaseModel): is_custom: bool = False +class PgdConfig(BaseModel): + n_steps: int + step_size: float + + class OptimizationParams(BaseModel): """Optimization parameters that affect graph computation.""" @@ -51,11 +80,12 @@ class OptimizationParams(BaseModel): pnorm: float beta: float mask_type: MaskType - # CE loss params (optional, must be set together) - label_token: int | None = None - ce_loss_coeff: float | None = None - # KL loss param (optional) - kl_loss_coeff: float | None = None + loss: PositionalLossConfig + pgd: PgdConfig | None = None + # Computed metrics (persisted for display on reload) + ci_masked_label_prob: float | None = None + stoch_masked_label_prob: float | None = None + adv_pgd_label_prob: float | None = None class StoredGraph(BaseModel): @@ -68,13 +98,16 @@ class StoredGraph(BaseModel): # Core graph data (all types) edges: list[Edge] - out_probs: dict[str, OutputProbability] # seq:c_idx -> {prob, target_prob, token} + edges_abs: list[Edge] | None = ( + None # absolute-target variant (∂|y|/∂x · x), None for old graphs + ) + ci_masked_out_logits: torch.Tensor # [seq, vocab] + target_out_logits: torch.Tensor # [seq, vocab] node_ci_vals: dict[str, float] # layer:seq:c_idx -> ci_val (required for all graphs) node_subcomp_acts: dict[str, float] = {} # layer:seq:c_idx -> subcomp act (v_i^T @ a) # Optimized-specific (None for other types) optimization_params: OptimizationParams | None = None - label_prob: float | None = None # P(label_token) with optimized CI mask # Manual-specific (None for other types) included_nodes: list[str] | None = None # Nodes included in this graph @@ -86,17 +119,7 @@ class 
InterventionRunRecord(BaseModel): id: int graph_id: int selected_nodes: list[str] # node keys that were selected - result_json: str # JSON-encoded InterventionResponse - created_at: str - - -class ForkedInterventionRunRecord(BaseModel): - """A forked intervention run with modified tokens.""" - - id: int - intervention_run_id: int - token_replacements: list[tuple[int, int]] # [(seq_pos, new_token_id), ...] - result_json: str # JSON-encoded InterventionResponse + result_json: str # JSON-encoded InterventionResult created_at: str @@ -105,16 +128,15 @@ class PromptAttrDB: Schema: - runs: One row per SPD run (keyed by wandb_path) - - activation_contexts: Component metadata + generation config, 1:1 with runs - prompts: One row per stored prompt (token sequence), keyed by run_id - - original_component_seq_max_activations: Inverted index mapping components to prompts by a - component's max activation for that prompt + - graphs: Attribution graphs for prompts - Attribution graphs (edges) are computed on-demand at serve time, not stored. + Attribution graphs are computed on-demand and cached. 
""" def __init__(self, db_path: Path | None = None, check_same_thread: bool = True): - self.db_path = db_path or DEFAULT_DB_PATH + self.db_path = db_path or get_default_db_path() + self._lock_path = self.db_path.with_suffix(".db.lock") self._check_same_thread = check_same_thread self._conn: sqlite3.Connection | None = None @@ -138,6 +160,16 @@ def __enter__(self) -> "PromptAttrDB": def __exit__(self, *args: object) -> None: self.close() + @contextmanager + def _write_lock(self): + """Acquire an exclusive file lock for write operations (NFS-safe).""" + with open(self._lock_path, "w") as lock_fd: + try: + fcntl.flock(lock_fd, fcntl.LOCK_EX) + yield + finally: + fcntl.flock(lock_fd, fcntl.LOCK_UN) + # ------------------------------------------------------------------------- # Schema initialization # ------------------------------------------------------------------------- @@ -145,7 +177,7 @@ def __exit__(self, *args: object) -> None: def init_schema(self) -> None: """Initialize the database schema. 
Safe to call multiple times.""" conn = self._get_conn() - conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA journal_mode=DELETE") conn.execute("PRAGMA foreign_keys=ON") conn.executescript(""" CREATE TABLE IF NOT EXISTS runs ( @@ -162,19 +194,8 @@ def init_schema(self) -> None: is_custom INTEGER NOT NULL DEFAULT 0 ); - CREATE TABLE IF NOT EXISTS original_component_seq_max_activations ( - prompt_id INTEGER NOT NULL REFERENCES prompts(id), - component_key TEXT NOT NULL, - max_ci REAL NOT NULL, - positions TEXT NOT NULL - ); - CREATE INDEX IF NOT EXISTS idx_prompts_run_id ON prompts(run_id); - CREATE INDEX IF NOT EXISTS idx_component_key - ON original_component_seq_max_activations(component_key); - CREATE INDEX IF NOT EXISTS idx_prompt_id - ON original_component_seq_max_activations(prompt_id); CREATE TABLE IF NOT EXISTS graphs ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -182,14 +203,20 @@ def init_schema(self) -> None: graph_type TEXT NOT NULL, -- 'standard', 'optimized', 'manual' -- Optimization params (NULL for non-optimized graphs) - label_token INTEGER, imp_min_coeff REAL, - ce_loss_coeff REAL, - kl_loss_coeff REAL, steps INTEGER, pnorm REAL, beta REAL, mask_type TEXT, + loss_config TEXT, -- JSON: {type: "ce"|"kl", coeff, position, label_token?} + loss_config_hash TEXT, -- SHA256 hash for uniqueness indexing + adv_pgd_n_steps INTEGER, + adv_pgd_step_size REAL, + + -- Optimization metrics (NULL for non-optimized graphs) + ci_masked_label_prob REAL, + stoch_masked_label_prob REAL, + adv_pgd_label_prob REAL, -- Manual graph params (NULL for non-manual graphs) included_nodes TEXT, -- JSON array of node keys in this graph @@ -197,15 +224,14 @@ def init_schema(self) -> None: -- The actual graph data (JSON) edges_data TEXT NOT NULL, + -- Absolute-target edges (∂|y|/∂x · x), NULL for old graphs + edges_data_abs TEXT, -- Node CI values: "layer:seq:c_idx" -> ci_val (required for all graphs) node_ci_vals TEXT NOT NULL, -- Node subcomponent activations: 
"layer:seq:c_idx" -> v_i^T @ a node_subcomp_acts TEXT NOT NULL DEFAULT '{}', - -- Output probabilities: "seq:c_idx" -> {prob, token} - output_probs_data TEXT NOT NULL, - - -- Optimization stats (NULL for non-optimized graphs) - label_prob REAL, + -- Output logits: torch.save({ci_masked, target}) as blob + output_logits BLOB NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); @@ -217,7 +243,7 @@ def init_schema(self) -> None: -- One optimized graph per unique parameter combination CREATE UNIQUE INDEX IF NOT EXISTS idx_graphs_optimized - ON graphs(prompt_id, label_token, imp_min_coeff, ce_loss_coeff, kl_loss_coeff, steps, pnorm, beta, mask_type) + ON graphs(prompt_id, imp_min_coeff, steps, pnorm, beta, mask_type, loss_config_hash, adv_pgd_n_steps, adv_pgd_step_size) WHERE graph_type = 'optimized'; -- One manual graph per unique node set (using hash for reliable uniqueness) @@ -232,24 +258,14 @@ def init_schema(self) -> None: id INTEGER PRIMARY KEY AUTOINCREMENT, graph_id INTEGER NOT NULL REFERENCES graphs(id), selected_nodes TEXT NOT NULL, -- JSON array of node keys - result TEXT NOT NULL, -- JSON InterventionResponse + result TEXT NOT NULL, -- JSON InterventionResult created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); CREATE INDEX IF NOT EXISTS idx_intervention_runs_graph ON intervention_runs(graph_id); - - CREATE TABLE IF NOT EXISTS forked_intervention_runs ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - intervention_run_id INTEGER NOT NULL REFERENCES intervention_runs(id) ON DELETE CASCADE, - token_replacements TEXT NOT NULL, -- JSON array of [seq_pos, new_token_id] tuples - result TEXT NOT NULL, -- JSON InterventionResponse - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ); - - CREATE INDEX IF NOT EXISTS idx_forked_intervention_runs_parent - ON forked_intervention_runs(intervention_run_id); """) + conn.commit() # ------------------------------------------------------------------------- @@ -258,15 +274,16 @@ def init_schema(self) -> None: def create_run(self, 
wandb_path: str) -> int: """Create a new run. Returns the run ID.""" - conn = self._get_conn() - cursor = conn.execute( - "INSERT INTO runs (wandb_path) VALUES (?)", - (wandb_path,), - ) - conn.commit() - run_id = cursor.lastrowid - assert run_id is not None - return run_id + with self._write_lock(): + conn = self._get_conn() + cursor = conn.execute( + "INSERT INTO runs (wandb_path) VALUES (?)", + (wandb_path,), + ) + conn.commit() + run_id = cursor.lastrowid + assert run_id is not None + return run_id def get_run_by_wandb_path(self, wandb_path: str) -> Run | None: """Get a run by its wandb path.""" @@ -294,48 +311,6 @@ def get_run(self, run_id: int) -> Run | None: # Prompt operations # ------------------------------------------------------------------------- - def add_prompts( - self, - run_id: int, - prompts: list[tuple[list[int], dict[str, tuple[float, list[int]]]]], - context_length: int, - ) -> list[int]: - """Add multiple prompts to the database in a single transaction. - - Args: - run_id: The run these prompts belong to. - prompts: List of (token_ids, active_components) tuples. - context_length: The context length setting used when generating these prompts. - - Returns: - List of prompt IDs. 
- """ - conn = self._get_conn() - prompt_ids: list[int] = [] - component_rows: list[tuple[int, str, float, str]] = [] - - for token_ids, active_components in prompts: - cursor = conn.execute( - "INSERT INTO prompts (run_id, token_ids, context_length) VALUES (?, ?, ?)", - (run_id, json.dumps(token_ids), context_length), - ) - prompt_id = cursor.lastrowid - assert prompt_id is not None - prompt_ids.append(prompt_id) - - for component_key, (max_ci, positions) in active_components.items(): - component_rows.append((prompt_id, component_key, max_ci, json.dumps(positions))) - - if component_rows: - conn.executemany( - """INSERT INTO original_component_seq_max_activations - (prompt_id, component_key, max_ci, positions) VALUES (?, ?, ?, ?)""", - component_rows, - ) - - conn.commit() - return prompt_ids - def find_prompt_by_token_ids( self, run_id: int, @@ -354,7 +329,6 @@ def add_custom_prompt( self, run_id: int, token_ids: list[int], - active_components: dict[str, tuple[float, list[int]]], context_length: int, ) -> int: """Add a custom prompt to the database, or return existing if duplicate. @@ -362,37 +336,25 @@ def add_custom_prompt( Args: run_id: The run this prompt belongs to. token_ids: The token IDs for the prompt. - active_components: Dict mapping component_key to (max_ci, positions). context_length: The context length setting. Returns: The prompt ID (existing or newly created). 
""" - existing_id = self.find_prompt_by_token_ids(run_id, token_ids, context_length) - if existing_id is not None: - return existing_id - - conn = self._get_conn() - cursor = conn.execute( - "INSERT INTO prompts (run_id, token_ids, context_length, is_custom) VALUES (?, ?, ?, 1)", - (run_id, json.dumps(token_ids), context_length), - ) - prompt_id = cursor.lastrowid - assert prompt_id is not None + with self._write_lock(): + existing_id = self.find_prompt_by_token_ids(run_id, token_ids, context_length) + if existing_id is not None: + return existing_id - component_rows = [ - (prompt_id, component_key, max_ci, json.dumps(positions)) - for component_key, (max_ci, positions) in active_components.items() - ] - if component_rows: - conn.executemany( - """INSERT INTO original_component_seq_max_activations - (prompt_id, component_key, max_ci, positions) VALUES (?, ?, ?, ?)""", - component_rows, + conn = self._get_conn() + cursor = conn.execute( + "INSERT INTO prompts (run_id, token_ids, context_length, is_custom) VALUES (?, ?, ?, 1)", + (run_id, json.dumps(token_ids), context_length), ) - - conn.commit() - return prompt_id + prompt_id = cursor.lastrowid + assert prompt_id is not None + conn.commit() + return prompt_id def get_prompt(self, prompt_id: int) -> PromptRecord | None: """Get a prompt by ID.""" @@ -429,57 +391,6 @@ def get_all_prompt_ids(self, run_id: int, context_length: int) -> list[int]: ).fetchall() return [row["id"] for row in rows] - def has_prompts(self, run_id: int, context_length: int) -> bool: - """Check if any prompts exist for a run with a specific context length.""" - return self.get_prompt_count(run_id, context_length) > 0 - - # ------------------------------------------------------------------------- - # Query operations - # ------------------------------------------------------------------------- - - def find_prompts_with_components( - self, - run_id: int, - component_keys: list[str], - require_all: bool = True, - ) -> list[int]: - """Find prompts 
where specified components are active. - - Args: - run_id: The run to search within. - component_keys: List of component keys like "h.0.attn.q_proj:5". - require_all: If True, require ALL components to be active (intersection). - If False, require ANY component to be active (union). - - Returns: - List of prompt IDs matching the query. - """ - assert component_keys, "No component keys provided" - - conn = self._get_conn() - placeholders = ",".join("?" * len(component_keys)) - - if require_all: - query = f""" - SELECT ca.prompt_id - FROM original_component_seq_max_activations ca - JOIN prompts p ON ca.prompt_id = p.id - WHERE p.run_id = ? AND ca.component_key IN ({placeholders}) - GROUP BY ca.prompt_id - HAVING COUNT(DISTINCT ca.component_key) = ? - """ - rows = conn.execute(query, (run_id, *component_keys, len(component_keys))).fetchall() - else: - query = f""" - SELECT DISTINCT ca.prompt_id - FROM original_component_seq_max_activations ca - JOIN prompts p ON ca.prompt_id = p.id - WHERE p.run_id = ? 
AND ca.component_key IN ({placeholders}) - """ - rows = conn.execute(query, (run_id, *component_keys)).fetchall() - - return [row["prompt_id"] for row in rows] - # ------------------------------------------------------------------------- # Graph operations # ------------------------------------------------------------------------- @@ -500,32 +411,69 @@ def save_graph( """ conn = self._get_conn() - edges_json = json.dumps([asdict(e) for e in graph.edges]) - probs_json = json.dumps({k: v.model_dump() for k, v in graph.out_probs.items()}) + def _node_to_dict(n: Node) -> dict[str, str | int]: + return { + "layer": n.layer, + "seq_pos": n.seq_pos, + "component_idx": n.component_idx, + } + + def _edges_to_json(edges: list[Edge]) -> str: + return json.dumps( + [ + { + "source": _node_to_dict(e.source), + "target": _node_to_dict(e.target), + "strength": e.strength, + "is_cross_seq": e.is_cross_seq, + } + for e in edges + ] + ) + + edges_json = _edges_to_json(graph.edges) + edges_abs_json = _edges_to_json(graph.edges_abs) if graph.edges_abs is not None else None + buf = io.BytesIO() + logits_dict: dict[str, torch.Tensor] = { + "ci_masked": graph.ci_masked_out_logits, + "target": graph.target_out_logits, + } + torch.save(logits_dict, buf) + output_logits_blob = buf.getvalue() node_ci_vals_json = json.dumps(graph.node_ci_vals) node_subcomp_acts_json = json.dumps(graph.node_subcomp_acts) # Extract optimization-specific values (NULL for non-optimized graphs) - label_token = None imp_min_coeff = None - ce_loss_coeff = None - kl_loss_coeff = None steps = None pnorm = None beta = None mask_type = None - label_prob = None + loss_config_json: str | None = None + loss_config_hash: str | None = None + adv_pgd_n_steps = None + adv_pgd_step_size = None + ci_masked_label_prob = None + stoch_masked_label_prob = None + adv_pgd_label_prob = None if graph.optimization_params: - label_token = graph.optimization_params.label_token imp_min_coeff = graph.optimization_params.imp_min_coeff - 
ce_loss_coeff = graph.optimization_params.ce_loss_coeff - kl_loss_coeff = graph.optimization_params.kl_loss_coeff steps = graph.optimization_params.steps pnorm = graph.optimization_params.pnorm beta = graph.optimization_params.beta mask_type = graph.optimization_params.mask_type - label_prob = graph.label_prob + loss_config_json = graph.optimization_params.loss.model_dump_json() + loss_config_hash = hashlib.sha256(loss_config_json.encode()).hexdigest() + adv_pgd_n_steps = ( + graph.optimization_params.pgd.n_steps if graph.optimization_params.pgd else None + ) + adv_pgd_step_size = ( + graph.optimization_params.pgd.step_size if graph.optimization_params.pgd else None + ) + ci_masked_label_prob = graph.optimization_params.ci_masked_label_prob + stoch_masked_label_prob = graph.optimization_params.stoch_masked_label_prob + adv_pgd_label_prob = graph.optimization_params.adv_pgd_label_prob # Extract manual-specific values (NULL for non-manual graphs) # Sort included_nodes and compute hash for reliable uniqueness @@ -535,94 +483,128 @@ def save_graph( included_nodes_json = json.dumps(sorted(graph.included_nodes)) included_nodes_hash = hashlib.sha256(included_nodes_json.encode()).hexdigest() - try: - cursor = conn.execute( - """INSERT INTO graphs - (prompt_id, graph_type, - label_token, imp_min_coeff, ce_loss_coeff, kl_loss_coeff, steps, pnorm, - beta, mask_type, included_nodes, included_nodes_hash, - edges_data, output_probs_data, node_ci_vals, node_subcomp_acts, label_prob) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - ( - prompt_id, - graph.graph_type, - label_token, - imp_min_coeff, - ce_loss_coeff, - kl_loss_coeff, - steps, - pnorm, - beta, - mask_type, - included_nodes_json, - included_nodes_hash, - edges_json, - probs_json, - node_ci_vals_json, - node_subcomp_acts_json, - label_prob, - ), - ) - conn.commit() - graph_id = cursor.lastrowid - assert graph_id is not None - return graph_id - except sqlite3.IntegrityError as e: - match 
graph.graph_type: - case "standard": - raise ValueError( - f"Standard graph already exists for prompt_id={prompt_id}. " - "Use get_graphs() to retrieve existing graph or delete it first." - ) from e - case "optimized": - raise ValueError( - f"Optimized graph with same parameters already exists for prompt_id={prompt_id}." - ) from e - case "manual": - # Get-or-create semantics: return existing graph ID - conn.rollback() - row = conn.execute( - """SELECT id FROM graphs - WHERE prompt_id = ? AND graph_type = 'manual' - AND included_nodes_hash = ?""", - (prompt_id, included_nodes_hash), - ).fetchone() - if row: - return row["id"] - # Should not happen if constraint triggered - raise ValueError("A manual graph with the same nodes already exists.") from e + with self._write_lock(): + try: + cursor = conn.execute( + """INSERT INTO graphs + (prompt_id, graph_type, + imp_min_coeff, steps, pnorm, beta, mask_type, + loss_config, loss_config_hash, + adv_pgd_n_steps, adv_pgd_step_size, + ci_masked_label_prob, stoch_masked_label_prob, adv_pgd_label_prob, + included_nodes, included_nodes_hash, + edges_data, edges_data_abs, output_logits, node_ci_vals, node_subcomp_acts) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + prompt_id, + graph.graph_type, + imp_min_coeff, + steps, + pnorm, + beta, + mask_type, + loss_config_json, + loss_config_hash, + adv_pgd_n_steps, + adv_pgd_step_size, + ci_masked_label_prob, + stoch_masked_label_prob, + adv_pgd_label_prob, + included_nodes_json, + included_nodes_hash, + edges_json, + edges_abs_json, + output_logits_blob, + node_ci_vals_json, + node_subcomp_acts_json, + ), + ) + conn.commit() + graph_id = cursor.lastrowid + assert graph_id is not None + return graph_id + except sqlite3.IntegrityError as e: + match graph.graph_type: + case "standard": + raise ValueError( + f"Standard graph already exists for prompt_id={prompt_id}. " + "Use get_graphs() to retrieve existing graph or delete it first." 
+ ) from e + case "optimized": + raise ValueError( + f"Optimized graph with same parameters already exists for prompt_id={prompt_id}." + ) from e + case "manual": + conn.rollback() + row = conn.execute( + """SELECT id FROM graphs + WHERE prompt_id = ? AND graph_type = 'manual' + AND included_nodes_hash = ?""", + (prompt_id, included_nodes_hash), + ).fetchone() + if row: + return row["id"] + raise ValueError( + "A manual graph with the same nodes already exists." + ) from e def _row_to_stored_graph(self, row: sqlite3.Row) -> StoredGraph: """Convert a database row to a StoredGraph.""" - edges = [ - Edge( - source=Node(**e["source"]), - target=Node(**e["target"]), - strength=float(e["strength"]), - is_cross_seq=bool(e["is_cross_seq"]), + + def _node_from_dict(d: dict[str, str | int]) -> Node: + return Node( + layer=str(d["layer"]), + seq_pos=int(d["seq_pos"]), + component_idx=int(d["component_idx"]), ) - for e in json.loads(row["edges_data"]) - ] - out_probs = { - k: OutputProbability(**v) for k, v in json.loads(row["output_probs_data"]).items() - } + + def _parse_edges(data: str) -> list[Edge]: + return [ + Edge( + source=_node_from_dict(e["source"]), + target=_node_from_dict(e["target"]), + strength=float(e["strength"]), + is_cross_seq=bool(e["is_cross_seq"]), + ) + for e in json.loads(data) + ] + + edges = _parse_edges(row["edges_data"]) + edges_abs = _parse_edges(row["edges_data_abs"]) if row["edges_data_abs"] else None + logits_data = torch.load(io.BytesIO(row["output_logits"]), weights_only=True) + ci_masked_out_logits: torch.Tensor = logits_data["ci_masked"] + target_out_logits: torch.Tensor = logits_data["target"] node_ci_vals: dict[str, float] = json.loads(row["node_ci_vals"]) node_subcomp_acts: dict[str, float] = json.loads(row["node_subcomp_acts"] or "{}") opt_params: OptimizationParams | None = None - label_prob: float | None = None if row["graph_type"] == "optimized": + loss_config_data = json.loads(row["loss_config"]) + loss_type = 
loss_config_data["type"] + assert loss_type in ("ce", "kl", "logit"), f"Unknown loss type: {loss_type}" + loss_config: PositionalLossConfig + match loss_type: + case "ce": + loss_config = CELossConfig(**loss_config_data) + case "kl": + loss_config = KLLossConfig(**loss_config_data) + case "logit": + loss_config = LogitLossConfig(**loss_config_data) + pgd = None + if row["adv_pgd_n_steps"] is not None: + pgd = PgdConfig(n_steps=row["adv_pgd_n_steps"], step_size=row["adv_pgd_step_size"]) opt_params = OptimizationParams( imp_min_coeff=row["imp_min_coeff"], steps=row["steps"], pnorm=row["pnorm"], beta=row["beta"], mask_type=row["mask_type"], - label_token=row["label_token"], - ce_loss_coeff=row["ce_loss_coeff"], - kl_loss_coeff=row["kl_loss_coeff"], + loss=loss_config, + pgd=pgd, + ci_masked_label_prob=row["ci_masked_label_prob"], + stoch_masked_label_prob=row["stoch_masked_label_prob"], + adv_pgd_label_prob=row["adv_pgd_label_prob"], ) - label_prob = row["label_prob"] # Parse manual-specific fields included_nodes: list[str] | None = None @@ -633,11 +615,12 @@ def _row_to_stored_graph(self, row: sqlite3.Row) -> StoredGraph: id=row["id"], graph_type=row["graph_type"], edges=edges, - out_probs=out_probs, + edges_abs=edges_abs, + ci_masked_out_logits=ci_masked_out_logits, + target_out_logits=target_out_logits, node_ci_vals=node_ci_vals, node_subcomp_acts=node_subcomp_acts, optimization_params=opt_params, - label_prob=label_prob, included_nodes=included_nodes, ) @@ -652,10 +635,10 @@ def get_graphs(self, prompt_id: int) -> list[StoredGraph]: """ conn = self._get_conn() rows = conn.execute( - """SELECT id, graph_type, edges_data, output_probs_data, node_ci_vals, - node_subcomp_acts, label_token, imp_min_coeff, ce_loss_coeff, kl_loss_coeff, - steps, pnorm, beta, mask_type, label_prob, - included_nodes + """SELECT id, graph_type, edges_data, edges_data_abs, output_logits, node_ci_vals, + node_subcomp_acts, imp_min_coeff, steps, pnorm, beta, mask_type, + loss_config, 
adv_pgd_n_steps, adv_pgd_step_size, included_nodes, + ci_masked_label_prob, stoch_masked_label_prob, adv_pgd_label_prob FROM graphs WHERE prompt_id = ? ORDER BY @@ -669,10 +652,11 @@ def get_graph(self, graph_id: int) -> tuple[StoredGraph, int] | None: """Retrieve a single graph by its ID. Returns (graph, prompt_id) or None.""" conn = self._get_conn() row = conn.execute( - """SELECT id, prompt_id, graph_type, edges_data, output_probs_data, node_ci_vals, - node_subcomp_acts, label_token, imp_min_coeff, ce_loss_coeff, kl_loss_coeff, - steps, pnorm, beta, mask_type, label_prob, - included_nodes + """SELECT id, prompt_id, graph_type, edges_data, edges_data_abs, output_logits, + node_ci_vals, node_subcomp_acts, imp_min_coeff, steps, pnorm, beta, + mask_type, loss_config, adv_pgd_n_steps, adv_pgd_step_size, + included_nodes, ci_masked_label_prob, stoch_masked_label_prob, + adv_pgd_label_prob FROM graphs WHERE id = ?""", (graph_id,), @@ -681,23 +665,18 @@ def get_graph(self, graph_id: int) -> tuple[StoredGraph, int] | None: return None return (self._row_to_stored_graph(row), row["prompt_id"]) - def delete_graphs_for_prompt(self, prompt_id: int) -> int: - """Delete all graphs for a prompt. Returns the number of deleted rows.""" - conn = self._get_conn() - cursor = conn.execute("DELETE FROM graphs WHERE prompt_id = ?", (prompt_id,)) - conn.commit() - return cursor.rowcount - - def delete_graphs_for_run(self, run_id: int) -> int: - """Delete all graphs for all prompts in a run. Returns the number of deleted rows.""" - conn = self._get_conn() - cursor = conn.execute( - """DELETE FROM graphs - WHERE prompt_id IN (SELECT id FROM prompts WHERE run_id = ?)""", - (run_id,), - ) - conn.commit() - return cursor.rowcount + def delete_prompt(self, prompt_id: int) -> None: + """Delete a prompt and all its graphs, intervention runs, and forked runs.""" + with self._write_lock(): + conn = self._get_conn() + graph_ids_query = "SELECT id FROM graphs WHERE prompt_id = ?" 
+ conn.execute( + f"DELETE FROM intervention_runs WHERE graph_id IN ({graph_ids_query})", + (prompt_id,), + ) + conn.execute("DELETE FROM graphs WHERE prompt_id = ?", (prompt_id,)) + conn.execute("DELETE FROM prompts WHERE id = ?", (prompt_id,)) + conn.commit() # ------------------------------------------------------------------------- # Intervention run operations @@ -714,21 +693,22 @@ def save_intervention_run( Args: graph_id: The graph ID this run belongs to. selected_nodes: List of node keys that were selected. - result_json: JSON-encoded InterventionResponse. + result_json: JSON-encoded InterventionResult. Returns: The intervention run ID. """ - conn = self._get_conn() - cursor = conn.execute( - """INSERT INTO intervention_runs (graph_id, selected_nodes, result) - VALUES (?, ?, ?)""", - (graph_id, json.dumps(selected_nodes), result_json), - ) - conn.commit() - run_id = cursor.lastrowid - assert run_id is not None - return run_id + with self._write_lock(): + conn = self._get_conn() + cursor = conn.execute( + """INSERT INTO intervention_runs (graph_id, selected_nodes, result) + VALUES (?, ?, ?)""", + (graph_id, json.dumps(selected_nodes), result_json), + ) + conn.commit() + run_id = cursor.lastrowid + assert run_id is not None + return run_id def get_intervention_runs(self, graph_id: int) -> list[InterventionRunRecord]: """Get all intervention runs for a graph. @@ -761,102 +741,7 @@ def get_intervention_runs(self, graph_id: int) -> list[InterventionRunRecord]: def delete_intervention_run(self, run_id: int) -> None: """Delete an intervention run.""" - conn = self._get_conn() - conn.execute("DELETE FROM intervention_runs WHERE id = ?", (run_id,)) - conn.commit() - - def delete_intervention_runs_for_graph(self, graph_id: int) -> int: - """Delete all intervention runs for a graph. 
Returns count deleted.""" - conn = self._get_conn() - cursor = conn.execute("DELETE FROM intervention_runs WHERE graph_id = ?", (graph_id,)) - conn.commit() - return cursor.rowcount - - # ------------------------------------------------------------------------- - # Forked intervention run operations - # ------------------------------------------------------------------------- - - def save_forked_intervention_run( - self, - intervention_run_id: int, - token_replacements: list[tuple[int, int]], - result_json: str, - ) -> int: - """Save a forked intervention run. - - Args: - intervention_run_id: The parent intervention run ID. - token_replacements: List of (seq_pos, new_token_id) tuples. - result_json: JSON-encoded InterventionResponse. - - Returns: - The forked intervention run ID. - """ - conn = self._get_conn() - cursor = conn.execute( - """INSERT INTO forked_intervention_runs (intervention_run_id, token_replacements, result) - VALUES (?, ?, ?)""", - (intervention_run_id, json.dumps(token_replacements), result_json), - ) - conn.commit() - fork_id = cursor.lastrowid - assert fork_id is not None - return fork_id - - def get_forked_intervention_runs( - self, intervention_run_id: int - ) -> list[ForkedInterventionRunRecord]: - """Get all forked runs for an intervention run. - - Args: - intervention_run_id: The parent intervention run ID. - - Returns: - List of forked intervention run records, ordered by creation time. - """ - conn = self._get_conn() - rows = conn.execute( - """SELECT id, intervention_run_id, token_replacements, result, created_at - FROM forked_intervention_runs - WHERE intervention_run_id = ? 
- ORDER BY created_at""", - (intervention_run_id,), - ).fetchall() - - return [ - ForkedInterventionRunRecord( - id=row["id"], - intervention_run_id=row["intervention_run_id"], - token_replacements=json.loads(row["token_replacements"]), - result_json=row["result"], - created_at=row["created_at"], - ) - for row in rows - ] - - def get_intervention_run(self, run_id: int) -> InterventionRunRecord | None: - """Get a single intervention run by ID.""" - conn = self._get_conn() - row = conn.execute( - """SELECT id, graph_id, selected_nodes, result, created_at - FROM intervention_runs - WHERE id = ?""", - (run_id,), - ).fetchone() - - if row is None: - return None - - return InterventionRunRecord( - id=row["id"], - graph_id=row["graph_id"], - selected_nodes=json.loads(row["selected_nodes"]), - result_json=row["result"], - created_at=row["created_at"], - ) - - def delete_forked_intervention_run(self, fork_id: int) -> None: - """Delete a forked intervention run.""" - conn = self._get_conn() - conn.execute("DELETE FROM forked_intervention_runs WHERE id = ?", (fork_id,)) - conn.commit() + with self._write_lock(): + conn = self._get_conn() + conn.execute("DELETE FROM intervention_runs WHERE id = ?", (run_id,)) + conn.commit() diff --git a/spd/app/backend/optim_cis.py b/spd/app/backend/optim_cis.py index 798c88176..5ec2f15b5 100644 --- a/spd/app/backend/optim_cis.py +++ b/spd/app/backend/optim_cis.py @@ -1,6 +1,3 @@ -# %% -"""Optimize CI values for a single prompt while keeping component weights fixed.""" - from collections.abc import Callable from dataclasses import dataclass from pathlib import Path @@ -10,11 +7,13 @@ import torch.nn.functional as F import torch.optim as optim from jaxtyping import Bool, Float +from pydantic import BaseModel from torch import Tensor from tqdm.auto import tqdm -from spd.configs import ImportanceMinimalityLossConfig, SamplingType +from spd.configs import ImportanceMinimalityLossConfig, PGDInitStrategy, SamplingType from spd.metrics import 
importance_minimality_loss +from spd.metrics.pgd_utils import get_pgd_init_tensor, interpolate_pgd_mask from spd.models.component_model import CIOutputs, ComponentModel, OutputWithCache from spd.models.components import make_mask_infos from spd.routing import AllLayersRouter @@ -25,19 +24,75 @@ MaskType = Literal["stochastic", "ci"] -@dataclass -class OptimCELossConfig: - """Cross-entropy loss config for CI optimization. These losses apply to the final token only.""" +class AdvPGDConfig(BaseModel): + """PGD adversary config for robust CI optimization.""" + + n_steps: int + step_size: float + init: PGDInitStrategy + +class CELossConfig(BaseModel): + """Cross-entropy loss: optimize for a specific token at a position.""" + + type: Literal["ce"] = "ce" coeff: float + position: int label_token: int -@dataclass -class OptimKLLossConfig: - """KL divergence loss config for CI optimization. These losses apply to the final token only.""" +class KLLossConfig(BaseModel): + """KL divergence loss: match target model distribution at a position.""" + type: Literal["kl"] = "kl" coeff: float + position: int + + +class LogitLossConfig(BaseModel): + """Logit loss: maximize the pre-softmax logit for a specific token at a position.""" + + type: Literal["logit"] = "logit" + coeff: float + position: int + label_token: int + + +class MeanKLLossConfig(BaseModel): + """Mean KL divergence loss: match target model distribution across all positions.""" + + type: Literal["mean_kl"] = "mean_kl" + coeff: float = 1.0 + + +PositionalLossConfig = CELossConfig | KLLossConfig | LogitLossConfig +LossConfig = CELossConfig | KLLossConfig | LogitLossConfig | MeanKLLossConfig + + +def compute_recon_loss( + logits: Tensor, + loss_config: LossConfig, + target_out: Tensor, + device: str, +) -> Tensor: + """Compute recon loss (CE, KL, or mean KL) from model output logits.""" + match loss_config: + case CELossConfig(position=pos, label_token=label_token): + return F.cross_entropy( + logits[0, pos, 
:].unsqueeze(0), + torch.tensor([label_token], device=device), + ) + case KLLossConfig(position=pos): + target_probs = F.softmax(target_out[0, pos, :], dim=-1) + pred_log_probs = F.log_softmax(logits[0, pos, :], dim=-1) + return F.kl_div(pred_log_probs, target_probs, reduction="sum") + case LogitLossConfig(position=pos, label_token=label_token): + return -logits[0, pos, label_token] + case MeanKLLossConfig(): + target_probs = F.softmax(target_out, dim=-1) + pred_log_probs = F.log_softmax(logits, dim=-1) + # sum over vocab, mean over positions (consistent with batched version) + return F.kl_div(pred_log_probs, target_probs, reduction="none").sum(dim=-1).mean(dim=-1) @dataclass @@ -65,6 +120,17 @@ def compute_alive_info( return AliveComponentInfo(alive_masks=alive_masks, alive_counts=alive_counts) +class OptimizationMetrics(BaseModel): + """Final loss metrics from CI optimization.""" + + ci_masked_label_prob: float | None = None # Probability of label under CI mask (CE loss only) + stoch_masked_label_prob: float | None = ( + None # Probability of label under stochastic mask (CE loss only) + ) + adv_pgd_label_prob: float | None = None # Probability of label under adversarial mask (CE only) + l0_total: float # Total L0 (active components) + + @dataclass class OptimizableCIParams: """Container for optimizable CI pre-sigmoid parameters.""" @@ -139,108 +205,6 @@ def create_optimizable_ci_params( ) -def compute_label_prob( - model: ComponentModel, - tokens: Tensor, - ci_lower_leaky: dict[str, Tensor], - label_token: int, -) -> float: - """Compute probability of label_token at final position with CI mask.""" - mask_infos = make_mask_infos(ci_lower_leaky, routing_masks="all") - with bf16_autocast(): - logits = model(tokens, mask_infos=mask_infos) - probs = F.softmax(logits[0, -1, :], dim=-1) - return float(probs[label_token].item()) - - -def compute_l0_stats( - ci_outputs: CIOutputs, - ci_alive_threshold: float, -) -> dict[str, float]: - """Compute L0 statistics for each 
layer.""" - stats: dict[str, float] = {} - for layer_name, layer_ci in ci_outputs.lower_leaky.items(): - l0_val = calc_ci_l_zero(layer_ci, ci_alive_threshold) - stats[f"l0/{layer_name}"] = l0_val - stats["l0/total"] = sum(stats.values()) - return stats - - -def compute_final_token_ce_kl( - model: ComponentModel, - batch: Tensor, - target_out: Tensor, - ci: dict[str, Tensor], - rounding_threshold: float, -) -> dict[str, float]: - """Compute CE and KL metrics for the final token only. - - Args: - model: The ComponentModel. - batch: Input tokens of shape [1, seq_len]. - target_out: Target model output logits of shape [1, seq_len, vocab]. - ci: Causal importance values (lower_leaky) per layer. - rounding_threshold: Threshold for rounding CI values to binary masks. - - Returns: - Dict with kl and ce_difference metrics for ci_masked, unmasked, and rounded_masked. - """ - assert batch.ndim == 2 and batch.shape[0] == 1, "Expected batch shape [1, seq_len]" - - # Get the label for CE (next token prediction at final position) - # The label is the token at the final position for the second-to-last logit prediction - # But since we're optimizing for CI on a single prompt, we use the final logit position - final_target_logits = target_out[0, -1, :] # [vocab] - - def kl_vs_target(logits: Tensor) -> float: - """KL divergence between predicted and target logits at final position.""" - final_logits = logits[0, -1, :] # [vocab] - target_probs = F.softmax(final_target_logits, dim=-1) - pred_log_probs = F.log_softmax(final_logits, dim=-1) - return F.kl_div(pred_log_probs, target_probs, reduction="sum").item() - - def ce_vs_target(logits: Tensor) -> float: - """CE between predicted logits and target's argmax at final position.""" - final_logits = logits[0, -1, :] # [vocab] - target_token = final_target_logits.argmax() - return F.cross_entropy(final_logits.unsqueeze(0), target_token.unsqueeze(0)).item() - - # Target model CE (baseline) - target_ce = ce_vs_target(target_out) - - # CI 
masked - ci_mask_infos = make_mask_infos(ci) - with bf16_autocast(): - ci_masked_logits = model(batch, mask_infos=ci_mask_infos) - ci_masked_kl = kl_vs_target(ci_masked_logits) - ci_masked_ce = ce_vs_target(ci_masked_logits) - - # Unmasked (all components active) - unmasked_infos = make_mask_infos({k: torch.ones_like(v) for k, v in ci.items()}) - with bf16_autocast(): - unmasked_logits = model(batch, mask_infos=unmasked_infos) - unmasked_kl = kl_vs_target(unmasked_logits) - unmasked_ce = ce_vs_target(unmasked_logits) - - # Rounded masked (binary masks based on threshold) - rounded_mask_infos = make_mask_infos( - {k: (v > rounding_threshold).float() for k, v in ci.items()} - ) - with bf16_autocast(): - rounded_masked_logits = model(batch, mask_infos=rounded_mask_infos) - rounded_masked_kl = kl_vs_target(rounded_masked_logits) - rounded_masked_ce = ce_vs_target(rounded_masked_logits) - - return { - "kl_ci_masked": ci_masked_kl, - "kl_unmasked": unmasked_kl, - "kl_rounded_masked": rounded_masked_kl, - "ce_difference_ci_masked": ci_masked_ce - target_ce, - "ce_difference_unmasked": unmasked_ce - target_ce, - "ce_difference_rounded_masked": rounded_masked_ce - target_ce, - } - - @dataclass class OptimCIConfig: """Configuration for optimizing CI values on a single prompt.""" @@ -257,27 +221,97 @@ class OptimCIConfig: log_freq: int - # Loss configs + # Loss config (CE or KL — must target a specific position) imp_min_config: ImportanceMinimalityLossConfig - ce_loss_config: OptimCELossConfig | None - kl_loss_config: OptimKLLossConfig | None + loss_config: PositionalLossConfig sampling: SamplingType ce_kl_rounding_threshold: float - mask_type: MaskType = "stochastic" + mask_type: MaskType + adv_pgd: AdvPGDConfig | None ProgressCallback = Callable[[int, int, str], None] # (current, total, stage) +class CISnapshot(BaseModel): + """Snapshot of alive component counts during CI optimization for visualization.""" + + step: int + total_steps: int + layers: list[str] + seq_len: int 
+ initial_alive: list[list[int]] # layers × seq + current_alive: list[list[int]] # layers × seq + l0_total: float + loss: float + + +CISnapshotCallback = Callable[[CISnapshot], None] + + +@dataclass +class OptimizeCIResult: + """Result from CI optimization including params and final metrics.""" + + params: OptimizableCIParams + metrics: OptimizationMetrics + + +def run_adv_pgd( + model: ComponentModel, + tokens: Tensor, + ci: dict[str, Float[Tensor, "1 seq C"]], + alive_masks: dict[str, Bool[Tensor, "1 seq C"]], + adv_config: AdvPGDConfig, + target_out: Tensor, + loss_config: LossConfig, +) -> dict[str, Float[Tensor, "1 seq C"]]: + """Run PGD to find adversarial sources maximizing loss. + + Sources are optimized via signed gradient ascent. Only alive positions are optimized. + Masks are computed as ci + (1 - ci) * source (same interpolation as training PGD). + + Returns detached adversarial source tensors. + """ + ci_detached = {k: v.detach() for k, v in ci.items()} + + adv_sources: dict[str, Tensor] = {} + for layer_name, ci_val in ci_detached.items(): + source = get_pgd_init_tensor(adv_config.init, tuple(ci_val.shape), str(ci_val.device)) + source[~alive_masks[layer_name]] = 0.0 + source.requires_grad_(True) + adv_sources[layer_name] = source + + source_list = list(adv_sources.values()) + + for _ in range(adv_config.n_steps): + mask_infos = make_mask_infos(interpolate_pgd_mask(ci_detached, adv_sources)) + + with bf16_autocast(): + out = model(tokens, mask_infos=mask_infos) + + loss = compute_recon_loss(out, loss_config, target_out, str(tokens.device)) + + grads = torch.autograd.grad(loss, source_list) + with torch.no_grad(): + for (layer_name, source), grad in zip(adv_sources.items(), grads, strict=True): + source.add_(adv_config.step_size * grad.sign()) + source.clamp_(0.0, 1.0) + source[~alive_masks[layer_name]] = 0.0 + + return {k: v.detach() for k, v in adv_sources.items()} + + def optimize_ci_values( model: ComponentModel, tokens: Tensor, config: 
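The PGD adversary above interpolates masks as `ci + (1 - ci) * source` and updates the sources by signed gradient ascent, projecting back to `[0, 1]` each step. A minimal pure-Python sketch of that update rule on scalars — numeric gradient as a toy stand-in for autograd, and all names here are illustrative, not the repo's API:

```python
def interpolate_mask(ci: float, source: float) -> float:
    # mask = ci + (1 - ci) * source: at source=0 the mask equals ci,
    # at source=1 the component is fully on.
    return ci + (1.0 - ci) * source


def pgd_step(ci: float, source: float, loss_fn, step_size: float = 0.1) -> float:
    # Numeric gradient of the loss w.r.t. source (toy stand-in for torch.autograd.grad).
    eps = 1e-5
    g = (loss_fn(interpolate_mask(ci, source + eps))
         - loss_fn(interpolate_mask(ci, source - eps))) / (2 * eps)
    # Signed gradient *ascent* (the adversary maximizes loss), then project to [0, 1].
    source = source + step_size * (1.0 if g > 0 else -1.0 if g < 0 else 0.0)
    return min(1.0, max(0.0, source))


# Toy loss that grows as the mask turns on: ascent should push source toward 1.
source = 0.5
for _ in range(10):
    source = pgd_step(ci=0.2, source=source, loss_fn=lambda m: m * m)
```

With a monotone toy loss the source saturates at the upper clamp, mirroring how `source.clamp_(0.0, 1.0)` keeps the adversarial sources in range.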
OptimCIConfig, device: str, on_progress: ProgressCallback | None = None, -) -> OptimizableCIParams: + on_ci_snapshot: CISnapshotCallback | None = None, +) -> OptimizeCIResult: """Optimize CI values for a single prompt. Args: @@ -287,15 +321,13 @@ def optimize_ci_values( device: Device to run on. Returns: - The OptimizableCIParams object. + OptimizeCIResult containing params and final metrics. """ imp_min_coeff = config.imp_min_config.coeff assert imp_min_coeff is not None, "Importance minimality loss coefficient must be set" - # Freeze all model parameters model.requires_grad_(False) - # Get initial CI values from the model with torch.no_grad(), bf16_autocast(): output_with_cache: OutputWithCache = model(tokens, cache_type="input") initial_ci_outputs = model.calc_causal_importances( @@ -305,7 +337,6 @@ def optimize_ci_values( ) target_out = output_with_cache.output.detach() - # Compute alive info and create optimizable parameters alive_info = compute_alive_info(initial_ci_outputs.lower_leaky) ci_params: OptimizableCIParams = create_optimizable_ci_params( alive_info=alive_info, @@ -314,114 +345,488 @@ def optimize_ci_values( weight_deltas = model.calc_weight_deltas() + # Precompute snapshot metadata for CI visualization + snapshot_layers = list(alive_info.alive_counts.keys()) + snapshot_initial_alive = [alive_info.alive_counts[layer] for layer in snapshot_layers] + snapshot_seq_len = tokens.shape[1] + params = ci_params.get_parameters() optimizer = optim.AdamW(params, lr=config.lr, weight_decay=config.weight_decay) progress_interval = max(1, config.steps // 20) # Report ~20 times during optimization + latest_loss: float = 0.0 for step in tqdm(range(config.steps), desc="Optimizing CI values"): - if on_progress is not None and step % progress_interval == 0: - on_progress(step, config.steps, "optimizing") + if step % progress_interval == 0: + if on_progress is not None: + on_progress(step, config.steps, "optimizing") + + if on_ci_snapshot is not None: + with 
torch.no_grad(): + snap_ci = ci_params.create_ci_outputs(model, device) + current_alive = [ + (snap_ci.lower_leaky[layer][0] > 0.0).sum(dim=-1).tolist() + for layer in snapshot_layers + ] + on_ci_snapshot( + CISnapshot( + step=step, + total_steps=config.steps, + layers=snapshot_layers, + seq_len=snapshot_seq_len, + initial_alive=snapshot_initial_alive, + current_alive=current_alive, + l0_total=sum(sum(row) for row in current_alive), + loss=latest_loss, + ) + ) optimizer.zero_grad() - # Create CI outputs from current parameters ci_outputs = ci_params.create_ci_outputs(model, device) + # Recon forward pass (stochastic or CI masking) match config.mask_type: case "stochastic": - mask_infos = calc_stochastic_component_mask_info( + recon_mask_infos = calc_stochastic_component_mask_info( causal_importances=ci_outputs.lower_leaky, component_mask_sampling=config.sampling, weight_deltas=weight_deltas, router=AllLayersRouter(), ) case "ci": - mask_infos = make_mask_infos(component_masks=ci_outputs.lower_leaky) + recon_mask_infos = make_mask_infos(component_masks=ci_outputs.lower_leaky) with bf16_autocast(): - out = model(tokens, mask_infos=mask_infos) + recon_out = model(tokens, mask_infos=recon_mask_infos) + + imp_min_loss = importance_minimality_loss( + ci_upper_leaky=ci_outputs.upper_leaky, + current_frac_of_training=step / config.steps, + pnorm=config.imp_min_config.pnorm, + beta=config.imp_min_config.beta, + eps=config.imp_min_config.eps, + p_anneal_start_frac=config.imp_min_config.p_anneal_start_frac, + p_anneal_final_p=config.imp_min_config.p_anneal_final_p, + p_anneal_end_frac=config.imp_min_config.p_anneal_end_frac, + ) - imp_min_loss = importance_minimality_loss( - ci_upper_leaky=ci_outputs.upper_leaky, - current_frac_of_training=step / config.steps, - pnorm=config.imp_min_config.pnorm, - beta=config.imp_min_config.beta, - eps=config.imp_min_config.eps, - p_anneal_start_frac=config.imp_min_config.p_anneal_start_frac, - 
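The snapshot hook above counts, per layer and sequence position, how many components have CI > 0 via `(ci > 0.0).sum(dim=-1).tolist()`. The same counting logic on nested lists, as a pure-Python toy (data made up for illustration):

```python
def alive_counts(ci: list[list[float]]) -> list[int]:
    """Per-position count of components with positive causal importance."""
    return [sum(1 for c in pos if c > 0.0) for pos in ci]


# One layer, 3 sequence positions, 4 components each.
layer_ci = [
    [0.9, 0.0, 0.2, 0.0],
    [0.0, 0.0, 0.0, 0.0],
    [0.5, 0.5, 0.5, 0.1],
]
counts = alive_counts(layer_ci)
l0_total = sum(counts)  # summed over positions, as in the CISnapshot l0_total field
```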
p_anneal_final_p=config.imp_min_config.p_anneal_final_p, - p_anneal_end_frac=config.imp_min_config.p_anneal_end_frac, + recon_loss = compute_recon_loss(recon_out, config.loss_config, target_out, device) + total_loss = config.loss_config.coeff * recon_loss + imp_min_coeff * imp_min_loss + latest_loss = total_loss.item() + + # PGD adversarial loss (runs in tandem with recon) + if config.adv_pgd is not None: + adv_sources = run_adv_pgd( + model=model, + tokens=tokens, + ci=ci_outputs.lower_leaky, + alive_masks=alive_info.alive_masks, + adv_config=config.adv_pgd, + loss_config=config.loss_config, + target_out=target_out, + ) + pgd_mask_infos = make_mask_infos( + interpolate_pgd_mask(ci_outputs.lower_leaky, adv_sources) ) - # Compute faithfulness losses (CE and/or KL) - faithfulness_loss = torch.tensor(0.0, device=device) - ce_loss_val: float | None = None - kl_loss_val: float | None = None + with bf16_autocast(): + pgd_out = model(tokens, mask_infos=pgd_mask_infos) - if config.ce_loss_config is not None: - ce_loss = F.cross_entropy( - out[0, -1, :].unsqueeze(0), - torch.tensor([config.ce_loss_config.label_token], device=device), - ) - faithfulness_loss = faithfulness_loss + config.ce_loss_config.coeff * ce_loss - ce_loss_val = ce_loss.item() + pgd_loss = compute_recon_loss(pgd_out, config.loss_config, target_out, device) + total_loss = total_loss + config.loss_config.coeff * pgd_loss - if config.kl_loss_config is not None: - # KL divergence: encourage masked output to match target distribution - target_probs = F.softmax(target_out[0, -1, :], dim=-1) - pred_log_probs = F.log_softmax(out[0, -1, :], dim=-1) - kl_loss = F.kl_div(pred_log_probs, target_probs, reduction="sum") - faithfulness_loss = faithfulness_loss + config.kl_loss_config.coeff * kl_loss - kl_loss_val = kl_loss.item() + total_loss.backward() + optimizer.step() - total_loss = faithfulness_loss + imp_min_coeff * imp_min_loss + # Compute final metrics after optimization + with torch.no_grad(): + 
final_ci_outputs = ci_params.create_ci_outputs(model, device) - if step % config.log_freq == 0 or step == config.steps - 1: - l0_stats = compute_l0_stats(ci_outputs, ci_alive_threshold=0.0) + total_l0 = sum( + calc_ci_l_zero(layer_ci, 0.0) for layer_ci in final_ci_outputs.lower_leaky.values() + ) - # Compute CE/KL metrics for final token only - with torch.no_grad(): - ce_kl_stats = compute_final_token_ce_kl( - model=model, - batch=tokens, - target_out=target_out, - ci=ci_outputs.lower_leaky, - rounding_threshold=config.ce_kl_rounding_threshold, - ) + final_ci_masked_label_prob: float | None = None + final_stoch_masked_label_prob: float | None = None + + if isinstance(config.loss_config, CELossConfig | LogitLossConfig): + pos = config.loss_config.position + label_token = config.loss_config.label_token + + # CI-masked probability + ci_mask_infos = make_mask_infos(final_ci_outputs.lower_leaky, routing_masks="all") + ci_logits = model(tokens, mask_infos=ci_mask_infos) + ci_probs = F.softmax(ci_logits[0, pos, :], dim=-1) + final_ci_masked_label_prob = float(ci_probs[label_token].item()) + + # Stochastic-masked probability (sample once for final metric) + stoch_mask_infos = calc_stochastic_component_mask_info( + causal_importances=final_ci_outputs.lower_leaky, + component_mask_sampling=config.sampling, + weight_deltas=weight_deltas, + router=AllLayersRouter(), + ) + stoch_logits = model(tokens, mask_infos=stoch_mask_infos) + stoch_probs = F.softmax(stoch_logits[0, pos, :], dim=-1) + final_stoch_masked_label_prob = float(stoch_probs[label_token].item()) + + # Adversarial PGD final evaluation (needs gradients for PGD, so outside no_grad block) + final_adv_pgd_label_prob: float | None = None + + if config.adv_pgd is not None: + final_adv_sources = run_adv_pgd( + model=model, + tokens=tokens, + ci=final_ci_outputs.lower_leaky, + alive_masks=alive_info.alive_masks, + adv_config=config.adv_pgd, + target_out=target_out, + loss_config=config.loss_config, + ) + with 
torch.no_grad(): + adv_pgd_masks = make_mask_infos( + interpolate_pgd_mask(final_ci_outputs.lower_leaky, final_adv_sources) + ) + with bf16_autocast(): + adv_logits = model(tokens, mask_infos=adv_pgd_masks) + + if isinstance(config.loss_config, CELossConfig | LogitLossConfig): + pos = config.loss_config.position + label_token = config.loss_config.label_token + adv_probs = F.softmax(adv_logits[0, pos, :], dim=-1) + final_adv_pgd_label_prob = float(adv_probs[label_token].item()) + + metrics = OptimizationMetrics( + ci_masked_label_prob=final_ci_masked_label_prob, + stoch_masked_label_prob=final_stoch_masked_label_prob, + adv_pgd_label_prob=final_adv_pgd_label_prob, + l0_total=total_l0, + ) - log_terms: dict[str, float] = { - "imp_min_loss": imp_min_loss.item(), - "total_loss": total_loss.item(), - } - if ce_loss_val is not None: - log_terms["ce_loss"] = ce_loss_val - if kl_loss_val is not None: - log_terms["kl_loss"] = kl_loss_val - - # Log label probability if CE loss is used - if config.ce_loss_config is not None: - stoch_label_prob = F.softmax(out[0, -1, :], dim=-1)[ - config.ce_loss_config.label_token - ] - log_terms["stoch_masked_label_prob"] = stoch_label_prob.item() + return OptimizeCIResult( + params=ci_params, + metrics=metrics, + ) + + +def compute_recon_loss_batched( + logits: Float[Tensor, "N seq vocab"], + loss_config: LossConfig, + target_out: Float[Tensor, "N seq vocab"], + device: str, +) -> Float[Tensor, " N"]: + """Compute per-element reconstruction loss for batched logits.""" + match loss_config: + case CELossConfig(position=pos, label_token=label_token): + labels = torch.full((logits.shape[0],), label_token, device=device) + return F.cross_entropy(logits[:, pos, :], labels, reduction="none") + case KLLossConfig(position=pos): + target_probs = F.softmax(target_out[:, pos, :], dim=-1) + pred_log_probs = F.log_softmax(logits[:, pos, :], dim=-1) + return F.kl_div(pred_log_probs, target_probs, reduction="none").sum(dim=-1) + case 
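The final metrics above take a softmax over the logits at one position and read off the label token's probability. A hedged stand-in for `F.softmax(logits[0, pos, :], dim=-1)[label_token]` using only the standard library (toy logits and a made-up 4-token vocab):

```python
import math


def softmax(logits: list[float]) -> list[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


# Toy vocab of 4; pretend the label token is index 2.
logits_at_pos = [1.0, 1.0, 3.0, 0.0]
probs = softmax(logits_at_pos)
label_prob = probs[2]
```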
LogitLossConfig(position=pos, label_token=label_token): + return -logits[:, pos, label_token] + case MeanKLLossConfig(): + target_probs = F.softmax(target_out, dim=-1) + pred_log_probs = F.log_softmax(logits, dim=-1) + return F.kl_div(pred_log_probs, target_probs, reduction="none").sum(dim=-1).mean(dim=-1) + + +def importance_minimality_loss_per_element( + ci_upper_leaky_batched: dict[str, Float[Tensor, "N seq C"]], + n_batch: int, + current_frac_of_training: float, + pnorm: float, + beta: float, + eps: float, + p_anneal_start_frac: float, + p_anneal_final_p: float | None, + p_anneal_end_frac: float, +) -> Float[Tensor, " N"]: + """Compute importance minimality loss independently for each batch element.""" + losses = [] + for i in range(n_batch): + element_ci = {k: v[i : i + 1] for k, v in ci_upper_leaky_batched.items()} + losses.append( + importance_minimality_loss( + ci_upper_leaky=element_ci, + current_frac_of_training=current_frac_of_training, + pnorm=pnorm, + beta=beta, + eps=eps, + p_anneal_start_frac=p_anneal_start_frac, + p_anneal_final_p=p_anneal_final_p, + p_anneal_end_frac=p_anneal_end_frac, + ) + ) + return torch.stack(losses) + + +def run_adv_pgd_batched( + model: ComponentModel, + tokens: Float[Tensor, "N seq"], + ci: dict[str, Float[Tensor, "N seq C"]], + alive_masks: dict[str, Bool[Tensor, "N seq C"]], + adv_config: AdvPGDConfig, + target_out: Float[Tensor, "N seq vocab"], + loss_config: LossConfig, +) -> dict[str, Float[Tensor, "N seq C"]]: + """Run PGD adversary with batched tensors. 
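For the KL cases above, `F.kl_div(pred_log_probs, target_probs, reduction="none").sum(dim=-1)` evaluates, per element, `sum_i target_i * (log target_i - pred_log_prob_i)`, i.e. KL(target ‖ pred) over the vocab. A small pure-Python check of that identity (hypothetical distributions, not repo data):

```python
import math


def kl_div(target_probs: list[float], pred_log_probs: list[float]) -> float:
    # Matches F.kl_div(pred_log_probs, target_probs, reduction="none").sum(-1):
    # sum_i target_i * (log(target_i) - pred_log_prob_i), skipping zero-mass entries.
    return sum(
        t * (math.log(t) - lp)
        for t, lp in zip(target_probs, pred_log_probs)
        if t > 0.0
    )


target = [0.5, 0.5]
matched_log = [math.log(0.5), math.log(0.5)]
assert kl_div(target, matched_log) == 0.0  # identical distributions

skewed_log = [math.log(0.9), math.log(0.1)]
loss = kl_div(target, skewed_log)  # strictly positive for mismatched distributions
```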
Returns detached adversarial sources.""" + ci_detached = {k: v.detach() for k, v in ci.items()} + + adv_sources: dict[str, Tensor] = {} + for layer_name, ci_val in ci_detached.items(): + source = get_pgd_init_tensor(adv_config.init, tuple(ci_val.shape), str(ci_val.device)) + source[~alive_masks[layer_name]] = 0.0 + source.requires_grad_(True) + adv_sources[layer_name] = source + + source_list = list(adv_sources.values()) + + for _ in range(adv_config.n_steps): + mask_infos = make_mask_infos(interpolate_pgd_mask(ci_detached, adv_sources)) + + with bf16_autocast(): + out = model(tokens, mask_infos=mask_infos) + + losses = compute_recon_loss_batched(out, loss_config, target_out, str(tokens.device)) + loss = losses.sum() + + grads = torch.autograd.grad(loss, source_list) + with torch.no_grad(): + for (layer_name, source), grad in zip(adv_sources.items(), grads, strict=True): + source.add_(adv_config.step_size * grad.sign()) + source.clamp_(0.0, 1.0) + source[~alive_masks[layer_name]] = 0.0 + + return {k: v.detach() for k, v in adv_sources.items()} + + +def optimize_ci_values_batched( + model: ComponentModel, + tokens: Float[Tensor, "1 seq"], + configs: list[OptimCIConfig], + device: str, + on_progress: ProgressCallback | None = None, + on_ci_snapshot: CISnapshotCallback | None = None, +) -> list[OptimizeCIResult]: + """Optimize CI values for N sparsity coefficients in a single batched loop. + + All configs must share the same loss_config, steps, mask_type, adv_pgd settings — + only imp_min_config.coeff varies between them. 
+    """
+    N = len(configs)
+    assert N > 0
+
+    config = configs[0]
+    for c in configs:
+        assert c.imp_min_config.coeff is not None
+    imp_min_coeffs = torch.tensor([c.imp_min_config.coeff for c in configs], device=device)
+
+    model.requires_grad_(False)
+
+    with torch.no_grad(), bf16_autocast():
+        output_with_cache: OutputWithCache = model(tokens, cache_type="input")
+        initial_ci_outputs = model.calc_causal_importances(
+            pre_weight_acts=output_with_cache.cache,
+            sampling=config.sampling,
+            detach_inputs=False,
+        )
+        target_out = output_with_cache.output.detach()
+
+    alive_info = compute_alive_info(initial_ci_outputs.lower_leaky)
+
+    ci_params_list = [
+        create_optimizable_ci_params(
+            alive_info=alive_info,
+            initial_pre_sigmoid=initial_ci_outputs.pre_sigmoid,
+        )
+        for _ in range(N)
+    ]
+
+    weight_deltas = model.calc_weight_deltas()
+
+    all_params: list[Tensor] = []
+    for ci_params in ci_params_list:
+        all_params.extend(ci_params.get_parameters())
+
+    optimizer = optim.AdamW(all_params, lr=config.lr, weight_decay=config.weight_decay)
+    tokens_batched = tokens.expand(N, -1)
+    target_out_batched = target_out.expand(N, -1, -1)
+
+    snapshot_layers = list(alive_info.alive_counts.keys())
+    snapshot_initial_alive = [alive_info.alive_counts[layer] for layer in snapshot_layers]
+    snapshot_seq_len = tokens.shape[1]
+
+    progress_interval = max(1, config.steps // 20)
+    latest_loss = 0.0
+
+    for step in tqdm(range(config.steps), desc="Optimizing CI values (batched)"):
+        if step % progress_interval == 0:
+            if on_progress is not None:
+                on_progress(step, config.steps, "optimizing")
+
+            if on_ci_snapshot is not None:
                 with torch.no_grad():
-                    ci_masked_label_prob = compute_label_prob(
-                        model, tokens, ci_outputs.lower_leaky, config.ce_loss_config.label_token
+                    snap_ci = ci_params_list[0].create_ci_outputs(model, device)
+                    current_alive = [
+                        (snap_ci.lower_leaky[layer][0] > 0.0).sum(dim=-1).tolist()
+                        for layer in snapshot_layers
+                    ]
+                    on_ci_snapshot(
+                        CISnapshot(
+                            step=step,
total_steps=config.steps, + layers=snapshot_layers, + seq_len=snapshot_seq_len, + initial_alive=snapshot_initial_alive, + current_alive=current_alive, + l0_total=sum(sum(row) for row in current_alive), + loss=latest_loss, ) - log_terms["ci_masked_label_prob"] = ci_masked_label_prob + ) - tqdm.write(f"\n--- Step {step} ---") - for name, value in log_terms.items(): - tqdm.write(f" {name}: {value:.6f}") - for name, value in l0_stats.items(): - tqdm.write(f" {name}: {value:.2f}") - for name, value in ce_kl_stats.items(): - tqdm.write(f" {name}: {value:.6f}") + optimizer.zero_grad() + + ci_outputs_list = [cp.create_ci_outputs(model, device) for cp in ci_params_list] + + layers = list(ci_outputs_list[0].lower_leaky.keys()) + batched_ci_lower_leaky: dict[str, Tensor] = { + layer: torch.cat([co.lower_leaky[layer] for co in ci_outputs_list], dim=0) + for layer in layers + } + batched_ci_upper_leaky: dict[str, Tensor] = { + layer: torch.cat([co.upper_leaky[layer] for co in ci_outputs_list], dim=0) + for layer in layers + } + + match config.mask_type: + case "stochastic": + recon_mask_infos = calc_stochastic_component_mask_info( + causal_importances=batched_ci_lower_leaky, + component_mask_sampling=config.sampling, + weight_deltas=weight_deltas, + router=AllLayersRouter(), + ) + case "ci": + recon_mask_infos = make_mask_infos(component_masks=batched_ci_lower_leaky) + + with bf16_autocast(): + recon_out = model(tokens_batched, mask_infos=recon_mask_infos) + + imp_min_losses = importance_minimality_loss_per_element( + ci_upper_leaky_batched=batched_ci_upper_leaky, + n_batch=N, + current_frac_of_training=step / config.steps, + pnorm=config.imp_min_config.pnorm, + beta=config.imp_min_config.beta, + eps=config.imp_min_config.eps, + p_anneal_start_frac=config.imp_min_config.p_anneal_start_frac, + p_anneal_final_p=config.imp_min_config.p_anneal_final_p, + p_anneal_end_frac=config.imp_min_config.p_anneal_end_frac, + ) + + recon_losses = compute_recon_loss_batched( + recon_out, 
config.loss_config, target_out_batched, device + ) + + loss_coeff = config.loss_config.coeff + total_loss = (loss_coeff * recon_losses + imp_min_coeffs * imp_min_losses).sum() + latest_loss = total_loss.item() + + if config.adv_pgd is not None: + batched_alive_masks = { + k: v.expand(N, -1, -1) for k, v in alive_info.alive_masks.items() + } + adv_sources = run_adv_pgd_batched( + model=model, + tokens=tokens_batched, + ci=batched_ci_lower_leaky, + alive_masks=batched_alive_masks, + adv_config=config.adv_pgd, + target_out=target_out_batched, + loss_config=config.loss_config, + ) + pgd_masks = interpolate_pgd_mask(batched_ci_lower_leaky, adv_sources) + pgd_mask_infos = make_mask_infos(pgd_masks) + with bf16_autocast(): + pgd_out = model(tokens_batched, mask_infos=pgd_mask_infos) + pgd_losses = compute_recon_loss_batched( + pgd_out, config.loss_config, target_out_batched, device + ) + total_loss = total_loss + (loss_coeff * pgd_losses).sum() total_loss.backward() optimizer.step() - return ci_params + # Compute final metrics per element + results: list[OptimizeCIResult] = [] + for ci_params in ci_params_list: + with torch.no_grad(): + final_ci = ci_params.create_ci_outputs(model, device) + total_l0 = sum( + calc_ci_l_zero(layer_ci, 0.0) for layer_ci in final_ci.lower_leaky.values() + ) + + ci_masked_label_prob: float | None = None + stoch_masked_label_prob: float | None = None + + if isinstance(config.loss_config, CELossConfig | LogitLossConfig): + pos = config.loss_config.position + label_token = config.loss_config.label_token + + ci_mask_infos = make_mask_infos(final_ci.lower_leaky, routing_masks="all") + ci_logits = model(tokens, mask_infos=ci_mask_infos) + ci_probs = F.softmax(ci_logits[0, pos, :], dim=-1) + ci_masked_label_prob = float(ci_probs[label_token].item()) + + stoch_mask_infos = calc_stochastic_component_mask_info( + causal_importances=final_ci.lower_leaky, + component_mask_sampling=config.sampling, + weight_deltas=weight_deltas, + 
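The batched loop scalarizes N independent objectives as `(loss_coeff * recon_losses + imp_min_coeffs * imp_min_losses).sum()`; since each config's CI parameters only influence its own batch element, summing keeps the per-config gradients separate. The weighted combination, sketched with toy numbers:

```python
recon_losses = [0.2, 0.4, 0.6]      # one reconstruction loss per sparsity config
imp_min_losses = [1.0, 1.0, 1.0]    # one importance-minimality loss per config
imp_min_coeffs = [0.01, 0.1, 1.0]   # the only knob that varies between configs
loss_coeff = 2.0

per_config = [
    loss_coeff * r + c * s
    for r, s, c in zip(recon_losses, imp_min_losses, imp_min_coeffs)
]
# Summing a vector of independent objectives before backward() yields the same
# per-parameter gradients as N separate backward passes.
total_loss = sum(per_config)
```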
router=AllLayersRouter(), + ) + stoch_logits = model(tokens, mask_infos=stoch_mask_infos) + stoch_probs = F.softmax(stoch_logits[0, pos, :], dim=-1) + stoch_masked_label_prob = float(stoch_probs[label_token].item()) + + adv_pgd_label_prob: float | None = None + if config.adv_pgd is not None: + final_adv_sources = run_adv_pgd( + model=model, + tokens=tokens, + ci=final_ci.lower_leaky, + alive_masks=alive_info.alive_masks, + adv_config=config.adv_pgd, + target_out=target_out, + loss_config=config.loss_config, + ) + with torch.no_grad(): + adv_masks = make_mask_infos( + interpolate_pgd_mask(final_ci.lower_leaky, final_adv_sources) + ) + with bf16_autocast(): + adv_logits = model(tokens, mask_infos=adv_masks) + if isinstance(config.loss_config, CELossConfig | LogitLossConfig): + pos = config.loss_config.position + label_token = config.loss_config.label_token + adv_probs = F.softmax(adv_logits[0, pos, :], dim=-1) + adv_pgd_label_prob = float(adv_probs[label_token].item()) + + results.append( + OptimizeCIResult( + params=ci_params, + metrics=OptimizationMetrics( + ci_masked_label_prob=ci_masked_label_prob, + stoch_masked_label_prob=stoch_masked_label_prob, + adv_pgd_label_prob=adv_pgd_label_prob, + l0_total=total_l0, + ), + ) + ) + + return results def get_out_dir() -> Path: diff --git a/spd/app/backend/routers/__init__.py b/spd/app/backend/routers/__init__.py index ffbb0c5de..7a3dfadcf 100644 --- a/spd/app/backend/routers/__init__.py +++ b/spd/app/backend/routers/__init__.py @@ -4,11 +4,17 @@ from spd.app.backend.routers.agents import router as agents_router from spd.app.backend.routers.clusters import router as clusters_router from spd.app.backend.routers.correlations import router as correlations_router +from spd.app.backend.routers.data_sources import router as data_sources_router from spd.app.backend.routers.dataset_attributions import router as dataset_attributions_router from spd.app.backend.routers.dataset_search import router as dataset_search_router +from 
spd.app.backend.routers.graph_interp import router as graph_interp_router from spd.app.backend.routers.graphs import router as graphs_router from spd.app.backend.routers.intervention import router as intervention_router +from spd.app.backend.routers.investigations import router as investigations_router +from spd.app.backend.routers.mcp import router as mcp_router +from spd.app.backend.routers.pretrain_info import router as pretrain_info_router from spd.app.backend.routers.prompts import router as prompts_router +from spd.app.backend.routers.run_registry import router as run_registry_router from spd.app.backend.routers.runs import router as runs_router __all__ = [ @@ -16,10 +22,16 @@ "agents_router", "clusters_router", "correlations_router", + "data_sources_router", "dataset_attributions_router", "dataset_search_router", + "graph_interp_router", "graphs_router", "intervention_router", + "investigations_router", + "mcp_router", + "pretrain_info_router", "prompts_router", + "run_registry_router", "runs_router", ] diff --git a/spd/app/backend/routers/activation_contexts.py b/spd/app/backend/routers/activation_contexts.py index 260d8e0d0..266a7d40b 100644 --- a/spd/app/backend/routers/activation_contexts.py +++ b/spd/app/backend/routers/activation_contexts.py @@ -4,16 +4,18 @@ """ from collections import defaultdict +from typing import Annotated import torch -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, Query from pydantic import BaseModel +from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.compute import compute_ci_only from spd.app.backend.dependencies import DepLoadedRun from spd.app.backend.schemas import SubcomponentActivationContexts, SubcomponentMetadata from spd.app.backend.utils import log_errors -from spd.harvest.loaders import load_component_activation_contexts +from spd.harvest.schemas import ComponentData from spd.utils.distributed_utils import get_device @@ -31,27 +33,51 @@ class 
ComponentProbeResponse(BaseModel): tokens: list[str] ci_values: list[float] subcomp_acts: list[float] + next_token_probs: list[float | None] # Probability of next token (last is None) router = APIRouter(prefix="/api/activation_contexts", tags=["activation_contexts"]) +def example_to_activation_contexts( + comp: ComponentData, tokenizer: AppTokenizer, limit: int | None = None +) -> SubcomponentActivationContexts: + examples = comp.activation_examples + if limit is not None: + examples = examples[:limit] + + mean_ci = comp.mean_activations["causal_importance"] + example_tokens = [tokenizer.get_spans(ex.token_ids) for ex in examples] + example_ci = [ex.activations["causal_importance"] for ex in examples] + example_component_acts = [ex.activations["component_activation"] for ex in examples] + + return SubcomponentActivationContexts( + subcomponent_idx=comp.component_idx, + # We might consider replacing mean_ci here with firing density + mean_ci=mean_ci, + example_tokens=example_tokens, + example_ci=example_ci, + example_component_acts=example_component_acts, + ) + + @router.get("/summary") @log_errors def get_activation_contexts_summary( loaded: DepLoadedRun, ) -> dict[str, list[SubcomponentMetadata]]: """Return lightweight summary of activation contexts (just idx + mean_ci per component).""" - if not loaded.harvest.has_activation_contexts_summary(): - raise HTTPException(status_code=404, detail="No activation contexts summary found") - summary_data = loaded.harvest.activation_contexts_summary + if loaded.harvest is None: + raise HTTPException(status_code=404, detail="No harvest data available") + summary_data = loaded.harvest.get_summary() summary: dict[str, list[SubcomponentMetadata]] = defaultdict(list) for comp in summary_data.values(): - summary[comp.layer].append( + canonical_layer = loaded.topology.target_to_canon(comp.layer) + summary[canonical_layer].append( SubcomponentMetadata( subcomponent_idx=comp.component_idx, - mean_ci=comp.mean_ci, + 
mean_ci=comp.mean_activations["causal_importance"], ) ) @@ -68,32 +94,66 @@ def get_activation_context_detail( layer: str, component_idx: int, loaded: DepLoadedRun, + limit: Annotated[int | None, Query(ge=1, description="Max examples to return")] = None, ) -> SubcomponentActivationContexts: - """Return full activation context data for a single component.""" - component_key = f"{layer}:{component_idx}" - comp = load_component_activation_contexts(loaded.harvest.run_id, component_key) + """Return full activation context data for a single component. - # Convert token IDs to strings - PADDING_SENTINEL = -1 - token_strings = loaded.token_strings + Args: + limit: Maximum number of activation examples to return. If None, returns all. + Use limit=30 for initial load, then fetch more via pagination if needed. - def token_str(tid: int) -> str: - if tid == PADDING_SENTINEL: - return "" - assert tid in token_strings, f"Token ID {tid} not in vocab" - return token_strings[tid] + TODO: Add offset parameter for pagination to allow fetching remaining examples + after initial view is loaded. 
+ """ + assert loaded.harvest is not None, "No harvest data available" + concrete_layer = loaded.topology.canon_to_target(layer) + component_key = f"{concrete_layer}:{component_idx}" + comp = loaded.harvest.get_component(component_key) + if comp is None: + raise HTTPException(status_code=404, detail=f"Component {component_key} not found") - example_tokens = [[token_str(tid) for tid in ex.token_ids] for ex in comp.activation_examples] - example_ci = [ex.ci_values for ex in comp.activation_examples] - example_component_acts = [ex.component_acts for ex in comp.activation_examples] + return example_to_activation_contexts(comp, loaded.tokenizer, limit) - return SubcomponentActivationContexts( - subcomponent_idx=comp.component_idx, - mean_ci=comp.mean_ci, - example_tokens=example_tokens, - example_ci=example_ci, - example_component_acts=example_component_acts, - ) + +class BulkActivationContextsRequest(BaseModel): + """Request for bulk activation contexts.""" + + component_keys: list[str] # canonical keys, e.g. ["0.mlp.up:5", "1.attn.q:12"] + limit: int = 30 + + +@router.post("/bulk") +@log_errors +def get_activation_contexts_bulk( + request: BulkActivationContextsRequest, + loaded: DepLoadedRun, +) -> dict[str, SubcomponentActivationContexts]: + """Bulk fetch activation contexts for multiple components. + + Returns a dict keyed by component_key. Components not found are omitted. + Uses optimized bulk loader with single file handle and sorted seeks. 
+ """ + + # Translate canonical component keys to concrete paths for harvest lookup + def _to_concrete_key(canonical_key: str) -> str: + layer, idx = canonical_key.rsplit(":", 1) + concrete = loaded.topology.canon_to_target(layer) + return f"{concrete}:{idx}" + + assert loaded.harvest is not None, "No harvest data available" + concrete_to_canonical = {_to_concrete_key(k): k for k in request.component_keys} + concrete_keys = list(concrete_to_canonical.keys()) + components = loaded.harvest.get_components_bulk(concrete_keys) + + # Convert to response format with limit applied, keyed by canonical keys + result: dict[str, SubcomponentActivationContexts] = {} + for concrete_key, comp in components.items(): + canonical_key = concrete_to_canonical[concrete_key] + result[canonical_key] = example_to_activation_contexts( + comp, loaded.tokenizer, request.limit + ) + + return result @router.post("/probe") @@ -109,7 +169,7 @@ def probe_component( """ device = get_device() - token_ids = loaded.tokenizer.encode(request.text, add_special_tokens=False) + token_ids = loaded.tokenizer.encode(request.text) assert len(token_ids) > 0, "Text produced no tokens" tokens_tensor = torch.tensor([token_ids], device=device) @@ -120,15 +180,28 @@ def probe_component( sampling=loaded.config.sampling, ) - assert request.layer in loaded.model.components, f"Layer {request.layer} not in model" + concrete_layer = loaded.topology.canon_to_target(request.layer) + assert concrete_layer in loaded.model.components, f"Layer {request.layer} not in model" - ci_tensor = result.ci_lower_leaky[request.layer] + ci_tensor = result.ci_lower_leaky[concrete_layer] ci_values = ci_tensor[0, :, request.component_idx].tolist() - token_strings = [loaded.token_strings[t] for t in token_ids] + spans = loaded.tokenizer.get_spans(token_ids) - subcomp_acts_tensor = result.component_acts[request.layer] + subcomp_acts_tensor = result.component_acts[concrete_layer] subcomp_acts = subcomp_acts_tensor[0, :, 
request.component_idx].tolist() + # Get probability of next token at each position + probs = result.target_out_probs[0] # [seq, vocab] + next_token_probs: list[float | None] = [] + for i in range(len(token_ids) - 1): + next_token_id = token_ids[i + 1] + prob = probs[i, next_token_id].item() + next_token_probs.append(prob) + next_token_probs.append(None) # No next token for last position + return ComponentProbeResponse( - tokens=token_strings, ci_values=ci_values, subcomp_acts=subcomp_acts + tokens=spans, + ci_values=ci_values, + subcomp_acts=subcomp_acts, + next_token_probs=next_token_probs, ) diff --git a/spd/app/backend/routers/agents.py b/spd/app/backend/routers/agents.py index 7d380d52f..0df390bfa 100644 --- a/spd/app/backend/routers/agents.py +++ b/spd/app/backend/routers/agents.py @@ -39,7 +39,7 @@ def get_graph_by_id( return stored_graph_to_response( graph=graph, token_ids=prompt.token_ids, - token_strings_map=loaded.token_strings, + tokenizer=loaded.tokenizer, normalize=normalize, ci_threshold=ci_threshold, ) diff --git a/spd/app/backend/routers/clusters.py b/spd/app/backend/routers/clusters.py index ad8ded6bf..b2dc1d5b9 100644 --- a/spd/app/backend/routers/clusters.py +++ b/spd/app/backend/routers/clusters.py @@ -10,6 +10,7 @@ from spd.app.backend.utils import log_errors from spd.base_config import BaseConfig from spd.settings import SPD_OUT_DIR +from spd.topology import TransformerTopology router = APIRouter(prefix="/api/clusters", tags=["clusters"]) @@ -26,11 +27,10 @@ class ClusterMapping(BaseConfig): class ClusterMappingFile(BaseConfig): """Schema for the on-disk cluster mapping JSON file.""" - ensemble_id: str + clustering_run_id: str notes: str spd_run: str - n_iterations: int - run_idx: int + iteration: int clusters: dict[str, int | None] @@ -42,7 +42,7 @@ def load_cluster_mapping(file_path: str) -> ClusterMapping: Paths are resolved relative to SPD_OUT_DIR unless they are absolute. 
The file should contain a JSON object with: - - ensemble_id: string + - clustering_run_id: string - notes: string - spd_run: wandb path (must match currently loaded run) - clusters: dict mapping component keys to cluster IDs @@ -87,4 +87,17 @@ def load_cluster_mapping(file_path: str) -> ClusterMapping: f"but loaded run is '{run_state.run.wandb_path}'", ) - return ClusterMapping(mapping=parsed.clusters) + canonical_clusters = _to_canonical_keys(parsed.clusters, run_state.topology) + return ClusterMapping(mapping=canonical_clusters) + + +def _to_canonical_keys( + clusters: dict[str, int | None], topology: TransformerTopology +) -> dict[str, int | None]: + """Convert concrete component keys (e.g. 'h.3.mlp.down_proj:5') to canonical (e.g. '3.mlp.down:5').""" + result: dict[str, int | None] = {} + for key, cluster_id in clusters.items(): + layer, idx = key.rsplit(":", 1) + canonical_layer = topology.target_to_canon(layer) + result[f"{canonical_layer}:{idx}"] = cluster_id + return result diff --git a/spd/app/backend/routers/correlations.py b/spd/app/backend/routers/correlations.py index 865e881e6..ae9af7b64 100644 --- a/spd/app/backend/routers/correlations.py +++ b/spd/app/backend/routers/correlations.py @@ -11,9 +11,27 @@ from spd.app.backend.dependencies import DepLoadedRun from spd.app.backend.utils import log_errors +from spd.autointerp.schemas import ModelMetadata +from spd.configs import LMTaskConfig from spd.harvest import analysis -from spd.harvest.loaders import load_component_activation_contexts from spd.log import logger +from spd.topology import TransformerTopology +from spd.utils.general_utils import runtime_cast + + +def _canonical_to_concrete_key( + canonical_layer: str, component_idx: int, topology: TransformerTopology +) -> str: + """Translate canonical layer address + component idx to concrete component key for harvest data.""" + concrete = topology.canon_to_target(canonical_layer) + return f"{concrete}:{component_idx}" + + +def 
_concrete_to_canonical_key(concrete_key: str, topology: TransformerTopology) -> str: + """Translate concrete component key to canonical component key.""" + layer, idx = concrete_key.rsplit(":", 1) + canonical = topology.target_to_canon(layer) + return f"{canonical}:{idx}" class CorrelatedComponent(BaseModel): @@ -71,6 +89,8 @@ class InterpretationHeadline(BaseModel): label: str confidence: str + detection_score: float | None = None + fuzzing_score: float | None = None class InterpretationDetail(BaseModel): @@ -85,18 +105,28 @@ class InterpretationDetail(BaseModel): def get_all_interpretations( loaded: DepLoadedRun, ) -> dict[str, InterpretationHeadline]: - """Get all interpretation headlines (label + confidence only). + """Get all interpretation headlines (label + confidence + eval scores). Returns a dict keyed by component_key (layer:cIdx). + Returns empty dict if no interpretations are available. Reasoning and prompt are excluded - fetch individually via GET /interpretations/{layer}/{component_idx} when needed. """ + if loaded.interp is None: + return {} + + interpretations = loaded.interp.get_all_interpretations() + detection_scores = loaded.interp.get_detection_scores() + fuzzing_scores = loaded.interp.get_fuzzing_scores() + return { - key: InterpretationHeadline( + _concrete_to_canonical_key(key, loaded.topology): InterpretationHeadline( label=result.label, confidence=result.confidence, + detection_score=detection_scores.get(key) if detection_scores else None, + fuzzing_score=fuzzing_scores.get(key) if fuzzing_scores else None, ) - for key, result in loaded.harvest.interpretations.items() + for key, result in interpretations.items() } @@ -111,16 +141,17 @@ def get_interpretation_detail( Returns reasoning and prompt for the specified component. 
""" - component_key = f"{layer}:{component_idx}" - interpretations = loaded.harvest.interpretations + if loaded.interp is None: + raise HTTPException(status_code=404, detail="No autointerp data available") + concrete_key = _canonical_to_concrete_key(layer, component_idx, loaded.topology) + result = loaded.interp.get_interpretation(concrete_key) - if component_key not in interpretations: + if result is None: raise HTTPException( status_code=404, - detail=f"No interpretation found for component {component_key}", + detail=f"No interpretation found for component {layer}:{component_idx}", ) - result = interpretations[component_key] return InterpretationDetail(reasoning=result.reasoning, prompt=result.prompt) @@ -136,33 +167,28 @@ async def request_component_interpretation( Requires OPENROUTER_API_KEY environment variable. Returns the headline (label + confidence). Full detail available via GET endpoint. """ - import json import os - from dataclasses import asdict from openrouter import OpenRouter - from spd.autointerp.interpret import ( - OpenRouterModelName, - get_architecture_info, - interpret_component, - ) - from spd.autointerp.schemas import get_autointerp_dir + from spd.autointerp.config import CompactSkepticalConfig + from spd.autointerp.interpret import interpret_component - component_key = f"{layer}:{component_idx}" + assert loaded.harvest is not None, "No harvest data available" + assert loaded.interp is not None, "No autointerp data available" - interpretations = loaded.harvest.interpretations + component_key = _canonical_to_concrete_key(layer, component_idx, loaded.topology) - if component_key in interpretations: - result = interpretations[component_key] + existing = loaded.interp.get_interpretation(component_key) + if existing is not None: return InterpretationHeadline( - label=result.label, - confidence=result.confidence, + label=existing.label, + confidence=existing.confidence, ) - component_data = load_component_activation_contexts(loaded.harvest.run_id, 
component_key) + component_data = loaded.harvest.get_component(component_key) + assert component_data is not None, f"Component {component_key} not found in harvest" - # Get API key api_key = os.getenv("OPENROUTER_API_KEY") if not api_key: raise HTTPException( @@ -170,11 +196,8 @@ async def request_component_interpretation( detail="OPENROUTER_API_KEY environment variable not set", ) - # Get architecture info and tokenizer - arch = get_architecture_info(loaded.run.wandb_path) - - # Get token stats - token_stats = loaded.harvest.token_stats + token_stats = loaded.harvest.get_token_stats() + assert token_stats is not None, "Token stats required for interpretation" input_token_stats = analysis.get_input_token_stats( token_stats, component_key, loaded.tokenizer, top_k=20 @@ -188,40 +211,35 @@ async def request_component_interpretation( detail=f"Token stats not available for component {component_key}", ) - # Interpret the component - model_name = OpenRouterModelName.GEMINI_3_FLASH_PREVIEW - - async with OpenRouter(api_key=api_key) as client: - res = await interpret_component( - client=client, - model=model_name, - component=component_data, - arch=arch, - tokenizer=loaded.tokenizer, - input_token_stats=input_token_stats, - output_token_stats=output_token_stats, - ) - - if res is None: - raise HTTPException( - status_code=500, - detail="Failed to generate interpretation", - ) - - result, _, _ = res + model_metadata = ModelMetadata( + n_blocks=loaded.topology.n_blocks, + model_class=loaded.model.__class__.__name__, + dataset_name=runtime_cast(LMTaskConfig, loaded.config.task_config).dataset_name, + layer_descriptions={ + path: loaded.topology.target_to_canon(path) for path in loaded.model.target_module_paths + }, + ) - # Save to file - autointerp_dir = get_autointerp_dir(loaded.harvest.run_id) - autointerp_dir.mkdir(parents=True, exist_ok=True) - output_path = autointerp_dir / "results.jsonl" - with open(output_path, "a") as f: - f.write(json.dumps(asdict(result)) + "\n") + 
async with OpenRouter(api_key=api_key) as api: + try: + result = await interpret_component( + api=api, + model="google/gemini-3-flash-preview", + reasoning_effort="none", + strategy=CompactSkepticalConfig(), + component=component_data, + model_metadata=model_metadata, + app_tok=loaded.tokenizer, + input_token_stats=input_token_stats, + output_token_stats=output_token_stats, + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Failed to generate interpretation: {e}", + ) from e - # Update the cache - if loaded.harvest._interpretations is None: - loaded.harvest._interpretations = {} - assert isinstance(loaded.harvest._interpretations, dict) - loaded.harvest._interpretations[component_key] = result + loaded.interp.save_interpretation(result) logger.info(f"Generated interpretation for {component_key}: {result.label}") @@ -231,6 +249,24 @@ async def request_component_interpretation( ) +@router.get("/intruder_scores") +@log_errors +def get_intruder_scores(loaded: DepLoadedRun) -> dict[str, float]: + """Get intruder eval scores for all components. + + Returns a dict keyed by component_key (layer:cIdx) → score (0-1). + Returns empty dict if no intruder scores are available. + """ + if loaded.harvest is None: + return {} + scores = loaded.harvest.get_scores("intruder") + if not scores: + return {} + return { + _concrete_to_canonical_key(key, loaded.topology): score for key, score in scores.items() + } + + # ============================================================================= # Component Correlation Data Endpoints # ============================================================================= @@ -250,8 +286,11 @@ def get_component_token_stats( and output tokens (what this component predicts). Returns None if token stats haven't been harvested for this run. 
""" - token_stats = loaded.harvest.token_stats - component_key = f"{layer}:{component_idx}" + assert loaded.harvest is not None, "No harvest data available" + token_stats = loaded.harvest.get_token_stats() + if token_stats is None: + return None + component_key = _canonical_to_concrete_key(layer, component_idx, loaded.topology) input_stats = analysis.get_input_token_stats( token_stats, component_key, loaded.tokenizer, top_k @@ -263,9 +302,6 @@ def get_component_token_stats( if input_stats is None or output_stats is None: return None - assert input_stats.bottom_pmi is None, "Input stats should not have bottom PMI" - assert output_stats.bottom_pmi is not None, "Output stats should have bottom PMI" - return TokenStatsResponse( input=TokenPRLiftPMI( top_recall=input_stats.top_recall, @@ -297,8 +333,11 @@ def get_component_correlations( Returns top-k correlations across different metrics (precision, recall, Jaccard, PMI). Returns None if correlations haven't been harvested for this run. """ - correlations = loaded.harvest.correlations - component_key = f"{layer}:{component_idx}" + assert loaded.harvest is not None, "No harvest data available" + correlations = loaded.harvest.get_correlations() + if correlations is None: + raise HTTPException(status_code=404, detail="No correlations data available") + component_key = _canonical_to_concrete_key(layer, component_idx, loaded.topology) if not analysis.has_component(correlations, component_key): raise HTTPException( @@ -307,7 +346,7 @@ def get_component_correlations( def to_schema(c: analysis.CorrelatedComponent) -> CorrelatedComponent: return CorrelatedComponent( - component_key=c.component_key, + component_key=_concrete_to_canonical_key(c.component_key, loaded.topology), score=c.score, count_i=c.count_i, count_j=c.count_j, diff --git a/spd/app/backend/routers/data_sources.py b/spd/app/backend/routers/data_sources.py new file mode 100644 index 000000000..6888b339f --- /dev/null +++ b/spd/app/backend/routers/data_sources.py @@ 
-0,0 +1,96 @@ +"""Data sources provenance endpoint. + +Shows where harvest/autointerp/attribution data came from: subrun IDs, configs, counts. +""" + +from typing import Any + +from fastapi import APIRouter +from pydantic import BaseModel + +from spd.app.backend.dependencies import DepLoadedRun +from spd.app.backend.utils import log_errors + + +class HarvestInfo(BaseModel): + subrun_id: str + config: dict[str, Any] + n_components: int + has_intruder_scores: bool + + +class AutointerpInfo(BaseModel): + subrun_id: str + config: dict[str, Any] + n_interpretations: int + eval_scores: list[str] + + +class AttributionsInfo(BaseModel): + subrun_id: str + n_tokens_processed: int + ci_threshold: float + + +class GraphInterpInfo(BaseModel): + subrun_id: str + config: dict[str, Any] | None + label_counts: dict[str, int] + + +class DataSourcesResponse(BaseModel): + harvest: HarvestInfo | None + autointerp: AutointerpInfo | None + attributions: AttributionsInfo | None + graph_interp: GraphInterpInfo | None + + +router = APIRouter(prefix="/api/data_sources", tags=["data_sources"]) + + +@router.get("") +@log_errors +def get_data_sources(loaded: DepLoadedRun) -> DataSourcesResponse: + harvest_info: HarvestInfo | None = None + if loaded.harvest is not None: + harvest_info = HarvestInfo( + subrun_id=loaded.harvest.subrun_id, + config=loaded.harvest.get_config(), + n_components=loaded.harvest.get_component_count(), + has_intruder_scores=bool(loaded.harvest.get_scores("intruder")), + ) + + autointerp_info: AutointerpInfo | None = None + if loaded.interp is not None: + config = loaded.interp.get_config() + if config is not None: + autointerp_info = AutointerpInfo( + subrun_id=loaded.interp.subrun_id, + config=config, + n_interpretations=loaded.interp.get_interpretation_count(), + eval_scores=loaded.interp.get_available_score_types(), + ) + + attributions_info: AttributionsInfo | None = None + if loaded.attributions is not None: + storage = loaded.attributions.get_attributions() + 
attributions_info = AttributionsInfo( + subrun_id=loaded.attributions.subrun_id, + n_tokens_processed=storage.n_tokens_processed, + ci_threshold=storage.ci_threshold, + ) + + graph_interp_info: GraphInterpInfo | None = None + if loaded.graph_interp is not None: + graph_interp_info = GraphInterpInfo( + subrun_id=loaded.graph_interp.subrun_id, + config=loaded.graph_interp.get_config(), + label_counts=loaded.graph_interp.get_label_counts(), + ) + + return DataSourcesResponse( + harvest=harvest_info, + autointerp=autointerp_info, + attributions=attributions_info, + graph_interp=graph_interp_info, + ) diff --git a/spd/app/backend/routers/dataset_attributions.py b/spd/app/backend/routers/dataset_attributions.py index fa38f5146..178eefc72 100644 --- a/spd/app/backend/routers/dataset_attributions.py +++ b/spd/app/backend/routers/dataset_attributions.py @@ -7,50 +7,43 @@ from typing import Annotated, Literal from fastapi import APIRouter, HTTPException, Query -from jaxtyping import Float from pydantic import BaseModel -from torch import Tensor, nn from spd.app.backend.dependencies import DepLoadedRun from spd.app.backend.utils import log_errors -from spd.dataset_attributions.storage import ( - DatasetAttributionEntry as StorageEntry, -) -from spd.dataset_attributions.storage import ( - DatasetAttributionStorage, -) +from spd.dataset_attributions.storage import AttrMetric, DatasetAttributionStorage +from spd.dataset_attributions.storage import DatasetAttributionEntry as StorageEntry +ATTR_METRICS: list[AttrMetric] = ["attr", "attr_abs"] -class DatasetAttributionEntry(BaseModel): - """A single entry in attribution results.""" +class DatasetAttributionEntry(BaseModel): component_key: str layer: str component_idx: int value: float + token_str: str | None = None class DatasetAttributionMetadata(BaseModel): - """Metadata about dataset attributions availability.""" - available: bool - n_batches_processed: int | None n_tokens_processed: int | None n_component_layer_keys: int | None 
- vocab_size: int | None - d_model: int | None ci_threshold: float | None class ComponentAttributions(BaseModel): - """All attribution data for a single component (sources and targets, positive and negative).""" - positive_sources: list[DatasetAttributionEntry] negative_sources: list[DatasetAttributionEntry] positive_targets: list[DatasetAttributionEntry] negative_targets: list[DatasetAttributionEntry] +class AllMetricAttributions(BaseModel): + attr: ComponentAttributions + attr_abs: ComponentAttributions + + router = APIRouter(prefix="/api/dataset_attributions", tags=["dataset_attributions"]) NOT_AVAILABLE_MSG = ( @@ -59,73 +52,66 @@ class ComponentAttributions(BaseModel): def _require_storage(loaded: DepLoadedRun) -> DatasetAttributionStorage: - """Get storage or raise 404.""" - if not loaded.harvest.has_dataset_attributions(): + if loaded.attributions is None: raise HTTPException(status_code=404, detail=NOT_AVAILABLE_MSG) - return loaded.harvest.dataset_attributions - - -def _require_source(storage: DatasetAttributionStorage, component_key: str) -> None: - """Validate component exists as a source or raise 404.""" - if not storage.has_source(component_key): - raise HTTPException( - status_code=404, - detail=f"Component {component_key} not found as source in attributions", - ) - + return loaded.attributions.get_attributions() -def _require_target(storage: DatasetAttributionStorage, component_key: str) -> None: - """Validate component exists as a target or raise 404.""" - if not storage.has_target(component_key): - raise HTTPException( - status_code=404, - detail=f"Component {component_key} not found as target in attributions", - ) - -def _get_w_unembed(loaded: DepLoadedRun) -> Float[Tensor, "d_model vocab"]: - """Get the unembedding matrix from the loaded model.""" - lm_head = loaded.model.target_model.lm_head - assert isinstance(lm_head, nn.Linear), f"lm_head must be nn.Linear, got {type(lm_head)}" - return lm_head.weight.T.detach() - - -def 
_to_api_entries(entries: list[StorageEntry]) -> list[DatasetAttributionEntry]: - """Convert storage entries to API response format.""" +def _to_api_entries( + entries: list[StorageEntry], loaded: DepLoadedRun +) -> list[DatasetAttributionEntry]: return [ DatasetAttributionEntry( component_key=e.component_key, layer=e.layer, component_idx=e.component_idx, value=e.value, + token_str=loaded.tokenizer.decode([e.component_idx]) + if e.layer in ("embed", "output") + else None, ) for e in entries ] +def _get_component_attributions_for_metric( + storage: DatasetAttributionStorage, + loaded: DepLoadedRun, + component_key: str, + k: int, + metric: AttrMetric, +) -> ComponentAttributions: + return ComponentAttributions( + positive_sources=_to_api_entries( + storage.get_top_sources(component_key, k, "positive", metric), loaded + ), + negative_sources=_to_api_entries( + storage.get_top_sources(component_key, k, "negative", metric), loaded + ), + positive_targets=_to_api_entries( + storage.get_top_targets(component_key, k, "positive", metric), loaded + ), + negative_targets=_to_api_entries( + storage.get_top_targets(component_key, k, "negative", metric), loaded + ), + ) + + @router.get("/metadata") @log_errors def get_attribution_metadata(loaded: DepLoadedRun) -> DatasetAttributionMetadata: - """Get metadata about dataset attributions availability.""" - if not loaded.harvest.has_dataset_attributions(): + if loaded.attributions is None: return DatasetAttributionMetadata( available=False, - n_batches_processed=None, n_tokens_processed=None, n_component_layer_keys=None, - vocab_size=None, - d_model=None, ci_threshold=None, ) - - storage = loaded.harvest.dataset_attributions + storage = loaded.attributions.get_attributions() return DatasetAttributionMetadata( available=True, - n_batches_processed=storage.n_batches_processed, n_tokens_processed=storage.n_tokens_processed, n_component_layer_keys=storage.n_components, - vocab_size=storage.vocab_size, - d_model=storage.d_model, 
ci_threshold=storage.ci_threshold, ) @@ -137,40 +123,18 @@ def get_component_attributions( component_idx: int, loaded: DepLoadedRun, k: Annotated[int, Query(ge=1)] = 10, -) -> ComponentAttributions: - """Get all attribution data for a component (sources and targets, positive and negative).""" +) -> AllMetricAttributions: + """Get all attribution data for a component across all metrics.""" storage = _require_storage(loaded) component_key = f"{layer}:{component_idx}" - # Component can be both a source and a target, so we need to check both - is_source = storage.has_source(component_key) - is_target = storage.has_target(component_key) - - if not is_source and not is_target: - raise HTTPException( - status_code=404, - detail=f"Component {component_key} not found in attributions", - ) - - w_unembed = _get_w_unembed(loaded) if is_source else None - - return ComponentAttributions( - positive_sources=_to_api_entries(storage.get_top_sources(component_key, k, "positive")) - if is_target - else [], - negative_sources=_to_api_entries(storage.get_top_sources(component_key, k, "negative")) - if is_target - else [], - positive_targets=_to_api_entries( - storage.get_top_targets(component_key, k, "positive", w_unembed=w_unembed) - ) - if is_source - else [], - negative_targets=_to_api_entries( - storage.get_top_targets(component_key, k, "negative", w_unembed=w_unembed) - ) - if is_source - else [], + return AllMetricAttributions( + **{ + metric: _get_component_attributions_for_metric( + storage, loaded, component_key, k, metric + ) + for metric in ATTR_METRICS + } ) @@ -182,15 +146,12 @@ def get_attribution_sources( loaded: DepLoadedRun, k: Annotated[int, Query(ge=1)] = 10, sign: Literal["positive", "negative"] = "positive", + metric: AttrMetric = "attr", ) -> list[DatasetAttributionEntry]: - """Get top-k source components that attribute TO this target over the dataset.""" storage = _require_storage(loaded) - target_key = f"{layer}:{component_idx}" - _require_target(storage, 
target_key) - - w_unembed = _get_w_unembed(loaded) if layer == "output" else None - - return _to_api_entries(storage.get_top_sources(target_key, k, sign, w_unembed=w_unembed)) + return _to_api_entries( + storage.get_top_sources(f"{layer}:{component_idx}", k, sign, metric), loaded + ) @router.get("/{layer}/{component_idx}/targets") @@ -201,33 +162,9 @@ def get_attribution_targets( loaded: DepLoadedRun, k: Annotated[int, Query(ge=1)] = 10, sign: Literal["positive", "negative"] = "positive", + metric: AttrMetric = "attr", ) -> list[DatasetAttributionEntry]: - """Get top-k target components this source attributes TO over the dataset.""" - storage = _require_storage(loaded) - source_key = f"{layer}:{component_idx}" - _require_source(storage, source_key) - - w_unembed = _get_w_unembed(loaded) - - return _to_api_entries(storage.get_top_targets(source_key, k, sign, w_unembed=w_unembed)) - - -@router.get("/between/{source_layer}/{source_idx}/{target_layer}/{target_idx}") -@log_errors -def get_attribution_between( - source_layer: str, - source_idx: int, - target_layer: str, - target_idx: int, - loaded: DepLoadedRun, -) -> float: - """Get attribution strength from source component to target component.""" storage = _require_storage(loaded) - source_key = f"{source_layer}:{source_idx}" - target_key = f"{target_layer}:{target_idx}" - _require_source(storage, source_key) - _require_target(storage, target_key) - - w_unembed = _get_w_unembed(loaded) if target_layer == "output" else None - - return storage.get_attribution(source_key, target_key, w_unembed=w_unembed) + return _to_api_entries( + storage.get_top_targets(f"{layer}:{component_idx}", k, sign, metric), loaded + ) diff --git a/spd/app/backend/routers/dataset_search.py b/spd/app/backend/routers/dataset_search.py index 5e9d6c1fa..c079a372c 100644 --- a/spd/app/backend/routers/dataset_search.py +++ b/spd/app/backend/routers/dataset_search.py @@ -1,20 +1,25 @@ -"""Dataset search endpoints for SimpleStories exploration. 
+"""Dataset search endpoints. -This module provides search functionality for the SimpleStories dataset, -independent of any loaded SPD run. Results are cached in memory for pagination. +Provides search functionality for the training dataset of the loaded run. +The dataset name and text column are read from the run's config. +Results are cached in memory for pagination. """ +import random import time from typing import Annotated, Any +import torch from datasets import Dataset, load_dataset from fastapi import APIRouter, HTTPException, Query from pydantic import BaseModel -from spd.app.backend.dependencies import DepStateManager +from spd.app.backend.dependencies import DepLoadedRun, DepStateManager from spd.app.backend.state import DatasetSearchState from spd.app.backend.utils import log_errors +from spd.configs import LMTaskConfig from spd.log import logger +from spd.utils.distributed_utils import get_device # ============================================================================= # Schemas @@ -22,12 +27,20 @@ class DatasetSearchResult(BaseModel): - """A single search result from the SimpleStories dataset.""" + """A single search result from the dataset.""" - story: str + text: str occurrence_count: int - topic: str | None = None - theme: str | None = None + metadata: dict[str, str] + + +class TokenizedSearchResult(BaseModel): + """A tokenized search result with per-token probability.""" + + tokens: list[str] + next_token_probs: list[float | None] + occurrence_count: int + metadata: dict[str, str] class DatasetSearchMetadata(BaseModel): @@ -35,6 +48,7 @@ class DatasetSearchMetadata(BaseModel): query: str split: str + dataset_name: str total_results: int search_time_seconds: float @@ -49,74 +63,106 @@ class DatasetSearchPage(BaseModel): total_pages: int +class TokenizedSearchPage(BaseModel): + """Paginated tokenized results from a dataset search.""" + + results: list[TokenizedSearchResult] + query: str + page: int + page_size: int + total_results: int + 
total_pages: int + + router = APIRouter(prefix="/api/dataset", tags=["dataset"]) +def _get_lm_task_config(loaded: DepLoadedRun) -> LMTaskConfig: + """Extract LMTaskConfig from the loaded run, or raise 400.""" + task_config = loaded.config.task_config + if not isinstance(task_config, LMTaskConfig): + raise HTTPException( + status_code=400, + detail=f"Dataset search requires an LM experiment, got {task_config.task_name}", + ) + return task_config + + @router.post("/search") @log_errors def search_dataset( query: Annotated[str, Query(min_length=1)], + loaded: DepLoadedRun, manager: DepStateManager, split: Annotated[str, Query(pattern="^(train|test)$")] = "train", ) -> DatasetSearchMetadata: - """Search SimpleStories dataset for stories containing query string. + """Search the run's training dataset for entries containing query string. + Reads dataset_name and column_name from the loaded run's config. Caches results for pagination via /results endpoint. - Works independently of any loaded run. 
Args: query: Text to search for (case-insensitive) split: Dataset split to search ("train" or "test") Returns: - Search metadata (query, split, total results, search time) + Search metadata (query, split, dataset_name, total results, search time) """ + task_config = _get_lm_task_config(loaded) + dataset_name = task_config.dataset_name + text_column = task_config.column_name + start_time = time.time() search_query = query.lower() - logger.info(f"Loading SimpleStories dataset (split={split})...") - dataset = load_dataset("lennart-finke/SimpleStories", split=split) + logger.info(f"Loading dataset {dataset_name} (split={split})...") + dataset = load_dataset(dataset_name, split=split) assert isinstance(dataset, Dataset), f"Expected Dataset, got {type(dataset)}" - total_stories = len(dataset) - logger.info(f"Searching {total_stories} stories for '{query}'...") + total_rows = len(dataset) + logger.info(f"Searching {total_rows} rows for '{query}'...") filtered = dataset.filter( - lambda x: search_query in x["story"].lower(), + lambda x: search_query in x[text_column].lower(), num_proc=8, ) + # Collect extra string columns as metadata (skip the text column itself) + column_names = dataset.column_names + metadata_columns = [c for c in column_names if c != text_column] + results: list[dict[str, Any]] = [] for item in filtered: item_dict: dict[str, Any] = dict(item) - story: str = item_dict["story"] + text: str = item_dict[text_column] + row_metadata = { + col: str(item_dict[col]) for col in metadata_columns if item_dict.get(col) is not None + } results.append( { - "story": story, - "occurrence_count": story.lower().count(search_query), - "topic": item_dict.get("topic"), - "theme": item_dict.get("theme"), + "text": text, + "occurrence_count": text.lower().count(search_query), + "metadata": row_metadata, } ) search_time = time.time() - start_time - metadata = DatasetSearchMetadata( + search_metadata = DatasetSearchMetadata( query=query, split=split, + dataset_name=dataset_name, 
total_results=len(results), search_time_seconds=search_time, ) manager.state.dataset_search_state = DatasetSearchState( results=results, - metadata=metadata.model_dump(), + metadata=search_metadata.model_dump(), ) - logger.info( - f"Found {len(results)} results in {search_time:.2f}s (searched {total_stories} stories)" - ) + logger.info(f"Found {len(results)} results in {search_time:.2f}s (searched {total_rows} rows)") - return metadata + return search_metadata @router.get("/results") @@ -162,3 +208,266 @@ def get_dataset_results( total_results=total_results, total_pages=total_pages, ) + + +@router.get("/results_tokenized") +@log_errors +def get_tokenized_results( + loaded: DepLoadedRun, + manager: DepStateManager, + page: Annotated[int, Query(ge=1)] = 1, + page_size: Annotated[int, Query(ge=1, le=20)] = 10, + max_tokens: Annotated[int, Query(ge=16, le=512)] = 256, +) -> TokenizedSearchPage: + """Get paginated tokenized results with per-token probability. + + Requires a loaded run for model inference. Results are tokenized and + run through the model to compute next-token probabilities. + + Args: + page: Page number (1-indexed) + page_size: Results per page (1-20, lower limit due to model inference) + max_tokens: Maximum tokens per result (truncated if longer) + + Returns: + Paginated tokenized results with probabilities + """ + search_state = manager.state.dataset_search_state + if search_state is None: + raise HTTPException( + status_code=404, + detail="No search results available. 
Perform a search first.",
+        )
+
+    device = get_device()
+    model = loaded.model
+    tokenizer = loaded.tokenizer
+
+    total_results = len(search_state.results)
+    total_pages = max(1, (total_results + page_size - 1) // page_size)
+
+    if page > total_pages and total_results > 0:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Page {page} exceeds total pages {total_pages}",
+        )
+
+    start_idx = (page - 1) * page_size
+    end_idx = start_idx + page_size
+    page_results = search_state.results[start_idx:end_idx]
+
+    tokenized_results: list[TokenizedSearchResult] = []
+
+    for result in page_results:
+        text: str = result["text"]
+
+        token_ids = tokenizer.encode(text)
+        if len(token_ids) > max_tokens:
+            token_ids = token_ids[:max_tokens]
+
+        if len(token_ids) == 0:
+            continue
+
+        tokens_tensor = torch.tensor([token_ids], device=device)
+
+        with torch.no_grad():
+            logits = model(tokens_tensor)
+            probs = torch.softmax(logits, dim=-1)
+
+        next_token_probs: list[float | None] = []
+        for i in range(len(token_ids) - 1):
+            next_token_id = token_ids[i + 1]
+            prob = probs[0, i, next_token_id].item()
+            next_token_probs.append(prob)
+        next_token_probs.append(None)
+
+        token_strings = loaded.tokenizer.get_spans(token_ids)
+
+        # Search results already carry extra columns under "metadata" (see /search)
+        metadata = result["metadata"]
+
+        tokenized_results.append(
+            TokenizedSearchResult(
+                tokens=token_strings,
+                next_token_probs=next_token_probs,
+                occurrence_count=result["occurrence_count"],
+                metadata=metadata,
+            )
+        )
+
+    query = search_state.metadata.get("query", "")
+
+    return TokenizedSearchPage(
+        results=tokenized_results,
+        query=query,
+        page=page,
+        page_size=page_size,
+        total_results=total_results,
+        total_pages=total_pages,
+    )
+
+
+class RandomSamplesResult(BaseModel):
+    """Random samples from the dataset."""
+
+    results: list[DatasetSearchResult]
+    total_available: int
+    seed: int
+
+
+@router.get("/random")
+@log_errors
+def 
get_random_samples( + n_samples: Annotated[int, Query(ge=1, le=200)] = 100, + seed: Annotated[int, Query(ge=0)] = 42, + split: Annotated[str, Query(pattern="^(train|test)$")] = "train", +) -> RandomSamplesResult: + """Get random samples from the SimpleStories dataset. + + Args: + n_samples: Number of random samples to return (1-200) + seed: Random seed for reproducibility + split: Dataset split ("train" or "test") + + Returns: + Random samples with metadata + """ + logger.info(f"Loading SimpleStories dataset (split={split}) for random sampling...") + dataset = load_dataset("lennart-finke/SimpleStories", split=split) + assert isinstance(dataset, Dataset), f"Expected Dataset, got {type(dataset)}" + + total_available = len(dataset) + actual_samples = min(n_samples, total_available) + + # Generate random indices directly instead of shuffling entire dataset (~100x faster) + rng = random.Random(seed) + indices = rng.sample(range(total_available), actual_samples) + samples = dataset.select(indices) + + results = [] + for item in samples: + item_dict: dict[str, Any] = dict(item) + # Extract text field (usually "story" for SimpleStories, but could be different) + text = item_dict.get("story") or item_dict.get("text", "") + # Extract all non-text fields as metadata + metadata = {k: str(v) for k, v in item_dict.items() if k not in ["story", "text"]} + results.append( + DatasetSearchResult( + text=text, + occurrence_count=0, + metadata=metadata, + ) + ) + + logger.info(f"Returned {len(results)} random samples from {total_available} total stories") + + return RandomSamplesResult( + results=results, + total_available=total_available, + seed=seed, + ) + + +class TokenizedSample(BaseModel): + """A single tokenized sample with per-token next-token probability.""" + + tokens: list[str] + next_token_probs: list[float | None] # Probability of next token; None for last position + metadata: dict[str, str] + + +class RandomSamplesWithLossResult(BaseModel): + """Random samples with 
tokenized data and next-token probabilities.""" + + results: list[TokenizedSample] + total_available: int + seed: int + + +@router.get("/random_with_loss") +@log_errors +def get_random_samples_with_loss( + loaded: DepLoadedRun, + n_samples: Annotated[int, Query(ge=1, le=50)] = 20, + seed: Annotated[int, Query(ge=0)] = 42, + split: Annotated[str, Query(pattern="^(train|test)$")] = "train", + max_tokens: Annotated[int, Query(ge=16, le=512)] = 256, +) -> RandomSamplesWithLossResult: + """Get random samples with tokenized data and per-token next-token probability. + + This endpoint requires a loaded run (for model and tokenizer). + Each sample is tokenized and run through the model to compute probabilities. + + Args: + n_samples: Number of random samples to return (1-50, lower limit due to model inference) + seed: Random seed for reproducibility + split: Dataset split ("train" or "test") + max_tokens: Maximum tokens per sample (truncated if longer) + + Returns: + Tokenized samples with next-token probability per token + """ + device = get_device() + model = loaded.model + tokenizer = loaded.tokenizer + + logger.info(f"Loading SimpleStories dataset (split={split}) for random sampling with loss...") + dataset = load_dataset("lennart-finke/SimpleStories", split=split) + assert isinstance(dataset, Dataset), f"Expected Dataset, got {type(dataset)}" + + total_available = len(dataset) + actual_samples = min(n_samples, total_available) + + rng = random.Random(seed) + indices = rng.sample(range(total_available), actual_samples) + samples = dataset.select(indices) + + results: list[TokenizedSample] = [] + + for item in samples: + item_dict: dict[str, Any] = dict(item) + story: str = item_dict["story"] + + token_ids = tokenizer.encode(story) + if len(token_ids) > max_tokens: + token_ids = token_ids[:max_tokens] + + if len(token_ids) == 0: + continue + + tokens_tensor = torch.tensor([token_ids], device=device) + + with torch.no_grad(): + logits = model(tokens_tensor) + probs = 
torch.softmax(logits, dim=-1) + + # Get probability of next token at each position + next_token_probs: list[float | None] = [] + for i in range(len(token_ids) - 1): + next_token_id = token_ids[i + 1] + prob = probs[0, i, next_token_id].item() + next_token_probs.append(prob) + next_token_probs.append(None) # No next token for last position + + token_strings = loaded.tokenizer.get_spans(token_ids) + + # Extract all non-text fields as metadata + metadata = {k: str(v) for k, v in item_dict.items() if k not in ["story", "text"]} + + results.append( + TokenizedSample( + tokens=token_strings, + next_token_probs=next_token_probs, + metadata=metadata, + ) + ) + + logger.info( + f"Returned {len(results)} tokenized samples with CE loss from {total_available} total stories" + ) + + return RandomSamplesWithLossResult( + results=results, + total_available=total_available, + seed=seed, + ) diff --git a/spd/app/backend/routers/graph_interp.py b/spd/app/backend/routers/graph_interp.py new file mode 100644 index 000000000..003075b4d --- /dev/null +++ b/spd/app/backend/routers/graph_interp.py @@ -0,0 +1,244 @@ +"""Graph interpretation endpoints. + +Serves context-aware component labels (output/input/unified) and the +prompt-edge graph produced by the graph_interp pipeline. 
+""" + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from spd.app.backend.dependencies import DepLoadedRun +from spd.app.backend.utils import log_errors +from spd.graph_interp.schemas import LabelResult +from spd.topology import TransformerTopology + +MAX_GRAPH_NODES = 500 + + +_ALREADY_CANONICAL = {"embed", "output"} + + +def _concrete_to_canonical_key(concrete_key: str, topology: TransformerTopology) -> str: + layer, idx = concrete_key.rsplit(":", 1) + if layer in _ALREADY_CANONICAL: + return concrete_key + canonical = topology.target_to_canon(layer) + return f"{canonical}:{idx}" + + +def _canonical_to_concrete_key( + canonical_layer: str, component_idx: int, topology: TransformerTopology +) -> str: + concrete = topology.canon_to_target(canonical_layer) + return f"{concrete}:{component_idx}" + + +# -- Schemas ------------------------------------------------------------------- + + +class GraphInterpHeadline(BaseModel): + label: str + confidence: str + output_label: str | None + input_label: str | None + + +class LabelDetail(BaseModel): + label: str + confidence: str + reasoning: str + prompt: str + + +class GraphInterpDetail(BaseModel): + output: LabelDetail | None + input: LabelDetail | None + unified: LabelDetail | None + + +class PromptEdgeResponse(BaseModel): + related_key: str + pass_name: str + attribution: float + related_label: str | None + related_confidence: str | None + token_str: str | None + + +class GraphInterpComponentDetail(BaseModel): + output: LabelDetail | None + input: LabelDetail | None + unified: LabelDetail | None + edges: list[PromptEdgeResponse] + + +class GraphNode(BaseModel): + component_key: str + label: str + confidence: str + + +class GraphEdge(BaseModel): + source: str + target: str + attribution: float + pass_name: str + + +class ModelGraphResponse(BaseModel): + nodes: list[GraphNode] + edges: list[GraphEdge] + + +# -- Router -------------------------------------------------------------------- + 
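+# Review note: the `_concrete_to_canonical_key` / `_canonical_to_concrete_key`
+# helpers above form a round trip through the topology mapping. A minimal
+# self-contained sketch (the `StubTopology` mapping below is hypothetical, a
+# stand-in for `spd.topology.TransformerTopology`):

```python
# Sketch of the component-key round trip. The layer-name mapping here is
# illustrative only; the real one comes from TransformerTopology.
_ALREADY_CANONICAL = {"embed", "output"}


class StubTopology:
    _canon_to_target = {"mlp_out.0": "model.layers.0.mlp.down_proj"}
    _target_to_canon = {v: k for k, v in _canon_to_target.items()}

    def target_to_canon(self, layer: str) -> str:
        return self._target_to_canon[layer]

    def canon_to_target(self, layer: str) -> str:
        return self._canon_to_target[layer]


def concrete_to_canonical_key(concrete_key: str, topology: StubTopology) -> str:
    # rsplit(":", 1) splits off only the trailing component index, so layer
    # paths containing dots survive intact.
    layer, idx = concrete_key.rsplit(":", 1)
    if layer in _ALREADY_CANONICAL:
        return concrete_key
    return f"{topology.target_to_canon(layer)}:{idx}"


def canonical_to_concrete_key(
    canonical_layer: str, component_idx: int, topology: StubTopology
) -> str:
    return f"{topology.canon_to_target(canonical_layer)}:{component_idx}"


topo = StubTopology()
concrete = canonical_to_concrete_key("mlp_out.0", 7, topo)
assert concrete == "model.layers.0.mlp.down_proj:7"
assert concrete_to_canonical_key(concrete, topo) == "mlp_out.0:7"
# "embed"/"output" keys are already canonical and pass through unchanged.
assert concrete_to_canonical_key("embed:123", topo) == "embed:123"
```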
+router = APIRouter(prefix="/api/graph_interp", tags=["graph_interp"]) + + +@router.get("/labels") +@log_errors +def get_all_labels(loaded: DepLoadedRun) -> dict[str, GraphInterpHeadline]: + repo = loaded.graph_interp + if repo is None: + return {} + + topology = loaded.topology + unified = repo.get_all_unified_labels() + output = repo.get_all_output_labels() + input_ = repo.get_all_input_labels() + + all_keys = set(unified) | set(output) | set(input_) + result: dict[str, GraphInterpHeadline] = {} + + for concrete_key in all_keys: + u = unified.get(concrete_key) + o = output.get(concrete_key) + i = input_.get(concrete_key) + + label = u or o or i + assert label is not None + canonical_key = _concrete_to_canonical_key(concrete_key, topology) + + result[canonical_key] = GraphInterpHeadline( + label=label.label, + confidence=label.confidence, + output_label=o.label if o else None, + input_label=i.label if i else None, + ) + + return result + + +def _to_detail(label: LabelResult | None) -> LabelDetail | None: + if label is None: + return None + return LabelDetail( + label=label.label, + confidence=label.confidence, + reasoning=label.reasoning, + prompt=label.prompt, + ) + + +@router.get("/labels/{layer}/{c_idx}") +@log_errors +def get_label_detail(layer: str, c_idx: int, loaded: DepLoadedRun) -> GraphInterpDetail: + repo = loaded.graph_interp + if repo is None: + raise HTTPException(status_code=404, detail="Graph interp data not available") + + concrete_key = _canonical_to_concrete_key(layer, c_idx, loaded.topology) + + return GraphInterpDetail( + output=_to_detail(repo.get_output_label(concrete_key)), + input=_to_detail(repo.get_input_label(concrete_key)), + unified=_to_detail(repo.get_unified_label(concrete_key)), + ) + + +@router.get("/detail/{layer}/{c_idx}") +@log_errors +def get_component_detail( + layer: str, c_idx: int, loaded: DepLoadedRun +) -> GraphInterpComponentDetail: + repo = loaded.graph_interp + if repo is None: + raise HTTPException(status_code=404, 
detail="Graph interp data not available") + + topology = loaded.topology + concrete_key = _canonical_to_concrete_key(layer, c_idx, topology) + + raw_edges = repo.get_prompt_edges(concrete_key) + tokenizer = loaded.tokenizer + edges = [] + for e in raw_edges: + rel_layer, rel_idx = e.related_key.rsplit(":", 1) + token_str = tokenizer.decode([int(rel_idx)]) if rel_layer in ("embed", "output") else None + edges.append( + PromptEdgeResponse( + related_key=_concrete_to_canonical_key(e.related_key, topology), + pass_name=e.pass_name, + attribution=e.attribution, + related_label=e.related_label, + related_confidence=e.related_confidence, + token_str=token_str, + ) + ) + + return GraphInterpComponentDetail( + output=_to_detail(repo.get_output_label(concrete_key)), + input=_to_detail(repo.get_input_label(concrete_key)), + unified=_to_detail(repo.get_unified_label(concrete_key)), + edges=edges, + ) + + +@router.get("/graph") +@log_errors +def get_model_graph(loaded: DepLoadedRun) -> ModelGraphResponse: + repo = loaded.graph_interp + if repo is None: + raise HTTPException(status_code=404, detail="Graph interp data not available") + + topology = loaded.topology + + unified = repo.get_all_unified_labels() + nodes = [] + for concrete_key, label in unified.items(): + canonical_key = _concrete_to_canonical_key(concrete_key, topology) + nodes.append( + GraphNode( + component_key=canonical_key, + label=label.label, + confidence=label.confidence, + ) + ) + + nodes = nodes[:MAX_GRAPH_NODES] + node_keys = {n.component_key for n in nodes} + + raw_edges = repo.get_all_prompt_edges() + edges = [] + for e in raw_edges: + comp_canon = _concrete_to_canonical_key(e.component_key, topology) + rel_canon = _concrete_to_canonical_key(e.related_key, topology) + + match e.pass_name: + case "output": + source, target = comp_canon, rel_canon + case "input": + source, target = rel_canon, comp_canon + + if source not in node_keys or target not in node_keys: + continue + + edges.append( + GraphEdge( + 
source=source, + target=target, + attribution=e.attribution, + pass_name=e.pass_name, + ) + ) + + return ModelGraphResponse(nodes=nodes, edges=edges) diff --git a/spd/app/backend/routers/graphs.py b/spd/app/backend/routers/graphs.py index 4e7b0b15c..12cddb3cb 100644 --- a/spd/app/backend/routers/graphs.py +++ b/spd/app/backend/routers/graphs.py @@ -5,8 +5,10 @@ import queue import sys import threading +import time import traceback from collections.abc import Callable, Generator +from dataclasses import dataclass from itertools import groupby from typing import Annotated, Any, Literal @@ -15,19 +17,98 @@ from fastapi.responses import StreamingResponse from pydantic import BaseModel +from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.compute import ( + DEFAULT_EVAL_PGD_CONFIG, + MAX_OUTPUT_NODES_PER_POS, Edge, + compute_intervention, compute_prompt_attributions, compute_prompt_attributions_optimized, + compute_prompt_attributions_optimized_batched, +) +from spd.app.backend.database import ( + GraphType, + OptimizationParams, + PgdConfig, + PromptAttrDB, + StoredGraph, ) -from spd.app.backend.database import GraphType, OptimizationParams, StoredGraph from spd.app.backend.dependencies import DepLoadedRun, DepStateManager -from spd.app.backend.optim_cis import MaskType, OptimCELossConfig, OptimCIConfig, OptimKLLossConfig +from spd.app.backend.optim_cis import ( + AdvPGDConfig, + CELossConfig, + CISnapshot, + KLLossConfig, + LogitLossConfig, + LossConfig, + MaskType, + MeanKLLossConfig, + OptimCIConfig, + PositionalLossConfig, +) from spd.app.backend.schemas import OutputProbability from spd.app.backend.utils import log_errors -from spd.configs import ImportanceMinimalityLossConfig +from spd.configs import ImportanceMinimalityLossConfig, SamplingType +from spd.log import logger +from spd.models.component_model import ComponentModel +from spd.topology import TransformerTopology from spd.utils.distributed_utils import get_device 
+NON_INTERVENTABLE_LAYERS = {"embed", "output"} + + +def _save_base_intervention_run( + graph_id: int, + model: ComponentModel, + tokens: torch.Tensor, + node_ci_vals: dict[str, float], + tokenizer: AppTokenizer, + topology: TransformerTopology, + db: PromptAttrDB, + sampling: SamplingType, + loss_config: LossConfig | None = None, +) -> None: + """Compute intervention for all interventable nodes and save as an intervention run.""" + interventable_keys = [ + k + for k, ci in node_ci_vals.items() + if k.split(":")[0] not in NON_INTERVENTABLE_LAYERS and ci > 0 + ] + if not interventable_keys: + logger.warning( + f"Graph {graph_id}: no interventable nodes with CI > 0, skipping base intervention run" + ) + return + + active_nodes: list[tuple[str, int, int]] = [] + for key in interventable_keys: + canon_layer, seq_str, cidx_str = key.split(":") + concrete_path = topology.canon_to_target(canon_layer) + active_nodes.append((concrete_path, int(seq_str), int(cidx_str))) + + effective_loss_config: LossConfig = ( + loss_config if loss_config is not None else MeanKLLossConfig() + ) + + result = compute_intervention( + model=model, + tokens=tokens, + active_nodes=active_nodes, + nodes_to_ablate=None, + tokenizer=tokenizer, + adv_pgd_config=DEFAULT_EVAL_PGD_CONFIG, + loss_config=effective_loss_config, + sampling=sampling, + top_k=10, + ) + + db.save_intervention_run( + graph_id=graph_id, + selected_nodes=interventable_keys, + result_json=result.model_dump_json(), + ) + class EdgeData(BaseModel): """Edge in the attribution graph.""" @@ -54,16 +135,111 @@ class GraphData(BaseModel): graphType: GraphType tokens: list[str] edges: list[EdgeData] + edgesAbs: list[EdgeData] | None = None # absolute-target variant, None for old graphs outputProbs: dict[str, OutputProbability] nodeCiVals: dict[ str, float - ] # node key -> CI value (or output prob for output nodes or 1 for wte node) + ] # node key -> CI value (or output prob for output nodes or 1 for embed node) nodeSubcompActs: dict[str, 
float] # node key -> subcomponent activation (v_i^T @ a) maxAbsAttr: float # max absolute edge value + maxAbsAttrAbs: float | None = None # max absolute edge value for abs-target variant maxAbsSubcompAct: float # max absolute subcomponent activation for normalization l0_total: int # total active components at current CI threshold +class CELossResult(BaseModel): + """CE loss result (specific token target).""" + + type: Literal["ce"] = "ce" + coeff: float + position: int + label_token: int + label_str: str + + +class KLLossResult(BaseModel): + """KL loss result (distribution matching).""" + + type: Literal["kl"] = "kl" + coeff: float + position: int + + +class LogitLossResult(BaseModel): + """Logit loss result (maximize pre-softmax logit).""" + + type: Literal["logit"] = "logit" + coeff: float + position: int + label_token: int + label_str: str + + +LossType = Literal["ce", "kl", "logit"] +LossResult = CELossResult | KLLossResult | LogitLossResult + + +def _build_loss_config( + loss_type: LossType, + loss_coeff: float, + loss_position: int, + label_token: int | None, +) -> PositionalLossConfig: + match loss_type: + case "ce": + assert label_token is not None, "label_token is required for CE loss" + return CELossConfig(coeff=loss_coeff, position=loss_position, label_token=label_token) + case "kl": + return KLLossConfig(coeff=loss_coeff, position=loss_position) + case "logit": + assert label_token is not None, "label_token is required for logit loss" + return LogitLossConfig( + coeff=loss_coeff, position=loss_position, label_token=label_token + ) + + +def _build_loss_result( + loss_config: PositionalLossConfig, + tok_display: Callable[[int], str], +) -> LossResult: + match loss_config: + case CELossConfig(coeff=coeff, position=pos, label_token=label_tok): + return CELossResult( + coeff=coeff, position=pos, label_token=label_tok, label_str=tok_display(label_tok) + ) + case KLLossConfig(coeff=coeff, position=pos): + return KLLossResult(coeff=coeff, position=pos) + case 
LogitLossConfig(coeff=coeff, position=pos, label_token=label_tok): + return LogitLossResult( + coeff=coeff, position=pos, label_token=label_tok, label_str=tok_display(label_tok) + ) + + +def _maybe_pgd( + n_steps: int | None, step_size: float | None +) -> tuple[PgdConfig, AdvPGDConfig] | None: + assert (n_steps is None) == (step_size is None), ( + "adv_pgd n_steps and step_size must both be set or both be None" + ) + if n_steps is None: + return None + assert step_size is not None # for narrowing + return PgdConfig(n_steps=n_steps, step_size=step_size), AdvPGDConfig( + n_steps=n_steps, step_size=step_size, init="random" + ) + + +class OptimizationMetricsResult(BaseModel): + """Final loss metrics from CI optimization.""" + + ci_masked_label_prob: float | None = None # Probability of label under CI mask (CE loss only) + stoch_masked_label_prob: float | None = ( + None # Probability of label under stochastic mask (CE loss only) + ) + adv_pgd_label_prob: float | None = None # Probability of label under adversarial mask (CE only) + l0_total: float # Total L0 (active components) + + class OptimizationResult(BaseModel): """Results from optimized CI computation.""" @@ -72,13 +248,9 @@ class OptimizationResult(BaseModel): pnorm: float beta: float mask_type: MaskType - # CE loss params (optional - required together) - label_token: int | None = None - label_str: str | None = None - ce_loss_coeff: float | None = None - label_prob: float | None = None - # KL loss param (optional) - kl_loss_coeff: float | None = None + loss: CELossResult | KLLossResult | LogitLossResult + metrics: OptimizationMetricsResult + pgd: PgdConfig | None = None class GraphDataWithOptimization(GraphData): @@ -116,19 +288,7 @@ class TokenizeResponse(BaseModel): token_ids: list[int] tokens: list[str] text: str - - -class TokenInfo(BaseModel): - """A single token from the tokenizer vocabulary.""" - - id: int - string: str - - -class TokensResponse(BaseModel): - """Response containing all tokens in the 
vocabulary.""" - - tokens: list[TokenInfo] + next_token_probs: list[float | None] # Probability of next token (last token is None) # SSE streaming message types @@ -162,103 +322,120 @@ class CompleteMessageWithOptimization(BaseModel): data: GraphDataWithOptimization +class BatchGraphResult(BaseModel): + """Batch optimization result containing multiple graphs.""" + + graphs: list[GraphDataWithOptimization] + + router = APIRouter(prefix="/api/graphs", tags=["graphs"]) DEVICE = get_device() # This is a bit of a hack. We want to limit the number of edges returned to avoid overwhelming the frontend. -GLOBAL_EDGE_LIMIT = 5_000 +GLOBAL_EDGE_LIMIT = 50_000 ProgressCallback = Callable[[int, int, str], None] -def build_out_probs( - ci_masked_out_probs: torch.Tensor, +def _build_out_probs( ci_masked_out_logits: torch.Tensor, - target_out_probs: torch.Tensor, target_out_logits: torch.Tensor, - output_prob_threshold: float, - token_strings: dict[int, str], + tok_display: Callable[[int], str], ) -> dict[str, OutputProbability]: - """Build output probs dict from CI-masked and target model tensors. + """Build output probs dict from logit tensors. - Filters by CI-masked probability threshold, but includes both probabilities. - - Args: - ci_masked_out_probs: Shape [seq, vocab] - CI-masked model output probabilities - ci_masked_out_logits: Shape [seq, vocab] - CI-masked model output logits - target_out_probs: Shape [seq, vocab] - Target model output probabilities - target_out_logits: Shape [seq, vocab] - Target model output logits - output_prob_threshold: Threshold for filtering output probabilities - token_strings: Dictionary mapping token IDs to strings + Takes top MAX_OUTPUT_NODES_PER_POS per position (CI slider handles threshold filtering). 
""" - assert ci_masked_out_probs.ndim == 2, f"Expected [seq, vocab], got {ci_masked_out_probs.shape}" - assert target_out_probs.ndim == 2, f"Expected [seq, vocab], got {target_out_probs.shape}" - assert ci_masked_out_probs.shape == target_out_probs.shape, ( - f"Shape mismatch: {ci_masked_out_probs.shape} vs {target_out_probs.shape}" - ) + ci_masked_out_probs = torch.softmax(ci_masked_out_logits, dim=-1) + target_out_probs = torch.softmax(target_out_logits, dim=-1) out_probs: dict[str, OutputProbability] = {} for s in range(ci_masked_out_probs.shape[0]): - for c_idx in range(ci_masked_out_probs.shape[1]): - prob = float(ci_masked_out_probs[s, c_idx].item()) - if prob < output_prob_threshold: - continue + pos_probs = ci_masked_out_probs[s] + top_vals, top_idxs = torch.topk( + pos_probs, min(MAX_OUTPUT_NODES_PER_POS, pos_probs.shape[0]) + ) + for prob_t, c_idx_t in zip(top_vals, top_idxs, strict=True): + prob = float(prob_t.item()) + c_idx = int(c_idx_t.item()) logit = float(ci_masked_out_logits[s, c_idx].item()) target_prob = float(target_out_probs[s, c_idx].item()) target_logit = float(target_out_logits[s, c_idx].item()) + key = f"{s}:{c_idx}" out_probs[key] = OutputProbability( prob=round(prob, 6), logit=round(logit, 4), target_prob=round(target_prob, 6), target_logit=round(target_logit, 4), - token=token_strings[c_idx], + token=tok_display(c_idx), ) return out_probs +CISnapshotCallback = Callable[[CISnapshot], None] + + def stream_computation( - work: Callable[[ProgressCallback], GraphData | GraphDataWithOptimization], + work: Callable[[ProgressCallback, CISnapshotCallback | None], BaseModel], + gpu_lock: threading.Lock, ) -> StreamingResponse: - """Run graph computation in a thread with SSE streaming for progress updates.""" + """Run graph computation in a thread with SSE streaming for progress updates. + + Acquires gpu_lock before starting and holds it until computation completes. + Raises 503 if the lock is already held by another operation. 
+ """ + # Try to acquire lock non-blocking - fail fast if GPU is busy + if not gpu_lock.acquire(blocking=False): + raise HTTPException( + status_code=503, + detail="GPU operation already in progress. Please wait and retry.", + ) + progress_queue: queue.Queue[dict[str, Any]] = queue.Queue() def on_progress(current: int, total: int, stage: str) -> None: progress_queue.put({"type": "progress", "current": current, "total": total, "stage": stage}) + def on_ci_snapshot(snapshot: CISnapshot) -> None: + progress_queue.put({"type": "ci_snapshot", **snapshot.model_dump()}) + def compute_thread() -> None: try: - result = work(on_progress) + result = work(on_progress, on_ci_snapshot) progress_queue.put({"type": "result", "result": result}) except Exception as e: traceback.print_exc(file=sys.stderr) progress_queue.put({"type": "error", "error": str(e)}) def generate() -> Generator[str]: - thread = threading.Thread(target=compute_thread) - thread.start() - - while True: - try: - msg = progress_queue.get(timeout=0.1) - except queue.Empty: - if not thread.is_alive(): + try: + thread = threading.Thread(target=compute_thread) + thread.start() + + while True: + try: + msg = progress_queue.get(timeout=0.1) + except queue.Empty: + if not thread.is_alive(): + break + continue + + if msg["type"] in ("progress", "ci_snapshot"): + yield f"data: {json.dumps(msg)}\n\n" + elif msg["type"] == "error": + yield f"data: {json.dumps(msg)}\n\n" + break + elif msg["type"] == "result": + complete_data = {"type": "complete", "data": msg["result"].model_dump()} + yield f"data: {json.dumps(complete_data)}\n\n" break - continue - - if msg["type"] == "progress": - yield f"data: {json.dumps(msg)}\n\n" - elif msg["type"] == "error": - yield f"data: {json.dumps(msg)}\n\n" - break - elif msg["type"] == "result": - complete_data = {"type": "complete", "data": msg["result"].model_dump()} - yield f"data: {json.dumps(complete_data)}\n\n" - break - thread.join() + thread.join() + finally: + gpu_lock.release() 
return StreamingResponse(generate(), media_type="text/event-stream") @@ -266,23 +443,89 @@ def generate() -> Generator[str]: @router.post("/tokenize") @log_errors def tokenize_text(text: str, loaded: DepLoadedRun) -> TokenizeResponse: - """Tokenize text and return tokens for preview (special tokens filtered).""" - token_ids = loaded.tokenizer.encode(text, add_special_tokens=False) + """Tokenize text and return tokens with probability of next token.""" + device = get_device() + token_ids = loaded.tokenizer.encode(text) + + if len(token_ids) == 0: + return TokenizeResponse( + text=text, + token_ids=[], + tokens=[], + next_token_probs=[], + ) + + tokens_tensor = torch.tensor([token_ids], device=device) + + with torch.no_grad(): + logits = loaded.model(tokens_tensor) + probs = torch.softmax(logits, dim=-1) + + # Get probability of next token at each position + next_token_probs: list[float | None] = [] + for i in range(len(token_ids) - 1): + next_token_id = token_ids[i + 1] + prob = probs[0, i, next_token_id].item() + next_token_probs.append(prob) + next_token_probs.append(None) # No next token for last position return TokenizeResponse( text=text, token_ids=token_ids, - tokens=[loaded.token_strings[t] for t in token_ids], + tokens=loaded.tokenizer.get_spans(token_ids), + next_token_probs=next_token_probs, ) -@router.get("/tokens") +class TokenSearchResult(BaseModel): + """A token search result with model probability at the queried position.""" + + id: int + string: str + prob: float + + +class TokenSearchResponse(BaseModel): + """Response from token search endpoint.""" + + tokens: list[TokenSearchResult] + + +@router.get("/tokens/search") @log_errors -def get_all_tokens(loaded: DepLoadedRun) -> TokensResponse: - """Get all tokens in the tokenizer vocabulary for client-side search.""" - return TokensResponse( - tokens=[TokenInfo(id=tid, string=tstr) for tid, tstr in loaded.token_strings.items()] - ) +def search_tokens( + q: Annotated[str, Query(min_length=1)], + 
prompt_id: Annotated[int, Query()], + position: Annotated[int, Query()], + loaded: DepLoadedRun, + manager: DepStateManager, + limit: Annotated[int, Query(ge=1, le=50)] = 20, +) -> TokenSearchResponse: + """Search tokens by substring match, sorted by target model probability at position.""" + prompt = manager.state.db.get_prompt(prompt_id) + if prompt is None: + raise HTTPException(status_code=404, detail=f"prompt {prompt_id} not found") + if not (0 <= position < len(prompt.token_ids)): + raise HTTPException( + status_code=422, + detail=f"position {position} out of range for prompt with {len(prompt.token_ids)} tokens", + ) + + device = next(loaded.model.parameters()).device + tokens_tensor = torch.tensor([prompt.token_ids], device=device) + with manager.gpu_lock(), torch.no_grad(): + logits = loaded.model(tokens_tensor) + probs = torch.softmax(logits[0, position], dim=-1) + + query = q.lower() + matches: list[TokenSearchResult] = [] + for tid in range(loaded.tokenizer.vocab_size): + string = loaded.tokenizer.get_tok_display(tid) + if query in string.lower(): + matches.append(TokenSearchResult(id=tid, string=string, prob=probs[tid].item())) + + matches.sort(key=lambda m: m.prob, reverse=True) + return TokenSearchResponse(tokens=matches[:limit]) NormalizeType = Literal["none", "target", "layer"] @@ -320,8 +563,6 @@ def compute_graph_stream( Args: included_nodes: JSON array of node keys to include (creates manual graph if provided) """ - output_prob_threshold = 0.01 - # Parse and validate included_nodes if provided included_nodes_set: set[str] | None = None included_nodes_list: list[str] | None = None @@ -355,68 +596,92 @@ def compute_graph_stream( raise HTTPException(status_code=404, detail="Prompt not found") token_ids = prompt.token_ids - token_strings = [loaded.token_strings[t] for t in token_ids] + spans = loaded.tokenizer.get_spans(token_ids) tokens_tensor = torch.tensor([token_ids], device=DEVICE) - def work(on_progress: ProgressCallback) -> GraphData: + def 
work(
+        on_progress: ProgressCallback, _on_ci_snapshot: CISnapshotCallback | None
+    ) -> GraphData:
+        t_total = time.perf_counter()
+
         result = compute_prompt_attributions(
             model=loaded.model,
+            topology=loaded.topology,
             tokens=tokens_tensor,
             sources_by_target=loaded.sources_by_target,
-            output_prob_threshold=output_prob_threshold,
+            output_prob_threshold=0.01,
             sampling=loaded.config.sampling,
             device=DEVICE,
-            show_progress=False,
             on_progress=on_progress,
             included_nodes=included_nodes_set,
         )
-        out_probs = build_out_probs(
-            ci_masked_out_probs=result.ci_masked_out_probs.cpu(),
-            ci_masked_out_logits=result.ci_masked_out_logits.cpu(),
-            target_out_probs=result.target_out_probs.cpu(),
-            target_out_logits=result.target_out_logits.cpu(),
-            output_prob_threshold=output_prob_threshold,
-            token_strings=loaded.token_strings,
-        )
+        ci_masked_out_logits = result.ci_masked_out_logits.cpu()
+        target_out_logits = result.target_out_logits.cpu()
+
+        t0 = time.perf_counter()
         graph_id = db.save_graph(
             prompt_id=prompt_id,
             graph=StoredGraph(
                 graph_type=graph_type,
                 edges=result.edges,
-                out_probs=out_probs,
+                edges_abs=result.edges_abs,
+                ci_masked_out_logits=ci_masked_out_logits,
+                target_out_logits=target_out_logits,
                 node_ci_vals=result.node_ci_vals,
                 node_subcomp_acts=result.node_subcomp_acts,
                 included_nodes=included_nodes_list,
             ),
         )
+        logger.info(f"[perf] save_graph: {time.perf_counter() - t0:.2f}s")

-        filtered_node_ci_vals = {k: v for k, v in result.node_ci_vals.items() if v > ci_threshold}
-        node_ci_vals_with_pseudo = _add_pseudo_layer_nodes(
-            filtered_node_ci_vals, len(token_ids), out_probs
+        t0 = time.perf_counter()
+        _save_base_intervention_run(
+            graph_id=graph_id,
+            model=loaded.model,
+            tokens=tokens_tensor,
+            node_ci_vals=result.node_ci_vals,
+            tokenizer=loaded.tokenizer,
+            topology=loaded.topology,
+            db=db,
+            sampling=loaded.config.sampling,
         )
-        edges_data, max_abs_attr = process_edges_for_response(
+        logger.info(f"[perf] base intervention run: {time.perf_counter() - t0:.2f}s")
+
+        t0 = time.perf_counter()
+        fg = filter_graph_for_display(
             raw_edges=result.edges,
-            normalize=normalize,
+            node_ci_vals=result.node_ci_vals,
+            node_subcomp_acts=result.node_subcomp_acts,
+            ci_masked_out_logits=ci_masked_out_logits,
+            target_out_logits=target_out_logits,
+            tok_display=loaded.tokenizer.get_tok_display,
             num_tokens=len(token_ids),
-            node_ci_vals_with_pseudo=node_ci_vals_with_pseudo,
-            is_optimized=False,
+            ci_threshold=ci_threshold,
+            normalize=normalize,
+            raw_edges_abs=result.edges_abs,
         )
+        logger.info(
+            f"[perf] filter_graph: {time.perf_counter() - t0:.2f}s ({len(fg.edges)} edges after filter)"
+        )
+        logger.info(f"[perf] Total graph computation: {time.perf_counter() - t_total:.2f}s")

         return GraphData(
             id=graph_id,
             graphType=graph_type,
-            tokens=token_strings,
-            edges=edges_data,
-            outputProbs=out_probs,
-            nodeCiVals=node_ci_vals_with_pseudo,
+            tokens=spans,
+            edges=fg.edges,
+            edgesAbs=fg.edges_abs,
+            outputProbs=fg.out_probs,
+            nodeCiVals=fg.node_ci_vals,
             nodeSubcompActs=result.node_subcomp_acts,
-            maxAbsAttr=max_abs_attr,
-            maxAbsSubcompAct=compute_max_abs_subcomp_act(result.node_subcomp_acts),
-            l0_total=len(filtered_node_ci_vals),
+            maxAbsAttr=fg.max_abs_attr,
+            maxAbsAttrAbs=fg.max_abs_attr_abs,
+            maxAbsSubcompAct=fg.max_abs_subcomp_act,
+            l0_total=fg.l0_total,
         )

-    return stream_computation(work)
+    return stream_computation(work, manager._gpu_lock)


 def _edge_to_edge_data(edge: Edge) -> EdgeData:
@@ -469,35 +734,25 @@ def compute_graph_optimized_stream(
     pnorm: Annotated[float, Query(gt=0)],
     beta: Annotated[float, Query(ge=0)],
     normalize: Annotated[NormalizeType, Query()],
-    output_prob_threshold: Annotated[float, Query(ge=0, le=1)],
     loaded: DepLoadedRun,
     manager: DepStateManager,
     ci_threshold: Annotated[float, Query()],
-    mask_type: Annotated[MaskType, Query()] = "stochastic",
-    # Optional CE loss params (required together)
+    mask_type: Annotated[MaskType, Query()],
+    loss_type: Annotated[LossType, Query()],
+    loss_coeff: Annotated[float, Query(gt=0)],
+    loss_position: Annotated[int, Query(ge=0)],
     label_token: Annotated[int | None, Query()] = None,
-    ce_loss_coeff: Annotated[float | None, Query(gt=0)] = None,
-    # Optional KL loss param
-    kl_loss_coeff: Annotated[float | None, Query(gt=0)] = None,
+    adv_pgd_n_steps: Annotated[int | None, Query(gt=0)] = None,
+    adv_pgd_step_size: Annotated[float | None, Query(gt=0)] = None,
 ):
     """Compute optimized attribution graph for a prompt with streaming progress.

-    At least one of (ce_loss_coeff, kl_loss_coeff) must be provided.
-    If ce_loss_coeff is provided, label_token is also required.
+    loss_type determines whether to use CE (cross-entropy for a specific token) or KL (distribution matching).
+    label_token is required when loss_type is "ce".
+    adv_pgd_n_steps and adv_pgd_step_size enable adversarial PGD when both are provided.
     """
-    # Validation
-    if ce_loss_coeff is None and kl_loss_coeff is None:
-        raise HTTPException(
-            status_code=400,
-            detail="At least one of ce_loss_coeff or kl_loss_coeff must be provided",
-        )
-    if ce_loss_coeff is not None and label_token is None:
-        raise HTTPException(
-            status_code=400,
-            detail="label_token is required when ce_loss_coeff is provided",
-        )
-
-    lr = 1e-2
+    loss_config = _build_loss_config(loss_type, loss_coeff, loss_position, label_token)
+    pgd_configs = _maybe_pgd(adv_pgd_n_steps, adv_pgd_step_size)

     db = manager.db
     prompt = db.get_prompt(prompt_id)
@@ -505,32 +760,31 @@
         raise HTTPException(status_code=404, detail="Prompt not found")

     token_ids = prompt.token_ids
-    label_str = loaded.token_strings[label_token] if label_token is not None else None
-    token_strings = [loaded.token_strings[t] for t in token_ids]
+    if loss_position >= len(token_ids):
+        raise HTTPException(
+            status_code=400,
+            detail=f"loss_position {loss_position} out of bounds for prompt with {len(token_ids)} tokens",
+        )
+
+    spans = loaded.tokenizer.get_spans(token_ids)
     tokens_tensor = torch.tensor([token_ids], device=DEVICE)

+    num_tokens = loss_position + 1
+    spans_sliced = spans[:num_tokens]
+
     opt_params = OptimizationParams(
         imp_min_coeff=imp_min_coeff,
         steps=steps,
         pnorm=pnorm,
         beta=beta,
         mask_type=mask_type,
-        label_token=label_token,
-        ce_loss_coeff=ce_loss_coeff,
-        kl_loss_coeff=kl_loss_coeff,
+        loss=loss_config,
+        pgd=pgd_configs[0] if pgd_configs else None,
     )

-    ce_loss_config: OptimCELossConfig | None = None
-    if ce_loss_coeff is not None:
-        assert label_token is not None
-        ce_loss_config = OptimCELossConfig(coeff=ce_loss_coeff, label_token=label_token)
-    kl_loss_config: OptimKLLossConfig | None = None
-    if kl_loss_coeff is not None:
-        kl_loss_config = OptimKLLossConfig(coeff=kl_loss_coeff)
-
     optim_config = OptimCIConfig(
         seed=0,
-        lr=lr,
+        lr=1e-2,
         steps=steps,
         weight_decay=0.0,
         lr_schedule="cosine",
@@ -538,203 +792,448 @@
         lr_warmup_pct=0.01,
         log_freq=max(1, steps // 4),
         imp_min_config=ImportanceMinimalityLossConfig(coeff=imp_min_coeff, pnorm=pnorm, beta=beta),
-        ce_loss_config=ce_loss_config,
-        kl_loss_config=kl_loss_config,
+        loss_config=loss_config,
         sampling=loaded.config.sampling,
         ce_kl_rounding_threshold=0.5,
         mask_type=mask_type,
+        adv_pgd=pgd_configs[1] if pgd_configs else None,
     )

-    def work(on_progress: ProgressCallback) -> GraphDataWithOptimization:
+    def work(
+        on_progress: ProgressCallback, on_ci_snapshot: CISnapshotCallback | None
+    ) -> GraphDataWithOptimization:
         result = compute_prompt_attributions_optimized(
             model=loaded.model,
+            topology=loaded.topology,
             tokens=tokens_tensor,
             sources_by_target=loaded.sources_by_target,
             optim_config=optim_config,
-            output_prob_threshold=output_prob_threshold,
+            output_prob_threshold=0.01,
             device=DEVICE,
-            show_progress=False,
             on_progress=on_progress,
+            on_ci_snapshot=on_ci_snapshot,
         )
-        out_probs = build_out_probs(
-            ci_masked_out_probs=result.ci_masked_out_probs.cpu(),
-            ci_masked_out_logits=result.ci_masked_out_logits.cpu(),
-            target_out_probs=result.target_out_probs.cpu(),
-            target_out_logits=result.target_out_logits.cpu(),
-            output_prob_threshold=output_prob_threshold,
-            token_strings=loaded.token_strings,
-        )
+        ci_masked_out_logits = result.ci_masked_out_logits.cpu()
+        target_out_logits = result.target_out_logits.cpu()
+
+        opt_params.ci_masked_label_prob = result.metrics.ci_masked_label_prob
+        opt_params.stoch_masked_label_prob = result.metrics.stoch_masked_label_prob
+        opt_params.adv_pgd_label_prob = result.metrics.adv_pgd_label_prob
+
         graph_id = db.save_graph(
             prompt_id=prompt_id,
             graph=StoredGraph(
                 graph_type="optimized",
                 edges=result.edges,
-                out_probs=out_probs,
+                edges_abs=result.edges_abs,
+                ci_masked_out_logits=ci_masked_out_logits,
+                target_out_logits=target_out_logits,
                 node_ci_vals=result.node_ci_vals,
                 node_subcomp_acts=result.node_subcomp_acts,
                 optimization_params=opt_params,
-                label_prob=result.label_prob,
             ),
         )

-        filtered_node_ci_vals = {k: v for k, v in result.node_ci_vals.items() if v > ci_threshold}
-        node_ci_vals_with_pseudo = _add_pseudo_layer_nodes(
-            filtered_node_ci_vals, len(token_ids), out_probs
+        _save_base_intervention_run(
+            graph_id=graph_id,
+            model=loaded.model,
+            tokens=tokens_tensor,
+            node_ci_vals=result.node_ci_vals,
+            tokenizer=loaded.tokenizer,
+            topology=loaded.topology,
+            db=db,
+            sampling=loaded.config.sampling,
+            loss_config=loss_config,
         )
-        edges_data, max_abs_attr = process_edges_for_response(
+
+        fg = filter_graph_for_display(
             raw_edges=result.edges,
+            node_ci_vals=result.node_ci_vals,
+            node_subcomp_acts=result.node_subcomp_acts,
+            ci_masked_out_logits=ci_masked_out_logits,
+            target_out_logits=target_out_logits,
+            tok_display=loaded.tokenizer.get_tok_display,
+            num_tokens=num_tokens,
+            ci_threshold=ci_threshold,
             normalize=normalize,
-            num_tokens=len(token_ids),
-            node_ci_vals_with_pseudo=node_ci_vals_with_pseudo,
-            is_optimized=True,
+            raw_edges_abs=result.edges_abs,
        )

         return GraphDataWithOptimization(
             id=graph_id,
             graphType="optimized",
-            tokens=token_strings,
-            edges=edges_data,
-            outputProbs=out_probs,
-            nodeCiVals=node_ci_vals_with_pseudo,
+            tokens=spans_sliced,
+            edges=fg.edges,
+            edgesAbs=fg.edges_abs,
+            outputProbs=fg.out_probs,
+            nodeCiVals=fg.node_ci_vals,
             nodeSubcompActs=result.node_subcomp_acts,
-            maxAbsAttr=max_abs_attr,
-            maxAbsSubcompAct=compute_max_abs_subcomp_act(result.node_subcomp_acts),
-            l0_total=len(filtered_node_ci_vals),
+            maxAbsAttr=fg.max_abs_attr,
+            maxAbsAttrAbs=fg.max_abs_attr_abs,
+            maxAbsSubcompAct=fg.max_abs_subcomp_act,
+            l0_total=fg.l0_total,
             optimization=OptimizationResult(
                 imp_min_coeff=imp_min_coeff,
                 steps=steps,
                 pnorm=pnorm,
                 beta=beta,
                 mask_type=mask_type,
-                label_token=label_token,
-                label_str=label_str,
-                ce_loss_coeff=ce_loss_coeff,
-                label_prob=result.label_prob,
-                kl_loss_coeff=kl_loss_coeff,
+                loss=_build_loss_result(loss_config, loaded.tokenizer.get_tok_display),
+                metrics=OptimizationMetricsResult(
+                    ci_masked_label_prob=result.metrics.ci_masked_label_prob,
+                    stoch_masked_label_prob=result.metrics.stoch_masked_label_prob,
+                    adv_pgd_label_prob=result.metrics.adv_pgd_label_prob,
+                    l0_total=result.metrics.l0_total,
+                ),
+                pgd=pgd_configs[0] if pgd_configs else None,
             ),
         )

-    return stream_computation(work)
+    return stream_computation(work, manager._gpu_lock)


-def _add_pseudo_layer_nodes(
-    node_ci_vals: dict[str, float],
-    num_tokens: int,
-    out_probs: dict[str, OutputProbability],
-) -> dict[str, float]:
-    """Add wte and output pseudo-nodes for simpler rendering and filtering logic.
+class BatchOptimizedRequest(BaseModel):
+    """Request body for batch optimized graph computation."""

-    wte nodes get CI=1.0 (always visible), output nodes use their CI-masked probability.
+    prompt_id: int
+    imp_min_coeffs: list[float]
+    steps: int
+    pnorm: float
+    beta: float
+    normalize: NormalizeType
+    ci_threshold: float
+    mask_type: MaskType
+    loss_type: LossType
+    loss_coeff: float
+    loss_position: int
+    label_token: int | None = None
+    adv_pgd_n_steps: int | None = None
+    adv_pgd_step_size: float | None = None
+
+
+@router.post("/optimized/batch/stream")
+@log_errors
+def compute_graph_optimized_batch_stream(
+    body: BatchOptimizedRequest,
+    loaded: DepLoadedRun,
+    manager: DepStateManager,
+):
+    """Compute optimized graphs for multiple sparsity coefficients in one batched optimization.
+
+    Returns N graphs (one per imp_min_coeff) via SSE streaming.
+    All coefficients share the same loss config, steps, and other hyperparameters.
     """
-    result = dict(node_ci_vals)
-    for seq_pos in range(num_tokens):
-        result[f"wte:{seq_pos}:0"] = 1.0
-    for key, out_prob in out_probs.items():
-        seq_pos, token_id = key.split(":")
-        result[f"output:{seq_pos}:{token_id}"] = out_prob.prob
-    return result
+    assert len(body.imp_min_coeffs) > 0, "At least one coefficient required"
+    assert len(body.imp_min_coeffs) <= 20, "Too many coefficients (max 20)"
+
+    loss_config = _build_loss_config(
+        body.loss_type, body.loss_coeff, body.loss_position, body.label_token
+    )
+    pgd_configs = _maybe_pgd(body.adv_pgd_n_steps, body.adv_pgd_step_size)
+
+    db = manager.db
+    prompt = db.get_prompt(body.prompt_id)
+    assert prompt is not None, f"prompt {body.prompt_id} not found"
+
+    token_ids = prompt.token_ids
+    assert body.loss_position < len(token_ids), (
+        f"loss_position {body.loss_position} out of bounds for prompt with {len(token_ids)} tokens"
+    )
+
+    spans = loaded.tokenizer.get_spans(token_ids)
+    tokens_tensor = torch.tensor([token_ids], device=DEVICE)
+
+    num_tokens = body.loss_position + 1
+    spans_sliced = spans[:num_tokens]
+
+    configs = [
+        OptimCIConfig(
+            seed=0,
+            lr=1e-2,
+            steps=body.steps,
+            weight_decay=0.0,
+            lr_schedule="cosine",
+            lr_exponential_halflife=None,
+            lr_warmup_pct=0.01,
+            log_freq=max(1, body.steps // 4),
+            imp_min_config=ImportanceMinimalityLossConfig(
+                coeff=coeff, pnorm=body.pnorm, beta=body.beta
+            ),
+            loss_config=loss_config,
+            sampling=loaded.config.sampling,
+            ce_kl_rounding_threshold=0.5,
+            mask_type=body.mask_type,
+            adv_pgd=pgd_configs[1] if pgd_configs else None,
+        )
+        for coeff in body.imp_min_coeffs
+    ]

+    def work(
+        on_progress: ProgressCallback, on_ci_snapshot: CISnapshotCallback | None
+    ) -> BatchGraphResult:
+        results = compute_prompt_attributions_optimized_batched(
+            model=loaded.model,
+            topology=loaded.topology,
+            tokens=tokens_tensor,
+            sources_by_target=loaded.sources_by_target,
+            configs=configs,
+            output_prob_threshold=0.01,
+            device=DEVICE,
+            on_progress=on_progress,
+            on_ci_snapshot=on_ci_snapshot,
+        )
+
+        graphs: list[GraphDataWithOptimization] = []
+        for result, coeff in zip(results, body.imp_min_coeffs, strict=True):
+            ci_masked_out_logits = result.ci_masked_out_logits.cpu()
+            target_out_logits = result.target_out_logits.cpu()
+
+            opt_params = OptimizationParams(
+                imp_min_coeff=coeff,
+                steps=body.steps,
+                pnorm=body.pnorm,
+                beta=body.beta,
+                mask_type=body.mask_type,
+                loss=loss_config,
+                pgd=pgd_configs[0] if pgd_configs else None,
+            )
+            opt_params.ci_masked_label_prob = result.metrics.ci_masked_label_prob
+            opt_params.stoch_masked_label_prob = result.metrics.stoch_masked_label_prob
+            opt_params.adv_pgd_label_prob = result.metrics.adv_pgd_label_prob
+
+            graph_id = db.save_graph(
+                prompt_id=body.prompt_id,
+                graph=StoredGraph(
+                    graph_type="optimized",
+                    edges=result.edges,
+                    edges_abs=result.edges_abs,
+                    ci_masked_out_logits=ci_masked_out_logits,
+                    target_out_logits=target_out_logits,
+                    node_ci_vals=result.node_ci_vals,
+                    node_subcomp_acts=result.node_subcomp_acts,
+                    optimization_params=opt_params,
+                ),
+            )
+
+            _save_base_intervention_run(
+                graph_id=graph_id,
+                model=loaded.model,
+                tokens=tokens_tensor,
+                node_ci_vals=result.node_ci_vals,
+                tokenizer=loaded.tokenizer,
+                topology=loaded.topology,
+                db=db,
+                sampling=loaded.config.sampling,
+                loss_config=loss_config,
+            )
+
+            fg = filter_graph_for_display(
+                raw_edges=result.edges,
+                node_ci_vals=result.node_ci_vals,
+                node_subcomp_acts=result.node_subcomp_acts,
+                ci_masked_out_logits=ci_masked_out_logits,
+                target_out_logits=target_out_logits,
+                tok_display=loaded.tokenizer.get_tok_display,
+                num_tokens=num_tokens,
+                ci_threshold=body.ci_threshold,
+                normalize=body.normalize,
+                raw_edges_abs=result.edges_abs,
+            )
+
+            graphs.append(
+                GraphDataWithOptimization(
+                    id=graph_id,
+                    graphType="optimized",
+                    tokens=spans_sliced,
+                    edges=fg.edges,
+                    edgesAbs=fg.edges_abs,
+                    outputProbs=fg.out_probs,
+                    nodeCiVals=fg.node_ci_vals,
+                    nodeSubcompActs=result.node_subcomp_acts,
+                    maxAbsAttr=fg.max_abs_attr,
+                    maxAbsAttrAbs=fg.max_abs_attr_abs,
+                    maxAbsSubcompAct=fg.max_abs_subcomp_act,
+                    l0_total=fg.l0_total,
+                    optimization=OptimizationResult(
+                        imp_min_coeff=coeff,
+                        steps=body.steps,
+                        pnorm=body.pnorm,
+                        beta=body.beta,
+                        mask_type=body.mask_type,
+                        loss=_build_loss_result(loss_config, loaded.tokenizer.get_tok_display),
+                        metrics=OptimizationMetricsResult(
+                            ci_masked_label_prob=result.metrics.ci_masked_label_prob,
+                            stoch_masked_label_prob=result.metrics.stoch_masked_label_prob,
+                            adv_pgd_label_prob=result.metrics.adv_pgd_label_prob,
+                            l0_total=result.metrics.l0_total,
+                        ),
+                        pgd=pgd_configs[0] if pgd_configs else None,
+                    ),
+                )
+            )
+
+        return BatchGraphResult(graphs=graphs)
+
+    return stream_computation(work, manager._gpu_lock)
+

+@dataclass
+class FilteredGraph:
+    """Result of filtering a raw graph for display."""

-def process_edges_for_response(
+    edges: list[EdgeData]
+    edges_abs: list[EdgeData] | None  # absolute-target variant, None for old graphs
+    node_ci_vals: dict[str, float]  # with pseudo nodes
+    out_probs: dict[str, OutputProbability]
+    max_abs_attr: float
+    max_abs_attr_abs: float | None  # max abs for absolute-target edges
+    max_abs_subcomp_act: float
+    l0_total: int
+
+
+def filter_graph_for_display(
     raw_edges: list[Edge],
-    normalize: NormalizeType,
+    node_ci_vals: dict[str, float],
+    node_subcomp_acts: dict[str, float],
+    ci_masked_out_logits: torch.Tensor,
+    target_out_logits: torch.Tensor,
+    tok_display: Callable[[int], str],
     num_tokens: int,
-    node_ci_vals_with_pseudo: dict[str, float],
-    is_optimized: bool,
+    ci_threshold: float,
+    normalize: NormalizeType,
+    raw_edges_abs: list[Edge] | None = None,
     edge_limit: int = GLOBAL_EDGE_LIMIT,
-) -> tuple[list[EdgeData], float]:
-    """Process edges: filter by CI, normalize, and limit."""
-
-    # Filter to final seq position for optimized graphs
-    if is_optimized:
-        final_seq_pos = num_tokens - 1
-        raw_edges = [e for e in raw_edges if e.target.seq_pos == final_seq_pos]
-
-    # Only include edges that connect to nodes in node_ci_vals_with_pseudo
-    node_keys = set(node_ci_vals_with_pseudo.keys())
-    edges = [e for e in raw_edges if str(e.source) in node_keys and str(e.target) in node_keys]
+) -> FilteredGraph:
+    """Filter and transform a raw attribution graph for display.
+
+    1. Build out_probs from logit tensors (top MAX_OUTPUT_NODES_PER_POS per position)
+    2. Filter component nodes by CI threshold
+    3. Add embed (CI=1.0) and output (CI=prob) pseudo-nodes
+    4. Drop edges not connecting surviving nodes
+    5. Normalize edge strengths (if requested)
+    6. Cap edges at edge_limit
+    """
+    out_probs = _build_out_probs(ci_masked_out_logits, target_out_logits, tok_display)

-    edges = _normalize_edges(edges=edges, normalize=normalize)
-    max_abs_attr = compute_max_abs_attr(edges=edges)
+    filtered_node_ci_vals = {k: v for k, v in node_ci_vals.items() if v > ci_threshold}

-    if len(edges) > edge_limit:
-        print(f"[WARNING] Edge limit {edge_limit} exceeded ({len(edges)} edges), truncating")
-        edges = sorted(edges, key=lambda e: abs(e.strength), reverse=True)[:edge_limit]
+    # Add pseudo-nodes: embed always visible, output nodes use their probability
+    node_ci_vals_with_pseudo = dict(filtered_node_ci_vals)
+    for seq_pos in range(num_tokens):
+        node_ci_vals_with_pseudo[f"embed:{seq_pos}:0"] = 1.0
+    for key, out_prob in out_probs.items():
+        seq_pos, token_id = key.split(":")
+        node_ci_vals_with_pseudo[f"output:{seq_pos}:{token_id}"] = out_prob.prob

-    edges_data = [_edge_to_edge_data(e) for e in edges]
+    # Filter, normalize, sort, and truncate an edge list to the surviving node set.
+    node_keys = set(node_ci_vals_with_pseudo.keys())

-    return edges_data, max_abs_attr
+    def _filter_edges(raw: list[Edge]) -> tuple[list[EdgeData], float]:
+        filtered = [e for e in raw if str(e.source) in node_keys and str(e.target) in node_keys]
+        filtered = _normalize_edges(edges=filtered, normalize=normalize)
+        max_abs = compute_max_abs_attr(edges=filtered)
+        filtered = sorted(filtered, key=lambda e: abs(e.strength), reverse=True)
+        if len(filtered) > edge_limit:
+            logger.warning(f"Edge limit {edge_limit} exceeded ({len(filtered)} edges), truncating")
+            filtered = filtered[:edge_limit]
+        return [_edge_to_edge_data(e) for e in filtered], max_abs
+
+    edges_out, max_abs_attr = _filter_edges(raw_edges)
+
+    edges_abs_out: list[EdgeData] | None = None
+    max_abs_attr_abs: float | None = None
+    if raw_edges_abs is not None:
+        edges_abs_out, max_abs_attr_abs = _filter_edges(raw_edges_abs)
+
+    return FilteredGraph(
+        edges=edges_out,
+        edges_abs=edges_abs_out,
+        node_ci_vals=node_ci_vals_with_pseudo,
+        out_probs=out_probs,
+        max_abs_attr=max_abs_attr,
+        max_abs_attr_abs=max_abs_attr_abs,
+        max_abs_subcomp_act=compute_max_abs_subcomp_act(node_subcomp_acts),
+        l0_total=len(filtered_node_ci_vals),
+    )


 def stored_graph_to_response(
     graph: StoredGraph,
     token_ids: list[int],
-    token_strings_map: dict[int, str],
+    tokenizer: AppTokenizer,
     normalize: NormalizeType,
     ci_threshold: float,
 ) -> GraphData | GraphDataWithOptimization:
     """Convert a StoredGraph to API response format."""
-    token_strings = [token_strings_map[t] for t in token_ids]
+    spans = tokenizer.get_spans(token_ids)
     num_tokens = len(token_ids)
     is_optimized = graph.optimization_params is not None

-    filtered_node_ci_vals = {k: v for k, v in graph.node_ci_vals.items() if v > ci_threshold}
-    l0_total = len(filtered_node_ci_vals)
+    if is_optimized:
+        assert graph.optimization_params is not None
+        num_tokens = graph.optimization_params.loss.position + 1
+        spans = spans[:num_tokens]

-    node_ci_vals_with_pseudo = _add_pseudo_layer_nodes(
-        filtered_node_ci_vals, num_tokens, graph.out_probs
-    )
-    edges_data, max_abs_attr = process_edges_for_response(
+    fg = filter_graph_for_display(
         raw_edges=graph.edges,
-        normalize=normalize,
+        node_ci_vals=graph.node_ci_vals,
+        node_subcomp_acts=graph.node_subcomp_acts,
+        ci_masked_out_logits=graph.ci_masked_out_logits,
+        target_out_logits=graph.target_out_logits,
+        tok_display=tokenizer.get_tok_display,
         num_tokens=num_tokens,
-        node_ci_vals_with_pseudo=node_ci_vals_with_pseudo,
-        is_optimized=is_optimized,
+        ci_threshold=ci_threshold,
+        normalize=normalize,
+        raw_edges_abs=graph.edges_abs,
     )

     if not is_optimized:
         return GraphData(
             id=graph.id,
             graphType=graph.graph_type,
-            tokens=token_strings,
-            edges=edges_data,
-            outputProbs=graph.out_probs,
-            nodeCiVals=node_ci_vals_with_pseudo,
+            tokens=spans,
+            edges=fg.edges,
+            edgesAbs=fg.edges_abs,
+            outputProbs=fg.out_probs,
+            nodeCiVals=fg.node_ci_vals,
             nodeSubcompActs=graph.node_subcomp_acts,
-            maxAbsAttr=max_abs_attr,
-            maxAbsSubcompAct=compute_max_abs_subcomp_act(graph.node_subcomp_acts),
-            l0_total=l0_total,
+            maxAbsAttr=fg.max_abs_attr,
+            maxAbsAttrAbs=fg.max_abs_attr_abs,
+            maxAbsSubcompAct=fg.max_abs_subcomp_act,
+            l0_total=fg.l0_total,
         )

     assert graph.optimization_params is not None
-
-    label_str: str | None = None
-    if graph.optimization_params.label_token is not None:
-        label_str = token_strings_map[graph.optimization_params.label_token]
+    opt = graph.optimization_params

     return GraphDataWithOptimization(
         id=graph.id,
         graphType=graph.graph_type,
-        tokens=token_strings,
-        edges=edges_data,
-        outputProbs=graph.out_probs,
-        nodeCiVals=node_ci_vals_with_pseudo,
+        tokens=spans,
+        edges=fg.edges,
+        edgesAbs=fg.edges_abs,
+        outputProbs=fg.out_probs,
+        nodeCiVals=fg.node_ci_vals,
         nodeSubcompActs=graph.node_subcomp_acts,
-        maxAbsAttr=max_abs_attr,
-        maxAbsSubcompAct=compute_max_abs_subcomp_act(graph.node_subcomp_acts),
-        l0_total=l0_total,
+        maxAbsAttr=fg.max_abs_attr,
+        maxAbsAttrAbs=fg.max_abs_attr_abs,
+        maxAbsSubcompAct=fg.max_abs_subcomp_act,
+        l0_total=fg.l0_total,
         optimization=OptimizationResult(
-            imp_min_coeff=graph.optimization_params.imp_min_coeff,
-            steps=graph.optimization_params.steps,
-            pnorm=graph.optimization_params.pnorm,
-            beta=graph.optimization_params.beta,
-            mask_type=graph.optimization_params.mask_type,
-            label_token=graph.optimization_params.label_token,
-            label_str=label_str,
-            ce_loss_coeff=graph.optimization_params.ce_loss_coeff,
-            label_prob=graph.label_prob,
-            kl_loss_coeff=graph.optimization_params.kl_loss_coeff,
+            imp_min_coeff=opt.imp_min_coeff,
+            steps=opt.steps,
+            pnorm=opt.pnorm,
+            beta=opt.beta,
+            mask_type=opt.mask_type,
+            loss=_build_loss_result(opt.loss, tokenizer.get_tok_display),
+            metrics=OptimizationMetricsResult(
+                l0_total=float(fg.l0_total),
+                ci_masked_label_prob=opt.ci_masked_label_prob,
+                stoch_masked_label_prob=opt.stoch_masked_label_prob,
+                adv_pgd_label_prob=opt.adv_pgd_label_prob,
+            ),
+            pgd=opt.pgd,
         ),
     )
@@ -763,7 +1262,7 @@ def get_graphs(
         stored_graph_to_response(
             graph=graph,
             token_ids=prompt.token_ids,
-            token_strings_map=loaded.token_strings,
+            tokenizer=loaded.tokenizer,
             normalize=normalize,
             ci_threshold=ci_threshold,
         )
diff --git a/spd/app/backend/routers/intervention.py b/spd/app/backend/routers/intervention.py
index 4c46e136c..e26a73462 100644
--- a/spd/app/backend/routers/intervention.py
+++ b/spd/app/backend/routers/intervention.py
@@ -4,9 +4,14 @@
 from fastapi import APIRouter, HTTPException
 from pydantic import BaseModel

-from spd.app.backend.compute import compute_intervention_forward
+from spd.app.backend.compute import (
+    InterventionResult,
+    compute_intervention,
+)
 from spd.app.backend.dependencies import DepDB, DepLoadedRun, DepStateManager
+from spd.app.backend.optim_cis import AdvPGDConfig, LossConfig, MeanKLLossConfig
 from spd.app.backend.utils import log_errors
+from spd.topology import TransformerTopology
 from spd.utils.distributed_utils import get_device

 # =============================================================================
@@ -14,56 +19,19 @@
 # =============================================================================

-class InterventionNode(BaseModel):
-    """A specific node to activate during intervention."""
-
-    layer: str
-    seq_pos: int
-    component_idx: int
-
-
-class InterventionRequest(BaseModel):
-    """Request for intervention forward pass."""
-
-    text: str
-    nodes: list[InterventionNode]
-    top_k: int
-
-
-class TokenPrediction(BaseModel):
-    """A single token prediction with probability."""
-
-    token: str
-    token_id: int
-    spd_prob: float
-    target_prob: float
-    logit: float
-    target_logit: float
-
-
-class InterventionResponse(BaseModel):
-    """Response from intervention forward pass."""
-
-    input_tokens: list[str]
-    predictions_per_position: list[list[TokenPrediction]]
+class AdvPgdParams(BaseModel):
+    n_steps: int
+    step_size: float


 class RunInterventionRequest(BaseModel):
     """Request to run and save an intervention."""

     graph_id: int
-    text: str
     selected_nodes: list[str]  # node keys (layer:seq:cIdx)
-    top_k: int = 10
-
-
-class ForkedInterventionRunSummary(BaseModel):
-    """Summary of a forked intervention run with modified tokens."""
-
-    id: int
-    token_replacements: list[tuple[int, int]]  # [(seq_pos, new_token_id), ...]
-    result: InterventionResponse
-    created_at: str
+    nodes_to_ablate: list[str] | None = None  # node keys to ablate in the ablated run (omit to skip)
+    top_k: int
+    adv_pgd: AdvPgdParams


 class InterventionRunSummary(BaseModel):
@@ -71,16 +39,8 @@ class InterventionRunSummary(BaseModel):
     id: int
     selected_nodes: list[str]
-    result: InterventionResponse
+    result: InterventionResult
     created_at: str
-    forked_runs: list[ForkedInterventionRunSummary]
-
-
-class ForkInterventionRequest(BaseModel):
-    """Request to fork an intervention run with modified tokens."""
-
-    token_replacements: list[tuple[int, int]]  # [(seq_pos, new_token_id), ...]
-    top_k: int = 10

 router = APIRouter(prefix="/api/intervention", tags=["intervention"])
@@ -88,105 +48,30 @@ class ForkInterventionRequest(BaseModel):
 DEVICE = get_device()

-def _parse_node_key(key: str) -> tuple[str, int, int]:
-    """Parse 'layer:seq:cIdx' into (layer, seq_pos, component_idx)."""
+def _parse_node_key(key: str, topology: TransformerTopology) -> tuple[str, int, int]:
+    """Parse canonical node key into (concrete_path, seq_pos, component_idx).
+
+    Translates canonical layer address back to concrete module path for ComponentModel.
+    """
     parts = key.split(":")
     assert len(parts) == 3, f"Invalid node key format: {key!r} (expected 'layer:seq:cIdx')"
-    layer, seq_str, cidx_str = parts
-    # wte and output are pseudo-layers for visualization only - not interventable
-    assert layer not in ("wte", "output"), (
-        f"Cannot intervene on {layer!r} nodes - only internal layers (attn/mlp) are interventable"
-    )
-    return layer, int(seq_str), int(cidx_str)
-
-
-def _run_intervention_forward(
-    text: str,
-    selected_nodes: list[str],
-    top_k: int,
-    loaded: DepLoadedRun,
-) -> InterventionResponse:
-    """Run intervention forward pass and return response."""
-    token_ids = loaded.tokenizer.encode(text, add_special_tokens=False)
-    tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE)
-
-    active_nodes = [_parse_node_key(key) for key in selected_nodes]
-
-    seq_len = tokens.shape[1]
-    for _, seq_pos, _ in active_nodes:
-        if seq_pos >= seq_len:
-            raise ValueError(f"seq_pos {seq_pos} out of bounds for text with {seq_len} tokens")
-
-    result = compute_intervention_forward(
-        model=loaded.model,
-        tokens=tokens,
-        active_nodes=active_nodes,
-        top_k=top_k,
-        tokenizer=loaded.tokenizer,
-    )
-
-    predictions_per_position = [
-        [
-            TokenPrediction(
-                token=token,
-                token_id=token_id,
-                spd_prob=spd_prob,
-                target_prob=target_prob,
-                logit=logit,
-                target_logit=target_logit,
-            )
-            for token, token_id, spd_prob, logit, target_prob, target_logit in pos_predictions
-        ]
-        for pos_predictions in result.predictions_per_position
-    ]
-
-    return InterventionResponse(
-        input_tokens=result.input_tokens,
-        predictions_per_position=predictions_per_position,
+    canonical_layer, seq_str, cidx_str = parts
+    assert canonical_layer not in ("embed", "output"), (
+        f"Cannot intervene on {canonical_layer!r} nodes - only internal layers are interventable"
     )
+    concrete_path = topology.canon_to_target(canonical_layer)
+    return concrete_path, int(seq_str), int(cidx_str)


-@router.post("")
-@log_errors
-def run_intervention(request: InterventionRequest, loaded: DepLoadedRun) -> InterventionResponse:
-    """Run intervention forward pass with specified nodes active (legacy endpoint)."""
-    token_ids = loaded.tokenizer.encode(request.text, add_special_tokens=False)
-    tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE)
-
-    active_nodes = [(n.layer, n.seq_pos, n.component_idx) for n in request.nodes]
-
-    seq_len = tokens.shape[1]
+def _parse_and_validate_active_nodes(
+    selected_nodes: list[str], topology: TransformerTopology, seq_len: int
+) -> list[tuple[str, int, int]]:
+    """Parse node keys and validate sequence bounds for the current prompt."""
+    active_nodes = [_parse_node_key(key, topology) for key in selected_nodes]
     for _, seq_pos, _ in active_nodes:
         if seq_pos >= seq_len:
             raise ValueError(f"seq_pos {seq_pos} out of bounds for text with {seq_len} tokens")
-
-    result = compute_intervention_forward(
-        model=loaded.model,
-        tokens=tokens,
-        active_nodes=active_nodes,
-        top_k=request.top_k,
-        tokenizer=loaded.tokenizer,
-    )
-
-    predictions_per_position = [
-        [
-            TokenPrediction(
-                token=token,
-                token_id=token_id,
-                spd_prob=spd_prob,
-                target_prob=target_prob,
-                logit=logit,
-                target_logit=target_logit,
-            )
-            for token, token_id, spd_prob, logit, target_prob, target_logit in pos_predictions
-        ]
-        for pos_predictions in result.predictions_per_position
-    ]
-
-    return InterventionResponse(
-        input_tokens=result.input_tokens,
-        predictions_per_position=predictions_per_position,
-    )
+    return active_nodes


 @router.post("/run")
@@ -195,19 +80,59 @@ def run_and_save_intervention(
     request: RunInterventionRequest,
     loaded: DepLoadedRun,
     db: DepDB,
+    manager: DepStateManager,
 ) -> InterventionRunSummary:
     """Run an intervention and save the result."""
-    response = _run_intervention_forward(
-        text=request.text,
-        selected_nodes=request.selected_nodes,
-        top_k=request.top_k,
-        loaded=loaded,
-    )
+    with manager.gpu_lock():
+        graph_record = db.get_graph(request.graph_id)
+        if graph_record is None:
+            raise HTTPException(status_code=404, detail="Graph not found")
+        graph, prompt_id = graph_record
+
+        prompt = db.get_prompt(prompt_id)
+        if prompt is None:
+            raise HTTPException(status_code=404, detail="Prompt not found")
+
+        token_ids = prompt.token_ids
+        active_nodes = _parse_and_validate_active_nodes(
+            request.selected_nodes, loaded.topology, len(token_ids)
+        )
+        nodes_to_ablate = (
+            _parse_and_validate_active_nodes(
+                request.nodes_to_ablate, loaded.topology, len(token_ids)
+            )
+            if request.nodes_to_ablate is not None
+            else None
+        )
+        tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE)
+
+        # Use graph's loss config if optimized, else mean KL
+        loss_config: LossConfig = (
+            graph.optimization_params.loss
+            if graph.optimization_params is not None
+            else MeanKLLossConfig()
+        )
+
+        result = compute_intervention(
+            model=loaded.model,
+            tokens=tokens,
+            active_nodes=active_nodes,
+            nodes_to_ablate=nodes_to_ablate,
+            tokenizer=loaded.tokenizer,
+            adv_pgd_config=AdvPGDConfig(
+                n_steps=request.adv_pgd.n_steps,
+                step_size=request.adv_pgd.step_size,
+                init="random",
+            ),
+            loss_config=loss_config,
+            sampling=loaded.config.sampling,
+            top_k=request.top_k,
+        )

     run_id = db.save_intervention_run(
         graph_id=request.graph_id,
         selected_nodes=request.selected_nodes,
-        result_json=response.model_dump_json(),
+        result_json=result.model_dump_json(),
     )

     record = db.get_intervention_runs(request.graph_id)
@@ -217,41 +142,25 @@ def run_and_save_intervention(
     return InterventionRunSummary(
         id=run_id,
         selected_nodes=request.selected_nodes,
-        result=response,
+        result=result,
         created_at=saved_run.created_at,
-        forked_runs=[],
     )


 @router.get("/runs/{graph_id}")
 @log_errors
 def get_intervention_runs(graph_id: int, db: DepDB) -> list[InterventionRunSummary]:
-    """Get all intervention runs for a graph, including forked runs."""
+    """Get all intervention runs for a graph."""
     records = db.get_intervention_runs(graph_id)
-    results = []
-    for r in records:
-        # Get forked runs for this intervention run
-        forked_records = db.get_forked_intervention_runs(r.id)
-        forked_runs = [
-            ForkedInterventionRunSummary(
-                id=fr.id,
-                token_replacements=fr.token_replacements,
-                result=InterventionResponse.model_validate_json(fr.result_json),
-                created_at=fr.created_at,
-            )
-            for fr in forked_records
-        ]
-
-        results.append(
-            InterventionRunSummary(
-                id=r.id,
-                selected_nodes=r.selected_nodes,
-                result=InterventionResponse.model_validate_json(r.result_json),
-                created_at=r.created_at,
-                forked_runs=forked_runs,
-            )
+    return [
+        InterventionRunSummary(
+            id=r.id,
+            selected_nodes=r.selected_nodes,
+            result=InterventionResult.model_validate_json(r.result_json),
+            created_at=r.created_at,
         )
-    return results
+        for r in records
+    ]


 @router.delete("/runs/{run_id}")
@@ -260,86 +169,3 @@ def delete_intervention_run(run_id: int, db: DepDB) -> dict[str, bool]:
     """Delete an intervention run."""
     db.delete_intervention_run(run_id)
     return {"success": True}
-
-
-@router.post("/runs/{run_id}/fork")
-@log_errors
-def fork_intervention_run(
-    run_id: int,
-    request: ForkInterventionRequest,
-    loaded: DepLoadedRun,
-    manager: DepStateManager,
-) -> ForkedInterventionRunSummary:
-    """Fork an intervention run with modified tokens.
-
-    Takes the same selected_nodes from the parent run, applies token replacements
-    to the original prompt, and runs the intervention forward pass.
-    """
-    db = manager.db
-
-    # Get the parent intervention run
-    parent_run = db.get_intervention_run(run_id)
-    if parent_run is None:
-        raise HTTPException(status_code=404, detail="Intervention run not found")
-
-    # Get the prompt_id from the graph
-    conn = db._get_conn()
-    row = conn.execute(
-        "SELECT prompt_id FROM graphs WHERE id = ?", (parent_run.graph_id,)
-    ).fetchone()
-    if row is None:
-        raise HTTPException(status_code=404, detail="Graph not found")
-    prompt_id = row["prompt_id"]
-
-    # Get the prompt to get original token_ids
-    prompt = db.get_prompt(prompt_id)
-    if prompt is None:
-        raise HTTPException(status_code=404, detail="Prompt not found")
-
-    # Apply token replacements to get modified token_ids
-    modified_token_ids = list(prompt.token_ids)  # Make a copy
-    for seq_pos, new_token_id in request.token_replacements:
-        if seq_pos < 0 or seq_pos >= len(modified_token_ids):
-            raise HTTPException(
-                status_code=400,
-                detail=f"Invalid seq_pos {seq_pos} for prompt with {len(modified_token_ids)} tokens",
-            )
-        modified_token_ids[seq_pos] = new_token_id
-
-    # Decode the modified tokens back to text
-    modified_text = loaded.tokenizer.decode(modified_token_ids)
-
-    # Run the intervention forward pass with modified tokens but same selected nodes
-    response = _run_intervention_forward(
-        text=modified_text,
-        selected_nodes=parent_run.selected_nodes,
-        top_k=request.top_k,
-        loaded=loaded,
-    )
-
-    # Save the forked run
-    fork_id = db.save_forked_intervention_run(
-        intervention_run_id=run_id,
-        token_replacements=request.token_replacements,
-        result_json=response.model_dump_json(),
-    )
-
-    # Get the saved record for created_at
-    forked_records = db.get_forked_intervention_runs(run_id)
-    saved_fork = next((f for f in forked_records if f.id == fork_id), None)
-    assert saved_fork is not None
-
-    return ForkedInterventionRunSummary(
-        id=fork_id,
-        token_replacements=request.token_replacements,
-        result=response,
-        created_at=saved_fork.created_at,
-    )
-
-
-@router.delete("/forks/{fork_id}") -@log_errors -def delete_forked_intervention_run(fork_id: int, db: DepDB) -> dict[str, bool]: - """Delete a forked intervention run.""" - db.delete_forked_intervention_run(fork_id) - return {"success": True} diff --git a/spd/app/backend/routers/investigations.py b/spd/app/backend/routers/investigations.py new file mode 100644 index 000000000..296452b46 --- /dev/null +++ b/spd/app/backend/routers/investigations.py @@ -0,0 +1,313 @@ +"""Investigations endpoint for viewing agent investigation results. + +Lists and serves investigation data from SPD_OUT_DIR/investigations/. +Each investigation directory contains findings from a single agent run. +""" + +import json +from datetime import datetime +from pathlib import Path +from typing import Any + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from spd.app.backend.dependencies import DepLoadedRun +from spd.settings import SPD_OUT_DIR +from spd.utils.wandb_utils import parse_wandb_run_path + +router = APIRouter(prefix="/api/investigations", tags=["investigations"]) + +INVESTIGATIONS_DIR = SPD_OUT_DIR / "investigations" + + +class InvestigationSummary(BaseModel): + """Summary of a single investigation.""" + + id: str + wandb_path: str | None + prompt: str | None + created_at: str + has_research_log: bool + has_explanations: bool + event_count: int + last_event_time: str | None + last_event_message: str | None + title: str | None + summary: str | None + status: str | None + + +class EventEntry(BaseModel): + """A single event from events.jsonl.""" + + event_type: str + timestamp: str + message: str + details: dict[str, Any] | None = None + + +class InvestigationDetail(BaseModel): + """Full detail of an investigation including logs.""" + + id: str + wandb_path: str | None + prompt: str | None + created_at: str + research_log: str | None + events: list[EventEntry] + explanations: list[dict[str, Any]] + artifact_ids: list[str] + title: str | None + summary: 
str | None
+    status: str | None
+
+
+def _parse_metadata(inv_path: Path) -> dict[str, Any] | None:
+    """Parse metadata.json from an investigation directory."""
+    metadata_path = inv_path / "metadata.json"
+    if not metadata_path.exists():
+        return None
+    try:
+        data: dict[str, Any] = json.loads(metadata_path.read_text())
+        return data
+    except json.JSONDecodeError:
+        return None
+
+
+def _get_last_event(events_path: Path) -> tuple[str | None, str | None, int]:
+    """Get the last event timestamp, message, and total event count from events.jsonl."""
+    if not events_path.exists():
+        return None, None, 0
+
+    last_time = None
+    last_msg = None
+    count = 0
+
+    with open(events_path) as f:
+        for line in f:
+            line = line.strip()
+            if not line:
+                continue
+            try:
+                event = json.loads(line)
+                # Only count lines that parse as events; skip malformed lines.
+                count += 1
+                last_time = event.get("timestamp")
+                last_msg = event.get("message")
+            except json.JSONDecodeError:
+                continue
+
+    return last_time, last_msg, count
+
+
+def _parse_task_summary(inv_path: Path) -> tuple[str | None, str | None, str | None]:
+    """Parse summary.json from an investigation directory. 
Returns (title, summary, status)."""
+    summary_path = inv_path / "summary.json"
+    if not summary_path.exists():
+        return None, None, None
+    try:
+        data: dict[str, Any] = json.loads(summary_path.read_text())
+        return data.get("title"), data.get("summary"), data.get("status")
+    except json.JSONDecodeError:
+        return None, None, None
+
+
+def _list_artifact_ids(inv_path: Path) -> list[str]:
+    """List all artifact IDs for an investigation."""
+    artifacts_dir = inv_path / "artifacts"
+    if not artifacts_dir.exists():
+        return []
+    return [f.stem for f in sorted(artifacts_dir.glob("graph_*.json"))]
+
+
+def _get_created_at(inv_path: Path, metadata: dict[str, Any] | None) -> str:
+    """Get creation time for an investigation.
+
+    Prefers the first event's timestamp, then metadata "created_at", and falls
+    back to the directory mtime as a last resort.
+    """
+    events_path = inv_path / "events.jsonl"
+    if events_path.exists():
+        try:
+            with open(events_path) as f:
+                first_line = f.readline().strip()
+                if first_line:
+                    event = json.loads(first_line)
+                    if "timestamp" in event:
+                        return event["timestamp"]
+        except json.JSONDecodeError:
+            pass
+
+    if metadata and "created_at" in metadata:
+        return metadata["created_at"]
+
+    return datetime.fromtimestamp(inv_path.stat().st_mtime).isoformat()
+
+
+@router.get("")
+def list_investigations(loaded: DepLoadedRun) -> list[InvestigationSummary]:
+    """List investigations for the currently loaded run."""
+    if not INVESTIGATIONS_DIR.exists():
+        return []
+
+    wandb_path = loaded.run.wandb_path
+    results = []
+
+    for inv_path in INVESTIGATIONS_DIR.iterdir():
+        if not inv_path.is_dir() or not inv_path.name.startswith("inv-"):
+            continue
+
+        inv_id = inv_path.name
+        metadata = _parse_metadata(inv_path)
+
+        meta_wandb_path = metadata.get("wandb_path") if metadata else None
+        if meta_wandb_path is None:
+            continue
+        # Normalize to canonical form for comparison (strips "runs/", "wandb:" prefix, etc.) 
+ try: + e, p, r = parse_wandb_run_path(meta_wandb_path) + canonical_meta_path = f"{e}/{p}/{r}" + except ValueError: + continue + if canonical_meta_path != wandb_path: + continue + + events_path = inv_path / "events.jsonl" + last_time, last_msg, event_count = _get_last_event(events_path) + title, summary, status = _parse_task_summary(inv_path) + + explanations_path = inv_path / "explanations.jsonl" + + results.append( + InvestigationSummary( + id=inv_id, + wandb_path=meta_wandb_path, + prompt=metadata.get("prompt") if metadata else None, + created_at=_get_created_at(inv_path, metadata), + has_research_log=(inv_path / "research_log.md").exists(), + has_explanations=explanations_path.exists() + and explanations_path.stat().st_size > 0, + event_count=event_count, + last_event_time=last_time, + last_event_message=last_msg, + title=title, + summary=summary, + status=status, + ) + ) + + results.sort(key=lambda x: x.created_at, reverse=True) + return results + + +@router.get("/{inv_id}") +def get_investigation(inv_id: str) -> InvestigationDetail: + """Get full details of an investigation.""" + inv_path = INVESTIGATIONS_DIR / inv_id + + if not inv_path.exists() or not inv_path.is_dir(): + raise HTTPException(status_code=404, detail=f"Investigation {inv_id} not found") + + metadata = _parse_metadata(inv_path) + + research_log = None + research_log_path = inv_path / "research_log.md" + if research_log_path.exists(): + research_log = research_log_path.read_text() + + events = [] + events_path = inv_path / "events.jsonl" + if events_path.exists(): + with open(events_path) as f: + for line in f: + line = line.strip() + if not line: + continue + try: + event = json.loads(line) + events.append( + EventEntry( + event_type=event.get("event_type", "unknown"), + timestamp=event.get("timestamp", ""), + message=event.get("message", ""), + details=event.get("details"), + ) + ) + except json.JSONDecodeError: + continue + + explanations: list[dict[str, Any]] = [] + explanations_path = 
inv_path / "explanations.jsonl" + if explanations_path.exists(): + with open(explanations_path) as f: + for line in f: + line = line.strip() + if not line: + continue + try: + explanations.append(json.loads(line)) + except json.JSONDecodeError: + continue + + title, summary, status = _parse_task_summary(inv_path) + artifact_ids = _list_artifact_ids(inv_path) + + return InvestigationDetail( + id=inv_id, + wandb_path=metadata.get("wandb_path") if metadata else None, + prompt=metadata.get("prompt") if metadata else None, + created_at=_get_created_at(inv_path, metadata), + research_log=research_log, + events=events, + explanations=explanations, + artifact_ids=artifact_ids, + title=title, + summary=summary, + status=status, + ) + + +class LaunchRequest(BaseModel): + prompt: str + + +class LaunchResponse(BaseModel): + inv_id: str + job_id: str + + +@router.post("/launch") +def launch_investigation_endpoint(request: LaunchRequest, loaded: DepLoadedRun) -> LaunchResponse: + """Launch a new investigation for the currently loaded run.""" + from spd.investigate.scripts.run_slurm import launch_investigation + + result = launch_investigation( + wandb_path=loaded.run.wandb_path, + prompt=request.prompt, + context_length=loaded.context_length, + max_turns=50, + time="8:00:00", + job_suffix=None, + ) + return LaunchResponse(inv_id=result.inv_id, job_id=result.job_id) + + +@router.get("/{inv_id}/artifacts") +def list_artifacts(inv_id: str) -> list[str]: + """List all artifact IDs for an investigation.""" + inv_path = INVESTIGATIONS_DIR / inv_id + if not inv_path.exists(): + raise HTTPException(status_code=404, detail=f"Investigation {inv_id} not found") + return _list_artifact_ids(inv_path) + + +@router.get("/{inv_id}/artifacts/{artifact_id}") +def get_artifact(inv_id: str, artifact_id: str) -> dict[str, Any]: + """Get a specific artifact by ID.""" + inv_path = INVESTIGATIONS_DIR / inv_id + artifact_path = inv_path / "artifacts" / f"{artifact_id}.json" + + if not 
artifact_path.exists(): + raise HTTPException( + status_code=404, + detail=f"Artifact {artifact_id} not found in {inv_id}", + ) + + data: dict[str, Any] = json.loads(artifact_path.read_text()) + return data diff --git a/spd/app/backend/routers/mcp.py b/spd/app/backend/routers/mcp.py new file mode 100644 index 000000000..488e48b4f --- /dev/null +++ b/spd/app/backend/routers/mcp.py @@ -0,0 +1,1487 @@ +"""MCP (Model Context Protocol) endpoint for Claude Code integration. + +This router implements the MCP JSON-RPC protocol over HTTP, allowing Claude Code +to use SPD tools directly with proper schemas and streaming progress. + +MCP Spec: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports +""" + +import inspect +import json +import queue +import threading +import traceback +from collections.abc import Callable, Generator +from dataclasses import dataclass +from datetime import UTC, datetime +from pathlib import Path +from typing import Any, Literal + +import torch +from fastapi import APIRouter, Request +from fastapi.responses import JSONResponse, StreamingResponse +from pydantic import BaseModel + +from spd.app.backend.compute import ( + compute_ci_only, + compute_prompt_attributions_optimized, +) +from spd.app.backend.database import StoredGraph +from spd.app.backend.optim_cis import CELossConfig, OptimCIConfig +from spd.app.backend.routers.graphs import _build_out_probs +from spd.app.backend.routers.pretrain_info import _get_pretrain_info +from spd.app.backend.state import StateManager +from spd.configs import ImportanceMinimalityLossConfig +from spd.harvest import analysis +from spd.log import logger +from spd.utils.distributed_utils import get_device + +router = APIRouter(tags=["mcp"]) + +DEVICE = get_device() + +# MCP protocol version +MCP_PROTOCOL_VERSION = "2024-11-05" + + +@dataclass +class InvestigationConfig: + """Configuration for investigation mode. 
All paths are required when in investigation mode.""" + + events_log_path: Path + investigation_dir: Path + + +_investigation_config: InvestigationConfig | None = None + + +def set_investigation_config(config: InvestigationConfig) -> None: + """Configure MCP for investigation mode.""" + global _investigation_config + _investigation_config = config + + +def _log_event(event_type: str, message: str, details: dict[str, Any] | None = None) -> None: + """Log an event to the events file if in investigation mode.""" + if _investigation_config is None: + return + event = { + "event_type": event_type, + "timestamp": datetime.now(UTC).isoformat(), + "message": message, + "details": details or {}, + } + with open(_investigation_config.events_log_path, "a") as f: + f.write(json.dumps(event) + "\n") + + +# ============================================================================= +# MCP Protocol Types +# ============================================================================= + + +class MCPRequest(BaseModel): + """JSON-RPC 2.0 request.""" + + jsonrpc: Literal["2.0"] + id: int | str | None = None + method: str + params: dict[str, Any] | None = None + + +class MCPResponse(BaseModel): + """JSON-RPC 2.0 response. + + Per JSON-RPC 2.0 spec, exactly one of result/error must be present (not both, not neither). + Use model_dump(exclude_none=True) when serializing to avoid including null fields. + """ + + jsonrpc: Literal["2.0"] = "2.0" + id: int | str | None + result: Any | None = None + error: dict[str, Any] | None = None + + +class ToolDefinition(BaseModel): + """MCP tool definition.""" + + name: str + description: str + inputSchema: dict[str, Any] + + +# ============================================================================= +# Tool Definitions +# ============================================================================= + +TOOLS: list[ToolDefinition] = [ + ToolDefinition( + name="optimize_graph", + description="""Optimize a sparse circuit for a specific behavior. 
+ +Given a prompt and target token, finds the minimal set of components that produce the target prediction. +Returns the optimized graph with component CI values and edges showing information flow. + +This is the primary tool for understanding how the model produces a specific output.""", + inputSchema={ + "type": "object", + "properties": { + "prompt_text": { + "type": "string", + "description": "The input text to analyze (e.g., 'The boy said that')", + }, + "target_token": { + "type": "string", + "description": "The token to predict (e.g., ' he'). Include leading space if needed.", + }, + "loss_position": { + "type": "integer", + "description": "Position to optimize prediction at (0-indexed, usually last position). If not specified, uses the last position.", + }, + "steps": { + "type": "integer", + "description": "Optimization steps (default: 100, more = sparser but slower)", + "default": 100, + }, + "ci_threshold": { + "type": "number", + "description": "CI threshold for including components (default: 0.5, lower = more components)", + "default": 0.5, + }, + }, + "required": ["prompt_text", "target_token"], + }, + ), + ToolDefinition( + name="get_component_info", + description="""Get detailed information about a component. + +Returns the component's interpretation (what it does), token statistics (what tokens +activate it and what it predicts), and correlated components. 
+ +Use this to understand what role a component plays in a circuit.""", + inputSchema={ + "type": "object", + "properties": { + "layer": { + "type": "string", + "description": "Canonical layer name (e.g., '0.mlp.up', '2.attn.o')", + }, + "component_idx": { + "type": "integer", + "description": "Component index within the layer", + }, + "top_k": { + "type": "integer", + "description": "Number of top tokens/correlations to return (default: 20)", + "default": 20, + }, + }, + "required": ["layer", "component_idx"], + }, + ), + ToolDefinition( + name="run_ablation", + description="""Run an ablation experiment with only selected components active. + +Tests a hypothesis by running the model with a sparse set of components. +Returns predictions showing what the circuit produces vs the full model. + +Use this to verify that identified components are necessary and sufficient.""", + inputSchema={ + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "Input text for the ablation", + }, + "selected_nodes": { + "type": "array", + "items": {"type": "string"}, + "description": "Node keys to keep active (format: 'layer:seq_pos:component_idx')", + }, + "top_k": { + "type": "integer", + "description": "Number of top predictions to return per position (default: 10)", + "default": 10, + }, + }, + "required": ["text", "selected_nodes"], + }, + ), + ToolDefinition( + name="search_dataset", + description="""Search the SimpleStories training dataset for patterns. + +Finds stories containing the query string. Use this to find examples of +specific linguistic patterns (pronouns, verb forms, etc.) 
for investigation.""", + inputSchema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Text to search for (case-insensitive)", + }, + "limit": { + "type": "integer", + "description": "Maximum results to return (default: 20)", + "default": 20, + }, + }, + "required": ["query"], + }, + ), + ToolDefinition( + name="create_prompt", + description="""Create a prompt for analysis. + +Tokenizes the text and returns token IDs and next-token probabilities. +The returned prompt_id can be used with other tools.""", + inputSchema={ + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "The text to create a prompt from", + }, + }, + "required": ["text"], + }, + ), + ToolDefinition( + name="update_research_log", + description="""Append content to your research log. + +Use this to document your investigation progress, findings, and next steps. +The research log is your primary output for humans to follow your work. + +Call this frequently (every few minutes) with updates on what you're doing.""", + inputSchema={ + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "Markdown content to append to the research log", + }, + }, + "required": ["content"], + }, + ), + ToolDefinition( + name="save_explanation", + description="""Save a complete behavior explanation. + +Use this when you have finished investigating a behavior and want to document +your findings. This creates a structured record of the behavior, the components +involved, and your explanation of how they work together. 
+ +Only call this for complete, validated explanations - not preliminary hypotheses.""", + inputSchema={ + "type": "object", + "properties": { + "subject_prompt": { + "type": "string", + "description": "A prompt that demonstrates the behavior", + }, + "behavior_description": { + "type": "string", + "description": "Clear description of the behavior", + }, + "components_involved": { + "type": "array", + "items": { + "type": "object", + "properties": { + "component_key": { + "type": "string", + "description": "Component key (e.g., '0.mlp.up:5')", + }, + "role": { + "type": "string", + "description": "The role this component plays", + }, + "interpretation": { + "type": "string", + "description": "Auto-interp label if available", + }, + }, + "required": ["component_key", "role"], + }, + "description": "List of components and their roles", + }, + "explanation": { + "type": "string", + "description": "How the components work together", + }, + "supporting_evidence": { + "type": "array", + "items": { + "type": "object", + "properties": { + "evidence_type": { + "type": "string", + "enum": [ + "ablation", + "attribution", + "activation_pattern", + "correlation", + "other", + ], + }, + "description": {"type": "string"}, + "details": {"type": "object"}, + }, + "required": ["evidence_type", "description"], + }, + "description": "Evidence supporting this explanation", + }, + "confidence": { + "type": "string", + "enum": ["high", "medium", "low"], + "description": "Your confidence level", + }, + "alternative_hypotheses": { + "type": "array", + "items": {"type": "string"}, + "description": "Other hypotheses you considered", + }, + "limitations": { + "type": "array", + "items": {"type": "string"}, + "description": "Known limitations of this explanation", + }, + }, + "required": [ + "subject_prompt", + "behavior_description", + "components_involved", + "explanation", + "confidence", + ], + }, + ), + ToolDefinition( + name="set_investigation_summary", + description="""Set a title and 
summary for your investigation. + +Call this when you've completed your investigation (or periodically as you make progress) +to provide a human-readable title and summary that will be shown in the investigations UI. + +The title should be short and descriptive. The summary should be 1-3 sentences +explaining what you investigated and what you found.""", + inputSchema={ + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "Short title for the investigation (e.g., 'Gendered Pronoun Circuit')", + }, + "summary": { + "type": "string", + "description": "Brief summary of findings (1-3 sentences)", + }, + "status": { + "type": "string", + "enum": ["in_progress", "completed", "inconclusive"], + "description": "Current status of the investigation", + "default": "in_progress", + }, + }, + "required": ["title", "summary"], + }, + ), + ToolDefinition( + name="save_graph_artifact", + description="""Save a graph as an artifact for inclusion in your research report. + +After calling optimize_graph and getting a graph_id, call this to save the graph +as an artifact. Then reference it in your research log using the spd:graph syntax: + +```spd:graph +artifact: graph_001 +``` + +This allows humans reviewing your investigation to see interactive circuit visualizations +inline with your research notes.""", + inputSchema={ + "type": "object", + "properties": { + "graph_id": { + "type": "integer", + "description": "The graph ID returned by optimize_graph", + }, + "caption": { + "type": "string", + "description": "Optional caption describing what this graph shows", + }, + }, + "required": ["graph_id"], + }, + ), + ToolDefinition( + name="probe_component", + description="""Fast CI probing on custom text. + +Computes causal importance values and subcomponent activations for a specific component +across all positions in the input text. Also returns next-token probabilities. 
+ +Use this for quick, targeted analysis of how a component responds to specific inputs.""", + inputSchema={ + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "The input text to probe", + }, + "layer": { + "type": "string", + "description": "Canonical layer name (e.g., '0.mlp.up')", + }, + "component_idx": { + "type": "integer", + "description": "Component index within the layer", + }, + }, + "required": ["text", "layer", "component_idx"], + }, + ), + ToolDefinition( + name="get_component_activation_examples", + description="""Get activation examples from harvest data for a component. + +Returns examples showing token windows where the component fires, along with +CI values and activation strengths at each position. + +Use this to understand what inputs activate a component.""", + inputSchema={ + "type": "object", + "properties": { + "layer": { + "type": "string", + "description": "Canonical layer name (e.g., '0.mlp.up')", + }, + "component_idx": { + "type": "integer", + "description": "Component index within the layer", + }, + "limit": { + "type": "integer", + "description": "Maximum number of examples to return (default: 10)", + "default": 10, + }, + }, + "required": ["layer", "component_idx"], + }, + ), + ToolDefinition( + name="get_component_attributions", + description="""Get dataset-level component dependencies from pre-computed attributions. + +Returns the top source and target components that this component attributes to/from, +aggregated over the training dataset. Both positive and negative attributions are returned. 
+ +Use this to understand a component's role in the broader network.""", + inputSchema={ + "type": "object", + "properties": { + "layer": { + "type": "string", + "description": "Canonical layer name (e.g., '0.mlp.up') or 'output'", + }, + "component_idx": { + "type": "integer", + "description": "Component index within the layer", + }, + "k": { + "type": "integer", + "description": "Number of top attributions to return per direction (default: 10)", + "default": 10, + }, + }, + "required": ["layer", "component_idx"], + }, + ), + ToolDefinition( + name="get_model_info", + description="""Get architecture details about the pretrained model. + +Returns model type, summary, target model config, topology, and pretrain info. +No parameters required.""", + inputSchema={ + "type": "object", + "properties": {}, + }, + ), +] + + +# ============================================================================= +# Tool Implementations +# ============================================================================= + + +def _get_state(): + """Get state manager and loaded run, raising clear errors if not available.""" + manager = StateManager.get() + if manager.run_state is None: + raise ValueError("No run loaded. The backend must load a run first.") + return manager, manager.run_state + + +def _canonicalize_layer(layer: str, loaded: Any) -> str: + """Translate concrete layer name to canonical, passing through 'output'.""" + if layer == "output": + return layer + return loaded.topology.target_to_canon(layer) + + +def _canonicalize_key(concrete_key: str, loaded: Any) -> str: + """Translate concrete component key (e.g. 'h.0.mlp.c_fc:444') to canonical ('0.mlp.up:444').""" + layer, idx = concrete_key.rsplit(":", 1) + return f"{_canonicalize_layer(layer, loaded)}:{idx}" + + +def _tool_optimize_graph(params: dict[str, Any]) -> Generator[dict[str, Any]]: + """Optimize a sparse circuit for a behavior. 
Yields progress events."""
+    manager, loaded = _get_state()
+
+    prompt_text = params["prompt_text"]
+    target_token = params["target_token"]
+    steps = params.get("steps", 100)
+    ci_threshold = params.get("ci_threshold", 0.5)
+
+    # Tokenize prompt
+    token_ids = loaded.tokenizer.encode(prompt_text)
+    if not token_ids:
+        raise ValueError("Prompt text produced no tokens")
+
+    # Find target token ID
+    target_token_ids = loaded.tokenizer.encode(target_token)
+    if len(target_token_ids) != 1:
+        raise ValueError(
+            f"Target token '{target_token}' tokenizes to {len(target_token_ids)} tokens, expected 1. "
+            f"Token IDs: {target_token_ids}"
+        )
+    label_token = target_token_ids[0]
+
+    # Determine loss position
+    loss_position = params.get("loss_position")
+    if loss_position is None:
+        loss_position = len(token_ids) - 1
+
+    if loss_position < 0 or loss_position >= len(token_ids):
+        raise ValueError(
+            f"loss_position {loss_position} out of bounds for prompt with {len(token_ids)} tokens"
+        )
+
+    _log_event(
+        "tool_start",
+        f"optimize_graph: '{prompt_text}' → '{target_token}'",
+        {"steps": steps, "loss_position": loss_position},
+    )
+
+    yield {"type": "progress", "current": 0, "total": steps, "stage": "starting optimization"}
+
+    # Create prompt in DB
+    prompt_id = manager.db.add_custom_prompt(
+        run_id=loaded.run.id,
+        token_ids=token_ids,
+        context_length=loaded.context_length,
+    )
+
+    # Build optimization config
+    loss_config = CELossConfig(coeff=1.0, position=loss_position, label_token=label_token)
+
+    optim_config = OptimCIConfig(
+        adv_pgd=None,  # AdvPGDConfig(n_steps=10, step_size=0.01, init="random"),
+        seed=0,
+        lr=1e-2,
+        steps=steps,
+        weight_decay=0.0,
+        lr_schedule="cosine",
+        lr_exponential_halflife=None,
+        lr_warmup_pct=0.01,
+        log_freq=max(1, steps // 10),
+        imp_min_config=ImportanceMinimalityLossConfig(coeff=0.1, pnorm=0.5, beta=0.0),
+        loss_config=loss_config,
+        sampling=loaded.config.sampling,
+        ce_kl_rounding_threshold=0.5,
+        mask_type="ci",
+    )
+
+    tokens_tensor = 
torch.tensor([token_ids], device=DEVICE) + progress_queue: queue.Queue[dict[str, Any]] = queue.Queue() + + def on_progress(current: int, total: int, stage: str) -> None: + progress_queue.put({"current": current, "total": total, "stage": stage}) + + # Run optimization in thread + result_holder: list[Any] = [] + error_holder: list[Exception] = [] + + def compute(): + try: + with manager.gpu_lock(): + result = compute_prompt_attributions_optimized( + model=loaded.model, + topology=loaded.topology, + tokens=tokens_tensor, + sources_by_target=loaded.sources_by_target, + optim_config=optim_config, + output_prob_threshold=0.01, + device=DEVICE, + on_progress=on_progress, + ) + result_holder.append(result) + except Exception as e: + error_holder.append(e) + + thread = threading.Thread(target=compute) + thread.start() + + # Yield progress events (throttle logging to every 10% or 10 steps) + last_logged_step = -1 + log_interval = max(1, steps // 10) + + while thread.is_alive() or not progress_queue.empty(): + try: + progress = progress_queue.get(timeout=0.1) + current = progress["current"] + # Log to events.jsonl at intervals (for human monitoring) + if current - last_logged_step >= log_interval or current == progress["total"]: + _log_event( + "optimization_progress", + f"optimize_graph: step {current}/{progress['total']} ({progress['stage']})", + {"prompt": prompt_text, "target": target_token, **progress}, + ) + last_logged_step = current + # Always yield to SSE stream (for Claude) + yield {"type": "progress", **progress} + except queue.Empty: + continue + + thread.join() + + if error_holder: + raise error_holder[0] + + if not result_holder: + raise RuntimeError("Optimization completed but no result was produced") + + result = result_holder[0] + + ci_masked_out_logits = result.ci_masked_out_logits.cpu() + target_out_logits = result.target_out_logits.cpu() + + # Build output probs for response + out_probs = _build_out_probs( + ci_masked_out_logits, + target_out_logits, + 
loaded.tokenizer.get_tok_display, + ) + + # Save graph to DB + from spd.app.backend.database import OptimizationParams + + opt_params = OptimizationParams( + imp_min_coeff=0.1, + steps=steps, + pnorm=0.5, + beta=0.0, + mask_type="ci", + loss=loss_config, + ci_masked_label_prob=result.metrics.ci_masked_label_prob, + stoch_masked_label_prob=result.metrics.stoch_masked_label_prob, + adv_pgd_label_prob=result.metrics.adv_pgd_label_prob, + ) + graph_id = manager.db.save_graph( + prompt_id=prompt_id, + graph=StoredGraph( + graph_type="optimized", + edges=result.edges, + edges_abs=result.edges_abs, + ci_masked_out_logits=ci_masked_out_logits, + target_out_logits=target_out_logits, + node_ci_vals=result.node_ci_vals, + node_subcomp_acts=result.node_subcomp_acts, + optimization_params=opt_params, + ), + ) + + # Filter nodes by CI threshold + active_components = {k: v for k, v in result.node_ci_vals.items() if v >= ci_threshold} + + # Get target token probability + target_key = f"{loss_position}:{label_token}" + target_prob = out_probs.get(target_key) + + token_strings = [loaded.tokenizer.get_tok_display(t) for t in token_ids] + + final_result = { + "graph_id": graph_id, + "prompt_id": prompt_id, + "tokens": token_strings, + "target_token": target_token, + "target_token_id": label_token, + "target_position": loss_position, + "target_probability": target_prob.prob if target_prob else None, + "target_probability_baseline": target_prob.target_prob if target_prob else None, + "active_components": active_components, + "total_active": len(active_components), + "output_probs": {k: {"prob": v.prob, "token": v.token} for k, v in out_probs.items()}, + } + + _log_event( + "tool_complete", + f"optimize_graph complete: {len(active_components)} active components", + {"graph_id": graph_id, "target_prob": target_prob.prob if target_prob else None}, + ) + + yield {"type": "result", "data": final_result} + + +def _tool_get_component_info(params: dict[str, Any]) -> dict[str, Any]: + """Get 
detailed information about a component.""" + _, loaded = _get_state() + + layer = params["layer"] + component_idx = params["component_idx"] + top_k = params.get("top_k", 20) + canonical_key = f"{layer}:{component_idx}" + + # Harvest/interp repos store concrete keys (e.g. "h.0.mlp.c_fc:444") + concrete_layer = loaded.topology.canon_to_target(layer) + concrete_key = f"{concrete_layer}:{component_idx}" + + _log_event( + "tool_call", + f"get_component_info: {canonical_key}", + {"layer": layer, "idx": component_idx}, + ) + + result: dict[str, Any] = {"component_key": canonical_key} + + # Get interpretation + if loaded.interp is not None: + interp = loaded.interp.get_interpretation(concrete_key) + if interp is not None: + result["interpretation"] = { + "label": interp.label, + "confidence": interp.confidence, + "reasoning": interp.reasoning, + } + else: + result["interpretation"] = None + else: + result["interpretation"] = None + + # Get token stats + assert loaded.harvest is not None, "harvest data not loaded" + token_stats = loaded.harvest.get_token_stats() + if token_stats is not None: + input_stats = analysis.get_input_token_stats( + token_stats, concrete_key, loaded.tokenizer, top_k + ) + output_stats = analysis.get_output_token_stats( + token_stats, concrete_key, loaded.tokenizer, top_k + ) + if input_stats and output_stats: + result["token_stats"] = { + "input": { + "top_recall": input_stats.top_recall, + "top_precision": input_stats.top_precision, + "top_pmi": input_stats.top_pmi, + }, + "output": { + "top_recall": output_stats.top_recall, + "top_precision": output_stats.top_precision, + "top_pmi": output_stats.top_pmi, + "bottom_pmi": output_stats.bottom_pmi, + }, + } + else: + result["token_stats"] = None + else: + result["token_stats"] = None + + # Get correlations (return canonical keys) + correlations = loaded.harvest.get_correlations() + if correlations is not None and analysis.has_component(correlations, concrete_key): + result["correlated_components"] = { 
+ "precision": [ + {"key": _canonicalize_key(c.component_key, loaded), "score": c.score} + for c in analysis.get_correlated_components( + correlations, concrete_key, "precision", top_k + ) + ], + "pmi": [ + {"key": _canonicalize_key(c.component_key, loaded), "score": c.score} + for c in analysis.get_correlated_components( + correlations, concrete_key, "pmi", top_k + ) + ], + } + else: + result["correlated_components"] = None + + return result + + +def _tool_run_ablation(params: dict[str, Any]) -> dict[str, Any]: + """Run ablation with selected components.""" + from spd.app.backend.compute import ( + DEFAULT_EVAL_PGD_CONFIG, + compute_intervention, + ) + from spd.app.backend.optim_cis import MeanKLLossConfig + + manager, loaded = _get_state() + + text = params["text"] + selected_nodes = params["selected_nodes"] + top_k = params.get("top_k", 10) + + _log_event( + "tool_call", + f"run_ablation: '{text[:50]}...' with {len(selected_nodes)} nodes", + {"text": text, "n_nodes": len(selected_nodes)}, + ) + + token_ids = loaded.tokenizer.encode(text) + tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) + + active_nodes = [] + for key in selected_nodes: + parts = key.split(":") + if len(parts) != 3: + raise ValueError(f"Invalid node key format: {key!r} (expected 'layer:seq:cIdx')") + layer, seq_str, cidx_str = parts + if layer in ("wte", "embed", "output"): + raise ValueError(f"Cannot intervene on {layer!r} nodes - only internal layers allowed") + active_nodes.append((layer, int(seq_str), int(cidx_str))) + + with manager.gpu_lock(): + result = compute_intervention( + model=loaded.model, + tokens=tokens, + active_nodes=active_nodes, + nodes_to_ablate=None, + tokenizer=loaded.tokenizer, + adv_pgd_config=DEFAULT_EVAL_PGD_CONFIG, + loss_config=MeanKLLossConfig(), + sampling=loaded.config.sampling, + top_k=top_k, + ) + + predictions = [] + for pos_predictions in result.ci: + pos_result = [] + for pred in pos_predictions: + pos_result.append( + { + "token": 
pred.token, + "token_id": pred.token_id, + "circuit_prob": round(pred.prob, 6), + "full_model_prob": round(pred.target_prob, 6), + } + ) + predictions.append(pos_result) + + return { + "input_tokens": result.input_tokens, + "predictions_per_position": predictions, + "selected_nodes": selected_nodes, + } + + +def _tool_search_dataset(params: dict[str, Any]) -> dict[str, Any]: + """Search the SimpleStories dataset.""" + import time + + from datasets import Dataset, load_dataset + + query = params["query"] + limit = params.get("limit", 20) + search_query = query.lower() + + _log_event("tool_call", f"search_dataset: '{query}'", {"query": query, "limit": limit}) + + start_time = time.time() + dataset = load_dataset("lennart-finke/SimpleStories", split="train") + assert isinstance(dataset, Dataset) + + filtered = dataset.filter( + lambda x: search_query in x["story"].lower(), + num_proc=4, + ) + + results = [] + for i, item in enumerate(filtered): + if i >= limit: + break + item_dict: dict[str, Any] = dict(item) + story: str = item_dict["story"] + results.append( + { + "story": story[:500] + "..." 
if len(story) > 500 else story, + "occurrence_count": story.lower().count(search_query), + } + ) + + return { + "query": query, + "total_matches": len(filtered), + "returned": len(results), + "search_time_seconds": round(time.time() - start_time, 2), + "results": results, + } + + +def _tool_create_prompt(params: dict[str, Any]) -> dict[str, Any]: + """Create a prompt from text.""" + manager, loaded = _get_state() + + text = params["text"] + + _log_event("tool_call", f"create_prompt: '{text[:50]}...'", {"text": text}) + + token_ids = loaded.tokenizer.encode(text) + if not token_ids: + raise ValueError("Text produced no tokens") + + prompt_id = manager.db.add_custom_prompt( + run_id=loaded.run.id, + token_ids=token_ids, + context_length=loaded.context_length, + ) + + # Compute next token probs + tokens_tensor = torch.tensor([token_ids], device=DEVICE) + with torch.no_grad(): + logits = loaded.model(tokens_tensor) + probs = torch.softmax(logits, dim=-1) + + next_token_probs = [] + for i in range(len(token_ids) - 1): + next_token_id = token_ids[i + 1] + prob = probs[0, i, next_token_id].item() + next_token_probs.append(round(prob, 6)) + next_token_probs.append(None) + + token_strings = [loaded.tokenizer.get_tok_display(t) for t in token_ids] + + return { + "prompt_id": prompt_id, + "text": text, + "tokens": token_strings, + "token_ids": token_ids, + "next_token_probs": next_token_probs, + } + + +def _require_investigation_config() -> InvestigationConfig: + """Get investigation config, raising if not in investigation mode.""" + assert _investigation_config is not None, "Not running in investigation mode" + return _investigation_config + + +def _tool_update_research_log(params: dict[str, Any]) -> dict[str, Any]: + """Append content to the research log.""" + config = _require_investigation_config() + content = params["content"] + research_log_path = config.investigation_dir / "research_log.md" + + _log_event( + "tool_call", f"update_research_log: {len(content)} chars", 
{"preview": content[:100]} + ) + + with open(research_log_path, "a") as f: + f.write(content) + if not content.endswith("\n"): + f.write("\n") + + return {"status": "ok", "path": str(research_log_path)} + + +def _tool_save_explanation(params: dict[str, Any]) -> dict[str, Any]: + """Save a behavior explanation to explanations.jsonl.""" + from spd.investigate.schemas import BehaviorExplanation, ComponentInfo, Evidence + + config = _require_investigation_config() + + _log_event( + "tool_call", + f"save_explanation: '{params['behavior_description'][:50]}...'", + {"prompt": params["subject_prompt"]}, + ) + + components = [ + ComponentInfo( + component_key=c["component_key"], + role=c["role"], + interpretation=c.get("interpretation"), + ) + for c in params["components_involved"] + ] + + evidence = [ + Evidence( + evidence_type=e["evidence_type"], + description=e["description"], + details=e.get("details", {}), + ) + for e in params.get("supporting_evidence", []) + ] + + explanation = BehaviorExplanation( + subject_prompt=params["subject_prompt"], + behavior_description=params["behavior_description"], + components_involved=components, + explanation=params["explanation"], + supporting_evidence=evidence, + confidence=params["confidence"], + alternative_hypotheses=params.get("alternative_hypotheses", []), + limitations=params.get("limitations", []), + ) + + explanations_path = config.investigation_dir / "explanations.jsonl" + with open(explanations_path, "a") as f: + f.write(explanation.model_dump_json() + "\n") + + _log_event( + "explanation", + f"Saved explanation: {params['behavior_description']}", + {"confidence": params["confidence"], "n_components": len(components)}, + ) + + return {"status": "ok", "path": str(explanations_path)} + + +def _tool_set_investigation_summary(params: dict[str, Any]) -> dict[str, Any]: + """Set the investigation title and summary.""" + config = _require_investigation_config() + + summary = { + "title": params["title"], + "summary": 
params["summary"], + "status": params.get("status", "in_progress"), + "updated_at": datetime.now(UTC).isoformat(), + } + + _log_event( + "tool_call", + f"set_investigation_summary: {params['title']}", + summary, + ) + + summary_path = config.investigation_dir / "summary.json" + summary_path.write_text(json.dumps(summary, indent=2)) + + return {"status": "ok", "path": str(summary_path)} + + +def _tool_save_graph_artifact(params: dict[str, Any]) -> dict[str, Any]: + """Save a graph as an artifact for the research report. + + Uses the same filtering logic as the main graph API: + 1. Filter nodes by CI threshold + 2. Add pseudo nodes (wte, output) + 3. Filter edges to only active nodes + 4. Apply edge limit + """ + config = _require_investigation_config() + manager, loaded = _get_state() + + graph_id = params["graph_id"] + caption = params.get("caption") + ci_threshold = params.get("ci_threshold", 0.5) + edge_limit = params.get("edge_limit", 5000) + + _log_event( + "tool_call", + f"save_graph_artifact: graph_id={graph_id}", + {"graph_id": graph_id, "caption": caption}, + ) + + # Fetch graph from DB + result = manager.db.get_graph(graph_id) + if result is None: + raise ValueError(f"Graph with id={graph_id} not found") + + graph, prompt_id = result + + # Get tokens from prompt + prompt_record = manager.db.get_prompt(prompt_id) + if prompt_record is None: + raise ValueError(f"Prompt with id={prompt_id} not found") + + tokens = [loaded.tokenizer.get_tok_display(tid) for tid in prompt_record.token_ids] + num_tokens = len(tokens) + + # Create artifacts directory + artifacts_dir = config.investigation_dir / "artifacts" + artifacts_dir.mkdir(exist_ok=True) + + # Generate artifact ID (find max existing number to avoid collisions) + existing_nums = [] + for f in artifacts_dir.glob("graph_*.json"): + try: + num = int(f.stem.split("_")[1]) + existing_nums.append(num) + except (IndexError, ValueError): + continue + artifact_num = max(existing_nums, default=0) + 1 + artifact_id = 
f"graph_{artifact_num:03d}" + + # Compute out_probs from stored logits + out_probs = _build_out_probs( + graph.ci_masked_out_logits, + graph.target_out_logits, + loaded.tokenizer.get_tok_display, + ) + + # Step 1: Filter nodes by CI threshold (same as main graph API) + filtered_ci_vals = {k: v for k, v in graph.node_ci_vals.items() if v > ci_threshold} + l0_total = len(filtered_ci_vals) + + # Step 2: Add pseudo nodes (embed and output) - same as _add_pseudo_layer_nodes + node_ci_vals_with_pseudo = dict(filtered_ci_vals) + for seq_pos in range(num_tokens): + node_ci_vals_with_pseudo[f"embed:{seq_pos}:0"] = 1.0 + for key, out_prob in out_probs.items(): + seq_pos, token_id = key.split(":") + node_ci_vals_with_pseudo[f"output:{seq_pos}:{token_id}"] = out_prob.prob + + # Step 3: Filter edges to only active nodes + active_node_keys = set(node_ci_vals_with_pseudo.keys()) + filtered_edges = [ + e + for e in graph.edges + if str(e.source) in active_node_keys and str(e.target) in active_node_keys + ] + + # Step 4: Sort by strength and apply edge limit + filtered_edges.sort(key=lambda e: abs(e.strength), reverse=True) + filtered_edges = filtered_edges[:edge_limit] + + # Build edges data + edges_data = [ + { + "src": str(e.source), + "tgt": str(e.target), + "val": e.strength, + } + for e in filtered_edges + ] + + # Compute max abs attr from filtered edges + max_abs_attr = max((abs(e.strength) for e in filtered_edges), default=0.0) + + # Filter nodeSubcompActs to match nodeCiVals + filtered_subcomp_acts = { + k: v for k, v in graph.node_subcomp_acts.items() if k in node_ci_vals_with_pseudo + } + + # Build artifact data (self-contained GraphData, same structure as API response) + artifact = { + "type": "graph", + "id": artifact_id, + "caption": caption, + "graph_id": graph_id, + "data": { + "tokens": tokens, + "edges": edges_data, + "outputProbs": { + k: { + "prob": v.prob, + "logit": v.logit, + "target_prob": v.target_prob, + "target_logit": v.target_logit, + "token": v.token, 
+ } + for k, v in out_probs.items() + }, + "nodeCiVals": node_ci_vals_with_pseudo, + "nodeSubcompActs": filtered_subcomp_acts, + "maxAbsAttr": max_abs_attr, + "l0_total": l0_total, + }, + } + + # Save artifact + artifact_path = artifacts_dir / f"{artifact_id}.json" + artifact_path.write_text(json.dumps(artifact, indent=2)) + + _log_event( + "artifact_saved", + f"Saved graph artifact: {artifact_id}", + {"artifact_id": artifact_id, "graph_id": graph_id, "path": str(artifact_path)}, + ) + + return {"artifact_id": artifact_id, "path": str(artifact_path)} + + +def _tool_probe_component(params: dict[str, Any]) -> dict[str, Any]: + """Fast CI probing on custom text for a specific component.""" + manager, loaded = _get_state() + + text = params["text"] + layer = params["layer"] + component_idx = params["component_idx"] + + _log_event( + "tool_call", + f"probe_component: '{text[:50]}...' layer={layer} idx={component_idx}", + {"text": text, "layer": layer, "component_idx": component_idx}, + ) + + token_ids = loaded.tokenizer.encode(text) + assert token_ids, "Text produced no tokens" + tokens_tensor = torch.tensor([token_ids], device=DEVICE) + + concrete_layer = loaded.topology.canon_to_target(layer) + + with manager.gpu_lock(): + result = compute_ci_only( + model=loaded.model, tokens=tokens_tensor, sampling=loaded.config.sampling + ) + + ci_values = result.ci_lower_leaky[concrete_layer][0, :, component_idx].tolist() + subcomp_acts = result.component_acts[concrete_layer][0, :, component_idx].tolist() + + # Get next token probs from target model output + next_token_probs = [] + for i in range(len(token_ids) - 1): + next_token_id = token_ids[i + 1] + prob = result.target_out_probs[0, i, next_token_id].item() + next_token_probs.append(round(prob, 6)) + next_token_probs.append(None) + + token_strings = [loaded.tokenizer.get_tok_display(t) for t in token_ids] + + return { + "tokens": token_strings, + "ci_values": ci_values, + "subcomp_acts": subcomp_acts, + "next_token_probs": 
next_token_probs, + } + + +def _tool_get_component_activation_examples(params: dict[str, Any]) -> dict[str, Any]: + """Get activation examples from harvest data.""" + _, loaded = _get_state() + + layer = params["layer"] + component_idx = params["component_idx"] + limit = params.get("limit", 10) + + concrete_layer = loaded.topology.canon_to_target(layer) + component_key = f"{concrete_layer}:{component_idx}" + + _log_event( + "tool_call", + f"get_component_activation_examples: {component_key}", + {"layer": layer, "component_idx": component_idx, "limit": limit}, + ) + + assert loaded.harvest is not None, "harvest data not loaded" + canonical_key = f"{layer}:{component_idx}" + comp = loaded.harvest.get_component(component_key) + if comp is None: + return {"component_key": canonical_key, "examples": [], "total": 0} + + examples = [] + for ex in comp.activation_examples[:limit]: + token_strings = [loaded.tokenizer.get_tok_display(t) for t in ex.token_ids] + examples.append( + { + "tokens": token_strings, + "ci_values": ex.activations["causal_importance"], + "component_acts": ex.activations["component_activation"], + } + ) + + return { + "component_key": canonical_key, + "examples": examples, + "total": len(comp.activation_examples), + "mean_ci": comp.mean_activations["causal_importance"], + } + + +def _tool_get_model_info(_params: dict[str, Any]) -> dict[str, Any]: + """Get architecture details about the pretrained model.""" + _, loaded = _get_state() + + _log_event("tool_call", "get_model_info", {}) + + info = _get_pretrain_info(loaded.config) + return info.model_dump() + + +# ============================================================================= +# MCP Protocol Handler +# ============================================================================= + + +_STREAMING_TOOLS: dict[str, Callable[..., Generator[dict[str, Any]]]] = { + "optimize_graph": _tool_optimize_graph, +} + +_SIMPLE_TOOLS: dict[str, Callable[..., dict[str, Any]]] = { + "get_component_info": 
_tool_get_component_info, + "run_ablation": _tool_run_ablation, + "search_dataset": _tool_search_dataset, + "create_prompt": _tool_create_prompt, + "update_research_log": _tool_update_research_log, + "save_explanation": _tool_save_explanation, + "set_investigation_summary": _tool_set_investigation_summary, + "save_graph_artifact": _tool_save_graph_artifact, + "probe_component": _tool_probe_component, + "get_component_activation_examples": _tool_get_component_activation_examples, + # "get_component_attributions": _tool_get_component_attributions, + "get_model_info": _tool_get_model_info, +} + + +def _handle_initialize(_params: dict[str, Any] | None) -> dict[str, Any]: + """Handle initialize request.""" + return { + "protocolVersion": MCP_PROTOCOL_VERSION, + "capabilities": {"tools": {}}, + "serverInfo": {"name": "spd-app", "version": "1.0.0"}, + } + + +def _handle_tools_list() -> dict[str, Any]: + """Handle tools/list request.""" + return {"tools": [t.model_dump() for t in TOOLS]} + + +def _handle_tools_call( + params: dict[str, Any], +) -> Generator[dict[str, Any]] | dict[str, Any]: + """Handle tools/call request. May return generator for streaming tools.""" + name = params.get("name") + arguments = params.get("arguments", {}) + + if name in _STREAMING_TOOLS: + return _STREAMING_TOOLS[name](arguments) + + if name in _SIMPLE_TOOLS: + result = _SIMPLE_TOOLS[name](arguments) + return {"content": [{"type": "text", "text": json.dumps(result, indent=2)}]} + + raise ValueError(f"Unknown tool: {name}") + + +@router.post("/mcp") +async def mcp_endpoint(request: Request): + """MCP JSON-RPC endpoint. + + Handles initialize, tools/list, and tools/call methods. + Returns SSE stream for streaming tools, JSON for others. 
+ """ + try: + body = await request.json() + mcp_request = MCPRequest(**body) + except Exception as e: + return JSONResponse( + status_code=400, + content=MCPResponse( + id=None, error={"code": -32700, "message": f"Parse error: {e}"} + ).model_dump(exclude_none=True), + ) + + logger.info(f"[MCP] {mcp_request.method} (id={mcp_request.id})") + + try: + if mcp_request.method == "initialize": + result = _handle_initialize(mcp_request.params) + return JSONResponse( + content=MCPResponse(id=mcp_request.id, result=result).model_dump(exclude_none=True), + headers={"Mcp-Session-Id": "spd-session"}, + ) + + elif mcp_request.method == "notifications/initialized": + # Client confirms initialization + return JSONResponse(status_code=202, content={}) + + elif mcp_request.method == "tools/list": + result = _handle_tools_list() + return JSONResponse( + content=MCPResponse(id=mcp_request.id, result=result).model_dump(exclude_none=True) + ) + + elif mcp_request.method == "tools/call": + if mcp_request.params is None: + raise ValueError("tools/call requires params") + + result = _handle_tools_call(mcp_request.params) + + # Check if result is a generator (streaming) + if inspect.isgenerator(result): + # Streaming response via SSE + gen = result # Capture for closure + + def generate_sse() -> Generator[str]: + try: + final_result = None + for event in gen: + if event.get("type") == "progress": + # Send progress notification + progress_msg = { + "jsonrpc": "2.0", + "method": "notifications/progress", + "params": event, + } + yield f"data: {json.dumps(progress_msg)}\n\n" + elif event.get("type") == "result": + final_result = event["data"] + + # Send final response + response = MCPResponse( + id=mcp_request.id, + result={ + "content": [ + {"type": "text", "text": json.dumps(final_result, indent=2)} + ] + }, + ) + yield f"data: {json.dumps(response.model_dump(exclude_none=True))}\n\n" + except Exception as e: + tb = traceback.format_exc() + logger.error(f"[MCP] Tool error: {e}\n{tb}") + 
error_response = MCPResponse( + id=mcp_request.id, + error={"code": -32000, "message": str(e)}, + ) + yield f"data: {json.dumps(error_response.model_dump(exclude_none=True))}\n\n" + + return StreamingResponse(generate_sse(), media_type="text/event-stream") + + else: + # Non-streaming response + return JSONResponse( + content=MCPResponse(id=mcp_request.id, result=result).model_dump( + exclude_none=True + ) + ) + + else: + return JSONResponse( + content=MCPResponse( + id=mcp_request.id, + error={"code": -32601, "message": f"Method not found: {mcp_request.method}"}, + ).model_dump(exclude_none=True) + ) + + except Exception as e: + tb = traceback.format_exc() + logger.error(f"[MCP] Error handling {mcp_request.method}: {e}\n{tb}") + return JSONResponse( + content=MCPResponse( + id=mcp_request.id, + error={"code": -32000, "message": str(e)}, + ).model_dump(exclude_none=True) + ) diff --git a/spd/app/backend/routers/pretrain_info.py b/spd/app/backend/routers/pretrain_info.py new file mode 100644 index 000000000..424f7b035 --- /dev/null +++ b/spd/app/backend/routers/pretrain_info.py @@ -0,0 +1,246 @@ +"""Pretrained model architecture info endpoint. + +Fetches target model architecture from pretrain runs, without loading checkpoints. +Used by the run picker to show architecture summaries and by the data sources tab +to show topology and raw pretrain config. +""" + +from typing import Any + +import wandb +import yaml +from fastapi import APIRouter +from pydantic import BaseModel + +from spd.app.backend.dependencies import DepLoadedRun +from spd.app.backend.utils import log_errors +from spd.configs import Config +from spd.log import logger +from spd.settings import SPD_OUT_DIR +from spd.utils.wandb_utils import download_wandb_file, fetch_wandb_run_dir, parse_wandb_run_path + +router = APIRouter(prefix="/api/pretrain_info", tags=["pretrain_info"]) + + +class BlockStructure(BaseModel): + index: int + attn_type: str # "separate" or "fused" + attn_projections: list[str] # e.g. 
["q","k","v","o"] or ["qkv","o"] + ffn_type: str # "glu" or "mlp" + ffn_projections: list[str] # e.g. ["gate","up","down"] or ["up","down"] + + +class TopologyInfo(BaseModel): + n_blocks: int + block_structure: list[BlockStructure] + + +class PretrainInfoResponse(BaseModel): + model_type: str + summary: str + dataset_short: str | None + target_model_config: dict[str, Any] | None + pretrain_config: dict[str, Any] | None + pretrain_wandb_path: str | None + topology: TopologyInfo | None + + +def _load_spd_config_lightweight(wandb_path: str) -> Config: + """Load just the SPD config YAML without downloading checkpoints.""" + entity, project, run_id = parse_wandb_run_path(wandb_path) + + # Check local cache first + run_dir = SPD_OUT_DIR / "runs" / f"{project}-{run_id}" + config_path = run_dir / "final_config.yaml" + + if not config_path.exists(): + logger.info(f"[pretrain_info] Downloading config for {entity}/{project}/{run_id}") + api = wandb.Api() + run = api.run(f"{entity}/{project}/{run_id}") + run_dir = fetch_wandb_run_dir(run_id) + config_path = download_wandb_file(run, run_dir, "final_config.yaml") + + with open(config_path) as f: + return Config(**yaml.safe_load(f)) + + +def _load_pretrain_configs(pretrain_path: str) -> tuple[dict[str, Any], dict[str, Any]]: + """Load model config and training config from a pretrain run, config files only.""" + entity, project, run_id = parse_wandb_run_path(pretrain_path) + + cache_dir = SPD_OUT_DIR / "pretrain_cache" / f"{project}-{run_id}" + model_config_path = cache_dir / "model_config.yaml" + config_path = cache_dir / "final_config.yaml" + + if not model_config_path.exists() or not config_path.exists(): + logger.info(f"[pretrain_info] Downloading pretrain configs for {pretrain_path}") + api = wandb.Api() + run = api.run(f"{entity}/{project}/{run_id}") + cache_dir.mkdir(parents=True, exist_ok=True) + for f in run.files(): + if f.name in ("model_config.yaml", "final_config.yaml"): + f.download(root=str(cache_dir), 
exist_ok=True) + + assert model_config_path.exists(), f"model_config.yaml not found at {model_config_path}" + assert config_path.exists(), f"final_config.yaml not found at {config_path}" + + with open(model_config_path) as f: + target_model_config = yaml.safe_load(f) + with open(config_path) as f: + pretrain_config = yaml.safe_load(f) + + return target_model_config, pretrain_config + + +_MODEL_TYPE_TOPOLOGY: dict[str, tuple[str, list[str], str, list[str]]] = { + # model_type -> (attn_type, attn_projs, ffn_type, ffn_projs) + "LlamaSimple": ("separate", ["q", "k", "v", "o"], "glu", ["gate", "up", "down"]), + "LlamaSimpleMLP": ("separate", ["q", "k", "v", "o"], "mlp", ["up", "down"]), + "GPT2Simple": ("separate", ["q", "k", "v", "o"], "mlp", ["up", "down"]), + "GPT2": ("fused", ["qkv", "o"], "mlp", ["up", "down"]), + "Llama": ("separate", ["q", "k", "v", "o"], "glu", ["gate", "up", "down"]), +} + + +def _build_topology(model_type: str, n_blocks: int) -> TopologyInfo | None: + topo = _MODEL_TYPE_TOPOLOGY.get(model_type) + if topo is None: + return None + attn_type, attn_projs, ffn_type, ffn_projs = topo + blocks = [ + BlockStructure( + index=i, + attn_type=attn_type, + attn_projections=attn_projs, + ffn_type=ffn_type, + ffn_projections=ffn_projs, + ) + for i in range(n_blocks) + ] + return TopologyInfo(n_blocks=n_blocks, block_structure=blocks) + + +def _build_summary(model_type: str, target_model_config: dict[str, Any] | None) -> str: + """One-line architecture summary for the run picker.""" + if target_model_config is None: + return model_type + + parts = [model_type] + + n_layer = target_model_config.get("n_layer") + n_embd = target_model_config.get("n_embd") + n_intermediate = target_model_config.get("n_intermediate") + n_head = target_model_config.get("n_head") + n_kv = target_model_config.get("n_key_value_heads") + vocab = target_model_config.get("vocab_size") + ctx = target_model_config.get("n_ctx") + + if n_layer is not None: + parts.append(f"{n_layer}L") + 
dims = [] + if n_embd is not None: + dims.append(f"d={n_embd}") + if n_intermediate is not None: + dims.append(f"ff={n_intermediate}") + if dims: + parts.append(" ".join(dims)) + heads = [] + if n_head is not None: + heads.append(f"{n_head}h") + if n_kv is not None and n_kv != n_head: + heads.append(f"{n_kv}kv") + if heads: + parts.append("/".join(heads)) + meta = [] + if vocab is not None: + meta.append(f"vocab={vocab}") + if ctx is not None: + meta.append(f"ctx={ctx}") + if meta: + parts.append(" ".join(meta)) + + return " · ".join(parts) + + +_DATASET_SHORT_NAMES: dict[str, str] = { + "simplestories": "SS", + "pile": "Pile", + "tinystories": "TS", +} + + +def _get_dataset_short(pretrain_config: dict[str, Any] | None) -> str | None: + """Extract a short dataset label from the pretrain config.""" + if pretrain_config is None: + return None + dataset_name: str = ( + pretrain_config.get("train_dataset_config", {}).get("name", "") + or pretrain_config.get("dataset", "") + ).lower() + for key, short in _DATASET_SHORT_NAMES.items(): + if key in dataset_name: + return short + return None + + +def _get_pretrain_info(spd_config: Config) -> PretrainInfoResponse: + """Extract pretrain info from an SPD config.""" + model_class_name = spd_config.pretrained_model_class + model_type = model_class_name.split(".")[-1] + + # Determine the pretrain wandb path + pretrain_path = spd_config.pretrained_model_name or ( + str(spd_config.pretrained_model_path) if spd_config.pretrained_model_path else None + ) + + target_model_config: dict[str, Any] | None = None + pretrain_config: dict[str, Any] | None = None + pretrain_wandb_path: str | None = None + + if pretrain_path and model_class_name.startswith("spd.pretrain.models."): + try: + pretrain_wandb_path = pretrain_path + target_model_config, pretrain_config = _load_pretrain_configs(pretrain_path) + # Use model_type from config if available + if "model_type" in target_model_config: + model_type = target_model_config["model_type"] + except 
Exception: + logger.exception( + f"[pretrain_info] Failed to load pretrain configs from {pretrain_path}" + ) + + n_blocks = target_model_config.get("n_layer", 0) if target_model_config else 0 + topology = _build_topology(model_type, n_blocks) + summary = _build_summary(model_type, target_model_config) + dataset_short = _get_dataset_short(pretrain_config) + + return PretrainInfoResponse( + model_type=model_type, + summary=summary, + dataset_short=dataset_short, + target_model_config=target_model_config, + pretrain_config=pretrain_config, + pretrain_wandb_path=pretrain_wandb_path, + topology=topology, + ) + + +@router.get("") +@log_errors +def get_pretrain_info_for_run(wandb_path: str) -> PretrainInfoResponse: + """Get pretrained model architecture info for an SPD run. + + Fetches only config files (no checkpoints) for efficiency. + """ + spd_config = _load_spd_config_lightweight(wandb_path) + return _get_pretrain_info(spd_config) + + +@router.get("/loaded") +@log_errors +def get_pretrain_info_for_loaded_run(loaded: DepLoadedRun) -> PretrainInfoResponse: + """Get pretrained model architecture info for the currently loaded run. + + Uses the already-loaded config (no additional wandb downloads). 
+ """ + return _get_pretrain_info(loaded.config) diff --git a/spd/app/backend/routers/prompts.py b/spd/app/backend/routers/prompts.py index 8002aa11c..01fd2e1ec 100644 --- a/spd/app/backend/routers/prompts.py +++ b/spd/app/backend/routers/prompts.py @@ -1,22 +1,11 @@ -"""Prompt listing and generation endpoints.""" - -import json -from collections.abc import Generator -from typing import Annotated +"""Prompt listing endpoints.""" import torch -from fastapi import APIRouter, HTTPException, Query -from fastapi.responses import StreamingResponse +from fastapi import APIRouter from pydantic import BaseModel -from spd.app.backend.compute import compute_ci_only, extract_active_from_ci from spd.app.backend.dependencies import DepLoadedRun, DepStateManager from spd.app.backend.utils import log_errors -from spd.configs import LMTaskConfig -from spd.data import DatasetConfig, create_data_loader -from spd.log import logger -from spd.utils.distributed_utils import get_device -from spd.utils.general_utils import extract_batch_data # ============================================================================= # Schemas @@ -30,26 +19,41 @@ class PromptPreview(BaseModel): token_ids: list[int] tokens: list[str] preview: str + next_token_probs: list[float | None] # Probability of next token (last is None) + +PREVIEW_MAX_CHARS = 60 -class PromptSearchQuery(BaseModel): - """Query parameters for prompt search.""" - components: list[str] - mode: str +def _make_preview(spans: list[str]) -> str: + text = "".join(spans) + if len(text) <= PREVIEW_MAX_CHARS: + return text + return text[:PREVIEW_MAX_CHARS] + "..." 
-class PromptSearchResponse(BaseModel):
-    """Response from prompt search endpoint."""
+router = APIRouter(prefix="/api/prompts", tags=["prompts"])
 
-    query: PromptSearchQuery
-    count: int
-    results: list[PromptPreview]
 
+def compute_next_token_probs(token_ids: list[int], loaded: DepLoadedRun) -> list[float | None]:
+    """Compute P(next_token | prefix) for each position."""
+    if len(token_ids) == 0:
+        return []
-router = APIRouter(prefix="/api/prompts", tags=["prompts"])
+    device = next(loaded.model.parameters()).device
+    tokens_tensor = torch.tensor([token_ids], device=device)
-DEVICE = get_device()
+    with torch.no_grad():
+        logits = loaded.model(tokens_tensor)
+        probs = torch.softmax(logits, dim=-1)
+
+    result: list[float | None] = []
+    for i in range(len(token_ids) - 1):
+        next_token_id = token_ids[i + 1]
+        prob = probs[0, i, next_token_id].item()
+        result.append(prob)
+    result.append(None)  # No next token for last position
+    return result
@@ -63,208 +67,46 @@ def list_prompts(manager: DepStateManager, loaded: DepLoadedRun) -> list[PromptP
     for pid in prompt_ids:
         prompt = db.get_prompt(pid)
         assert prompt is not None, f"Prompt {pid} in index but not in DB"
-        token_strings = [loaded.token_strings[t] for t in prompt.token_ids]
+        spans = loaded.tokenizer.get_spans(prompt.token_ids)
+        next_token_probs = compute_next_token_probs(prompt.token_ids, loaded)
         results.append(
             PromptPreview(
                 id=prompt.id,
                 token_ids=prompt.token_ids,
-                tokens=token_strings,
-                preview="".join(token_strings[:10]) + ("..." if len(token_strings) > 10 else ""),
+                tokens=spans,
+                preview=_make_preview(spans),
+                next_token_probs=next_token_probs,
             )
         )
     return results
 
-BATCH_SIZE = 32
-
-
-@router.post("/generate")
-@log_errors
-def generate_prompts(
-    n_prompts: int,
-    manager: DepStateManager,
-    loaded: DepLoadedRun,
-) -> StreamingResponse:
-    """Generate prompts from training data with CI harvesting.
- - Streams progress updates and stores prompts with their active components - (for the inverted index used by search). - """ - db = manager.db - spd_config = loaded.config - - task_config = spd_config.task_config - assert isinstance(task_config, LMTaskConfig) - train_data_config = DatasetConfig( - name=task_config.dataset_name, - hf_tokenizer_path=spd_config.tokenizer_name, - split=task_config.train_data_split, - n_ctx=loaded.context_length, - is_tokenized=task_config.is_tokenized, - streaming=task_config.streaming, - column_name=task_config.column_name, - shuffle_each_epoch=task_config.shuffle_each_epoch, - ) - logger.info(f"[API] Creating train loader for run {loaded.run.wandb_path}") - train_loader, _ = create_data_loader( - dataset_config=train_data_config, - batch_size=BATCH_SIZE, - buffer_size=task_config.buffer_size, - global_seed=spd_config.seed, - ) - - def generate() -> Generator[str]: - added_count = 0 - - for batch in train_loader: - if added_count >= n_prompts: - break - - tokens = extract_batch_data(batch).to(DEVICE) - batch_size, n_seq = tokens.shape - - # Compute CI for the whole batch - ci_result = compute_ci_only( - model=loaded.model, - tokens=tokens, - sampling=loaded.config.sampling, - ) - - # Process each sequence in the batch - prompts = [] - for i in range(batch_size): - if i % 5 == 0: - progress = min(added_count / n_prompts, 1.0) - progress_data = {"type": "progress", "progress": progress, "count": added_count} - yield f"data: {json.dumps(progress_data)}\n\n" - - if added_count >= n_prompts: - break - - token_ids = tokens[i].tolist() - - # Slice CI for this single sequence - ci_single = {k: v[i : i + 1] for k, v in ci_result.ci_lower_leaky.items()} - target_out_probs_single = ci_result.target_out_probs[i : i + 1] - - # Extract active components for inverted index - active_components = extract_active_from_ci( - ci_lower_leaky=ci_single, - target_out_probs=target_out_probs_single, - ci_threshold=0.0, - output_prob_threshold=0.01, - 
n_seq=n_seq, - ) - - # Add to DB with active components - prompts.append((token_ids, active_components)) - added_count += 1 - - db.add_prompts(loaded.run.id, prompts, loaded.context_length) - - # Final result - total = db.get_prompt_count(loaded.run.id, loaded.context_length) - complete_data = { - "type": "complete", - "prompts_added": added_count, - "total_prompts": total, - } - yield f"data: {json.dumps(complete_data)}\n\n" - logger.info(f"[API] Generated {added_count} prompts for run {loaded.run.id}") - - return StreamingResponse(generate(), media_type="text/event-stream") - - -@router.get("/search") +@router.delete("/{prompt_id}") @log_errors -def search_prompts( - manager: DepStateManager, - loaded: DepLoadedRun, - components: str = "", - mode: Annotated[str, Query(pattern="^(all|any)$")] = "all", -) -> PromptSearchResponse: - """Search for prompts with specified components in the loaded run.""" - db = manager.db - - component_list = [c.strip() for c in components.split(",") if c.strip()] - if not component_list: - raise HTTPException(status_code=400, detail="No components specified") - - require_all = mode == "all" - prompt_ids = db.find_prompts_with_components( - loaded.run.id, component_list, require_all=require_all - ) - - results: list[PromptPreview] = [] - for pid in prompt_ids: - prompt = db.get_prompt(pid) - assert prompt is not None, f"Prompt {pid} in index but not in DB" - token_strings = [loaded.token_strings[t] for t in prompt.token_ids] - results.append( - PromptPreview( - id=prompt.id, - token_ids=prompt.token_ids, - tokens=token_strings, - preview="".join(token_strings[:10]) + ("..." 
if len(token_strings) > 10 else ""), - ) - ) - - return PromptSearchResponse( - query=PromptSearchQuery(components=component_list, mode=mode), - count=len(results), - results=results, - ) +def delete_prompt(prompt_id: int, manager: DepStateManager) -> None: + """Delete a prompt and all associated data (graphs, interventions).""" + manager.db.delete_prompt(prompt_id) @router.post("/custom") @log_errors -def create_custom_prompt( - text: str, - manager: DepStateManager, - loaded: DepLoadedRun, -) -> PromptPreview: - """Create a custom prompt from text, computing CI and storing it. - - Returns the created prompt with its ID for further operations. - """ - db = manager.db - - # Tokenize - token_ids = loaded.tokenizer.encode(text, add_special_tokens=False) - if not token_ids: - raise HTTPException(status_code=400, detail="Text produced no tokens") +def add_custom_prompt(text: str, manager: DepStateManager, loaded: DepLoadedRun) -> PromptPreview: + """Add a custom text prompt.""" + token_ids = loaded.tokenizer.encode(text) + assert len(token_ids) > 0, "Text produced no tokens" - n_seq = len(token_ids) - tokens_tensor = torch.tensor([token_ids], device=DEVICE) + # Truncate to context length + token_ids = token_ids[: loaded.context_length] - # Compute CI - ci_result = compute_ci_only( - model=loaded.model, - tokens=tokens_tensor, - sampling=loaded.config.sampling, - ) - - # Extract active components for inverted index - active_components = extract_active_from_ci( - ci_lower_leaky=ci_result.ci_lower_leaky, - target_out_probs=ci_result.target_out_probs, - ci_threshold=0.0, - output_prob_threshold=0.01, - n_seq=n_seq, - ) - - # Save to DB - prompt_id = db.add_custom_prompt( - run_id=loaded.run.id, - token_ids=token_ids, - active_components=active_components, - context_length=loaded.context_length, - ) + db = manager.db + prompt_id = db.add_custom_prompt(loaded.run.id, token_ids, loaded.context_length) + spans = loaded.tokenizer.get_spans(token_ids) + next_token_probs = 
compute_next_token_probs(token_ids, loaded) - token_strings = [loaded.token_strings[t] for t in token_ids] return PromptPreview( id=prompt_id, token_ids=token_ids, - tokens=token_strings, - preview="".join(token_strings[:10]) + ("..." if len(token_strings) > 10 else ""), + tokens=spans, + preview=_make_preview(spans), + next_token_probs=next_token_probs, ) diff --git a/spd/app/backend/routers/run_registry.py b/spd/app/backend/routers/run_registry.py new file mode 100644 index 000000000..d44a7108f --- /dev/null +++ b/spd/app/backend/routers/run_registry.py @@ -0,0 +1,95 @@ +"""Run registry endpoint. + +Returns architecture and data availability for requested SPD runs. +The canonical run list lives in the frontend; the backend just hydrates it. +""" + +import asyncio +from pathlib import Path + +from fastapi import APIRouter +from pydantic import BaseModel + +from spd.app.backend.routers.pretrain_info import _get_pretrain_info, _load_spd_config_lightweight +from spd.app.backend.utils import log_errors +from spd.log import logger +from spd.settings import SPD_OUT_DIR +from spd.utils.wandb_utils import parse_wandb_run_path + +router = APIRouter(prefix="/api/run_registry", tags=["run_registry"]) + + +class DataAvailability(BaseModel): + harvest: bool + autointerp: bool + attributions: bool + graph_interp: bool + + +class RunInfoResponse(BaseModel): + wandb_run_id: str + architecture: str | None + availability: DataAvailability + + +def _has_glob_match(pattern_dir: Path, glob_pattern: str) -> bool: + """Check if any file matches a glob pattern under a directory.""" + if not pattern_dir.exists(): + return False + return next(pattern_dir.glob(glob_pattern), None) is not None + + +def _check_availability(run_id: str) -> DataAvailability: + """Lightweight filesystem checks for post-processing data availability.""" + harvest_dir = SPD_OUT_DIR / "harvest" / run_id + autointerp_dir = SPD_OUT_DIR / "autointerp" / run_id + attributions_dir = SPD_OUT_DIR / "dataset_attributions" / 
run_id + graph_interp_dir = SPD_OUT_DIR / "graph_interp" / run_id + + return DataAvailability( + harvest=_has_glob_match(harvest_dir, "h-*/harvest.db"), + autointerp=_has_glob_match(autointerp_dir, "a-*/.done"), + attributions=_has_glob_match(attributions_dir, "da-*/dataset_attributions.pt"), + graph_interp=_has_glob_match(graph_interp_dir, "*/interp.db"), + ) + + +def _get_architecture_summary(wandb_path: str) -> str | None: + """Get a short architecture label for a run. Returns None on failure.""" + try: + spd_config = _load_spd_config_lightweight(wandb_path) + info = _get_pretrain_info(spd_config) + parts: list[str] = [] + if info.dataset_short: + parts.append(info.dataset_short) + parts.append(info.model_type) + cfg = info.target_model_config + if cfg: + n_layer = cfg.get("n_layer") + n_embd = cfg.get("n_embd") + if n_layer is not None: + parts.append(f"{n_layer}L") + if n_embd is not None: + parts.append(f"d{n_embd}") + return " ".join(parts) + except Exception: + logger.exception(f"[run_registry] Failed to get architecture for {wandb_path}") + return None + + +def _build_run_info(wandb_run_id: str) -> RunInfoResponse: + _, _, run_id = parse_wandb_run_path(wandb_run_id) + return RunInfoResponse( + wandb_run_id=wandb_run_id, + architecture=_get_architecture_summary(wandb_run_id), + availability=_check_availability(run_id), + ) + + +@router.post("") +@log_errors +async def get_run_info(wandb_run_ids: list[str]) -> list[RunInfoResponse]: + """Return architecture and availability for the requested runs.""" + loop = asyncio.get_running_loop() + tasks = [loop.run_in_executor(None, _build_run_info, wid) for wid in wandb_run_ids] + return list(await asyncio.gather(*tasks)) diff --git a/spd/app/backend/routers/runs.py b/spd/app/backend/routers/runs.py index c62df6688..3b9c2300e 100644 --- a/spd/app/backend/routers/runs.py +++ b/spd/app/backend/routers/runs.py @@ -7,18 +7,25 @@ import yaml from fastapi import APIRouter from pydantic import BaseModel -from transformers 
import AutoTokenizer -from transformers.tokenization_utils_fast import PreTrainedTokenizerFast -from spd.app.backend.compute import get_sources_by_target +from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.dependencies import DepStateManager -from spd.app.backend.state import HarvestCache, RunState -from spd.app.backend.utils import build_token_lookup, log_errors +from spd.app.backend.state import RunState +from spd.app.backend.utils import log_errors +from spd.autointerp.repo import InterpRepo +from spd.configs import LMTaskConfig +from spd.dataset_attributions.repo import AttributionRepo +from spd.graph_interp.repo import GraphInterpRepo +from spd.harvest.repo import HarvestRepo from spd.log import logger from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.topology import TransformerTopology, get_sources_by_target from spd.utils.distributed_utils import get_device from spd.utils.wandb_utils import parse_wandb_run_path +# Datasets small enough to load into memory for search +_SEARCHABLE_DATASETS = {"SimpleStories/SimpleStories"} + # ============================================================================= # Schemas # ============================================================================= @@ -34,6 +41,10 @@ class LoadedRun(BaseModel): prompt_count: int context_length: int backend_user: str + dataset_attributions_available: bool + dataset_search_enabled: bool + graph_interp_available: bool + autointerp_available: bool router = APIRouter(prefix="/api", tags=["runs"]) @@ -100,26 +111,27 @@ def load_run(wandb_path: str, context_length: int, manager: DepStateManager): spd_config = run_info.config assert spd_config.tokenizer_name is not None logger.info(f"[API] Loading tokenizer for run {run.id}: {spd_config.tokenizer_name}") - loaded_tokenizer = AutoTokenizer.from_pretrained(spd_config.tokenizer_name) - assert isinstance(loaded_tokenizer, PreTrainedTokenizerFast) + app_tokenizer = 
AppTokenizer.from_pretrained(spd_config.tokenizer_name) - # Build sources_by_target mapping - logger.info(f"[API] Building sources_by_target mapping for run {run.id}") - sources_by_target = get_sources_by_target(model, DEVICE, spd_config.sampling) + # Build topology and sources_by_target mapping + logger.info(f"[API] Building topology for run {run.id}") + topology = TransformerTopology(model.target_model) - # Build token lookup for activation contexts - logger.info(f"[API] Building token lookup for run {run.id}") - token_strings = build_token_lookup(loaded_tokenizer, spd_config.tokenizer_name) + logger.info(f"[API] Building sources_by_target mapping for run {run.id}") + sources_by_target = get_sources_by_target(model, topology, DEVICE, spd_config.sampling) manager.run_state = RunState( run=run, model=model, - tokenizer=loaded_tokenizer, + topology=topology, + tokenizer=app_tokenizer, sources_by_target=sources_by_target, config=spd_config, - token_strings=token_strings, context_length=context_length, - harvest=HarvestCache(run_id=run_id), + harvest=HarvestRepo.open_most_recent(run_id), + interp=InterpRepo.open(run_id), + attributions=AttributionRepo.open(run_id), + graph_interp=GraphInterpRepo.open(run_id), ) logger.info(f"[API] Run {run.id} loaded on {DEVICE}") @@ -142,6 +154,11 @@ def get_status(manager: DepStateManager) -> LoadedRun | None: prompt_count = manager.db.get_prompt_count(run.id, context_length) + task_config = manager.run_state.config.task_config + dataset_search_enabled = ( + isinstance(task_config, LMTaskConfig) and task_config.dataset_name in _SEARCHABLE_DATASETS + ) + return LoadedRun( id=run.id, wandb_path=run.wandb_path, @@ -150,6 +167,10 @@ def get_status(manager: DepStateManager) -> LoadedRun | None: prompt_count=prompt_count, context_length=context_length, backend_user=getpass.getuser(), + dataset_attributions_available=manager.run_state.attributions is not None, + dataset_search_enabled=dataset_search_enabled, + 
graph_interp_available=manager.run_state.graph_interp is not None, + autointerp_available=manager.run_state.interp is not None, ) diff --git a/spd/app/backend/schemas.py b/spd/app/backend/schemas.py index 61fa3d9d2..bf5f3e9b2 100644 --- a/spd/app/backend/schemas.py +++ b/spd/app/backend/schemas.py @@ -27,17 +27,6 @@ class OutputProbability(BaseModel): # ============================================================================= -class ActivationContextsGenerationConfig(BaseModel): - """Configuration for generating activation contexts.""" - - importance_threshold: float = 0.01 - n_batches: int = 100 - batch_size: int = 32 - n_tokens_either_side: int = 5 - topk_examples: int = 20 - separation_tokens: int = 0 - - class SubcomponentMetadata(BaseModel): """Lightweight metadata for a subcomponent (without examples/token_prs)""" diff --git a/spd/app/backend/server.py b/spd/app/backend/server.py index 2feb2fdd9..afbff5db6 100644 --- a/spd/app/backend/server.py +++ b/spd/app/backend/server.py @@ -8,10 +8,12 @@ python -m spd.app.backend.server --port 8000 """ +import os import time import traceback from collections.abc import Awaitable, Callable from contextlib import asynccontextmanager +from pathlib import Path import fire import torch @@ -29,15 +31,22 @@ agents_router, clusters_router, correlations_router, + data_sources_router, dataset_attributions_router, dataset_search_router, + graph_interp_router, graphs_router, intervention_router, + investigations_router, + mcp_router, + pretrain_info_router, prompts_router, + run_registry_router, runs_router, ) from spd.app.backend.state import StateManager from spd.log import logger +from spd.settings import SPD_APP_DEFAULT_RUN from spd.utils.distributed_utils import get_device DEVICE = get_device() @@ -46,6 +55,8 @@ @asynccontextmanager async def lifespan(app: FastAPI): # pyright: ignore[reportUnusedParameter] """Initialize DB connection at startup. 
Model loaded on-demand via /api/runs/load.""" + from spd.app.backend.routers.mcp import InvestigationConfig, set_investigation_config + manager = StateManager.get() db = PromptAttrDB(check_same_thread=False) @@ -56,6 +67,24 @@ async def lifespan(app: FastAPI): # pyright: ignore[reportUnusedParameter] logger.info(f"[STARTUP] Device: {DEVICE}") logger.info(f"[STARTUP] CUDA available: {torch.cuda.is_available()}") + # Configure MCP for investigation mode (derives paths from investigation dir) + investigation_dir = os.environ.get("SPD_INVESTIGATION_DIR") + if investigation_dir: + inv_dir = Path(investigation_dir) + set_investigation_config( + InvestigationConfig( + events_log_path=inv_dir / "events.jsonl", + investigation_dir=inv_dir, + ) + ) + logger.info(f"[STARTUP] Investigation mode enabled: dir={investigation_dir}") + + if SPD_APP_DEFAULT_RUN is not None: + from spd.app.backend.routers.runs import load_run + + logger.info(f"[STARTUP] Auto-loading default run: {SPD_APP_DEFAULT_RUN}") + load_run(SPD_APP_DEFAULT_RUN, context_length=512, manager=manager) + yield manager.close() @@ -155,6 +184,12 @@ async def global_exception_handler(request: Request, exc: Exception) -> JSONResp app.include_router(dataset_search_router) app.include_router(dataset_attributions_router) app.include_router(agents_router) +app.include_router(investigations_router) +app.include_router(mcp_router) +app.include_router(data_sources_router) +app.include_router(graph_interp_router) +app.include_router(pretrain_info_router) +app.include_router(run_registry_router) def cli(port: int = 8000) -> None: diff --git a/spd/app/backend/state.py b/spd/app/backend/state.py index 47dacfe51..2cdabda73 100644 --- a/spd/app/backend/state.py +++ b/spd/app/backend/state.py @@ -1,108 +1,27 @@ """Application state management for the SPD backend. 
Contains: -- RunState: Runtime state for a loaded run (model, tokenizer, caches) +- RunState: Runtime state for a loaded run (model, tokenizer, repos) - StateManager: Singleton managing app-wide state with proper lifecycle """ +import threading +from collections.abc import Generator +from contextlib import contextmanager from dataclasses import dataclass, field from typing import Any -from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from fastapi import HTTPException +from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.database import PromptAttrDB, Run -from spd.autointerp.loaders import load_interpretations -from spd.autointerp.schemas import InterpretationResult +from spd.autointerp.repo import InterpRepo from spd.configs import Config -from spd.dataset_attributions import DatasetAttributionStorage, load_dataset_attributions -from spd.harvest.loaders import ( - load_activation_contexts_summary, - load_correlations, - load_token_stats, -) -from spd.harvest.schemas import ComponentSummary -from spd.harvest.storage import CorrelationStorage, TokenStatsStorage +from spd.dataset_attributions.repo import AttributionRepo +from spd.graph_interp.repo import GraphInterpRepo +from spd.harvest.repo import HarvestRepo from spd.models.component_model import ComponentModel - -_NOT_LOADED = object() - - -class HarvestCache: - """Lazily-loaded harvest data for a run. - - All fields are loaded on first access and cached for the lifetime of the run. - Uses a sentinel pattern to distinguish "not loaded" from "loaded but None". 
- """ - - def __init__(self, run_id: str) -> None: - self.run_id = run_id - self._correlations = _NOT_LOADED - self._token_stats = _NOT_LOADED - self._interpretations = _NOT_LOADED - self._activation_contexts_summary = _NOT_LOADED - self._dataset_attributions = _NOT_LOADED - - @property - def correlations(self) -> CorrelationStorage: - if self._correlations is _NOT_LOADED: - self._correlations = load_correlations(self.run_id) - assert isinstance(self._correlations, CorrelationStorage) - return self._correlations - - @property - def token_stats(self) -> TokenStatsStorage: - if self._token_stats is _NOT_LOADED: - self._token_stats = load_token_stats(self.run_id) - assert isinstance(self._token_stats, TokenStatsStorage) - return self._token_stats - - @property - def interpretations(self) -> dict[str, InterpretationResult]: - if self._interpretations is _NOT_LOADED: - self._interpretations = load_interpretations(self.run_id) - assert isinstance(self._interpretations, dict) - return self._interpretations - - def _load_activation_contexts_summary(self) -> dict[str, ComponentSummary] | None: - if self._activation_contexts_summary is _NOT_LOADED: - self._activation_contexts_summary = load_activation_contexts_summary(self.run_id) - if self._activation_contexts_summary is None: - return None - assert isinstance(self._activation_contexts_summary, dict) - return self._activation_contexts_summary - - def has_activation_contexts_summary(self) -> bool: - """Check if activation contexts summary is available.""" - return self._load_activation_contexts_summary() is not None - - @property - def activation_contexts_summary(self) -> dict[str, ComponentSummary]: - """Lightweight summary of activation contexts, keyed by component_key (e.g. 
'h.0.mlp.c_fc:5').""" - result = self._load_activation_contexts_summary() - assert result is not None, f"No activation contexts summary found for run {self.run_id}" - return result - - def _load_dataset_attributions(self) -> DatasetAttributionStorage | None: - if self._dataset_attributions is _NOT_LOADED: - self._dataset_attributions = load_dataset_attributions(self.run_id) - if self._dataset_attributions is None: - return None - assert isinstance(self._dataset_attributions, DatasetAttributionStorage) - return self._dataset_attributions - - def has_dataset_attributions(self) -> bool: - """Check if dataset attributions are available.""" - return self._load_dataset_attributions() is not None - - @property - def dataset_attributions(self) -> DatasetAttributionStorage: - """Dataset-aggregated attribution matrix.""" - result = self._load_dataset_attributions() - assert result is not None, ( - f"No dataset attributions found for run {self.run_id}. " - "Run: spd-attributions --n_batches N" - ) - return result +from spd.topology import TransformerTopology @dataclass @@ -111,12 +30,15 @@ class RunState: run: Run model: ComponentModel - tokenizer: PreTrainedTokenizerBase + topology: TransformerTopology + tokenizer: AppTokenizer sources_by_target: dict[str, list[str]] config: Config - token_strings: dict[int, str] context_length: int - harvest: HarvestCache + harvest: HarvestRepo | None + interp: InterpRepo | None + attributions: AttributionRepo | None + graph_interp: GraphInterpRepo | None @dataclass @@ -147,6 +69,7 @@ class StateManager: def __init__(self) -> None: self._state: AppState | None = None + self._gpu_lock = threading.Lock() @classmethod def get(cls) -> "StateManager": @@ -189,3 +112,21 @@ def close(self) -> None: """Clean up resources.""" if self._state is not None: self._state.db.close() + + @contextmanager + def gpu_lock(self) -> Generator[None]: + """Acquire GPU lock or fail with 503 if another GPU operation is in progress. 
+ + Use this for GPU-intensive endpoints to prevent concurrent operations + that would cause the server to hang. + """ + acquired = self._gpu_lock.acquire(blocking=False) + if not acquired: + raise HTTPException( + status_code=503, + detail="GPU operation already in progress. Please wait and retry.", + ) + try: + yield + finally: + self._gpu_lock.release() diff --git a/spd/app/backend/utils.py b/spd/app/backend/utils.py index e743c893f..dc79631f4 100644 --- a/spd/app/backend/utils.py +++ b/spd/app/backend/utils.py @@ -33,36 +33,32 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: _PUNCT_NO_SPACE = set(".,!?;:'\")-]}>/") -def build_token_lookup( - tokenizer: Any, - tokenizer_name: str, -) -> dict[int, str]: - """Build token ID -> string lookup. +def delimit_tokens(tokens: list[tuple[str, bool]]) -> str: + """Join token strings, wrapping active spans in <<...>>. - Uses tokenizer-specific strategy to produce strings that concatenate correctly. + Consecutive active tokens are grouped: [(" over", T), (" the", T), (" moon", T)] + produces " <<over the moon>>". 
""" - lookup: dict[int, str] = {} - vocab_size: int = tokenizer.vocab_size - - for tid in range(vocab_size): - decoded: str = tokenizer.decode([tid], skip_special_tokens=False) - - match tokenizer_name: - case "SimpleStories/test-SimpleStories-gpt2-1.25M": - # WordPiece handling: - if decoded.startswith("##"): - lookup[tid] = decoded[2:] - elif decoded and decoded[0] in _PUNCT_NO_SPACE: - lookup[tid] = decoded - else: - lookup[tid] = " " + decoded - case "openai-community/gpt2": - # BPE (GPT-2 style): spaces encoded in token via Ġ -> space - lookup[tid] = decoded - case _: - raise ValueError(f"Unsupported tokenizer name: {tokenizer_name}") - - return lookup + parts: list[str] = [] + in_span = False + for tok, active in tokens: + if active and not in_span: + stripped = tok.lstrip() + parts.append(tok[: len(tok) - len(stripped)]) + parts.append("<<") + parts.append(stripped) + in_span = True + elif active: + parts.append(tok) + elif in_span: + parts.append(">>") + parts.append(tok) + in_span = False + else: + parts.append(tok) + if in_span: + parts.append(">>") + return "".join(parts) @contextmanager diff --git a/spd/app/frontend/package-lock.json b/spd/app/frontend/package-lock.json index b6c451303..32da0218c 100644 --- a/spd/app/frontend/package-lock.json +++ b/spd/app/frontend/package-lock.json @@ -7,6 +7,9 @@ "": { "name": "frontend", "version": "0.0.0", + "dependencies": { + "marked": "^17.0.1" + }, "devDependencies": { "@eslint/js": "^9.38.0", "@sveltejs/vite-plugin-svelte": "^6.2.1", @@ -2347,6 +2350,18 @@ "@jridgewell/sourcemap-codec": "^1.5.5" } }, + "node_modules/marked": { + "version": "17.0.1", + "resolved": "https://registry.npmjs.org/marked/-/marked-17.0.1.tgz", + "integrity": "sha512-boeBdiS0ghpWcSwoNm/jJBwdpFaMnZWRzjA6SkUMYb40SVaN1x7mmfGKp0jvexGcx+7y2La5zRZsYFZI6Qpypg==", + "license": "MIT", + "bin": { + "marked": "bin/marked.js" + }, + "engines": { + "node": ">= 20" + } + }, "node_modules/merge2": { "version": "1.4.1", "resolved": 
"https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", diff --git a/spd/app/frontend/package.json b/spd/app/frontend/package.json index f54e1bb3d..f298885ce 100644 --- a/spd/app/frontend/package.json +++ b/spd/app/frontend/package.json @@ -27,5 +27,8 @@ "typescript": "~5.9.3", "typescript-eslint": "^8.46.2", "vite": "^7.1.7" + }, + "dependencies": { + "marked": "^17.0.1" } } diff --git a/spd/app/frontend/src/App.svelte b/spd/app/frontend/src/App.svelte index 9d5c922df..c7a073625 100644 --- a/spd/app/frontend/src/App.svelte +++ b/spd/app/frontend/src/App.svelte @@ -12,7 +12,7 @@ let backendUser = $state>({ status: "uninitialized" }); - let showWhichView = $derived(runState.run?.status === "loaded" ? "run-view" : "run-selector"); + let showWhichView = $derived(runState.run.status === "loaded" ? "run-view" : "run-selector"); async function handleLoadRun(wandbPath: string, contextLength: number) { await runState.loadRun(wandbPath, contextLength); @@ -20,15 +20,15 @@ onMount(() => { runState.syncStatus(); - api.getWhoami().then((user) => (backendUser = { status: "loaded", data: user })); + api.whoami().then((user) => (backendUser = { status: "loaded", data: user })); }); {#if showWhichView === "run-selector"} {:else} diff --git a/spd/app/frontend/src/app.css b/spd/app/frontend/src/app.css index 2c706d19b..bf6649aee 100644 --- a/spd/app/frontend/src/app.css +++ b/spd/app/frontend/src/app.css @@ -1,27 +1,33 @@ :root { - /* Punchy Research - crisp whites, bold contrasts */ + /* Goodfire-inspired - warm whites, navy text, vibrant blue accent */ --bg-base: #ffffff; --bg-surface: #ffffff; --bg-elevated: #ffffff; - --bg-inset: #f8f9fa; + --bg-inset: #f7f6f2; + --bg-hover: #f0efeb; - --border-subtle: #e0e0e0; - --border-default: #c0c0c0; - --border-strong: #888888; + --border-subtle: #e5e3dc; + --border-default: #c8c5bc; + --border-strong: #8a8780; - --text-primary: #111111; - --text-secondary: #555555; - --text-muted: #999999; + --text-primary: #1d272a; + --text-secondary: 
#646464; + --text-muted: #b4b4b4; - --accent-primary: #2563eb; - --accent-primary-dim: #1d4ed8; + --accent-primary: #7c4d33; + --accent-primary-bright: #96613f; + --accent-primary-dim: #5e3a27; --status-positive: #16a34a; --status-positive-bright: #22c55e; --status-negative: #dc2626; --status-negative-bright: #ef4444; - --status-info: #2563eb; - --status-info-bright: #3b82f6; + --status-warning: #eab308; + --status-warning-bright: #facc15; + --status-info: #4d65ff; + --status-info-bright: #6b7fff; + + --focus-ring: #4d65ff; /* Typography - Clean system fonts with mono for code */ --font-mono: "SF Mono", "Menlo", "Monaco", "Consolas", monospace; @@ -43,6 +49,16 @@ --radius-sm: 4px; --radius-md: 6px; --radius-lg: 8px; + + /* Shadows - standardized opacity levels */ + --shadow-sm: 0 2px 4px rgba(0, 0, 0, 0.08); + --shadow-md: 0 4px 12px rgba(0, 0, 0, 0.12); + --shadow-lg: 0 8px 24px rgba(0, 0, 0, 0.16); + + /* Transitions - standardized timing */ + --transition-fast: 0.1s ease; + --transition-normal: 0.15s ease; + --transition-slow: 0.2s ease; } * { @@ -71,9 +87,9 @@ button { border-radius: var(--radius-sm); cursor: pointer; transition: - background-color 0.1s, - border-color 0.1s, - color 0.1s; + background-color var(--transition-fast), + border-color var(--transition-fast), + color var(--transition-fast); } button:disabled { @@ -115,8 +131,8 @@ button:disabled { opacity: 0; pointer-events: none; z-index: 1000; - box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15); - transition: opacity 0.2s; + box-shadow: var(--shadow-md); + transition: opacity var(--transition-slow); } .info-icon:hover::after { diff --git a/spd/app/frontend/src/components/ActivationContextsPagedTable.svelte b/spd/app/frontend/src/components/ActivationContextsPagedTable.svelte index 146581ee8..c9c304950 100644 --- a/spd/app/frontend/src/components/ActivationContextsPagedTable.svelte +++ b/spd/app/frontend/src/components/ActivationContextsPagedTable.svelte @@ -1,25 +1,64 @@
- @@ -103,43 +135,70 @@
-
- - -
-
- - + Center on peak + +
+ {#if loading} +
+
+ {#each Array(pageSize) as _, i (i)} +
{/each} - +
-
-
-
- {#each paginatedIndices as idx (idx)} -
- + {:else} + {@const d = loaded!} +
+ {#if displaySettings.centerOnPeak} +
+ {#each paginatedIndices as idx (idx)} + {@const fp = firingPositions[idx]} +
+
+ +
+
+ +
+
+ +
+
+ {/each}
- {/each} + {:else} +
+ {#each paginatedIndices as idx (idx)} +
+ +
+ {/each} +
+ {/if}
-
+ {/if}
diff --git a/spd/app/frontend/src/components/ActivationContextsTab.svelte b/spd/app/frontend/src/components/ActivationContextsTab.svelte index 8866a68f1..ca025ad97 100644 --- a/spd/app/frontend/src/components/ActivationContextsTab.svelte +++ b/spd/app/frontend/src/components/ActivationContextsTab.svelte @@ -18,6 +18,8 @@
Loading activation contexts summary...
{:else if summary.status === "error"} Error loading summary: {String(summary.error)} + {:else if summary.data === null} + No harvest data available. Run postprocessing first. {:else} {/if} diff --git a/spd/app/frontend/src/components/ActivationContextsViewer.svelte b/spd/app/frontend/src/components/ActivationContextsViewer.svelte index 987e38f10..a90b951a7 100644 --- a/spd/app/frontend/src/components/ActivationContextsViewer.svelte +++ b/spd/app/frontend/src/components/ActivationContextsViewer.svelte @@ -1,21 +1,21 @@
@@ -203,6 +313,7 @@
+
+ + {#if plotData} + + + {#each plotData.yTicks as tick (tick.value)} + + + {tick.label} + + {/each} + + + {#if plotData.points.length > 1} + `${p.x},${p.y}`).join(" ")} + fill="none" + stroke="var(--accent-primary)" + stroke-width="1.5" + /> + {/if} + + + {#each plotData.points as point (point.rank)} + handlePlotClick(point.rank)} + /> + {/each} + + + {#if currentPointIndex !== null && plotData.points[currentPointIndex]} + {@const cp = plotData.points[currentPointIndex]} + + + {/if} + + + + Component rank ({plotData.n} total) + + + {/if} +
+
Mean CI: {formatMeanCi(currentMetadata.mean_ci)} + {#if currentIntruderScore !== null} + Intruder: {Math.round(currentIntruderScore * 100)}% + {/if} - +
+ + {#if currentGraphInterpLabel && componentData.graphInterpDetail.status === "loaded" && componentData.graphInterpDetail.data} + + {/if} +
- {#if componentData.componentDetail?.status === "loading"} -
Loading component data...
- {:else if componentData.componentDetail?.status === "loaded"} - token)} - {maxAbsComponentAct} - /> - {:else if componentData.componentDetail?.status === "error"} - Error loading component data: {String(componentData.componentDetail.error)} + {#if activationExamples.status === "error"} + Error loading component data: {String(activationExamples.error)} {:else} - Something went wrong loading component data. + {/if} + + {#if componentData.datasetAttributions?.status === "loaded" && componentData.datasetAttributions.data} + + {:else if componentData.datasetAttributions?.status === "loading"} +
+ + Loading... +
+ {:else if componentData.datasetAttributions?.status === "error"} +
+ + Error: {String(componentData.datasetAttributions.error)} +
+ {/if} +
{#if componentData.tokenStats.status === "uninitialized" || componentData.tokenStats.status === "loading"} Loading token stats... @@ -285,7 +488,7 @@
- {#if showCorrelations} + {#if anyCorrelationStatsEnabled()}
{#if componentData.correlations.status === "uninitialized" || componentData.correlations.status === "loading"} @@ -295,7 +498,10 @@ {:else if componentData.correlations.data === null} No correlations data. Run harvest pipeline first. {:else} - + {/if}
{/if} @@ -436,6 +642,45 @@ margin: 0; } + .ci-plot { + position: relative; + width: 100%; + border: 1px solid var(--border-default); + background: var(--bg-elevated); + } + + .plot-toggle { + position: absolute; + top: var(--space-1); + right: var(--space-2); + display: flex; + align-items: center; + gap: var(--space-1); + font-size: var(--text-xs); + font-family: var(--font-mono); + color: var(--text-muted); + cursor: pointer; + z-index: 1; + } + + .plot-toggle input { + cursor: pointer; + } + + .ci-plot svg { + display: block; + } + + .ci-plot .plot-label { + font-size: var(--text-xs); + font-family: var(--font-mono); + fill: var(--text-muted); + } + + .ci-plot .plot-hitarea { + cursor: pointer; + } + .component-section { display: flex; flex-direction: column; @@ -468,11 +713,15 @@ gap: var(--space-2); } - .loading { - padding: var(--space-4); - text-align: center; - font-size: var(--text-sm); - font-family: var(--font-sans); - color: var(--text-muted); + .interpretation-badges { + display: flex; + flex-direction: column; + gap: var(--space-2); + } + + .dataset-attributions-loading { + display: flex; + flex-direction: column; + gap: var(--space-2); } diff --git a/spd/app/frontend/src/components/ClusterComponentCard.svelte b/spd/app/frontend/src/components/ClusterComponentCard.svelte new file mode 100644 index 000000000..e5c6769b5 --- /dev/null +++ b/spd/app/frontend/src/components/ClusterComponentCard.svelte @@ -0,0 +1,254 @@ + + +
+
+

{layer}:{cIdx}

+
+ {#if componentData.componentDetail.status === "loaded"} + Mean CI: {formatNumericalValue(componentData.componentDetail.data.mean_ci)} + {/if} + {#if intruderScore !== null} + Intruder: {Math.round(intruderScore * 100)}% + {/if} +
+
+ + + +
+ + {#if activationExamples.status === "error"} + Error loading details: {String(activationExamples.error)} + {:else if activationExamples.status === "loaded" && activationExamples.data.tokens.length === 0} + + {:else} + + {/if} +
+ + + + {#if componentData.datasetAttributions.status === "uninitialized"} + uninitialized + {:else if componentData.datasetAttributions.status === "loaded"} + {#if componentData.datasetAttributions.data !== null} + + {:else} + No dataset attributions available. + {/if} + {:else if componentData.datasetAttributions.status === "loading"} +
+ + Loading... +
+ {:else if componentData.datasetAttributions.status === "error"} +
+ + Error: {String(componentData.datasetAttributions.error)} +
+ {/if} + +
+ +
+ {#if componentData.tokenStats.status === "loading" || componentData.tokenStats.status === "uninitialized"} + Loading token stats... + {:else if componentData.tokenStats.status === "error"} + Error: {String(componentData.tokenStats.error)} + {:else} + + + + {/if} +
+
+ + {#if anyCorrelationStatsEnabled()} +
+ + {#if componentData.correlations.status === "loading"} + Loading... + {:else if componentData.correlations.status === "loaded" && componentData.correlations.data} + + {:else if componentData.correlations.status === "error"} + Error loading correlations: {String(componentData.correlations.error)} + {:else} + No correlations available. + {/if} +
+ {/if} +
+ + diff --git a/spd/app/frontend/src/components/ClusterPathInput.svelte b/spd/app/frontend/src/components/ClusterPathInput.svelte index d4f769a55..6adcfb2b3 100644 --- a/spd/app/frontend/src/components/ClusterPathInput.svelte +++ b/spd/app/frontend/src/components/ClusterPathInput.svelte @@ -1,14 +1,14 @@
- {#if runState.clusterMapping?.filePath} + {#if runState.clusterMapping}
Clusters: - {runState.clusterMapping?.filePath.split("_").pop()?.replace(".json", "")} + {clusterRunId(runState.clusterMapping.filePath)} {#if showLoadedTooltip} @@ -85,7 +85,7 @@ {#if loadedClusterNotes}
{loadedClusterNotes}
{/if} -
{runState.clusterMapping?.filePath}
+
{runState.clusterMapping.filePath}
{/if}
@@ -94,7 +94,7 @@
- {mapping.path.split("_").pop()?.replace(".json", "")} + {clusterRunId(mapping.path)} {mapping.notes} @@ -306,7 +306,7 @@ background: var(--bg-elevated); border: 1px solid var(--border-strong); border-radius: var(--radius-md); - box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15); + box-shadow: var(--shadow-md); z-index: 1000; min-width: 200px; max-width: 400px; diff --git a/spd/app/frontend/src/components/ClustersTab.svelte b/spd/app/frontend/src/components/ClustersTab.svelte new file mode 100644 index 000000000..4ecf586c3 --- /dev/null +++ b/spd/app/frontend/src/components/ClustersTab.svelte @@ -0,0 +1,27 @@ + + +
+ {#if clusterMapping} + + {:else} + No clusters loaded. Use the cluster path input in the header bar to load a cluster mapping. + {/if} +
+ + diff --git a/spd/app/frontend/src/components/ClustersViewer.svelte b/spd/app/frontend/src/components/ClustersViewer.svelte new file mode 100644 index 000000000..6ff194c82 --- /dev/null +++ b/spd/app/frontend/src/components/ClustersViewer.svelte @@ -0,0 +1,249 @@ + + +
+ {#if selectedClusterId === null} +
+

Clusters ({clusterGroups.sorted.length})

+ {#each clusterGroups.sorted as [clusterId, members] (clusterId)} + {@const previewLabels = getPreviewLabels(members)} + + {/each} + {#if clusterGroups.singletons.length > 0} + + {/if} +
+ {:else} +
+
+ +

+ {selectedClusterId === "unclustered" ? "Unclustered" : `Cluster ${selectedClusterId}`} +

+ {selectedMembers.length} components +
+
+ {#each selectedMembers as member (`${member.layer}:${member.cIdx}`)} +
+ +
+ {/each} +
+
+ {/if} +
+ + diff --git a/spd/app/frontend/src/components/ComponentProbeInput.svelte b/spd/app/frontend/src/components/ComponentProbeInput.svelte index bde9576ba..bf537c38e 100644 --- a/spd/app/frontend/src/components/ComponentProbeInput.svelte +++ b/spd/app/frontend/src/components/ComponentProbeInput.svelte @@ -1,8 +1,8 @@
-
-
Custom Text
- {#if displaySettings.exampleColorMode === "ci"} - (Change "Color by" above to "Both" to see subcomponent activations) - {/if} -
- - {#if probeLoading} +
Custom Text
+ + {#if probeResult.status === "loading"}

Loading...

- {:else if probeError} -

{probeError}

- {:else if probeResult && probeResult.tokens.length > 0} + {:else if probeResult.status === "error"} +

{probeResult.error}

+ {:else if probeResult.status === "loaded" && probeResult.data.tokens.length > 0}
@@ -93,28 +78,6 @@ border: 1px solid var(--border-default); } - .header-with-hint { - display: flex; - align-items: baseline; - gap: var(--space-2); - margin-bottom: var(--space-2); - } - - h5 { - margin: 0; - font-size: var(--text-sm); - font-family: var(--font-sans); - color: var(--text-secondary); - font-weight: 600; - } - - .hint { - font-size: var(--text-xs); - font-family: var(--font-sans); - color: var(--text-muted); - font-style: italic; - } - .probe-input { width: 100%; padding: var(--space-2); diff --git a/spd/app/frontend/src/components/DataSourcesTab.svelte b/spd/app/frontend/src/components/DataSourcesTab.svelte new file mode 100644 index 000000000..c54a1fa0a --- /dev/null +++ b/spd/app/frontend/src/components/DataSourcesTab.svelte @@ -0,0 +1,350 @@ + + +
+ +
+ {#if runState.run.status === "loaded" && runState.run.data.config_yaml} +
+

Run Config

+
{runState.run.data.config_yaml}
+
+ {/if} + +
+

Target Model

+ {#if pretrainData.status === "loading"} +

Loading...

+ {:else if pretrainData.status === "loaded"} + {@const pt = pretrainData.data} +
+ Architecture + {pt.summary} + {#if pt.pretrain_wandb_path} + Pretrain run + {pt.pretrain_wandb_path} + {/if} +
+ {#if pt.topology} +
+ +
+ {/if} + {#if pt.pretrain_config} +
+ Pretraining config +
{formatPretrainConfigYaml(pt.pretrain_config)}
+
+ {/if} + {:else if pretrainData.status === "error"} +

Failed to load target model info

+ {/if} +
+
+ + +
+ +
+
+ +

Harvest

+
+ {#if data.status === "loading"} +

Loading...

+ {:else if data.status === "loaded" && data.data.harvest} + {@const harvest = data.data.harvest} +
+ Subrun + {harvest.subrun_id} + Components + {harvest.n_components.toLocaleString()} + Intruder eval + {harvest.has_intruder_scores ? "yes" : "no"} + {#each Object.entries(harvest.config) as [key, value] (key)} + {key} + {formatConfigValue(value)} + {/each} +
+ {:else if data.status === "loaded"} +

Not available

+ {/if} +
+ + +
+
+ +

Autointerp

+
+ {#if data.status === "loading"} +

Loading...

+ {:else if data.status === "loaded" && data.data.autointerp} + {@const autointerp = data.data.autointerp} +
+ Subrun + {autointerp.subrun_id} + Interpretations + {autointerp.n_interpretations.toLocaleString()} + Eval scores + + {#if autointerp.eval_scores.length > 0} + {autointerp.eval_scores.join(", ")} + {:else} + none + {/if} + + {#each Object.entries(autointerp.config) as [key, value] (key)} + {key} + {formatConfigValue(value)} + {/each} +
+ {:else if data.status === "loaded"} +

Not available

+ {/if} +
+ + +
+
+ +

Dataset Attributions

+
+ {#if data.status === "loading"} +

Loading...

+ {:else if data.status === "loaded" && data.data.attributions} + {@const attributions = data.data.attributions} +
+ Subrun + {attributions.subrun_id} + Tokens + {attributions.n_tokens_processed.toLocaleString()} + CI threshold + {attributions.ci_threshold} +
+ {:else if data.status === "loaded"} +

Not available

+ {/if} +
+ + +
+
+ +

Graph Interp

+
+ {#if data.status === "loading"} +

Loading...

+ {:else if data.status === "loaded" && data.data.graph_interp} + {@const graph_interp = data.data.graph_interp} +
+ Subrun + {graph_interp.subrun_id} + {#each Object.entries(graph_interp.label_counts) as [key, value] (key)} + {key} labels + {value.toLocaleString()} + {/each} + {#if graph_interp.config} + {#each Object.entries(graph_interp.config) as [key, value] (key)} + {key} + {formatConfigValue(value)} + {/each} + {/if} +
+ {:else if data.status === "loaded"} +

Not available

+ {/if} +
+
+
+ + diff --git a/spd/app/frontend/src/components/DatasetExplorerTab.svelte b/spd/app/frontend/src/components/DatasetExplorerTab.svelte new file mode 100644 index 000000000..4befbc7b2 --- /dev/null +++ b/spd/app/frontend/src/components/DatasetExplorerTab.svelte @@ -0,0 +1,97 @@ + + +
+
+ + +
+ +
+
+ +
+
+ +
+
+
+ + diff --git a/spd/app/frontend/src/components/DatasetRandomPanel.svelte b/spd/app/frontend/src/components/DatasetRandomPanel.svelte new file mode 100644 index 000000000..cda9ef0d7 --- /dev/null +++ b/spd/app/frontend/src/components/DatasetRandomPanel.svelte @@ -0,0 +1,357 @@ + + +
+
+
+ Random Dataset Samples +
+ + +
+
+
+
+ + +
+
+ + +
+
+ {#if randomSamples.status === "loaded"} + + {/if} +
+ +
+ {#if randomPageResults} +
+
+ P(token): + Low + + High +
+
+ {#each randomPageResults.results as sample, idx (idx)} + + {/each} +
+ {#if randomPageResults.total_pages > 1} + + {/if} +
+ {:else if randomSamples.status === "loading"} +
Loading random samples...
+ {:else if randomSamples.status === "error"} +
Error: {randomSamples.error}
+ {:else} +
+

Click "Load Samples" to fetch random dataset samples

+
+ {/if} +
+
+ + diff --git a/spd/app/frontend/src/components/DatasetSearchTab.svelte b/spd/app/frontend/src/components/DatasetSearchPanel.svelte similarity index 56% rename from spd/app/frontend/src/components/DatasetSearchTab.svelte rename to spd/app/frontend/src/components/DatasetSearchPanel.svelte index 655ef40c6..cdb073509 100644 --- a/spd/app/frontend/src/components/DatasetSearchTab.svelte +++ b/spd/app/frontend/src/components/DatasetSearchPanel.svelte @@ -1,115 +1,118 @@ -
+
- Search SimpleStories Dataset -
-
+
- +
- -
- {#if metadata} + {#if searchMetadata} - {/if} - {#if error} -
- {error} + Found {searchMetadata.total_results} results in {searchMetadata.search_time_seconds.toFixed(2)}s
{/if}
- {#if currentPageResults} + {#if searchResults.status === "loaded"} - {:else if loading} + {:else if searchResults.status === "loading"}
Searching dataset...
+ {:else if searchResults.status === "error"} +
Error: {searchResults.error}
{:else}

No search performed yet

-

Enter a query above to search the SimpleStories dataset

+

Enter a query above to search the dataset

{/if}
diff --git a/spd/app/frontend/src/components/DatasetSearchResults.svelte b/spd/app/frontend/src/components/DatasetSearchResults.svelte index 38a62daa9..799f7e38c 100644 --- a/spd/app/frontend/src/components/DatasetSearchResults.svelte +++ b/spd/app/frontend/src/components/DatasetSearchResults.svelte @@ -1,8 +1,9 @@ + +
+ {#if selected?.status === "loaded"} + +
+ +

{selected.data.title || formatId(selected.data.id)}

+ + {#if selected.data.status} + + {selected.data.status} + + {/if} +
+ + {#if selected.data.summary} +

{selected.data.summary}

+ {/if} + + +

+ {formatId(selected.data.id)} · Started {formatDate(selected.data.created_at)} + {#if selected.data.wandb_path} + · {selected.data.wandb_path} + {/if} +

+ +
+ + +
+ +
+ {#if activeTab === "research"} + {#if selected.data.research_log} + + {:else} +

No research log available

+ {/if} + {:else} +
+ {#each selected.data.events as event, i (i)} +
+ + {event.event_type} + + {formatDate(event.timestamp)} + {event.message} + {#if event.details && Object.keys(event.details).length > 0} +
+ Details +
{JSON.stringify(event.details, null, 2)}
+
+ {/if} +
+ {:else} +

No events recorded

+ {/each} +
+ {/if} +
+ {:else if selected?.status === "loading"} +
Loading investigation...
+ {:else} + +
+

Investigations

+ +
+ +
{ + e.preventDefault(); + handleLaunch(); + }} + > + + +
+ {#if launchState.status === "error"} +
{launchState.error}
+ {/if} + + {#if investigations.status === "loading"} +
Loading investigations...
+ {:else if investigations.status === "error"} +
{investigations.error}
+ {:else if investigations.status === "loaded"} +
+ {#each investigations.data as inv (inv.id)} + + {:else} +

+ No investigations found. Run spd-investigate to create one. +

+ {/each} +
+ {/if} + {/if} +
+ + diff --git a/spd/app/frontend/src/components/ModelGraph.svelte b/spd/app/frontend/src/components/ModelGraph.svelte new file mode 100644 index 000000000..ae49a25e3 --- /dev/null +++ b/spd/app/frontend/src/components/ModelGraph.svelte @@ -0,0 +1,520 @@ + + +
+ +
+
+ + + +
+
+ +
+
+ +
+
+ +
+
+ {filteredNodes.length} nodes, {visibleEdges.length} edges +
+
+ + + +
+
+ + + + + + {#if tooltipNode} +
+
{tooltipNode.label}
+
+ {tooltipNode.confidence} + {tooltipNode.key} +
+
+ {/if} + + + {#if selectedNodeKey} + {@const selectedNode = layout.nodes.get(selectedNodeKey)} + {#if selectedNode} +
+
+ {selectedNode.label} + {selectedNode.confidence} + +
+
{selectedNode.key}
+
+ {#if selectedNodeEdges.length > 0} +
+ {#each selectedNodeEdges as e, i (i)} + {@const other = e.source === selectedNodeKey ? e.target : e.source} + {@const otherNode = layout.nodes.get(other)} +
+ {e.source === selectedNodeKey ? "to" : "from"} + {otherNode?.label ?? other} + {e.attribution.toFixed(3)} +
+ {/each} +
+ {:else} + No edges + {/if} +
+
+ {/if} + {/if} +
+ + diff --git a/spd/app/frontend/src/components/ProbColoredTokens.svelte b/spd/app/frontend/src/components/ProbColoredTokens.svelte new file mode 100644 index 000000000..7b30870c3 --- /dev/null +++ b/spd/app/frontend/src/components/ProbColoredTokens.svelte @@ -0,0 +1,35 @@ + + +{#each tokens as tok, i (i)}{@const prob = getProbAtPosition(nextTokenProbs, i)}{/each} + + diff --git a/spd/app/frontend/src/components/PromptAttributionsGraph.svelte b/spd/app/frontend/src/components/PromptAttributionsGraph.svelte index 6d9862175..c5e182917 100644 --- a/spd/app/frontend/src/components/PromptAttributionsGraph.svelte +++ b/spd/app/frontend/src/components/PromptAttributionsGraph.svelte @@ -1,16 +1,16 @@
- (activeCardPromptId = id)} + {tabView} + onSelectCard={handleSelectCard} onCloseCard={handleCloseCard} - onAddClick={() => (showPromptPicker = !showPromptPicker)} - /> - {}} - onGenerate={handleGeneratePrompts} - onClose={() => (showPromptPicker = false)} + onSelectDraft={handleStartNewDraft} + onAddClick={handleStartNewDraft} />
- {#if activeCard} - - - - {#if activeGraph} - -
- - +
+ + {#if prompts.length > 0} +
+ +
+ {#each prompts as prompt (prompt.id)} + {#if confirmingDeleteId === prompt.id} +
+ Delete prompt #{prompt.id}? + + +
+ {:else} +
+ + +
+ {/if} + {/each} +
+
{/if} - +
+
+ {:else if activeCard} + +
+
- {#if activeCard.activeView === "graph"} -
- (hideUnpinnedEdges = v)} - onHideNodeCardChange={(v) => (hideNodeCard = v)} + + + + {#if activeGraph} + + {#if activeGraph.data.optimization} + -
- L0: - {activeGraph.data.l0_total.toFixed(0)} active at ci threshold {activeGraph - .viewSettings.ciThreshold} - {#if pinnedNodes.length > 0} - {pinnedNodes.length} pinned - {/if} -
- {#key activeGraph.id} - +
+ + + {#if activeCard.activeView === "graph"} +
+ (hideUnpinnedEdges = v)} + onHideNodeCardChange={(v) => (hideNodeCard = v)} + /> +
+ L0: + {activeGraph.data.l0_total.toFixed(0)} active at ci threshold {activeGraph + .viewSettings.ciThreshold} + {#if pinnedNodes.length > 0} + {pinnedNodes.length} pinned + {/if} +
+ {#key activeGraph.id} + (filteredEdgeCount = count)} + onHoveredNodeChange={(node) => (hoveredNode = node)} + /> + {/key} +
+ + {:else if activeInterventionState} + (filteredEdgeCount = count)} + hideNodeCard={true} + onTopKChange={handleTopKChange} + onComponentGapChange={handleComponentGapChange} + onLayerGapChange={handleLayerGapChange} + onNormalizeChange={handleNormalizeChange} + onCiThresholdChange={handleCiThresholdChange} + onHideUnpinnedEdgesChange={(v) => (hideUnpinnedEdges = v)} + onHideNodeCardChange={(v) => (hideNodeCard = v)} + {runningIntervention} + {generatingSubgraph} + onSelectionChange={handleDraftSelectionChange} + onForwardDraft={handleForwardDraft} + onCloneRun={handleCloneRun} + onSelectVersion={handleSelectVersion} + onDeleteRun={handleDeleteRun} + onGenerateGraphFromSelection={handleGenerateGraphFromSelection} + onHoveredNodeChange={(node) => (hoveredNode = node)} /> - {/key} -
- - {:else if activeComposerState} - (hideUnpinnedEdges = v)} - onHideNodeCardChange={(v) => (hideNodeCard = v)} - {runningIntervention} - {generatingSubgraph} - onSelectionChange={handleComposerSelectionChange} - onRunIntervention={handleRunIntervention} - onSelectRun={handleSelectRun} - onDeleteRun={handleDeleteRun} - onForkRun={handleForkRun} - onDeleteFork={handleDeleteFork} - onGenerateGraphFromSelection={handleGenerateGraphFromSelection} - /> - {/if} - {:else} - - {#if graphCompute.status === "error"} -
- {graphCompute.error} - - + {/if}
- {/if} - -
- {#if graphCompute.status === "computing" && graphCompute.cardId === activeCard.id} - - {:else} -
-

Click Compute to generate the attribution graph

+ {:else} + + {#if graphCompute.status === "error"} +
+ {graphCompute.error} + +
{/if} + +
+ {#if graphCompute.status === "computing" && graphCompute.cardId === activeCard.id} + + {:else} +
+
+ {#if !hasStandardGraph} + + {/if} + {#if hasStandardGraph || activeCard.useOptimized} + + {/if} +
+ + {#if hasStandardGraph || activeCard.useOptimized} + + {/if} +
+
+
+ {/if} +
+ {/if} + {:else if tabView.view === "loading"} +
+

Loading prompt...

+
+ {:else if tabView.view === "error"} +
+

Error loading prompt: {tabView.error}

+
{/if} - {:else if promptCardLoading.status === "loading"} -
-

Loading prompt...

-
- {:else if promptCardLoading.status === "error"} -
-

Error loading prompt: {promptCardLoading.error}

- -
- {:else} -
-

Click + Add Prompt to get started

-

{prompts.length} prompts available

+
+ + {#if !hideNodeCard && stickyComponentNode && activeGraph} + +
+
+ {#key `${stickyComponentNode.layer}:${stickyComponentNode.cIdx}`} + { + handlePinnedNodesChange([ + ...pinnedNodes.filter( + (p) => !(p.layer === layer && p.seqIdx === seqIdx && p.cIdx === cIdx), + ), + { layer, seqIdx, cIdx }, + ]); + }} + /> + {/key}
{/if}
@@ -916,58 +1302,49 @@ .card-content { flex: 1; display: flex; - flex-direction: column; - gap: var(--space-2); min-height: 0; + min-width: 0; padding: var(--space-4); border: 1px solid var(--border-default); background: var(--bg-inset); } - .graph-view-tabs { + .card-content-main { + flex: 1; display: flex; - gap: var(--space-1); + flex-direction: column; + gap: var(--space-2); + min-height: 0; + min-width: 0; + overflow: auto; } - .graph-view-tab { - padding: var(--space-1) var(--space-3); - background: var(--bg-elevated); - border: 1px solid var(--border-default); - font-size: var(--text-sm); - font-weight: 500; - color: var(--text-secondary); - display: inline-flex; - align-items: center; - gap: var(--space-1); + .resize-handle { + width: 6px; + cursor: col-resize; + background: transparent; + flex-shrink: 0; + position: relative; } - .graph-view-tab:hover { - color: var(--text-primary); - border-color: var(--border-strong); - background: var(--bg-surface); + .resize-handle:hover, + .resize-handle:active { + background: var(--accent-primary-dim); } - .graph-view-tab.active { - color: white; - background: var(--accent-primary); - border-color: var(--accent-primary); + .node-detail-panel { + flex-shrink: 0; + overflow-y: auto; + border: 1px solid var(--border-default); + background: var(--bg-elevated); + padding: var(--space-3); } - .graph-view-tab .badge { - display: inline-flex; + .prompt-tokens { + display: flex; + flex-wrap: wrap; + gap: 1px; align-items: center; - justify-content: center; - min-width: 16px; - height: 16px; - padding: 0 4px; - font-size: var(--text-xs); - font-weight: 600; - background: rgba(255, 255, 255, 0.2); - border-radius: 8px; - } - - .graph-view-tab.active .badge { - background: rgba(255, 255, 255, 0.3); } .graph-info { @@ -1022,18 +1399,73 @@ font-size: var(--text-base); } - .empty-state strong { + .empty-state .error-text { + color: var(--status-negative-bright); + } + + .btn-compute-center { + padding: var(--space-2) 
var(--space-4); + background: var(--bg-elevated); + border: 1px dashed var(--accent-primary-dim); + font-size: var(--text-base); + font-family: var(--font-mono); + font-weight: 500; color: var(--accent-primary); + cursor: pointer; + } + + .btn-compute-center:hover { + background: var(--bg-inset); + border-style: solid; + border-color: var(--accent-primary); + } + + .compute-buttons { + display: flex; + gap: var(--space-2); + justify-content: center; } - .empty-state .hint { + .btn-compute-batch { + padding: var(--space-2) var(--space-3); + background: var(--bg-elevated); + border: 1px dashed var(--border-default); font-size: var(--text-sm); - color: var(--text-muted); font-family: var(--font-mono); + color: var(--text-secondary); + cursor: pointer; } - .empty-state .error-text { - color: var(--status-negative-bright); + .btn-compute-batch:hover { + background: var(--bg-inset); + border-style: solid; + border-color: var(--accent-primary-dim); + color: var(--accent-primary); + } + + .compute-controls { + display: flex; + flex-direction: column; + align-items: center; + gap: var(--space-3); + } + + .optimize-checkbox { + display: flex; + align-items: center; + gap: var(--space-1); + font-size: var(--text-sm); + font-family: var(--font-sans); + color: var(--text-secondary); + cursor: pointer; + } + + .optimize-checkbox:hover { + color: var(--text-primary); + } + + .optimize-checkbox input { + cursor: pointer; } .error-banner { @@ -1060,4 +1492,223 @@ .error-banner button:hover { background: var(--status-negative-bright); } + + /* Draft staging area styles */ + .draft-staging { + flex: 1; + display: flex; + align-items: center; + justify-content: center; + padding: var(--space-6); + } + + .draft-main { + display: grid; + grid-template-columns: 1fr 1fr; + gap: var(--space-6); + max-width: 900px; + width: 100%; + align-items: start; + } + + .draft-input-section { + display: flex; + flex-direction: column; + gap: var(--space-2); + } + + .draft-label { + font-size: 
var(--text-sm); + font-weight: 500; + color: var(--text-secondary); + } + + .draft-textarea { + width: 100%; + padding: var(--space-3); + border: 1px solid var(--border-default); + background: var(--bg-elevated); + color: var(--text-primary); + font-size: var(--text-sm); + font-family: var(--font-mono); + resize: vertical; + min-height: 120px; + } + + .draft-textarea:focus { + outline: none; + border-color: var(--accent-primary); + } + + .draft-textarea::placeholder { + color: var(--text-muted); + } + + .btn-add-prompt { + align-self: flex-start; + padding: var(--space-1) var(--space-3); + background: var(--accent-primary); + border: none; + color: white; + font-size: var(--text-sm); + font-family: var(--font-mono); + font-weight: 500; + cursor: pointer; + } + + .btn-add-prompt:hover:not(:disabled) { + background: var(--accent-primary-bright); + } + + .btn-add-prompt:disabled { + opacity: 0.5; + cursor: not-allowed; + } + + .token-preview-row { + display: flex; + flex-wrap: wrap; + gap: 1px; + align-items: center; + } + + .token-preview-row.loading { + font-family: var(--font-mono); + font-size: var(--text-sm); + color: var(--text-muted); + } + + .token-preview-row.error { + font-family: var(--font-mono); + font-size: var(--text-sm); + color: var(--status-negative); + } + + .token-preview-row .token-count { + font-family: var(--font-mono); + font-size: var(--text-xs); + color: var(--text-muted); + margin-left: var(--space-2); + } + + .existing-prompts-section { + display: flex; + flex-direction: column; + gap: var(--space-2); + } + + .prompt-list { + display: flex; + flex-direction: column; + max-height: 400px; + overflow-y: auto; + background: var(--bg-inset); + border: 1px solid var(--border-default); + } + + .prompt-item-row { + display: flex; + align-items: stretch; + border-bottom: 1px solid var(--border-subtle); + } + + .prompt-item-row:last-child { + border-bottom: none; + } + + .prompt-item { + flex: 1; + padding: var(--space-2) var(--space-3); + 
background: transparent; + border: none; + cursor: pointer; + text-align: left; + display: flex; + gap: var(--space-2); + align-items: baseline; + color: var(--text-primary); + min-width: 0; + } + + .prompt-item:hover { + background: var(--bg-surface); + } + + .btn-delete-prompt { + padding: 0 var(--space-2); + background: transparent; + border: none; + color: var(--text-muted); + font-size: var(--text-base); + cursor: pointer; + flex-shrink: 0; + } + + .btn-delete-prompt:hover { + color: var(--status-negative-bright); + background: var(--bg-surface); + } + + .confirm-delete { + border-bottom: 1px solid var(--border-subtle); + padding: var(--space-2) var(--space-3); + display: flex; + align-items: center; + gap: var(--space-2); + } + + .confirm-delete:last-child { + border-bottom: none; + } + + .confirm-text { + font-size: var(--text-sm); + font-family: var(--font-mono); + color: var(--text-secondary); + flex: 1; + } + + .confirm-yes, + .confirm-no { + padding: var(--space-1) var(--space-2); + border: none; + font-size: var(--text-xs); + font-family: var(--font-mono); + cursor: pointer; + } + + .confirm-yes { + background: var(--status-negative); + color: white; + } + + .confirm-yes:hover { + background: var(--status-negative-bright); + } + + .confirm-no { + background: var(--bg-elevated); + color: var(--text-secondary); + border: 1px solid var(--border-default); + } + + .confirm-no:hover { + background: var(--bg-surface); + } + + .prompt-id { + font-size: var(--text-xs); + font-family: var(--font-mono); + color: var(--text-muted); + flex-shrink: 0; + } + + .prompt-text { + font-family: var(--font-mono); + font-size: var(--text-sm); + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + color: var(--text-primary); + } diff --git a/spd/app/frontend/src/components/RunSelector.svelte b/spd/app/frontend/src/components/RunSelector.svelte index 3dac8bd1e..19db5f3bc 100644 --- a/spd/app/frontend/src/components/RunSelector.svelte +++ 
b/spd/app/frontend/src/components/RunSelector.svelte @@ -1,5 +1,22 @@
- {#if runState.run?.status === "loaded" && runState.run.data} -
(showRunMenu = true)} onmouseleave={() => (showRunMenu = false)}> - - {#if showRunMenu} -
-
{runState.run.data.config_yaml}
- -
- {/if} -
+ {#if runState.run.status === "loaded"} + {@const wandbParts = runState.run.data.wandb_path.split("/")} + + {runState.run.data.wandb_path} + (wandb) + {/if}
- {#if runState.run?.status === "loaded" && runState.run.data} + {#if runState.run.status === "loaded"}
@@ -70,29 +123,42 @@
- {#if runState.run?.status === "error"} + {#if runState.run.status === "error"}
{runState.run.error}
{/if} - -
- + +
+
- {#if runState.prompts.status === "loaded" && runState.allTokens.status === "loaded"} + {#if runState.prompts.status === "loaded"}
- +
- {:else if runState.run.status === "loading" || runState.prompts.status === "loading" || runState.allTokens.status === "loading"} -
+ {#if datasetSearchEnabled} +
+ +
+ {/if} +
+ +
+ {#if runState.clusterMapping} +
+ +
+ {/if} + {:else if runState.run.status === "loading" || runState.prompts.status === "loading"} +

Loading run...

{:else} -
+

Enter a W&B Path above to get started

{/if} @@ -116,80 +182,45 @@ min-height: 44px; } - /* Run menu - hoverable dropdown */ - .run-menu { - position: relative; - display: flex; - align-items: stretch; - } - - .run-menu-trigger { + .run-path { display: flex; align-items: center; gap: var(--space-2); padding: 0 var(--space-3); - margin: 0; - background: none; - border: none; border-right: 1px solid var(--border-default); - border-radius: 0; - cursor: pointer; - font: inherit; - font-size: var(--text-sm); - transition: background 0.15s; - } - - .run-menu-trigger:hover .run-path { - background: var(--bg-inset); - } - - .run-path { font-family: var(--font-mono); + font-size: var(--text-sm); color: var(--text-primary); } - .run-menu-dropdown { - position: absolute; - top: 100%; - left: 0; - z-index: 1000; - display: flex; - flex-direction: column; - gap: var(--space-2); - padding: var(--space-3); - background: var(--bg-elevated); - border: 1px solid var(--border-strong); - border-radius: var(--radius-md); - box-shadow: 0 4px 12px rgba(0, 0, 0, 0.08); + .wandb-link { + color: var(--text-muted); + text-decoration: none; + font-size: var(--text-xs); } - .config-yaml { - max-width: 420px; - max-height: 50vh; - overflow: auto; - margin: 0; - font-size: var(--text-xs); - font-family: var(--font-mono); - color: var(--text-primary); - white-space: pre-wrap; - word-wrap: break-word; + .wandb-link:hover { + color: var(--accent-primary); } .change-run-button { - padding: var(--space-2) var(--space-3); - background: var(--bg-inset); - border: 1px solid var(--border-default); - border-radius: var(--radius-sm); + display: flex; + align-items: center; + padding: 0 var(--space-3); + margin: 0; + background: none; + border: none; + border-right: 1px solid var(--border-default); + border-radius: 0; + font: inherit; font-size: var(--text-sm); - font-family: var(--font-sans); - color: var(--text-secondary); font-weight: 500; + color: var(--text-muted); cursor: pointer; - text-align: center; } .change-run-button:hover { - 
background: var(--bg-surface); + background: var(--bg-inset); color: var(--text-primary); } @@ -211,8 +242,8 @@ color: var(--text-muted); cursor: pointer; transition: - color 0.15s, - background 0.15s; + color var(--transition-normal), + background var(--transition-normal); } .tab-button:hover { diff --git a/spd/app/frontend/src/components/TokenHighlights.svelte b/spd/app/frontend/src/components/TokenHighlights.svelte index cae4cfb15..456cdbc72 100644 --- a/spd/app/frontend/src/components/TokenHighlights.svelte +++ b/spd/app/frontend/src/components/TokenHighlights.svelte @@ -1,38 +1,35 @@ + +
+
+ #{index + 1} + {#each Object.entries(sample.metadata) as [metaKey, metaVal] (metaKey)} + {metaVal} + {/each} +
+
+ +
+
+ + diff --git a/spd/app/frontend/src/components/TokenizedSearchResultCard.svelte b/spd/app/frontend/src/components/TokenizedSearchResultCard.svelte new file mode 100644 index 000000000..c49e2454f --- /dev/null +++ b/spd/app/frontend/src/components/TokenizedSearchResultCard.svelte @@ -0,0 +1,142 @@ + + +
+
+ #{index + 1} + {#if result.occurrence_count > 0} + {result.occurrence_count} occurrence{result.occurrence_count !== 1 ? "s" : ""} + {/if} + {#each Object.entries(result.metadata) as [metaKey, metaVal] (metaKey)} + {metaVal} + {/each} +
+
+ {#each result.tokens as tok, i (i)}{@const prob = getProbAtPosition(result.next_token_probs, i)}{/each} +
+
+ + diff --git a/spd/app/frontend/src/components/TopologyDiagram.svelte b/spd/app/frontend/src/components/TopologyDiagram.svelte new file mode 100644 index 000000000..abdd55cdc --- /dev/null +++ b/spd/app/frontend/src/components/TopologyDiagram.svelte @@ -0,0 +1,112 @@ + + +
+
embed
+ {#each topology.block_structure as block (block.index)} +
+ {block.index} +
+
+ {block.attn_type === "fused" ? "attn_fused" : "attn"} +
+ {#each block.attn_projections as proj (proj)} + {proj} + {/each} +
+
+
+ {block.ffn_type} +
+ {#each block.ffn_projections as proj (proj)} + {proj} + {/each} +
+
+
+
+ {/each} +
output
+
+ + diff --git a/spd/app/frontend/src/components/investigations/ArtifactGraph.svelte b/spd/app/frontend/src/components/investigations/ArtifactGraph.svelte new file mode 100644 index 000000000..567d373dd --- /dev/null +++ b/spd/app/frontend/src/components/investigations/ArtifactGraph.svelte @@ -0,0 +1,434 @@ + + +
+ {#if caption} +
{caption}
+ {/if} + + +
+ + +
+ + + {#each Object.entries(layout.layerYPositions) as [layer, y] (layer)} + + {getRowLabel(getRowKey(layer))} + + {/each} + + +
+ +
+ + + + + {@html edgesSvg} + + + + {#each Object.entries(layout.nodePositions) as [key, pos] (key)} + {@const style = nodeStyles[key]} + {@const [layer, seqIdxStr, cIdxStr] = key.split(":")} + {@const seqIdx = parseInt(seqIdxStr)} + {@const cIdx = parseInt(cIdxStr)} + + handleNodeHover(e, layer, seqIdx, cIdx)} + onmouseleave={handleNodeLeave} + /> + + + {/each} + + + + +
+ + + {#each data.tokens as token, i (i)} + {@const colLeft = layout.seqXStarts[i] + 8} + + {token} + + [{i}] + {/each} + + +
+
+
+ +
+ L0: {data.l0_total} · Edges: {filteredEdges.length} +
+ + + {#if hoveredNode && runState} + (isHoveringTooltip = true)} + onMouseLeave={() => { + isHoveringTooltip = false; + hoveredNode = null; + }} + /> + {/if} +
+ + diff --git a/spd/app/frontend/src/components/investigations/ResearchLogViewer.svelte b/spd/app/frontend/src/components/investigations/ResearchLogViewer.svelte new file mode 100644 index 000000000..ef9b8d40d --- /dev/null +++ b/spd/app/frontend/src/components/investigations/ResearchLogViewer.svelte @@ -0,0 +1,223 @@ + + +
+ {#each contentBlocks as block, i (i)} + {#if block.type === "html"} + +
{@html block.content}
+ {:else if block.type === "graph"} + {@const artifact = artifacts[block.artifactId]} + {#if artifact} + + {:else if artifactsLoading} +
+ Loading graph: {block.artifactId}... +
+ {:else} +
+ Graph artifact not found: {block.artifactId} +
+ {/if} + {/if} + {/each} +
+ + diff --git a/spd/app/frontend/src/components/prompt-attr/ComponentCorrelationPills.svelte b/spd/app/frontend/src/components/prompt-attr/ComponentCorrelationPills.svelte index 802927d32..47b5e3831 100644 --- a/spd/app/frontend/src/components/prompt-attr/ComponentCorrelationPills.svelte +++ b/spd/app/frontend/src/components/prompt-attr/ComponentCorrelationPills.svelte @@ -1,13 +1,13 @@
- - {#if componentData.componentDetail?.status === "loaded"} - Mean CI: {formatNumericalValue(componentData.componentDetail.data.mean_ci)} - {/if} - +
+

{layer}:{seqIdx}:{cIdx}

+
"{token}"
+
+ {#if ciVal !== null} + CI: {formatNumericalValue(ciVal)} + {/if} + {#if subcompAct !== null} + Subcomp Act: {formatNumericalValue(subcompAct)} + {/if} + {#if clusterId !== undefined} + Cluster: {clusterId ?? "null"} + {/if} + {#if componentData.componentDetail.status === "loaded"} + Mean CI: {formatNumericalValue(componentData.componentDetail.data.mean_ci)} + {/if} + {#if intruderScore !== null} + Intruder: {Math.round(intruderScore * 100)}% + {/if} +
+
- +
+ + {#if graphInterpLabel && componentData.graphInterpDetail.status === "loaded" && componentData.graphInterpDetail.data} + + {/if} +
- {#if componentData.componentDetail?.status === "loading"} - Loading details... - {:else if componentData.componentDetail?.status === "loaded"} - {#if componentData.componentDetail.data.example_tokens.length > 0} - - {/if} - {:else if componentData.componentDetail?.status === "error"} - Error loading details: {String(componentData.componentDetail.error)} + {#if activationExamples.status === "error"} + Error loading details: {String(activationExamples.error)} + {:else if activationExamples.status === "loaded" && activationExamples.data.tokens.length === 0} + {:else} - Something went wrong loading details. + {/if}
@@ -218,72 +235,86 @@ title="Prompt Attributions" incomingLabel="Incoming" outgoingLabel="Outgoing" - {incomingPositive} - {incomingNegative} - {outgoingPositive} - {outgoingNegative} - pageSize={4} + {incoming} + {outgoing} + pageSize={COMPONENT_CARD_CONSTANTS.PROMPT_ATTRIBUTIONS_PAGE_SIZE} onClick={handleEdgeNodeClick} - {tokens} - {outputProbs} /> {/if} - {#if componentData.datasetAttributions?.status === "loaded" && componentData.datasetAttributions.data} - - {:else if componentData.datasetAttributions?.status === "loading"} + {#if componentData.datasetAttributions.status === "loading" || componentData.datasetAttributions.status === "uninitialized"}
- Loading... +
+
+
+
- {:else if componentData.datasetAttributions?.status === "error"} + {:else if componentData.datasetAttributions.status === "loaded"} + {#if componentData.datasetAttributions.data !== null} + + {/if} + {:else if componentData.datasetAttributions.status === "error"}
Error: {String(componentData.datasetAttributions.error)}
{/if} -
- {#if componentData.tokenStats === null || componentData.tokenStats.status === "loading"} - Loading token stats... - {:else if componentData.tokenStats.status === "error"} - Error: {String(componentData.tokenStats.error)} - {:else} - +
+ +
+ {#if componentData.tokenStats === null || componentData.tokenStats.status === "loading"} +
+
+
+
+
+
+ {:else if componentData.tokenStats.status === "error"} + Error: {String(componentData.tokenStats.error)} + {:else} + - - {/if} + + {/if} +
-
- - {#if componentData.correlations?.status === "loading"} - Loading... - {:else if componentData.correlations?.status === "loaded" && componentData.correlations.data} - - {:else if componentData.correlations?.status === "error"} - Error loading correlations: {String(componentData.correlations.error)} - {:else} - No correlations available. - {/if} -
+ {#if anyCorrelationStatsEnabled()} +
+ + {#if componentData.correlations.status === "loading"} +
+
+
+
+ {:else if componentData.correlations.status === "loaded" && componentData.correlations.data} + + {:else if componentData.correlations.status === "error"} + Error loading correlations: {String(componentData.correlations.error)} + {:else} + No correlations available. + {/if} +
+ {/if}
diff --git a/spd/app/frontend/src/components/prompt-attr/ComputeProgressOverlay.svelte b/spd/app/frontend/src/components/prompt-attr/ComputeProgressOverlay.svelte index cf4731a6f..23619c281 100644 --- a/spd/app/frontend/src/components/prompt-attr/ComputeProgressOverlay.svelte +++ b/spd/app/frontend/src/components/prompt-attr/ComputeProgressOverlay.svelte @@ -1,46 +1,54 @@
-
- {#each state.stages as stage, i (i)} - {@const isCurrent = i === state.currentStage} - {@const isComplete = i < state.currentStage} -
-
- {i + 1} - {stage.name} - {#if isComplete} - - {/if} -
- {#if isCurrent} -
- {#if stage.progress !== null} -
- {:else} -
+
+ {#if ciSnapshot} + + {/if} +
+ {#each state.stages as stage, i (i)} + {@const isCurrent = i === state.currentStage} + {@const isComplete = i < state.currentStage} +
+
+ {i + 1} + {stage.name} + {#if isComplete} + {/if}
- {:else if isComplete} -
-
-
- {:else} -
- {/if} -
- {/each} + {#if isCurrent} +
+ {#if stage.progress !== null} +
+ {:else} +
+ {/if} +
+ {:else if isComplete} +
+
+
+ {:else} +
+ {/if} +
+ {/each} +
@@ -54,6 +62,13 @@ z-index: 100; } + .content { + display: flex; + flex-direction: column; + align-items: center; + gap: var(--space-6); + } + .stages { display: flex; flex-direction: column; @@ -132,7 +147,7 @@ .progress-fill { height: 100%; background: var(--accent-primary); - transition: width 0.15s ease-out; + transition: width var(--transition-normal); } .stage.complete .progress-fill { diff --git a/spd/app/frontend/src/components/prompt-attr/GraphTabs.svelte b/spd/app/frontend/src/components/prompt-attr/GraphTabs.svelte new file mode 100644 index 000000000..2d403a8b3 --- /dev/null +++ b/spd/app/frontend/src/components/prompt-attr/GraphTabs.svelte @@ -0,0 +1,131 @@ + + +
+ {#each graphs as graph (graph.id)} +
+ + +
+ {/each} + {#if isNewGraphMode} +
+ New Graph +
+ {/if} + {#if graphs.length > 0 && !isNewGraphMode} + + {/if} +
+ + diff --git a/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte b/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte index 8d9bb906f..84d0fdc25 100644 --- a/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte +++ b/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte @@ -1,16 +1,20 @@ @@ -659,30 +662,53 @@
- {selectedCount} / {interventableCount} selected + {selectedCount} / {interventableCount} nodes{#if !isEditable} (read-only){/if} ?
- - - - + {#if isEditable} + + + + + + + + + + {:else} + + {/if}
@@ -696,16 +722,8 @@ onmouseup={zoom.endPan} onmouseleave={zoom.endPan} > - - -
+
+ + {#if predRows && PRED_AREA_HEIGHT > 0} +
+ + + {#each predRows as row, rowIdx (`pred-${row.label}`)} + {@const rowY = rowIdx * (PRED_ROW_HEIGHT + PRED_ROW_GAP) + PRED_ROW_GAP} + + {row.label} + + {#each row.preds as preds, seqIdx (seqIdx)} + {@const colX = layout.seqXStarts[seqIdx]} + {@const colW = layout.seqWidths[seqIdx]} + {@const chipW = 48} + {@const chipH = PRED_ROW_HEIGHT} + {@const chipGap = 1} + {@const isLabelPos = + interventionResult?.label != null && + seqIdx === interventionResult.label.position} + {@const labelTokenId = isLabelPos ? (row.labelPred?.token_id ?? null) : null} + {@const labelInTopk = + labelTokenId != null && preds.some((p) => p.token_id === labelTokenId)} + {@const maxChips = Math.min( + preds.length, + Math.max(1, Math.floor((colW - 2 + chipGap) / (chipW + chipGap))), + )} + {#each preds.slice(0, maxChips) as pred, rank (rank)} + {@const cx = colX + rank * (chipW + chipGap)} + {@const isLabel = labelTokenId != null && pred.token_id === labelTokenId} + + handlePredMouseEnter(e, pred, row.label, seqIdx)} + onmouseleave={handlePredMouseLeave} + > + + 0.5 ? "white" : colors.textPrimary} + >{pred.token} + + {/each} + + {#if isLabelPos && !labelInTopk && row.labelPred} + {@const cx = colX + maxChips * (chipW + chipGap) + chipGap} + + + handlePredMouseEnter(e, row.labelPred!, row.label, seqIdx)} + onmouseleave={handlePredMouseLeave} + > + + 0.5 ? "white" : colors.textPrimary} + >{row.labelPred.token} + + {/if} + {/each} + {/each} + + +
+ {/if} + + + {#if optimizationTarget} + {@const pos = optimizationTarget.position} + {@const xStart = layout.seqXStarts[pos]} + {@const width = layout.seqWidths[pos]} + + {/if} + {#each layout.clusterSpans as span (`${span.layer}:${span.seqIdx}:${span.clusterId}`)} @@ -837,7 +979,7 @@ y={selectionRect.y} width={selectionRect.width} height={selectionRect.height} - fill="rgba(99, 102, 241, 0.1)" + fill={rgbaToCss(colors.positiveRgb, 0.1)} stroke={colors.accent} stroke-width="1" stroke-dasharray="4 2" @@ -855,11 +997,11 @@ > {#each tokens as token, i (i)} - {@const colCenter = layout.seqXStarts[i] + layout.seqWidths[i] / 2} + {@const colX = layout.seqXStarts[i]} [{i}] + {#if optimizationTarget && i === optimizationTarget.position} + + {/if} {/each}
+ + +
+ +
- +
- Run History - {graph.interventionRuns.length} runs + Versions + {interventionState.runs.length}
- {#if graph.interventionRuns.length === 0} -
-

No runs yet

-

Select nodes and click Run

-
- {:else} -
- {#each graph.interventionRuns.slice().reverse() as run (run.id)} - {@const isActive = activeRunId === run.id} -
onSelectRun(run.id)} - onkeydown={(e) => e.key === "Enter" && onSelectRun(run.id)} - > +
+ {#each interventionState.runs as run, index (runIdentityKey(run, index))} + {@const isActive = index === interventionState.activeIndex} +
onSelectVersion(index)} + onkeydown={(e) => e.key === "Enter" && onSelectVersion(index)} + > + {#if run.kind === "draft"}
- {formatTime(run.created_at)} - {run.selected_nodes.length} nodes - - + Draft + {run.selectedNodes.size} nodes + (not forwarded)
- - -
- - - - {#each run.result.input_tokens as token, idx (idx)} - - {/each} - - - - {#each Array(Math.min(3, MAX_PREDICTIONS)) as _, rank (rank)} - - {#each run.result.predictions_per_position as preds, idx (idx)} - {@const pred = preds[rank]} - - {/each} - - {/each} - -
- "{token}" -
- {#if pred} - "{pred.token}" - SPD: {formatProb(pred.spd_prob)} (logit: {formatLogit( - pred.logit, - )}) - Targ: {formatProb(pred.target_prob)} (logit: {formatLogit( - pred.target_logit, - )}) - {:else} - - - {/if} -
+ {:else} +
+ {index === 0 ? "Base" : formatTime(run.createdAt)} + {run.selectedNodes.size} nodes + {#if index > 0} + + {/if}
- - - {#if run.forked_runs && run.forked_runs.length > 0} -
-
- - {run.forked_runs.length} fork{run.forked_runs.length > 1 ? "s" : ""} + {#if isActive} + {@const opt = graph.data.optimization} + {@const lossLabel = opt + ? opt.loss.type === "ce" + ? `CE "${opt.loss.label_str}" @ ${opt.loss.position}` + : `KL @ ${opt.loss.position}` + : "mean KL"} +
+
+ CI + {run.result.ci_loss.toFixed(3)}
- {#each run.forked_runs as fork (fork.id)} -
-
- {formatTime(fork.created_at)} - {fork.token_replacements.length} change{fork.token_replacements - .length > 1 - ? "s" - : ""} - -
- -
- - - - {#each fork.result.input_tokens as token, idx (idx)} - {@const isChanged = fork.token_replacements.some( - (r) => r[0] === idx, - )} - - {/each} - - - - {#each Array(Math.min(3, MAX_PREDICTIONS)) as _, rank (rank)} - - {#each fork.result.predictions_per_position as preds, idx (idx)} - {@const pred = preds[rank]} - - {/each} - - {/each} - -
- "{token}" -
- {#if pred} - "{pred.token}" - SPD: {formatProb(pred.spd_prob)} (logit: {formatLogit( - pred.logit, - )}) - Targ: {formatProb(pred.target_prob)} (logit: - {formatLogit(pred.target_logit)}) - {:else} - - - {/if} -
-
+
+ stoch + {run.result.stochastic_loss.toFixed(3)} +
+
+ adv + {run.result.adversarial_loss.toFixed(3)} +
+ {#if run.result.ablated_loss != null} +
+ T\S + {run.result.ablated_loss.toFixed(3)}
- {/each} + {/if} +
+ metric + {lossLabel} +
+ {#if opt} +
+ L0 + {opt.metrics.l0_total.toFixed(1)} +
+ {/if}
{/if} -
- {/each} -
- {/if} + {/if} +
+ {/each} +
+ +
+ +
@@ -1066,8 +1139,8 @@ nodeCiVals={graph.data.nodeCiVals} nodeSubcompActs={graph.data.nodeSubcompActs} {tokens} - edgesBySource={graph.data.edgesBySource} - edgesByTarget={graph.data.edgesByTarget} + edgesBySource={activeEdgesBySource} + edgesByTarget={activeEdgesByTarget} onMouseEnter={() => (isHoveringTooltip = true)} onMouseLeave={() => { isHoveringTooltip = false; @@ -1076,57 +1149,30 @@ /> {/if} - - {#if forkingRunId !== null} + + {#if hoveredPred} + {@const p = hoveredPred.pred} + {@const pos = predTooltipPos} {/if} + +
diff --git a/spd/app/frontend/src/components/prompt-attr/NodeTooltip.svelte b/spd/app/frontend/src/components/prompt-attr/NodeTooltip.svelte index 84ead3b54..32af775bc 100644 --- a/spd/app/frontend/src/components/prompt-attr/NodeTooltip.svelte +++ b/spd/app/frontend/src/components/prompt-attr/NodeTooltip.svelte @@ -1,10 +1,13 @@
e.stopPropagation()} > -

{getLayerDisplayName(hoveredNode.layer)}:{hoveredNode.seqIdx}:{hoveredNode.cIdx}

+

{hoveredNode.layer}:{hoveredNode.seqIdx}:{hoveredNode.cIdx}

{#if isComponent && ciVal !== null}
CI: {ciVal.toFixed(3)}
{/if} @@ -103,8 +117,19 @@ {hoveredNode.seqIdx}

+ {#if displaySettings.showEdgeAttributions && wteOutgoing.length > 0} + {}} + /> + {/if} {:else if isOutput} - + {:else if !hideNodeCard} {#key `${hoveredNode.layer}:${hoveredNode.cIdx}`} @@ -112,6 +137,9 @@ layer={hoveredNode.layer} cIdx={hoveredNode.cIdx} seqIdx={hoveredNode.seqIdx} + {ciVal} + {subcompAct} + {token} {edgesBySource} {edgesByTarget} {tokens} @@ -133,31 +161,15 @@ max-width: 800px; max-height: 80vh; overflow-y: auto; - box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15); + box-shadow: var(--shadow-md); } - .ci-value { - font-size: var(--text-sm); + .node-tooltip h3 { + font-size: var(--text-base); font-family: var(--font-mono); - color: var(--accent-primary); - font-weight: 600; - margin: var(--space-1) 0 var(--space-2) 0; - } - - .subcomp-act { - font-size: var(--text-sm); - font-family: var(--font-mono); - color: var(--accent-secondary); - font-weight: 600; - margin: var(--space-1) 0 var(--space-2) 0; - } - - .cluster-id { - font-size: var(--text-sm); - font-family: var(--font-mono); - color: var(--text-secondary); font-weight: 600; - margin: var(--space-1) 0 var(--space-2) 0; + color: var(--text-primary); + margin: 0 0 var(--space-2) 0; } .wte-info { diff --git a/spd/app/frontend/src/components/prompt-attr/OptimizationGrid.svelte b/spd/app/frontend/src/components/prompt-attr/OptimizationGrid.svelte new file mode 100644 index 000000000..3fa6a0213 --- /dev/null +++ b/spd/app/frontend/src/components/prompt-attr/OptimizationGrid.svelte @@ -0,0 +1,134 @@ + + +
+
+ + Step {snapshot.step}/{snapshot.total_steps} + + + L0: {Math.round(snapshot.l0_total)} / {initialL0} + ({(fractionRemaining * 100).toFixed(0)}%) + + {#if snapshot.loss > 0} + loss: {snapshot.loss.toFixed(4)} + {/if} +
+ +
+ + diff --git a/spd/app/frontend/src/components/prompt-attr/OptimizationParams.svelte b/spd/app/frontend/src/components/prompt-attr/OptimizationParams.svelte new file mode 100644 index 000000000..83d9f0594 --- /dev/null +++ b/spd/app/frontend/src/components/prompt-attr/OptimizationParams.svelte @@ -0,0 +1,150 @@ + + +
+ steps{optimization.steps} + imp_min{optimization.imp_min_coeff} + pnorm{optimization.pnorm} + beta{optimization.beta} + mask{optimization.mask_type} + + {optimization.loss.type}{optimization.loss.coeff} + + + pos + {optimization.loss.position} + {#if tokenAtPos !== null} + ({tokenAtPos}) + {/if} + + {#if optimization.loss.type === "ce" || optimization.loss.type === "logit"} + + label({optimization.loss.label_str}) + + {/if} + {#if optimization.pgd} + + pgd_steps{optimization.pgd.n_steps} + + + pgd_lr{optimization.pgd.step_size} + + {/if} + + + L0{optimization.metrics.l0_total.toFixed(1)} + + {#if optimization.loss.type === "ce" || optimization.loss.type === "logit"} + + CI prob{formatProb(optimization.metrics.ci_masked_label_prob)} + + + stoch prob{formatProb(optimization.metrics.stoch_masked_label_prob)} + + + adv prob{formatProb(optimization.metrics.adv_pgd_label_prob)} + + {/if} +
+ + diff --git a/spd/app/frontend/src/components/prompt-attr/OptimizationSettings.svelte b/spd/app/frontend/src/components/prompt-attr/OptimizationSettings.svelte new file mode 100644 index 000000000..3c15034b4 --- /dev/null +++ b/spd/app/frontend/src/components/prompt-attr/OptimizationSettings.svelte @@ -0,0 +1,537 @@ + + +
+ +
+ + + +
+ + +
+ +
+ {#each tokens as tok, i (i)} + {@const prob = getProbAtPosition(nextTokenProbs, i)} + + {/each} +
+
+ pos {config.loss.position} + {#if config.loss.type === "ce" || config.loss.type === "logit"} + {config.loss.type === "logit" ? "maximize" : "predict"} + { + if (config.loss.type !== "ce" && config.loss.type !== "logit") + throw new Error("inconsistent state: Token dropdown rendered but loss type has no label"); + + if (tokenId !== null) { + onChange({ + ...config, + loss: { ...config.loss, labelTokenId: tokenId, labelTokenText: tokenString }, + }); + } + }} + placeholder="token..." + /> + {/if} +
+
+ + +
+
+ + { + const val = parseFloat(e.currentTarget.value); + if (!isNaN(val) && val > 0) { + onChange({ ...config, impMinCoeff: val }); + } + }} + /> +
+ handleSliderChange(parseInt(e.currentTarget.value))} + /> +
+ 1e-5 + 10 +
+
+ + + + + {#if showAdvanced} +
+
+ + + + + + + +
+
+ {/if} +
+ + diff --git a/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte b/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte index ff1e8afe8..9dcab7f0e 100644 --- a/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte +++ b/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte @@ -1,14 +1,25 @@ - -
-
- {#each card.tokens as tok, i (i)} - {tok} - {/each} -
- -
- {#each card.graphs as graph (graph.id)} -
- - -
- {/each} - {#if card.graphs.length > 0} - - {/if} -
- -
- - {#if isNewGraphMode} - -
-
- ? - - {#if useOptimized} - - - - - - - - - - - - - {/if} - - {#if showComputeButton} - - {:else if hasMatchingGraph} - Graph already exists - {/if} -
-
- {:else if activeGraph?.data.optimization} - -
-
- Optimized graph params: - - - - - {#if displayConfig.ceLossCoeff > 0} - - {#if displayConfig.labelTokenId !== null} - - {/if} - {/if} - {#if displayConfig.klLossCoeff > 0} - - {/if} - -
-
- {:else if activeGraph} - -
- Standard graph (no optimization) -
- {/if} -
- - diff --git a/spd/app/frontend/src/components/prompt-attr/PromptPicker.svelte b/spd/app/frontend/src/components/prompt-attr/PromptPicker.svelte index 42a89604b..0932a9b87 100644 --- a/spd/app/frontend/src/components/prompt-attr/PromptPicker.svelte +++ b/spd/app/frontend/src/components/prompt-attr/PromptPicker.svelte @@ -1,7 +1,9 @@
@@ -122,9 +121,9 @@ class="dropdown-input" /> - {#if isOpen && filteredTokens.length > 0} + {#if isOpen && searchResults.length > 0} - {:else if isOpen && inputValue.trim() && filteredTokens.length === 0} + {:else if isOpen && inputValue.trim() && searchResults.length === 0} @@ -153,7 +155,7 @@ } .dropdown-input { - width: 100px; + width: 120px; padding: var(--space-1); border: 1px solid var(--border-default); background: var(--bg-elevated); @@ -173,7 +175,7 @@ .dropdown-list { position: fixed; - min-width: 200px; + min-width: 250px; max-height: 300px; overflow-y: auto; margin: 0; @@ -182,7 +184,7 @@ background: var(--bg-elevated); border: 1px solid var(--border-strong); border-radius: var(--radius-sm); - box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15); + box-shadow: var(--shadow-md); z-index: 10000; } @@ -199,21 +201,38 @@ color: var(--text-primary); font-family: var(--font-mono); font-size: var(--text-sm); + gap: var(--space-3); } .dropdown-item:hover, .dropdown-item.highlighted { - background: var(--bg-inset); + background: var(--bg-surface); } .token-string { white-space: pre; + background: var(--bg-base); + padding: 1px 3px; + border: 1px solid var(--border-subtle); + border-radius: var(--radius-sm); + } + + .token-meta { + display: flex; + align-items: center; + gap: var(--space-2); + flex-shrink: 0; + } + + .token-prob { + font-size: var(--text-xs); + color: var(--text-secondary); + font-variant-numeric: tabular-nums; } .token-id { font-size: var(--text-xs); color: var(--text-muted); - margin-left: var(--space-2); } .dropdown-empty { diff --git a/spd/app/frontend/src/components/prompt-attr/ViewControls.svelte b/spd/app/frontend/src/components/prompt-attr/ViewControls.svelte index 2e207b119..26c57b1b0 100644 --- a/spd/app/frontend/src/components/prompt-attr/ViewControls.svelte +++ b/spd/app/frontend/src/components/prompt-attr/ViewControls.svelte @@ -42,7 +42,7 @@ let localTopK = $state(topK); let localComponentGap = $state(componentGap); let localLayerGap = 
$state(layerGap); - let localCiThreshold = $derived.by(() => (ciThreshold?.status === "loaded" ? ciThreshold.data.toString() : "")); + let localCiThreshold = $derived.by(() => (ciThreshold.status === "loaded" ? ciThreshold.data.toString() : "")); // Sync local state when props change externally $effect(() => void (localTopK = topK)); @@ -71,28 +71,28 @@ onblur={() => applyIfChanged(localTopK, topK, onTopKChange)} onkeydown={(e) => e.key === "Enter" && e.currentTarget.blur()} min={0} - max={10_000} + max={50_000} step={100} /> -
@@ -105,6 +134,42 @@ {/each}
+
+

Edge Variant

+

Attribution target: value or |value|

+
+ {#each edgeVariants as variant (variant)} + + {/each} +
+
+
+

Component Filtering

+

Filter components in Components tab by mean CI

+
+ + { + const val = parseFloat((e.target as HTMLInputElement).value); + if (!isNaN(val) && val >= 0) { + displaySettings.meanCiCutoff = val; + } + }} + /> +
+
{/if}
@@ -154,7 +219,7 @@ background: var(--bg-elevated); border: 1px solid var(--border-strong); border-radius: var(--radius-md); - box-shadow: 0 4px 12px rgba(0, 0, 0, 0.08); + box-shadow: var(--shadow-md); padding: var(--space-3); } @@ -241,4 +306,32 @@ cursor: pointer; accent-color: var(--accent-primary); } + + .cutoff-control { + display: flex; + align-items: center; + gap: var(--space-2); + } + + .cutoff-control label { + font-size: var(--text-sm); + font-weight: 500; + color: var(--text-primary); + } + + .cutoff-control input[type="number"] { + width: 100px; + padding: var(--space-1) var(--space-2); + border: 1px solid var(--border-default); + border-radius: var(--radius-sm); + font-size: var(--text-sm); + font-family: var(--font-mono); + background: var(--bg-surface); + color: var(--text-primary); + } + + .cutoff-control input[type="number"]:focus { + outline: none; + border-color: var(--accent-primary); + } diff --git a/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte b/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte index bf677e2e8..ad3f821f5 100644 --- a/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte +++ b/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte @@ -1,5 +1,5 @@ -{#if hasAnyData} -
- -
- {#if hasAnyIncoming} -
-
{incomingLabel}
-
- {#if incomingPositive.length > 0} -
- -
- {/if} - {#if incomingNegative.length > 0} -
- -
- {/if} -
-
- {/if} +{#if incoming.length > 0} +
+ + +
+{/if} - {#if hasAnyOutgoing} -
-
{outgoingLabel}
-
- {#if outgoingPositive.length > 0} -
- -
- {/if} - {#if outgoingNegative.length > 0} -
- -
- {/if} -
-
- {/if} -
+{#if outgoing.length > 0} +
+ +
{/if} diff --git a/spd/app/frontend/src/components/ui/EdgeAttributionList.svelte b/spd/app/frontend/src/components/ui/EdgeAttributionList.svelte index f2b32875b..3809aceaf 100644 --- a/spd/app/frontend/src/components/ui/EdgeAttributionList.svelte +++ b/spd/app/frontend/src/components/ui/EdgeAttributionList.svelte @@ -1,7 +1,8 @@ + +
+ + + {#if expanded && detail} +
+
+
+ Input + {#if detail.input?.reasoning} +

{detail.input.reasoning}

+ {/if} + {#each incomingEdges as edge (edge.related_key)} +
+ {formatComponentKey(edge.related_key, edge.token_str)} + 0} + class:negative={edge.attribution < 0} + > + {edge.attribution > 0 ? "+" : ""}{edge.attribution.toFixed(3)} + + {#if edge.related_label} + {edge.related_label} + {/if} +
+ {/each} +
+
+ Output + {#if detail.output?.reasoning} +

{detail.output.reasoning}

+ {/if} + {#each outgoingEdges as edge (edge.related_key)} +
+ {formatComponentKey(edge.related_key, edge.token_str)} + 0} + class:negative={edge.attribution < 0} + > + {edge.attribution > 0 ? "+" : ""}{edge.attribution.toFixed(3)} + + {#if edge.related_label} + {edge.related_label} + {/if} +
+ {/each} +
+
+
+ {/if} +
+ + diff --git a/spd/app/frontend/src/components/ui/InterpretationBadge.svelte b/spd/app/frontend/src/components/ui/InterpretationBadge.svelte index bfb142936..f9ef4fb66 100644 --- a/spd/app/frontend/src/components/ui/InterpretationBadge.svelte +++ b/spd/app/frontend/src/components/ui/InterpretationBadge.svelte @@ -2,6 +2,7 @@ import type { Loadable } from "../../lib"; import type { InterpretationDetail } from "../../lib/api"; import type { InterpretationBackendState } from "../../lib/useRun.svelte"; + import { displaySettings } from "../../lib/displaySettings.svelte"; interface Props { interpretation: Loadable; @@ -12,6 +13,12 @@ let { interpretation, interpretationDetail, onGenerate }: Props = $props(); let showPrompt = $state(false); + + function scoreClass(score: number): string { + if (score >= 0.7) return "score-high"; + if (score >= 0.5) return "score-medium"; + return "score-low"; + }
@@ -35,16 +42,31 @@ {interpretationData.data.confidence} + {#if interpretationData.data.detection_score !== null} + Det {Math.round(interpretationData.data.detection_score * 100)}% + {/if} + {#if interpretationData.data.fuzzing_score !== null} + Fuz {Math.round(interpretationData.data.fuzzing_score * 100)}% + {/if}
{#if interpretationDetail.status === "loaded" && interpretationDetail.data?.reasoning} {interpretationDetail.data.reasoning} {:else if interpretationDetail.status === "loading"} Loading reasoning... +


+ {/if}
- + {#if displaySettings.showAutoInterpPromptButton} + + {/if} {:else if interpretationData.status === "generation-error"} {String(interpretationData.error)} @@ -56,7 +78,7 @@ {/if}
- {#if showPrompt} + {#if displaySettings.showAutoInterpPromptButton && showPrompt}
{#if interpretationDetail.status === "loading"} Loading prompt... @@ -83,9 +105,9 @@ align-items: flex-start; gap: var(--space-2); padding: var(--space-2) var(--space-3); - background: var(--bg-secondary); + background: var(--bg-inset); border-radius: var(--radius-md); - border-left: 3px solid var(--color-accent, #6366f1); + border-left: 3px solid var(--accent-primary); } .interpretation-content { @@ -135,20 +157,20 @@ .confidence { font-size: var(--text-xs); - padding: 2px 6px; + padding: var(--space-1) var(--space-2); border-radius: var(--radius-sm); text-transform: uppercase; font-weight: 600; } .confidence-high { - background: color-mix(in srgb, #22c55e 20%, transparent); - color: #22c55e; + background: color-mix(in srgb, var(--status-positive-bright) 20%, transparent); + color: var(--status-positive-bright); } .confidence-medium { - background: color-mix(in srgb, #eab308 20%, transparent); - color: #eab308; + background: color-mix(in srgb, var(--status-warning) 20%, transparent); + color: var(--status-warning); } .confidence-low { @@ -156,6 +178,29 @@ color: var(--text-muted); } + .score-pill { + font-size: var(--text-xs); + padding: var(--space-1) var(--space-2); + border-radius: var(--radius-sm); + font-weight: 600; + white-space: nowrap; + } + + .score-high { + background: color-mix(in srgb, var(--status-positive-bright) 20%, transparent); + color: var(--status-positive-bright); + } + + .score-medium { + background: color-mix(in srgb, var(--status-warning) 20%, transparent); + color: var(--status-warning); + } + + .score-low { + background: color-mix(in srgb, var(--text-muted) 20%, transparent); + color: var(--text-muted); + } + .generate-btn, .retry-btn { padding: var(--space-1) var(--space-2); @@ -202,7 +247,7 @@ } .prompt-display { - background: var(--bg-primary); + background: var(--bg-surface); border: 1px solid var(--border-default); border-radius: var(--radius-md); padding: var(--space-3); diff --git 
a/spd/app/frontend/src/components/ui/SetOverlapVis.svelte b/spd/app/frontend/src/components/ui/SetOverlapVis.svelte index a60d09497..73a61fe5a 100644 --- a/spd/app/frontend/src/components/ui/SetOverlapVis.svelte +++ b/spd/app/frontend/src/components/ui/SetOverlapVis.svelte @@ -1,5 +1,5 @@ -
- {#each items as { token, value } (token)} - {token} - {/each} +
+
+ {#each visibleItems as { token, value }, i (i)} + {token} + {/each} +
+ {#if hasPagination} + + {/if}
diff --git a/spd/app/frontend/src/components/ui/TokenSpan.svelte b/spd/app/frontend/src/components/ui/TokenSpan.svelte new file mode 100644 index 000000000..4a5fa9b81 --- /dev/null +++ b/spd/app/frontend/src/components/ui/TokenSpan.svelte @@ -0,0 +1,43 @@ + + +{sanitizeToken(token)} + + diff --git a/spd/app/frontend/src/lib/ZoomControls.svelte b/spd/app/frontend/src/lib/ZoomControls.svelte index 569dc8639..d455829c6 100644 --- a/spd/app/frontend/src/lib/ZoomControls.svelte +++ b/spd/app/frontend/src/lib/ZoomControls.svelte @@ -47,7 +47,7 @@ color: var(--text-secondary); font-size: var(--text-sm); cursor: pointer; - transition: all 0.1s; + transition: all var(--transition-fast); } .zoom-btn:hover { diff --git a/spd/app/frontend/src/lib/api/activationContexts.ts b/spd/app/frontend/src/lib/api/activationContexts.ts index 1744f17c0..a126739a1 100644 --- a/spd/app/frontend/src/lib/api/activationContexts.ts +++ b/spd/app/frontend/src/lib/api/activationContexts.ts @@ -2,30 +2,41 @@ * API client for /api/activation_contexts endpoints. 
*/ -import type { ActivationContextsSummary, ComponentDetail, ComponentProbeResult } from "../promptAttributionsTypes"; -import { API_URL, fetchJson } from "./index"; +import type { + ActivationContextsSummary, + SubcomponentProbeResult, + SubcomponentActivationContexts, +} from "../promptAttributionsTypes"; +import { ApiError, fetchJson } from "./index"; -// Types for activation contexts -export type SubcomponentActivationContexts = { - subcomponent_idx: number; - mean_ci: number; - example_tokens: string[][]; - example_ci: number[][]; - example_component_acts: number[][]; -}; - -export async function getActivationContextsSummary(): Promise { - return fetchJson(`${API_URL}/api/activation_contexts/summary`); +export async function getActivationContextsSummary(): Promise { + try { + return await fetchJson("/api/activation_contexts/summary"); + } catch (e) { + if (e instanceof ApiError && e.status === 404) return null; + throw e; + } } -export async function getComponentDetail(layer: string, componentIdx: number): Promise { - return fetchJson( - `${API_URL}/api/activation_contexts/${encodeURIComponent(layer)}/${componentIdx}`, +/** Default limit for initial load - 100 examples = 10 pages at 10 per page. 
*/ +const ACTIVATION_EXAMPLES_INITIAL_LIMIT = 100; + +export async function getActivationContextDetail( + layer: string, + componentIdx: number, + limit: number = ACTIVATION_EXAMPLES_INITIAL_LIMIT, +): Promise { + return fetchJson( + `/api/activation_contexts/${encodeURIComponent(layer)}/${componentIdx}?limit=${limit}`, ); } -export async function probeComponent(text: string, layer: string, componentIdx: number): Promise { - return fetchJson(`${API_URL}/api/activation_contexts/probe`, { +export async function probeComponent( + text: string, + layer: string, + componentIdx: number, +): Promise { + return fetchJson("/api/activation_contexts/probe", { method: "POST", headers: { "Content-Type": "application/json" }, body: JSON.stringify({ text, layer, component_idx: componentIdx }), diff --git a/spd/app/frontend/src/lib/api/clusters.ts b/spd/app/frontend/src/lib/api/clusters.ts index b05160ad9..f1c4744a3 100644 --- a/spd/app/frontend/src/lib/api/clusters.ts +++ b/spd/app/frontend/src/lib/api/clusters.ts @@ -2,14 +2,14 @@ * API client for /api/clusters endpoints. */ -import { API_URL } from "./index"; +import { apiUrl } from "./index"; export type ClusterMapping = { mapping: Record; }; export async function loadClusterMapping(filePath: string): Promise { - const url = new URL(`${API_URL}/api/clusters/load`); + const url = apiUrl("/api/clusters/load"); url.searchParams.set("file_path", filePath); const response = await fetch(url.toString(), { method: "POST" }); diff --git a/spd/app/frontend/src/lib/api/correlations.ts b/spd/app/frontend/src/lib/api/correlations.ts index 15633eaf1..8dcc63f04 100644 --- a/spd/app/frontend/src/lib/api/correlations.ts +++ b/spd/app/frontend/src/lib/api/correlations.ts @@ -2,33 +2,35 @@ * API client for /api/correlations endpoints. 
*/ -import type { ComponentCorrelations, TokenStats } from "../promptAttributionsTypes"; -import { API_URL, fetchJson } from "./index"; +import type { SubcomponentCorrelationsResponse, TokenStatsResponse } from "../promptAttributionsTypes"; +import { ApiError, apiUrl, fetchJson } from "./index"; export async function getComponentCorrelations( layer: string, componentIdx: number, topK: number, -): Promise<ComponentCorrelations> { - const url = new URL(`${API_URL}/api/correlations/components/${encodeURIComponent(layer)}/${componentIdx}`); +): Promise<SubcomponentCorrelationsResponse> { + const url = apiUrl(`/api/correlations/components/${encodeURIComponent(layer)}/${componentIdx}`); url.searchParams.set("top_k", String(topK)); - return fetchJson<ComponentCorrelations>(url.toString()); + return fetchJson<SubcomponentCorrelationsResponse>(url.toString()); } export async function getComponentTokenStats( layer: string, componentIdx: number, topK: number, -): Promise<TokenStats> { - const url = new URL(`${API_URL}/api/correlations/token_stats/${encodeURIComponent(layer)}/${componentIdx}`); +): Promise<TokenStatsResponse> { + const url = apiUrl(`/api/correlations/token_stats/${encodeURIComponent(layer)}/${componentIdx}`); url.searchParams.set("top_k", String(topK)); - return fetchJson<TokenStats>(url.toString()); + return fetchJson<TokenStatsResponse>(url.toString()); } // Interpretation headline (bulk-fetched) - lightweight data for badges -export type Interpretation = { +export type InterpretationHeadline = { label: string; confidence: "low" | "medium" | "high"; + detection_score: number | null; + fuzzing_score: number | null; }; // Interpretation detail (fetched on-demand) - reasoning and prompt @@ -37,19 +39,34 @@ export type InterpretationDetail = { prompt: string; }; -export async function getAllInterpretations(): Promise<Record<string, Interpretation>> { - return fetchJson<Record<string, Interpretation>>(`${API_URL}/api/correlations/interpretations`); +export async function getAllInterpretations(): Promise<Record<string, InterpretationHeadline>> { + return fetchJson<Record<string, InterpretationHeadline>>("/api/correlations/interpretations"); } -export async function getInterpretationDetail(layer: string, componentIdx: number): Promise<InterpretationDetail> { - return fetchJson<InterpretationDetail>( - 
`${API_URL}/api/correlations/interpretations/${encodeURIComponent(layer)}/${componentIdx}`, - ); +export async function getIntruderScores(): Promise<Record<string, number>> { + return fetchJson<Record<string, number>>("/api/correlations/intruder_scores"); } -export async function requestComponentInterpretation(layer: string, componentIdx: number): Promise<Interpretation> { - return fetchJson<Interpretation>( - `${API_URL}/api/correlations/interpretations/${encodeURIComponent(layer)}/${componentIdx}`, +export async function getInterpretationDetail( + layer: string, + componentIdx: number, +): Promise<InterpretationDetail | null> { + try { + return await fetchJson<InterpretationDetail>( + `/api/correlations/interpretations/${encodeURIComponent(layer)}/${componentIdx}`, + ); + } catch (e) { + if (e instanceof ApiError && e.status === 404) return null; + throw e; + } +} + +export async function requestComponentInterpretation( + layer: string, + componentIdx: number, +): Promise<InterpretationDetail> { + return fetchJson<InterpretationDetail>( + `/api/correlations/interpretations/${encodeURIComponent(layer)}/${componentIdx}`, { method: "POST" }, ); } diff --git a/spd/app/frontend/src/lib/api/dataSources.ts b/spd/app/frontend/src/lib/api/dataSources.ts new file mode 100644 index 000000000..ac20b7220 --- /dev/null +++ b/spd/app/frontend/src/lib/api/dataSources.ts @@ -0,0 +1,42 @@ +/** + * API client for /api/data_sources endpoint. 
+ */ + +import { fetchJson } from "./index"; + +export type HarvestInfo = { + subrun_id: string; + config: Record<string, unknown>; + n_components: number; + has_intruder_scores: boolean; +}; + +export type AutointerpInfo = { + subrun_id: string; + config: Record<string, unknown>; + n_interpretations: number; + eval_scores: string[]; +}; + +export type AttributionsInfo = { + subrun_id: string; + n_tokens_processed: number; + ci_threshold: number; +}; + +export type GraphInterpInfoDS = { + subrun_id: string; + config: Record<string, unknown> | null; + label_counts: Record<string, number>; +}; + +export type DataSourcesResponse = { + harvest: HarvestInfo | null; + autointerp: AutointerpInfo | null; + attributions: AttributionsInfo | null; + graph_interp: GraphInterpInfoDS | null; +}; + +export async function fetchDataSources(): Promise<DataSourcesResponse> { + return fetchJson<DataSourcesResponse>("/api/data_sources"); +} diff --git a/spd/app/frontend/src/lib/api/dataset.ts b/spd/app/frontend/src/lib/api/dataset.ts index 19e8bb63a..62c0f736c 100644 --- a/spd/app/frontend/src/lib/api/dataset.ts +++ b/spd/app/frontend/src/lib/api/dataset.ts @@ -2,18 +2,18 @@ * API client for /api/dataset endpoints. 
*/ -import { API_URL } from "./index"; +import { apiUrl } from "./index"; export type DatasetSearchResult = { - story: string; + text: string; occurrence_count: number; - topic: string | null; - theme: string | null; + metadata: Record; }; export type DatasetSearchMetadata = { query: string; split: string; + dataset_name: string; total_results: number; search_time_seconds: number; }; @@ -27,7 +27,7 @@ export type DatasetSearchPage = { }; export async function searchDataset(query: string, split: string): Promise { - const url = new URL(`${API_URL}/api/dataset/search`); + const url = apiUrl("/api/dataset/search"); url.searchParams.set("query", query); url.searchParams.set("split", split); @@ -40,8 +40,8 @@ export async function searchDataset(query: string, split: string): Promise { - const url = new URL(`${API_URL}/api/dataset/results`); +export async function getDatasetResults(page: number, pageSize: number): Promise { + const url = apiUrl("/api/dataset/results"); url.searchParams.set("page", String(page)); url.searchParams.set("page_size", String(pageSize)); @@ -53,3 +53,96 @@ export async function getDatasetSearchPage(page: number, pageSize: number): Prom return (await response.json()) as DatasetSearchPage; } + +export type TokenizedSearchResult = { + tokens: string[]; + next_token_probs: (number | null)[]; + occurrence_count: number; + metadata: Record; +}; + +export type TokenizedSearchPage = { + results: TokenizedSearchResult[]; + query: string; + page: number; + page_size: number; + total_results: number; + total_pages: number; +}; + +export async function getTokenizedResults( + page: number, + pageSize: number = 10, + maxTokens: number = 256, +): Promise { + const url = apiUrl("/api/dataset/results_tokenized"); + url.searchParams.set("page", String(page)); + url.searchParams.set("page_size", String(pageSize)); + url.searchParams.set("max_tokens", String(maxTokens)); + + const response = await fetch(url.toString(), { method: "GET" }); + if (!response.ok) { + 
const error = await response.json(); + throw new Error(error.detail || "Failed to get tokenized results"); + } + + return (await response.json()) as TokenizedSearchPage; +} + +export type RandomSamplesResult = { + results: DatasetSearchResult[]; + total_available: number; + seed: number; +}; + +export async function getRandomSamples( + nSamples: number = 100, + seed: number = 42, + split: "train" | "test" = "train", +): Promise<RandomSamplesResult> { + const url = apiUrl("/api/dataset/random"); + url.searchParams.set("n_samples", String(nSamples)); + url.searchParams.set("seed", String(seed)); + url.searchParams.set("split", split); + + const response = await fetch(url.toString(), { method: "GET" }); + if (!response.ok) { + const error = await response.json(); + throw new Error(error.detail || "Failed to get random samples"); + } + + return (await response.json()) as RandomSamplesResult; +} + +export type TokenizedSample = { + tokens: string[]; + next_token_probs: (number | null)[]; // Probability of next token; null for last position + metadata: Record<string, unknown>; +}; + +export type RandomSamplesWithLossResult = { + results: TokenizedSample[]; + total_available: number; + seed: number; +}; + +export async function getRandomSamplesWithLoss( + nSamples: number = 20, + seed: number = 42, + split: "train" | "test" = "train", + maxTokens: number = 256, +): Promise<RandomSamplesWithLossResult> { + const url = apiUrl("/api/dataset/random_with_loss"); + url.searchParams.set("n_samples", String(nSamples)); + url.searchParams.set("seed", String(seed)); + url.searchParams.set("split", split); + url.searchParams.set("max_tokens", String(maxTokens)); + + const response = await fetch(url.toString(), { method: "GET" }); + if (!response.ok) { + const error = await response.json(); + throw new Error(error.detail || "Failed to get random samples with loss"); + } + + return (await response.json()) as RandomSamplesWithLossResult; +} diff --git a/spd/app/frontend/src/lib/api/datasetAttributions.ts 
b/spd/app/frontend/src/lib/api/datasetAttributions.ts index 6fa2a6595..030eae6c6 100644 --- a/spd/app/frontend/src/lib/api/datasetAttributions.ts +++ b/spd/app/frontend/src/lib/api/datasetAttributions.ts @@ -2,145 +2,44 @@ * API client for /api/dataset_attributions endpoints. */ -import { API_URL, fetchJson } from "./index"; +import { apiUrl, fetchJson } from "./index"; export type DatasetAttributionEntry = { - componentKey: string; + component_key: string; layer: string; - componentIdx: number; + component_idx: number; value: number; + token_str: string | null; }; -export type DatasetAttributionMetadata = { - available: boolean; - nBatchesProcessed: number | null; - nTokensProcessed: number | null; - nComponentLayerKeys: number | null; - vocabSize: number | null; - ciThreshold: number | null; +export type SignedAttributions = { + positive_sources: DatasetAttributionEntry[]; + negative_sources: DatasetAttributionEntry[]; + positive_targets: DatasetAttributionEntry[]; + negative_targets: DatasetAttributionEntry[]; }; -export async function getDatasetAttributionMetadata(): Promise { - const data = await fetchJson<{ - available: boolean; - n_batches_processed: number | null; - n_tokens_processed: number | null; - n_component_layer_keys: number | null; - vocab_size: number | null; - ci_threshold: number | null; - }>(`${API_URL}/api/dataset_attributions/metadata`); +export type AttrMetric = "attr" | "attr_abs"; - return { - available: data.available, - nBatchesProcessed: data.n_batches_processed, - nTokensProcessed: data.n_tokens_processed, - nComponentLayerKeys: data.n_component_layer_keys, - vocabSize: data.vocab_size, - ciThreshold: data.ci_threshold, - }; -} +export type AllMetricAttributions = { + attr: SignedAttributions; + attr_abs: SignedAttributions; +}; -export type ComponentAttributions = { - positiveSources: DatasetAttributionEntry[]; - negativeSources: DatasetAttributionEntry[]; - positiveTargets: DatasetAttributionEntry[]; - negativeTargets: 
DatasetAttributionEntry[]; +export type DatasetAttributionsMetadata = { + available: boolean; }; -function mapEntries( - entries: Array<{ component_key: string; layer: string; component_idx: number; value: number }>, -): DatasetAttributionEntry[] { - return entries.map((e) => ({ - componentKey: e.component_key, - layer: e.layer, - componentIdx: e.component_idx, - value: e.value, - })); +export async function getDatasetAttributionsMetadata(): Promise { + return fetchJson(apiUrl("/api/dataset_attributions/metadata").toString()); } export async function getComponentAttributions( layer: string, componentIdx: number, k: number = 10, -): Promise { - const url = new URL(`${API_URL}/api/dataset_attributions/${encodeURIComponent(layer)}/${componentIdx}`); - url.searchParams.set("k", String(k)); - - const data = await fetchJson<{ - positive_sources: Array<{ component_key: string; layer: string; component_idx: number; value: number }>; - negative_sources: Array<{ component_key: string; layer: string; component_idx: number; value: number }>; - positive_targets: Array<{ component_key: string; layer: string; component_idx: number; value: number }>; - negative_targets: Array<{ component_key: string; layer: string; component_idx: number; value: number }>; - }>(url.toString()); - - return { - positiveSources: mapEntries(data.positive_sources), - negativeSources: mapEntries(data.negative_sources), - positiveTargets: mapEntries(data.positive_targets), - negativeTargets: mapEntries(data.negative_targets), - }; -} - -export async function getAttributionSources( - layer: string, - componentIdx: number, - k: number = 10, - sign: "positive" | "negative" = "positive", -): Promise { - const url = new URL(`${API_URL}/api/dataset_attributions/${encodeURIComponent(layer)}/${componentIdx}/sources`); +): Promise { + const url = apiUrl(`/api/dataset_attributions/${encodeURIComponent(layer)}/${componentIdx}`); url.searchParams.set("k", String(k)); - url.searchParams.set("sign", sign); - - const 
data = await fetchJson< - Array<{ - component_key: string; - layer: string; - component_idx: number; - value: number; - }> - >(url.toString()); - - return data.map((entry) => ({ - componentKey: entry.component_key, - layer: entry.layer, - componentIdx: entry.component_idx, - value: entry.value, - })); -} - -export async function getAttributionTargets( - layer: string, - componentIdx: number, - k: number = 10, - sign: "positive" | "negative" = "positive", -): Promise { - const url = new URL(`${API_URL}/api/dataset_attributions/${encodeURIComponent(layer)}/${componentIdx}/targets`); - url.searchParams.set("k", String(k)); - url.searchParams.set("sign", sign); - - const data = await fetchJson< - Array<{ - component_key: string; - layer: string; - component_idx: number; - value: number; - }> - >(url.toString()); - - return data.map((entry) => ({ - componentKey: entry.component_key, - layer: entry.layer, - componentIdx: entry.component_idx, - value: entry.value, - })); -} - -export async function getAttributionBetween( - sourceLayer: string, - sourceIdx: number, - targetLayer: string, - targetIdx: number, -): Promise { - const url = `${API_URL}/api/dataset_attributions/between/${encodeURIComponent(sourceLayer)}/${sourceIdx}/${encodeURIComponent(targetLayer)}/${targetIdx}`; - return fetchJson(url); + return fetchJson(url.toString()); } diff --git a/spd/app/frontend/src/lib/api/graphInterp.ts b/spd/app/frontend/src/lib/api/graphInterp.ts new file mode 100644 index 000000000..8229e757c --- /dev/null +++ b/spd/app/frontend/src/lib/api/graphInterp.ts @@ -0,0 +1,81 @@ +/** + * API client for /api/graph_interp endpoints. 
+ */ + +import { fetchJson } from "./index"; + +export type GraphInterpHeadline = { + label: string; + confidence: string; + output_label: string | null; + input_label: string | null; +}; + +export type LabelDetail = { + label: string; + confidence: string; + reasoning: string; + prompt: string; +}; + +export type GraphInterpDetail = { + output: LabelDetail | null; + input: LabelDetail | null; + unified: LabelDetail | null; +}; + +export type PromptEdgeResponse = { + related_key: string; + pass_name: string; + attribution: number; + related_label: string | null; + related_confidence: string | null; + token_str: string | null; +}; + +export type GraphInterpComponentDetail = { + output: LabelDetail | null; + input: LabelDetail | null; + unified: LabelDetail | null; + edges: PromptEdgeResponse[]; +}; + +export type GraphNode = { + component_key: string; + label: string; + confidence: string; +}; + +export type GraphEdge = { + source: string; + target: string; + attribution: number; + pass_name: string; +}; + +export type ModelGraphResponse = { + nodes: GraphNode[]; + edges: GraphEdge[]; +}; + +export type GraphInterpInfo = { + subrun_id: string; + config: Record<string, unknown> | null; + label_counts: Record<string, number>; +}; + +export async function getAllGraphInterpLabels(): Promise<Record<string, GraphInterpHeadline>> { + return fetchJson<Record<string, GraphInterpHeadline>>("/api/graph_interp/labels"); +} + +export async function getGraphInterpDetail(layer: string, cIdx: number): Promise<GraphInterpDetail> { + return fetchJson<GraphInterpDetail>(`/api/graph_interp/labels/${layer}/${cIdx}`); +} + +export async function getGraphInterpComponentDetail(layer: string, cIdx: number): Promise<GraphInterpComponentDetail> { + return fetchJson<GraphInterpComponentDetail>(`/api/graph_interp/detail/${layer}/${cIdx}`); +} + +export async function getModelGraph(): Promise<ModelGraphResponse> { + return fetchJson<ModelGraphResponse>("/api/graph_interp/graph"); +} diff --git a/spd/app/frontend/src/lib/api/graphs.ts b/spd/app/frontend/src/lib/api/graphs.ts index 3d2811b62..39dcaec19 100644 --- a/spd/app/frontend/src/lib/api/graphs.ts +++ b/spd/app/frontend/src/lib/api/graphs.ts @@ -2,9 +2,24 @@ * API client for 
/api/graphs endpoints. */ -import type { GraphData, TokenizeResult, TokenInfo } from "../promptAttributionsTypes"; +import type { GraphData, EdgeData, TokenizeResponse, TokenSearchResult, CISnapshot } from "../promptAttributionsTypes"; import { buildEdgeIndexes } from "../promptAttributionsTypes"; -import { API_URL, ApiError, fetchJson } from "./index"; +import { apiUrl, ApiError, fetchJson } from "./index"; + +/** Hydrate a raw API graph response into a full GraphData with edge indexes. */ +function hydrateGraph(raw: Record): GraphData { + const g = raw as Omit; + const { edgesBySource, edgesByTarget } = buildEdgeIndexes(g.edges); + const edgesAbs = (g.edgesAbs satisfies EdgeData[] | null) ?? null; + let edgesAbsBySource: Map | null = null; + let edgesAbsByTarget: Map | null = null; + if (edgesAbs) { + const absIndexes = buildEdgeIndexes(edgesAbs); + edgesAbsBySource = absIndexes.edgesBySource; + edgesAbsByTarget = absIndexes.edgesByTarget; + } + return { ...g, edgesBySource, edgesByTarget, edgesAbs, edgesAbsBySource, edgesAbsByTarget }; +} export type NormalizeType = "none" | "target" | "layer"; @@ -22,14 +37,13 @@ export type ComputeGraphParams = { includedNodes?: string[]; }; -/** - * Parse SSE stream and return GraphData result. - * Handles progress updates, errors, and completion messages. - */ -async function parseGraphSSEStream( +/** Generic SSE stream parser. Delegates result extraction to the caller via extractResult. 
*/ +async function parseSSEStream<T>( response: Response, + extractResult: (data: Record<string, unknown>) => T, onProgress?: (progress: GraphProgress) => void, -): Promise<GraphData> { + onCISnapshot?: (snapshot: CISnapshot) => void, +): Promise<T> { const reader = response.body?.getReader(); if (!reader) { throw new Error("Response body is not readable"); @@ -37,7 +51,7 @@ async function parseGraphSSEStream( const decoder = new TextDecoder(); let buffer = ""; - let result: GraphData | null = null; + let result: T | null = null; while (true) { const { done, value } = await reader.read(); @@ -55,11 +69,12 @@ async function parseGraphSSEStream( if (data.type === "progress" && onProgress) { onProgress({ current: data.current, total: data.total, stage: data.stage }); + } else if (data.type === "ci_snapshot" && onCISnapshot) { + onCISnapshot(data as CISnapshot); } else if (data.type === "error") { throw new ApiError(data.error, 500); } else if (data.type === "complete") { - const { edgesBySource, edgesByTarget } = buildEdgeIndexes(data.data.edges); - result = { ...data.data, edgesBySource, edgesByTarget }; + result = extractResult(data); await reader.cancel(); break; } @@ -75,11 +90,11 @@ async function parseGraphSSEStream( return result; } -export async function computeGraphStreaming( +export async function computeGraphStream( params: ComputeGraphParams, onProgress?: (progress: GraphProgress) => void, ): Promise<GraphData> { - const url = new URL(`${API_URL}/api/graphs`); + const url = apiUrl("/api/graphs"); url.searchParams.set("prompt_id", String(params.promptId)); url.searchParams.set("normalize", String(params.normalize)); url.searchParams.set("ci_threshold", String(params.ciThreshold)); @@ -93,10 +108,11 @@ export async function computeGraphStreaming( throw new ApiError(error.detail || `HTTP ${response.status}`, response.status); } - return parseGraphSSEStream(response, onProgress); + return parseSSEStream(response, (data) => hydrateGraph(data.data as Record<string, unknown>), onProgress); } export type MaskType = 
"stochastic" | "ci"; +export type LossType = "ce" | "kl" | "logit"; export type ComputeGraphOptimizedParams = { promptId: number; @@ -105,38 +121,42 @@ export type ComputeGraphOptimizedParams = { pnorm: number; beta: number; normalize: NormalizeType; - outputProbThreshold: number; ciThreshold: number; - labelToken?: number; - ceLossCoeff?: number; - klLossCoeff?: number; maskType: MaskType; + lossType: LossType; + lossCoeff: number; + lossPosition: number; + labelToken?: number; // Required for CE loss + advPgdNSteps?: number; + advPgdStepSize?: number; }; -export async function computeGraphOptimizedStreaming( +export async function computeGraphOptimizedStream( params: ComputeGraphOptimizedParams, onProgress?: (progress: GraphProgress) => void, + onCISnapshot?: (snapshot: CISnapshot) => void, ): Promise { - const url = new URL(`${API_URL}/api/graphs/optimized/stream`); + const url = apiUrl("/api/graphs/optimized/stream"); url.searchParams.set("prompt_id", String(params.promptId)); url.searchParams.set("imp_min_coeff", String(params.impMinCoeff)); url.searchParams.set("steps", String(params.steps)); url.searchParams.set("pnorm", String(params.pnorm)); url.searchParams.set("beta", String(params.beta)); url.searchParams.set("normalize", String(params.normalize)); - url.searchParams.set("output_prob_threshold", String(params.outputProbThreshold)); url.searchParams.set("ci_threshold", String(params.ciThreshold)); - + url.searchParams.set("mask_type", params.maskType); + url.searchParams.set("loss_type", params.lossType); + url.searchParams.set("loss_coeff", String(params.lossCoeff)); + url.searchParams.set("loss_position", String(params.lossPosition)); if (params.labelToken !== undefined) { url.searchParams.set("label_token", String(params.labelToken)); } - if (params.ceLossCoeff !== undefined) { - url.searchParams.set("ce_loss_coeff", String(params.ceLossCoeff)); + if (params.advPgdNSteps !== undefined) { + url.searchParams.set("adv_pgd_n_steps", 
String(params.advPgdNSteps)); } - if (params.klLossCoeff !== undefined) { - url.searchParams.set("kl_loss_coeff", String(params.klLossCoeff)); + if (params.advPgdStepSize !== undefined) { + url.searchParams.set("adv_pgd_step_size", String(params.advPgdStepSize)); } - url.searchParams.set("mask_type", params.maskType); const response = await fetch(url.toString(), { method: "POST" }); if (!response.ok) { @@ -144,27 +164,98 @@ export async function computeGraphOptimizedStreaming( throw new ApiError(error.detail || `HTTP ${response.status}`, response.status); } - return parseGraphSSEStream(response, onProgress); + return parseSSEStream( + response, + (data) => hydrateGraph(data.data as Record), + onProgress, + onCISnapshot, + ); +} + +export type ComputeGraphOptimizedBatchParams = { + promptId: number; + impMinCoeffs: number[]; + steps: number; + pnorm: number; + beta: number; + normalize: NormalizeType; + ciThreshold: number; + maskType: MaskType; + lossType: LossType; + lossCoeff: number; + lossPosition: number; + labelToken?: number; + advPgdNSteps?: number; + advPgdStepSize?: number; +}; + +export async function computeGraphOptimizedBatchStream( + params: ComputeGraphOptimizedBatchParams, + onProgress?: (progress: GraphProgress) => void, + onCISnapshot?: (snapshot: CISnapshot) => void, +): Promise { + const url = apiUrl("/api/graphs/optimized/batch/stream"); + + const body: Record = { + prompt_id: params.promptId, + imp_min_coeffs: params.impMinCoeffs, + steps: params.steps, + pnorm: params.pnorm, + beta: params.beta, + normalize: params.normalize, + ci_threshold: params.ciThreshold, + mask_type: params.maskType, + loss_type: params.lossType, + loss_coeff: params.lossCoeff, + loss_position: params.lossPosition, + }; + if (params.labelToken !== undefined) body.label_token = params.labelToken; + if (params.advPgdNSteps !== undefined) body.adv_pgd_n_steps = params.advPgdNSteps; + if (params.advPgdStepSize !== undefined) body.adv_pgd_step_size = params.advPgdStepSize; 
+ + const response = await fetch(url.toString(), { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(body), + }); + if (!response.ok) { + const error = await response.json(); + throw new ApiError(error.detail || `HTTP ${response.status}`, response.status); + } + + return parseSSEStream( + response, + (data) => (data.data as { graphs: Record<string, unknown>[] }).graphs.map((g) => hydrateGraph(g)), + onProgress, + onCISnapshot, + ); } export async function getGraphs(promptId: number, normalize: NormalizeType, ciThreshold: number): Promise<GraphData[]> { - const url = new URL(`${API_URL}/api/graphs/${promptId}`); + const url = apiUrl(`/api/graphs/${promptId}`); url.searchParams.set("normalize", normalize); url.searchParams.set("ci_threshold", String(ciThreshold)); - const graphs = await fetchJson<Omit<GraphData, "edgesBySource" | "edgesByTarget">[]>(url.toString()); - return graphs.map((g) => { - const { edgesBySource, edgesByTarget } = buildEdgeIndexes(g.edges); - return { ...g, edgesBySource, edgesByTarget }; - }); + const graphs = await fetchJson<Record<string, unknown>[]>(url.toString()); + return graphs.map((g) => hydrateGraph(g)); } -export async function tokenizeText(text: string): Promise<TokenizeResult> { - const url = new URL(`${API_URL}/api/graphs/tokenize`); +export async function tokenizeText(text: string): Promise<TokenizeResponse> { + const url = apiUrl("/api/graphs/tokenize"); url.searchParams.set("text", text); - return fetchJson<TokenizeResult>(url.toString(), { method: "POST" }); + return fetchJson<TokenizeResponse>(url.toString(), { method: "POST" }); } -export async function getAllTokens(): Promise<TokenInfo[]> { - const response = await fetchJson<{ tokens: TokenInfo[] }>(`${API_URL}/api/graphs/tokens`); +export async function searchTokens( + query: string, + promptId: number, + position: number, + limit: number = 20, +): Promise<TokenSearchResult[]> { + const url = apiUrl("/api/graphs/tokens/search"); + url.searchParams.set("q", query); + url.searchParams.set("limit", String(limit)); + url.searchParams.set("prompt_id", String(promptId)); + url.searchParams.set("position", String(position)); + const response 
= await fetchJson<{ tokens: TokenSearchResult[] }>(url.toString()); return response.tokens; } diff --git a/spd/app/frontend/src/lib/api/index.ts b/spd/app/frontend/src/lib/api/index.ts index 7f4bcc9e8..d2d810283 100644 --- a/spd/app/frontend/src/lib/api/index.ts +++ b/spd/app/frontend/src/lib/api/index.ts @@ -1,9 +1,17 @@ /** * Shared API utilities and exports. + * + * In development, Vite proxies /api requests to the backend. + * This allows the frontend to work regardless of which port the backend is on. */ -// eslint-disable-next-line @typescript-eslint/no-explicit-any -export const API_URL = (import.meta as any).env.VITE_API_URL || "http://localhost:8000"; +/** + * Build a URL for an API endpoint. + * Uses relative paths which Vite's proxy forwards to the backend. + */ +export function apiUrl(path: string): URL { + return new URL(path, window.location.origin); +} export class ApiError extends Error { constructor( @@ -17,13 +25,20 @@ export class ApiError extends Error { export async function fetchJson<T>(url: string, options?: RequestInit): Promise<T> { const response = await fetch(url, options); - const data = await response.json(); + const text = await response.text(); if (!response.ok) { - throw new ApiError(data.detail || data.error || `HTTP ${response.status}`, response.status); + let message = `HTTP ${response.status}`; + try { + const data = JSON.parse(text); + message = data.detail || data.error || message; + } catch { + message = text.slice(0, 200) || message; + } + throw new ApiError(message, response.status); } - return data as T; + return JSON.parse(text) as T; } // Re-export all API modules @@ -36,3 +51,8 @@ export * from "./datasetAttributions"; export * from "./intervention"; export * from "./dataset"; export * from "./clusters"; +export * from "./investigations"; +export * from "./dataSources"; +export * from "./graphInterp"; +export * from "./pretrainInfo"; +export * from "./runRegistry"; diff --git a/spd/app/frontend/src/lib/api/intervention.ts 
b/spd/app/frontend/src/lib/api/intervention.ts index 399fdff43..154228181 100644 --- a/spd/app/frontend/src/lib/api/intervention.ts +++ b/spd/app/frontend/src/lib/api/intervention.ts @@ -2,11 +2,10 @@ * API client for /api/intervention endpoints. */ -import type { ForkedInterventionRun, InterventionRunSummary, RunInterventionRequest } from "../interventionTypes"; -import { API_URL } from "./index"; +import type { InterventionRunSummary, RunInterventionRequest } from "../interventionTypes"; export async function runAndSaveIntervention(request: RunInterventionRequest): Promise { - const response = await fetch(`${API_URL}/api/intervention/run`, { + const response = await fetch("/api/intervention/run", { method: "POST", headers: { "Content-Type": "application/json" }, body: JSON.stringify(request), @@ -19,7 +18,7 @@ export async function runAndSaveIntervention(request: RunInterventionRequest): P } export async function getInterventionRuns(graphId: number): Promise { - const response = await fetch(`${API_URL}/api/intervention/runs/${graphId}`); + const response = await fetch(`/api/intervention/runs/${graphId}`); if (!response.ok) { const error = await response.json(); throw new Error(error.detail || "Failed to get intervention runs"); @@ -28,7 +27,7 @@ export async function getInterventionRuns(graphId: number): Promise { - const response = await fetch(`${API_URL}/api/intervention/runs/${runId}`, { + const response = await fetch(`/api/intervention/runs/${runId}`, { method: "DELETE", }); if (!response.ok) { @@ -36,30 +35,3 @@ export async function deleteInterventionRun(runId: number): Promise { throw new Error(error.detail || "Failed to delete intervention run"); } } - -export async function forkInterventionRun( - runId: number, - tokenReplacements: [number, number][], - topK: number = 10, -): Promise { - const response = await fetch(`${API_URL}/api/intervention/runs/${runId}/fork`, { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: 
JSON.stringify({ token_replacements: tokenReplacements, top_k: topK }), - }); - if (!response.ok) { - const error = await response.json(); - throw new Error(error.detail || "Failed to fork intervention run"); - } - return (await response.json()) as ForkedInterventionRun; -} - -export async function deleteForkedInterventionRun(forkId: number): Promise { - const response = await fetch(`${API_URL}/api/intervention/forks/${forkId}`, { - method: "DELETE", - }); - if (!response.ok) { - const error = await response.json(); - throw new Error(error.detail || "Failed to delete forked intervention run"); - } -} diff --git a/spd/app/frontend/src/lib/api/investigations.ts b/spd/app/frontend/src/lib/api/investigations.ts new file mode 100644 index 000000000..42f1fb1f3 --- /dev/null +++ b/spd/app/frontend/src/lib/api/investigations.ts @@ -0,0 +1,101 @@ +/** + * API client for investigation results. + */ + +export interface InvestigationSummary { + id: string; // inv_id (e.g., "inv-abc12345") + wandb_path: string | null; + prompt: string | null; + created_at: string; + has_research_log: boolean; + has_explanations: boolean; + event_count: number; + last_event_time: string | null; + last_event_message: string | null; + // Agent-provided summary + title: string | null; + summary: string | null; + status: string | null; // in_progress, completed, inconclusive +} + +export interface EventEntry { + event_type: string; + timestamp: string; + message: string; + details: Record | null; +} + +export interface InvestigationDetail { + id: string; + wandb_path: string | null; + prompt: string | null; + created_at: string; + research_log: string | null; + events: EventEntry[]; + explanations: Record[]; + artifact_ids: string[]; // List of artifact IDs available for this investigation + // Agent-provided summary + title: string | null; + summary: string | null; + status: string | null; +} + +import type { EdgeData, OutputProbability } from "../promptAttributionsTypes"; + +/** Data for a graph 
artifact (subset of GraphData, self-contained for offline viewing) */ +export interface ArtifactGraphData { + tokens: string[]; + edges: EdgeData[]; + outputProbs: Record; + nodeCiVals: Record; + nodeSubcompActs: Record; + maxAbsAttr: number; + l0_total: number; +} + +export interface GraphArtifact { + type: "graph"; + id: string; + caption: string | null; + graph_id: number; + data: ArtifactGraphData; +} + +export interface LaunchResponse { + inv_id: string; + job_id: string; +} + +export async function launchInvestigation(prompt: string): Promise { + const res = await fetch("/api/investigations/launch", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ prompt }), + }); + if (!res.ok) throw new Error(`Failed to launch investigation: ${res.statusText}`); + return res.json(); +} + +export async function listInvestigations(): Promise { + const res = await fetch("/api/investigations"); + if (!res.ok) throw new Error(`Failed to list investigations: ${res.statusText}`); + return res.json(); +} + +export async function getInvestigation(invId: string): Promise { + const res = await fetch(`/api/investigations/${invId}`); + if (!res.ok) throw new Error(`Failed to get investigation: ${res.statusText}`); + return res.json(); +} + +export async function listArtifacts(invId: string): Promise { + const res = await fetch(`/api/investigations/${invId}/artifacts`); + if (!res.ok) throw new Error(`Failed to list artifacts: ${res.statusText}`); + return res.json(); +} + +export async function getArtifact(invId: string, artifactId: string): Promise { + const res = await fetch(`/api/investigations/${invId}/artifacts/${artifactId}`); + if (!res.ok) throw new Error(`Failed to get artifact: ${res.statusText}`); + return res.json(); +} diff --git a/spd/app/frontend/src/lib/api/pretrainInfo.ts b/spd/app/frontend/src/lib/api/pretrainInfo.ts new file mode 100644 index 000000000..0cd66bd97 --- /dev/null +++ 
b/spd/app/frontend/src/lib/api/pretrainInfo.ts @@ -0,0 +1,37 @@ +/** + * API client for /api/pretrain_info endpoint. + */ + +import { fetchJson } from "./index"; + +export type BlockStructure = { + index: number; + attn_type: "separate" | "fused"; + attn_projections: string[]; + ffn_type: "glu" | "mlp"; + ffn_projections: string[]; +}; + +export type TopologyInfo = { + n_blocks: number; + block_structure: BlockStructure[]; +}; + +export type PretrainInfoResponse = { + model_type: string; + summary: string; + dataset_short: string | null; + target_model_config: Record<string, unknown> | null; + pretrain_config: Record<string, unknown> | null; + pretrain_wandb_path: string | null; + topology: TopologyInfo | null; +}; + +export async function fetchPretrainInfo(wandbPath: string): Promise<PretrainInfoResponse> { + const params = new URLSearchParams({ wandb_path: wandbPath }); + return fetchJson(`/api/pretrain_info?${params}`); +} + +export async function fetchPretrainInfoForLoadedRun(): Promise<PretrainInfoResponse> { + return fetchJson("/api/pretrain_info/loaded"); +} diff --git a/spd/app/frontend/src/lib/api/prompts.ts b/spd/app/frontend/src/lib/api/prompts.ts index bdc0530bd..76d562ab7 100644 --- a/spd/app/frontend/src/lib/api/prompts.ts +++ b/spd/app/frontend/src/lib/api/prompts.ts @@ -3,78 +3,18 @@ */ import type { PromptPreview } from "../promptAttributionsTypes"; -import { API_URL, ApiError, fetchJson } from "./index"; +import { apiUrl, fetchJson } from "./index"; export async function listPrompts(): Promise<PromptPreview[]> { - return fetchJson(`${API_URL}/api/prompts`); + return fetchJson("/api/prompts"); } export async function createCustomPrompt(text: string): Promise<PromptPreview> { - const url = new URL(`${API_URL}/api/prompts/custom`); + const url = apiUrl("/api/prompts/custom"); url.searchParams.set("text", text); return fetchJson(url.toString(), { method: "POST" }); } -export type GeneratePromptsConfig = { - nPrompts: number; -}; - -export type GeneratePromptsResult = { - prompts_added: number; - total_prompts: number; -}; - -export async function
generatePrompts( - config: GeneratePromptsConfig, - onProgress?: (progress: number, count: number) => void, -): Promise<GeneratePromptsResult> { - const url = new URL(`${API_URL}/api/prompts/generate`); - url.searchParams.set("n_prompts", String(config.nPrompts)); - - const response = await fetch(url.toString(), { method: "POST" }); - if (!response.ok) { - const error = await response.json(); - throw new ApiError(error.detail || `HTTP ${response.status}`, response.status); - } - - const reader = response.body?.getReader(); - if (!reader) { - throw new Error("Response body is not readable"); - } - - const decoder = new TextDecoder(); - let buffer = ""; - let result: GeneratePromptsResult | null = null; - - while (true) { - const { done, value } = await reader.read(); - if (done) break; - - buffer += decoder.decode(value, { stream: true }); - - const lines = buffer.split("\n\n"); - buffer = lines.pop() || ""; - - for (const line of lines) { - if (!line.trim() || !line.startsWith("data: ")) continue; - - const data = JSON.parse(line.substring(6)); - - if (data.type === "progress" && onProgress) { - onProgress(data.progress, data.count); - } else if (data.type === "complete") { - result = { prompts_added: data.prompts_added, total_prompts: data.total_prompts }; - await reader.cancel(); - break; - } - } - - if (result) break; - } - - if (!result) { - throw new Error("No result received from stream"); - } - - return result; +export async function deletePrompt(promptId: number): Promise<void> { + await fetchJson(`/api/prompts/${promptId}`, { method: "DELETE" }); } diff --git a/spd/app/frontend/src/lib/api/runRegistry.ts b/spd/app/frontend/src/lib/api/runRegistry.ts new file mode 100644 index 000000000..c727f4dcc --- /dev/null +++ b/spd/app/frontend/src/lib/api/runRegistry.ts @@ -0,0 +1,26 @@ +/** + * API client for /api/run_registry endpoint.
+ */ + +import { fetchJson } from "./index"; + +export type DataAvailability = { + harvest: boolean; + autointerp: boolean; + attributions: boolean; + graph_interp: boolean; +}; + +export type RunInfoResponse = { + wandb_run_id: string; + architecture: string | null; + availability: DataAvailability; +}; + +export async function fetchRunInfo(wandbRunIds: string[]): Promise<RunInfoResponse[]> { + return fetchJson("/api/run_registry", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(wandbRunIds), + }); +} diff --git a/spd/app/frontend/src/lib/api/runs.ts b/spd/app/frontend/src/lib/api/runs.ts index 297c670a2..d898c8671 100644 --- a/spd/app/frontend/src/lib/api/runs.ts +++ b/spd/app/frontend/src/lib/api/runs.ts @@ -2,9 +2,9 @@ * API client for /api/runs endpoints. */ -import { API_URL } from "./index"; +import { apiUrl } from "./index"; -export type RunState = { +export type LoadedRun = { id: number; wandb_path: string; config_yaml: string; @@ -12,22 +12,26 @@ export type RunState = { prompt_count: number; context_length: number; backend_user: string; + dataset_attributions_available: boolean; + dataset_search_enabled: boolean; + graph_interp_available: boolean; + autointerp_available: boolean; }; -export async function getStatus(): Promise { - const response = await fetch(`${API_URL}/api/status`); +export async function getStatus(): Promise { - const response = await fetch("/api/status"); const data = await response.json(); return data; } -export async function getWhoami(): Promise<string> { - const response = await fetch(`${API_URL}/api/whoami`); +export async function whoami(): Promise<string> { + const response = await fetch("/api/whoami"); const data = await response.json(); return data.user; } export async function loadRun(wandbRunPath: string, contextLength: number): Promise { - const url = new URL(`${API_URL}/api/runs/load`); + const url = apiUrl("/api/runs/load"); url.searchParams.set("wandb_path", wandbRunPath);
url.searchParams.set("context_length", String(contextLength)); const response = await fetch(url.toString(), { method: "POST" }); diff --git a/spd/app/frontend/src/lib/colors.ts b/spd/app/frontend/src/lib/colors.ts index 70c023b23..d15462693 100644 --- a/spd/app/frontend/src/lib/colors.ts +++ b/spd/app/frontend/src/lib/colors.ts @@ -1,45 +1,46 @@ /** * Centralized color definitions for graph visualization. * These match the CSS variables in app.css but are available for inline styles in SVG elements. + * + * RGB values for dynamic opacity (rgba) are stored as {r, g, b} objects. + * Hex values are used for direct color application. */ export const colors = { - // Text - punchy contrast - textPrimary: "#111111", - textSecondary: "#555555", - textMuted: "#999999", + // Text - warm navy contrast (matches --text-*) + textPrimary: "#1d272a", + textSecondary: "#646464", + textMuted: "#b4b4b4", - // Status colors for edges/data - vivid - positive: "#2563eb", + // Status colors for edges/data (matches --accent-primary, --status-negative) + positive: "#4d65ff", negative: "#dc2626", - // Output node gradient (green) - vivid green + // RGB components for dynamic opacity + positiveRgb: { r: 77, g: 101, b: 255 }, // vibrant blue - matches --accent-primary + negativeRgb: { r: 220, g: 38, b: 38 }, // red - matches --status-negative + + // Output node gradient (green) - matches --status-positive outputBase: { r: 22, g: 163, b: 74 }, - // Token highlight - vivid green + // Token highlight - matches --status-positive tokenHighlight: { r: 22, g: 163, b: 74 }, tokenHighlightOpacity: 0.4, // Node default - nodeDefault: "#6b7280", + nodeDefault: "#8a8780", - // Accent (for active states) - blue - accent: "#2563eb", + // Accent (for active states) - matches --accent-primary + accent: "#7C4D33", // Set overlap visualization (A/B/intersection) setOverlap: { - self: "rgb(20, 184, 166)", // teal - A-only - both: "rgb(100, 116, 139)", // slate - intersection - other: "rgb(249, 115, 22)", // 
orange - B-only + self: { r: 20, g: 184, b: 166 }, // teal - A-only + both: { r: 100, g: 116, b: 139 }, // slate - intersection + other: { r: 249, g: 115, b: 22 }, // orange - B-only }, } as const; -/** Get output node fill color based on probability */ -export function getOutputNodeColor(prob: number): string { - const { r, g, b } = colors.outputBase; - return `rgb(${r + Math.round(prob * 10)}, ${g + Math.round(prob * 25)}, ${b + Math.round(prob * 16)})`; -} - /** Get edge color based on value sign */ export function getEdgeColor(val: number): string { return val > 0 ? colors.positive : colors.negative; @@ -58,11 +59,8 @@ export function getTokenHighlightBg(ci: number): string { /** Get color for component activations (blue for positive, red for negative) */ export function getComponentActivationColor(value: number, normalizedAbs: number): string { - if (value >= 0) { - return `rgba(37, 99, 235, ${normalizedAbs})`; // blue - } else { - return `rgba(220, 38, 38, ${normalizedAbs})`; // red - } + const { r, g, b } = value >= 0 ? 
colors.positiveRgb : colors.negativeRgb; + return `rgba(${r}, ${g}, ${b}, ${normalizedAbs})`; } /** Compute the max absolute value across all component activations (for normalization) */ @@ -81,8 +79,32 @@ export function computeMaxAbsComponentAct(exampleComponentActs: number[][]): num export function getOutputHeaderColor(prob: number): string { const { r, g, b } = colors.outputBase; const opacity = Math.min(0.8, prob + 0.05); - return `rgba(${r},${g},${b},${opacity})`; + return `rgba(${r}, ${g}, ${b}, ${opacity})`; } /** Background color with opacity for overlays */ export const bgBaseRgb = { r: 255, g: 255, b: 255 }; + +/** Convert RGB object to CSS rgb() string */ +export function rgbToCss(rgb: { r: number; g: number; b: number }): string { + return `rgb(${rgb.r}, ${rgb.g}, ${rgb.b})`; +} + +/** Convert RGB object to CSS rgba() string with opacity */ +export function rgbaToCss(rgb: { r: number; g: number; b: number }, opacity: number): string { + return `rgba(${rgb.r}, ${rgb.g}, ${rgb.b}, ${opacity})`; +} + +/** + * Get background color for next-token probability visualization. + * High probability = green (expected), low probability = white. + */ +export function getNextTokenProbBgColor(prob: number | null): string { + if (prob === null) return "white"; + const { r: gR, g: gG, b: gB } = colors.outputBase; // green + // Interpolate from white (255,255,255) to green based on probability + const r = Math.round(255 + (gR - 255) * prob); + const g = Math.round(255 + (gG - 255) * prob); + const b = Math.round(255 + (gB - 255) * prob); + return `rgb(${r}, ${g}, ${b})`; +} diff --git a/spd/app/frontend/src/lib/componentCardConstants.ts b/spd/app/frontend/src/lib/componentCardConstants.ts new file mode 100644 index 000000000..97fb9c423 --- /dev/null +++ b/spd/app/frontend/src/lib/componentCardConstants.ts @@ -0,0 +1,15 @@ +/** + * Shared constants for component card displays. 
+ * Centralizes magic numbers to ensure consistency across ComponentNodeCard and ActivationContextsViewer. + */ + +export const COMPONENT_CARD_CONSTANTS = { + /** Number of correlations per page */ + CORRELATIONS_PAGE_SIZE: 10, + + /** Number of dataset attributions per page */ + DATASET_ATTRIBUTIONS_PAGE_SIZE: 4, + + /** Number of prompt attributions per page */ + PROMPT_ATTRIBUTIONS_PAGE_SIZE: 4, +} as const; diff --git a/spd/app/frontend/src/lib/componentKeys.ts b/spd/app/frontend/src/lib/componentKeys.ts new file mode 100644 index 000000000..ff83bda06 --- /dev/null +++ b/spd/app/frontend/src/lib/componentKeys.ts @@ -0,0 +1,17 @@ +/** + * Utilities for component key display (e.g. rendering embed/output keys with token strings). + */ + +export function isTokenNode(key: string): boolean { + const layer = key.split(":")[0]; + return layer === "embed" || layer === "output"; +} + +export function formatComponentKey(key: string, tokenStr: string | null): string { + if (tokenStr && isTokenNode(key)) { + const layer = key.split(":")[0]; + const label = layer === "embed" ? 
"input" : "output"; + return `'${tokenStr}' (${label})`; + } + return key; +} diff --git a/spd/app/frontend/src/lib/displaySettings.svelte.ts b/spd/app/frontend/src/lib/displaySettings.svelte.ts index 31435402b..4254acb31 100644 --- a/spd/app/frontend/src/lib/displaySettings.svelte.ts +++ b/spd/app/frontend/src/lib/displaySettings.svelte.ts @@ -13,6 +13,14 @@ export const NODE_COLOR_MODE_LABELS: Record = { subcomp_act: "Subcomp Act", }; +// Edge variant for attribution graphs +export type EdgeVariant = "signed" | "abs_target"; + +export const EDGE_VARIANT_LABELS: Record = { + signed: "Signed", + abs_target: "Abs Target", +}; + // Example color mode for activation contexts viewer export type ExampleColorMode = "ci" | "component_act" | "both"; @@ -36,18 +44,39 @@ export const CORRELATION_STAT_DESCRIPTIONS: Record jaccard: "Intersection over union", }; -export const displaySettings = $state({ - showPmi: true, - showPrecision: true, - showRecall: true, - showJaccard: true, +type DisplaySettings = { + showPmi: boolean; + showPrecision: boolean; + showRecall: boolean; + showJaccard: boolean; + showSetOverlapVis: boolean; + showEdgeAttributions: boolean; + nodeColorMode: NodeColorMode; + exampleColorMode: ExampleColorMode; + meanCiCutoff: number; + centerOnPeak: boolean; + showAutoInterpPromptButton: boolean; + curvedEdges: boolean; + edgeVariant: EdgeVariant; +}; + +export const displaySettings = $state({ + showPmi: false, + showPrecision: false, + showRecall: false, + showJaccard: false, showSetOverlapVis: true, showEdgeAttributions: true, - nodeColorMode: "ci" as NodeColorMode, - exampleColorMode: "ci" as ExampleColorMode, + nodeColorMode: "ci", + exampleColorMode: "ci", + meanCiCutoff: 1e-7, + centerOnPeak: false, + showAutoInterpPromptButton: false, + curvedEdges: true, + edgeVariant: "signed", }); -export function hasAnyCorrelationStats() { +export function anyCorrelationStatsEnabled() { return ( displaySettings.showPmi || displaySettings.showPrecision || diff --git 
a/spd/app/frontend/src/lib/graphLayout.ts b/spd/app/frontend/src/lib/graphLayout.ts new file mode 100644 index 000000000..cc3e6fa19 --- /dev/null +++ b/spd/app/frontend/src/lib/graphLayout.ts @@ -0,0 +1,136 @@ +/** + * Graph layout utilities for canonical transformer addresses. + * + * Canonical address format: + * "embed" — embedding + * "output" — unembed / logits + * "{block}.{sublayer}.{projection}" — e.g. "0.attn.q", "2.mlp.down" + * + * Node key format: + * "{layer}:{seqIdx}:{cIdx}" — e.g. "0.attn.q:3:5", "embed:0:0" + */ + +export type LayerInfo = { + name: string; + block: number; // -1 for embed, Infinity for output + sublayer: string; // "attn" | "attn_fused" | "mlp" | "glu" | "embed" | "output" + projection: string | null; // "q" | "k" | "v" | "o" | "qkv" | "up" | "down" | "gate" | null +}; + +const SUBLAYER_ORDER = ["attn", "attn_fused", "glu", "mlp"]; + +// Projections that share a row and get grouped horizontally +const GROUPED_PROJECTIONS: Record<string, string[]> = { + attn: ["q", "k", "v"], + glu: ["gate", "up"], +}; + +// Full projection ordering within each sublayer (grouped inputs first, then outputs) +const PROJECTION_ORDER: Record<string, string[]> = { + attn: ["q", "k", "v", "o"], + attn_fused: ["qkv", "o"], + glu: ["gate", "up", "down"], + mlp: ["up", "down"], +}; + +export function parseLayer(name: string): LayerInfo { + if (name === "embed") return { name, block: -1, sublayer: "embed", projection: null }; + if (name === "output") return { name, block: Infinity, sublayer: "output", projection: null }; + + const parts = name.split("."); + return { + name, + block: +parts[0], + sublayer: parts[1], + projection: parts[2], + }; +} + +/** + * Row key: layers that share the same visual row. + * q/k/v share "0.attn.qkv", gate/up share "0.glu.gate_up". + * Ungrouped projections (o, down) get their own row.
+ */ +export function getRowKey(layer: string): string { + const info = parseLayer(layer); + if (info.sublayer === "embed" || info.sublayer === "output") return layer; + + const grouped = GROUPED_PROJECTIONS[info.sublayer]; + if (grouped && info.projection && grouped.includes(info.projection)) { + return `${info.block}.${info.sublayer}.${grouped.join("_")}`; + } + return layer; +} + +/** + * Row label for display. + */ +export function getRowLabel(rowKey: string): string { + if (rowKey === "embed") return "embed"; + if (rowKey === "output") return "output"; + + const parts = rowKey.split("."); + const block = parts[0]; + const sublayer = parts[1]; + const projPart = parts[2]; + + if (!projPart) return `${block}.${sublayer}`; + + // Grouped projections: show "0.attn.qkv" or "0.glu.gate/up" + if (projPart.includes("_")) { + return `${block}.${sublayer}.${projPart.replace(/_/g, "/")}`; + } + return rowKey; +} + +/** + * Sort row keys: embed at bottom, output at top, blocks in between. + * Within a block: sublayers follow SUBLAYER_ORDER, grouped projections before ungrouped. + */ +export function sortRows(rows: string[]): string[] { + return [...rows].sort((a, b) => { + const partsA = a.split("."); + const partsB = b.split("."); + + const blockA = a === "embed" ? -1 : a === "output" ? Infinity : +partsA[0]; + const blockB = b === "embed" ? -1 : b === "output" ? Infinity : +partsB[0]; + + if (blockA !== blockB) return blockA - blockB; + + const sublayerA = partsA[1] ?? ""; + const sublayerB = partsB[1] ?? ""; + const sublayerDiff = SUBLAYER_ORDER.indexOf(sublayerA) - SUBLAYER_ORDER.indexOf(sublayerB); + if (sublayerDiff !== 0) return sublayerDiff; + + // Within same sublayer: order by first projection in the row key + const projOrder = PROJECTION_ORDER[sublayerA] ?? []; + const firstProjA = (partsA[2] ?? "").split("_")[0]; + const firstProjB = (partsB[2] ?? 
"").split("_")[0]; + const projIdxA = projOrder.indexOf(firstProjA); + const projIdxB = projOrder.indexOf(firstProjB); + return (projIdxA === -1 ? 999 : projIdxA) - (projIdxB === -1 ? 999 : projIdxB); + }); +} + +/** + * Get the grouped projections for a sublayer, if any. + * Returns null if no grouping (each projection gets its own horizontal space). + */ +export function getGroupProjections(sublayer: string): string[] | null { + return GROUPED_PROJECTIONS[sublayer] ?? null; +} + +/** + * Check if a specific projection is part of its sublayer's group. + */ +export function isGroupedProjection(sublayer: string, projection: string): boolean { + const group = GROUPED_PROJECTIONS[sublayer]; + return group !== undefined && group.includes(projection); +} + +/** + * Build the full layer address from block + sublayer + projection. + */ +export function buildLayerAddress(block: number, sublayer: string, projection: string): string { + return `${block}.${sublayer}.${projection}`; +} diff --git a/spd/app/frontend/src/lib/interventionTypes.ts b/spd/app/frontend/src/lib/interventionTypes.ts index e03826414..db2794da5 100644 --- a/spd/app/frontend/src/lib/interventionTypes.ts +++ b/spd/app/frontend/src/lib/interventionTypes.ts @@ -1,46 +1,105 @@ /** Types for the intervention forward pass feature */ -export type InterventionNode = { - layer: string; - seq_pos: number; - component_idx: number; -}; +/** Default eval PGD settings (distinct from training PGD which is an optimization regularizer) */ +export const EVAL_PGD_N_STEPS = 4; +export const EVAL_PGD_STEP_SIZE = 1.0; export type TokenPrediction = { token: string; token_id: number; - spd_prob: number; - target_prob: number; + prob: number; logit: number; + target_prob: number; target_logit: number; }; -export type InterventionResponse = { - input_tokens: string[]; - predictions_per_position: TokenPrediction[][]; +export type LabelPredictions = { + position: number; + ci: TokenPrediction; + stochastic: TokenPrediction; + 
adversarial: TokenPrediction; + ablated: TokenPrediction | null; +}; -/** A forked intervention run with modified tokens */ -export type ForkedInterventionRun = { - id: number; - token_replacements: [number, number][]; // [(seq_pos, new_token_id), ...] - result: InterventionResponse; - created_at: string; +export type InterventionResult = { + input_tokens: string[]; + ci: TokenPrediction[][]; + stochastic: TokenPrediction[][]; + adversarial: TokenPrediction[][]; + ablated: TokenPrediction[][] | null; + ci_loss: number; + stochastic_loss: number; + adversarial_loss: number; + ablated_loss: number | null; + label: LabelPredictions | null; }; /** Persisted intervention run from the server */ export type InterventionRunSummary = { id: number; selected_nodes: string[]; // node keys (layer:seq:cIdx) - result: InterventionResponse; + result: InterventionResult; created_at: string; - forked_runs?: ForkedInterventionRun[]; // child runs with modified tokens }; /** Request to run and save an intervention */ export type RunInterventionRequest = { graph_id: number; - text: string; selected_nodes: string[]; - top_k?: number; + nodes_to_ablate?: string[]; + top_k: number; + adv_pgd: { n_steps: number; step_size: number }; }; + +// --- Frontend-only run lifecycle types --- + +import { SvelteSet } from "svelte/reactivity"; +import { isInterventableNode } from "./promptAttributionsTypes"; + +/** Draft run: cloned from a parent, editable node selection. No forwarded results yet. */ +export type DraftRun = { + kind: "draft"; + parentId: number; + selectedNodes: SvelteSet<string>; +}; + +/** Baked run: forwarded and immutable. Wraps a persisted InterventionRunSummary.
*/ +export type BakedRun = { + kind: "baked"; + id: number; + selectedNodes: Set<string>; + result: InterventionResult; + createdAt: string; +}; + +export type InterventionRun = DraftRun | BakedRun; + +export type InterventionState = { + runs: InterventionRun[]; + activeIndex: number; +}; + +/** Build initial InterventionState from persisted runs. + * The first persisted run is the base run (all CI > 0 nodes), auto-created during graph computation. */ +export function buildInterventionState(persistedRuns: InterventionRunSummary[]): InterventionState { + if (persistedRuns.length === 0) throw new Error("Graph must have at least one intervention run (the base run)"); + const runs: InterventionRun[] = persistedRuns.map( + (r): BakedRun => ({ + kind: "baked", + id: r.id, + selectedNodes: new Set(r.selected_nodes), + result: r.result, + createdAt: r.created_at, + }), + ); + return { runs, activeIndex: 0 }; +} + +/** Get all interventable node keys with CI > 0 from a nodeCiVals record */ +export function getInterventableNodes(nodeCiVals: Record<string, number>): Set<string> { + const nodes = new Set<string>(); + for (const [nodeKey, ci] of Object.entries(nodeCiVals)) { + if (isInterventableNode(nodeKey) && ci > 0) nodes.add(nodeKey); + } + return nodes; +} diff --git a/spd/app/frontend/src/lib/promptAttributionsTypes.ts b/spd/app/frontend/src/lib/promptAttributionsTypes.ts index 100a60c16..d77b22fc4 100644 --- a/spd/app/frontend/src/lib/promptAttributionsTypes.ts +++ b/spd/app/frontend/src/lib/promptAttributionsTypes.ts @@ -7,9 +7,10 @@ export type PromptPreview = { token_ids: number[]; tokens: string[]; preview: string; + next_token_probs: (number | null)[]; // Probability of next token (last is null) }; -export type Edge = { +export type EdgeData = { src: string; // "layer:seq:cIdx" tgt: string; // "layer:seq:cIdx" val: number; @@ -19,9 +20,30 @@ export type EdgeAttribution = { key: string; // "layer:seq:cIdx" for prompt or "layer:cIdx" for dataset value: number; // raw attribution value (positive or
negative) normalizedMagnitude: number; // |value| / maxAbsValue, for color intensity (0-1) + tokenStr: string | null; // resolved token string for embed/output layers }; -export type OutputProbEntry = { +/** Sort edges by |val| desc, take top N, normalize magnitudes to [0,1]. */ +export function topEdgeAttributions( + edges: EdgeData[], + getKey: (e: EdgeData) => string, + n: number, + resolveTokenStr?: (key: string) => string | null, +): EdgeAttribution[] { + const sorted = [...edges].sort((a, b) => Math.abs(b.val) - Math.abs(a.val)).slice(0, n); + const maxAbsVal = Math.abs(sorted[0]?.val || 1); + return sorted.map((e) => { + const key = getKey(e); + return { + key, + value: e.val, + normalizedMagnitude: Math.abs(e.val) / maxAbsVal, + tokenStr: resolveTokenStr ? resolveTokenStr(key) : null, + }; + }); +} + +export type OutputProbability = { prob: number; // CI-masked (SPD model) probability logit: number; // CI-masked (SPD model) raw logit target_prob: number; // Target model probability @@ -29,31 +51,47 @@ target_logit: number; token: string; }; +export type CISnapshot = { + step: number; + total_steps: number; + layers: string[]; + seq_len: number; + initial_alive: number[][]; + current_alive: number[][]; + l0_total: number; + loss: number; +}; + export type GraphType = "standard" | "optimized" | "manual"; export type GraphData = { id: number; graphType: GraphType; tokens: string[]; - edges: Edge[]; - edgesBySource: Map<string, Edge[]>; // nodeKey -> edges where this node is source - edgesByTarget: Map<string, Edge[]>; // nodeKey -> edges where this node is target - outputProbs: Record<string, OutputProbEntry>; // key is "seq:cIdx" + edges: EdgeData[]; + edgesBySource: Map<string, EdgeData[]>; // nodeKey -> edges where this node is source + edgesByTarget: Map<string, EdgeData[]>; // nodeKey -> edges where this node is target + // Absolute-target variant (∂|y|/∂x · x), null for old graphs + edgesAbs: EdgeData[] | null; + edgesAbsBySource: Map<string, EdgeData[]> | null; + edgesAbsByTarget: Map<string, EdgeData[]> | null; + outputProbs: Record<string, OutputProbability>; // key is "seq:cIdx" + nodeCiVals: Record<string, number>; //
node key -> CI value (or output prob for output nodes or 1 for wte node) nodeSubcompActs: Record<string, number>; // node key -> subcomponent activation (v_i^T @ a) maxAbsAttr: number; // max absolute edge value + maxAbsAttrAbs: number | null; // max absolute edge value for abs-target variant maxAbsSubcompAct: number; // max absolute subcomponent activation for normalization l0_total: number; // total active components at current CI threshold optimization?: OptimizationResult; }; /** Build edge indexes from flat edge array (single pass) */ -export function buildEdgeIndexes(edges: Edge[]): { - edgesBySource: Map<string, Edge[]>; - edgesByTarget: Map<string, Edge[]>; +export function buildEdgeIndexes(edges: EdgeData[]): { + edgesBySource: Map<string, EdgeData[]>; + edgesByTarget: Map<string, EdgeData[]>; } { - const edgesBySource = new Map<string, Edge[]>(); - const edgesByTarget = new Map<string, Edge[]>(); + const edgesBySource = new Map<string, EdgeData[]>(); + const edgesByTarget = new Map<string, EdgeData[]>(); for (const edge of edges) { const bySrc = edgesBySource.get(edge.src); @@ -76,30 +114,62 @@ export type MaskType = "stochastic" | "ci"; +export type CELossResult = { + type: "ce"; + coeff: number; + position: number; + label_token: number; + label_str: string; +}; + +export type KLLossResult = { + type: "kl"; + coeff: number; + position: number; +}; + +export type LogitLossResult = { + type: "logit"; + coeff: number; + position: number; + label_token: number; + label_str: string; +}; + +export type LossResult = CELossResult | KLLossResult | LogitLossResult; + +export type OptimizationMetrics = { + ci_masked_label_prob: number | null; // Probability of label under CI mask (CE loss only) + stoch_masked_label_prob: number | null; // Probability of label under stochastic mask (CE loss only) + adv_pgd_label_prob: number | null; // Probability of label under adversarial mask (CE loss only) + l0_total: number; // Total L0 (active components) +}; + +export type PgdConfig = { + n_steps: number; + step_size: number; +}; + export type OptimizationResult = { imp_min_coeff:
number; steps: number; pnorm: number; beta: number; - // CE loss params (optional - required together) - label_token: number | null; - label_str: string | null; - ce_loss_coeff: number | null; - label_prob: number | null; - // KL loss param (optional) - kl_loss_coeff: number | null; mask_type: MaskType; + loss: LossResult; + metrics: OptimizationMetrics; + pgd: PgdConfig | null; }; -export type ComponentSummary = { +export type SubcomponentMetadata = { subcomponent_idx: number; mean_ci: number; }; -export type ActivationContextsSummary = Record; +export type ActivationContextsSummary = Record; // Note: Token P/R/lift stats come from /token_stats endpoint (batch job), not here -export type ComponentDetail = { +export type SubcomponentActivationContexts = { subcomponent_idx: number; mean_ci: number; example_tokens: string[][]; @@ -107,7 +177,7 @@ export type ComponentDetail = { example_component_acts: number[][]; }; -export type CorrelatedComponent = { +export type CorrelatedSubcomponent = { component_key: string; score: number; count_i: number; // Subject (query component) firing count @@ -116,12 +186,12 @@ export type CorrelatedComponent = { n_tokens: number; // Total tokens }; -export type ComponentCorrelations = { - precision: CorrelatedComponent[]; - recall: CorrelatedComponent[]; - jaccard: CorrelatedComponent[]; - pmi: CorrelatedComponent[]; - bottom_pmi: CorrelatedComponent[]; +export type SubcomponentCorrelationsResponse = { + precision: CorrelatedSubcomponent[]; + recall: CorrelatedSubcomponent[]; + jaccard: CorrelatedSubcomponent[]; + pmi: CorrelatedSubcomponent[]; + bottom_pmi: CorrelatedSubcomponent[]; }; // Token P/R/lift/PMI for a single category (input or output) @@ -134,30 +204,46 @@ export type TokenPRLiftPMI = { }; // Token stats from batch job - includes both input and output stats -export type TokenStats = { +export type TokenStatsResponse = { input: TokenPRLiftPMI; // What tokens activate this component output: TokenPRLiftPMI; // What tokens 
this component predicts }; -export type TokenizeResult = { +export type TokenizeResponse = { token_ids: number[]; tokens: string[]; text: string; + next_token_probs: (number | null)[]; // Probability of next token (last is null) }; -export type TokenInfo = { +export type TokenSearchResult = { id: number; string: string; + prob: number; }; -// Client-side computed types +/** Select active edge set based on variant preference. Falls back to signed if abs unavailable. */ +export function getActiveEdges( + data: GraphData, + variant: "signed" | "abs_target", +): { edges: EdgeData[]; bySource: Map<string, EdgeData[]>; byTarget: Map<string, EdgeData[]>; maxAbsAttr: number } { + if (variant === "abs_target" && data.edgesAbs) { + return { + edges: data.edgesAbs, + bySource: data.edgesAbsBySource!, + byTarget: data.edgesAbsByTarget!, + maxAbsAttr: data.maxAbsAttrAbs || 1, + }; + } + return { + edges: data.edges, + bySource: data.edgesBySource, + byTarget: data.edgesByTarget, + maxAbsAttr: data.maxAbsAttr || 1, + }; +} -export type LayerInfo = { - name: string; - block: number; - type: "attn" | "mlp" | "embed" | "output"; - subtype: string; -}; +// Client-side computed types export type NodePosition = { x: number; @@ -193,34 +279,30 @@ }; // Component probe result -export type ComponentProbeResult = { +export type SubcomponentProbeResult = { tokens: string[]; ci_values: number[]; subcomp_acts: number[]; + next_token_probs: (number | null)[]; // Probability of next token (last is null) }; -// Display name mapping for special layers -const LAYER_DISPLAY_NAMES: Record<string, string> = { - lm_head: "W_U", -}; - -/** Get display name for a layer (e.g., "lm_head" -> "W_U") */ -export function getLayerDisplayName(layer: string): string { - return LAYER_DISPLAY_NAMES[layer] ?? layer; +/** Get display name for a layer (e.g., "lm_head" -> "W_U") using model-provided names */ +export function getLayerDisplayName(layer: string, displayNames: Record<string, string>): string { + return displayNames[layer] ??
layer; } /** Format a node key for display, replacing layer names with display names */ -export function formatNodeKeyForDisplay(nodeKey: string): string { +export function formatNodeKeyForDisplay(nodeKey: string, displayNames: Record<string, string>): string { const [layer, ...rest] = nodeKey.split(":"); - const displayName = getLayerDisplayName(layer); + const displayName = getLayerDisplayName(layer, displayNames); return [displayName, ...rest].join(":"); } // Node intervention helpers -// "wte" and "output" are pseudo-layers used for visualization but are not part of the +// "embed" and "output" are pseudo-layers used for visualization but are not part of the // decomposed model. They cannot be intervened on - only the internal layers (attn/mlp) // can have their components selectively activated. -const NON_INTERVENTABLE_LAYERS = new Set(["wte", "output"]); +const NON_INTERVENTABLE_LAYERS = new Set(["embed", "wte", "output"]); export function isInterventableNode(nodeKey: string): boolean { const layer = nodeKey.split(":")[0]; @@ -230,3 +312,28 @@ } export function filterInterventableNodes(nodeKeys: Iterable<string>): Set<string> { return new Set([...nodeKeys].filter(isInterventableNode)); } + +/** + * Convert a node key (layer:seq:cIdx) to a component key (layer:cIdx). + * Component keys are used for caching/fetching component data. + */ +export function nodeKeyToComponentKey(nodeKey: string): string { + const [layer, , cIdx] = nodeKey.split(":"); + return `${layer}:${cIdx}`; +} + +/** + * Extract unique component keys from a graph. + * Filters out non-interventable nodes (wte, output) and returns unique layer:cIdx keys.
+ */ +export function extractComponentKeys(graph: GraphData): string[] { + const componentKeys = new Set<string>(); + + for (const nodeKey of Object.keys(graph.nodeCiVals)) { + if (isInterventableNode(nodeKey)) { + componentKeys.add(nodeKeyToComponentKey(nodeKey)); + } + } + + return Array.from(componentKeys); +} diff --git a/spd/app/frontend/src/lib/registry.ts b/spd/app/frontend/src/lib/registry.ts index a1f3d59e7..7bb388355 100644 --- a/spd/app/frontend/src/lib/registry.ts +++ b/spd/app/frontend/src/lib/registry.ts @@ -1,79 +1,64 @@ /** - * Registry of canonical SPD runs for quick access in the app. + * Canonical SPD runs for the run picker. + * + * Static data (name, notes) renders instantly in the UI. + * Dynamic data (architecture, availability) is hydrated from the backend. */ +export type ClusterMappingEntry = { path: string; notes: string }; + export type RegistryEntry = { - /** Full wandb run id (e.g., "goodfire/spd/jyo9duz5") */ wandbRunId: string; - /** Human-readable model name */ - modelName: string; - /** Optional notes about the run */ + name?: string; notes?: string; - /** Optional cluster mappings for the run */ - clusterMappings?: { - path: string; - notes: string; - }[]; + clusterMappings?: ClusterMappingEntry[]; }; const DEFAULT_ENTITY_PROJECT = "goodfire/spd"; -/** - * Canonical runs registry - add new entries here. - * These appear in the dropdown for quick selection.
- */ export const CANONICAL_RUNS: RegistryEntry[] = [ { - wandbRunId: "goodfire/spd/s-8dc8cf09", - modelName: "ss_llama_simple_mlp-2L-wide", - notes: "Lucius run with beta=0.1, Jan 16", - }, - { - wandbRunId: "goodfire/spd/s-7884efcc", - modelName: "ss_llama_simple_mlp-1.25M (4L)", - notes: "Lucius' new run, Jan 8", + name: "Thomas", + wandbRunId: "goodfire/spd/s-82ffb969", + notes: "pile_llama_simple_mlp-4L", + clusterMappings: [ + { + path: "/mnt/polished-lake/artifacts/mechanisms/spd/clustering/runs/c-f9cc81c8/cluster_mapping.json", + notes: "All layers, 9100 iterations", + }, + ], }, { - wandbRunId: "goodfire/spd/vjbol27n", - modelName: "ss_llama_simple_mlp-1.25M (4L)", - notes: "Lucius' run, Dec 8", + name: "Jose", + wandbRunId: "goodfire/spd/s-55ea3f9b", + notes: "pile_llama_simple_mlp-4L", clusterMappings: [ { - path: "clustering/ensembles/e-c313e883/cluster_mapping_e-c313e883.json", - notes: "All layers, 80 iterations", + path: "/mnt/polished-lake/artifacts/mechanisms/spd/clustering/runs/c-70b28465/cluster_mapping.json", + notes: "All layers, 9100 iterations", }, ], }, { - wandbRunId: "goodfire/spd/278we8gk", - modelName: "ss_llama_simple_mlp-1.25M (4L)", - notes: "Dan's initial run, Dec 6", + name: "finetune", + wandbRunId: "goodfire/spd/s-17805b61", + notes: "finetune", }, { - wandbRunId: "goodfire/spd/jyo9duz5", - modelName: "ss_gpt2_simple-1.25M (4L)", + wandbRunId: "goodfire/spd/s-275c8f21", + notes: "Lucius' pile run Feb 11", }, { - wandbRunId: "goodfire/spd/5cr21lbs", - modelName: "ss_llama_simple_mlp (1L)", - clusterMappings: [ - { - path: "clustering/ensembles/e-04370c84/cluster_mapping_e-04370c84.json", - notes: "All layers, 200 iterations", - }, - { - path: "clustering/ensembles/e-5f228e5f/cluster_mapping_e-5f228e5f.json", - notes: "Just down_proj, 80 iterations", - }, - ], + wandbRunId: "goodfire/spd/s-eab2ace8", + notes: "Oli's PPGD run, great metrics", }, { - wandbRunId: "goodfire/spd/itmexlj0", - modelName: "ss_llama_simple_mlp (2L)", + 
wandbRunId: "goodfire/spd/s-892f140b", + notes: "Lucius run, Jan 22", }, { - wandbRunId: "goodfire/spd/33n6xjjt", - modelName: "ss_gpt2_simple (1L)", + wandbRunId: "goodfire/spd/s-7884efcc", + notes: "Lucius' new run, Jan 8", }, ]; @@ -84,7 +69,6 @@ export const CANONICAL_RUNS: RegistryEntry[] = [ */ export function formatRunIdForDisplay(wandbRunId: string): string { if (wandbRunId.startsWith(`${DEFAULT_ENTITY_PROJECT}/`)) { - // Extract just the run id (last segment) const parts = wandbRunId.split("/"); return parts[parts.length - 1]; } diff --git a/spd/app/frontend/src/lib/tokenUtils.ts b/spd/app/frontend/src/lib/tokenUtils.ts new file mode 100644 index 000000000..87a9ce2d1 --- /dev/null +++ b/spd/app/frontend/src/lib/tokenUtils.ts @@ -0,0 +1,41 @@ +/** + * Shared token display utilities. + * + * Backend already escapes most control chars via `escape_for_display()` in app_tokenizer.py, + * but the frontend applies the same transforms defensively (some paths may bypass the backend). + */ + +const CONTROL_CHAR_MAP: [string, string][] = [ + ["\n", "↵"], + ["\r", "⏎"], + ["\t", "⇥"], + ["\v", "⇣"], + ["\f", "⇟"], + ["\x00", "␀"], +]; + +/** Replace invisible / control characters with visible unicode proxies. */ +export function sanitizeToken(tok: string): string { + let out = tok; + for (const [char, replacement] of CONTROL_CHAR_MAP) { + out = out.replaceAll(char, replacement); + } + return out; +} + +/** + * Get the next-token probability at a given position. + * + * nextTokenProbs[i] is P(token[i+1] | token[0..i]), so the probability + * "for" position i (the token displayed there) is nextTokenProbs[i-1]. + * Position 0 has no prediction (it's the first token). 
+ */ +export function getProbAtPosition(nextTokenProbs: (number | null)[], i: number): number | null { + if (i === 0) return null; + return nextTokenProbs[i - 1]; +} + +export function formatProb(prob: number | null): string { + if (prob === null) return ""; + return `${(prob * 100).toFixed(1)}%`; +} diff --git a/spd/app/frontend/src/lib/useComponentData.svelte.ts b/spd/app/frontend/src/lib/useComponentData.svelte.ts index da96d307d..d954faa27 100644 --- a/spd/app/frontend/src/lib/useComponentData.svelte.ts +++ b/spd/app/frontend/src/lib/useComponentData.svelte.ts @@ -1,45 +1,53 @@ -import { getContext } from "svelte"; +import { getContext, untrack } from "svelte"; import type { Loadable } from "."; import { ApiError, getComponentAttributions, getComponentCorrelations, getComponentTokenStats, + getGraphInterpComponentDetail, getInterpretationDetail, requestComponentInterpretation, } from "./api"; -import type { ComponentAttributions, InterpretationDetail } from "./api"; -import type { ComponentCorrelations, ComponentDetail, TokenStats } from "./promptAttributionsTypes"; +import type { AllMetricAttributions, GraphInterpComponentDetail, InterpretationDetail } from "./api"; +import type { + SubcomponentCorrelationsResponse, + SubcomponentActivationContexts, + TokenStatsResponse, +} from "./promptAttributionsTypes"; import { RUN_KEY, type InterpretationBackendState, type RunContext } from "./useRun.svelte"; /** Correlations are paginated in the UI, so fetch more */ const CORRELATIONS_TOP_K = 100; -/** Token stats are displayed directly (max 50 shown) */ -const TOKEN_STATS_TOP_K = 50; +/** Token stats are paginated in the UI */ +const TOKEN_STATS_TOP_K = 200; /** Dataset attributions top-k */ const DATASET_ATTRIBUTIONS_TOP_K = 20; -export type { ComponentAttributions as DatasetAttributions }; +export type { AllMetricAttributions as DatasetAttributions }; export type ComponentCoords = { layer: string; cIdx: number }; /** - * Hook for loading component data (detail, 
correlations, token stats, interpretation detail). + * Hook for loading component data via network requests. * * Call `load(layer, cIdx)` explicitly when you want to fetch data. * Interpretation headline is derived from the global runState cache. * Interpretation detail (reasoning + prompt) is fetched on-demand. + * + * For graph tooltips (smaller initial limits + background fetch), use useComponentDataExpectCached. */ export function useComponentData() { const runState = getContext<RunContext>(RUN_KEY); - let componentDetail = $state<Loadable<ComponentDetail | null>>({ status: "uninitialized" }); + let componentDetail = $state<Loadable<SubcomponentActivationContexts | null>>({ status: "uninitialized" }); // null inside Loadable means "no data for this component" (404) - let correlations = $state<Loadable<ComponentCorrelations | null>>({ status: "uninitialized" }); - let tokenStats = $state<Loadable<TokenStats | null>>({ status: "uninitialized" }); - let datasetAttributions = $state<Loadable<ComponentAttributions | null>>({ status: "uninitialized" }); + let correlations = $state<Loadable<SubcomponentCorrelationsResponse | null>>({ status: "uninitialized" }); + let tokenStats = $state<Loadable<TokenStatsResponse | null>>({ status: "uninitialized" }); + let datasetAttributions = $state<Loadable<AllMetricAttributions | null>>({ status: "uninitialized" }); let interpretationDetail = $state<Loadable<InterpretationDetail | null>>({ status: "uninitialized" }); + let graphInterpDetail = $state<Loadable<GraphInterpComponentDetail | null>>({ status: "uninitialized" }); // Current coords being loaded/displayed (for interpretation lookup) let currentCoords = $state<ComponentCoords | null>(null); @@ -67,7 +75,7 @@ export function useComponentData() { // Fetch component detail (cached in runState after first call) runState - .getComponentDetail(layer, cIdx) + .getActivationContextDetail(layer, cIdx) .then((data) => { if (isStale()) return; componentDetail = { status: "loaded", data }; @@ -107,35 +115,59 @@ export function useComponentData() { } }); - // Fetch dataset attributions (404 = not available) - getComponentAttributions(layer, cIdx, DATASET_ATTRIBUTIONS_TOP_K) - .then((data) => { - if (isStale()) return; - datasetAttributions = { status: "loaded", data }; - }) - .catch((error) => { - if (isStale()) return; - if (error instanceof ApiError && error.status === 404) { - datasetAttributions = {
status: "loaded", data: null }; - } else { - datasetAttributions = { status: "error", error }; - } - }); + // Fetch dataset attributions (skip entirely if not available for this run) + if (runState.datasetAttributionsAvailable) { + getComponentAttributions(layer, cIdx, DATASET_ATTRIBUTIONS_TOP_K) + .then((data) => { + if (isStale()) return; + datasetAttributions = { status: "loaded", data }; + }) + .catch((error) => { + if (isStale()) return; + if (error instanceof ApiError && error.status === 404) { + datasetAttributions = { status: "loaded", data: null }; + } else { + datasetAttributions = { status: "error", error }; + } + }); + } else { + datasetAttributions = { status: "loaded", data: null }; + } - // Fetch interpretation detail (404 = no interpretation for this component) - getInterpretationDetail(layer, cIdx) - .then((data) => { - if (isStale()) return; - interpretationDetail = { status: "loaded", data }; - }) - .catch((error) => { - if (isStale()) return; - if (error instanceof ApiError && error.status === 404) { - interpretationDetail = { status: "loaded", data: null }; - } else { + const interpState = untrack(() => runState.getInterpretation(`${layer}:${cIdx}`)); + if (interpState.status === "loaded" && interpState.data.status !== "none") { + getInterpretationDetail(layer, cIdx) + .then((data) => { + if (isStale()) return; + interpretationDetail = { status: "loaded", data }; + }) + .catch((error) => { + if (isStale()) return; interpretationDetail = { status: "error", error }; - } - }); + }); + } else { + interpretationDetail = { status: "loaded", data: null }; + } + + // Fetch graph interp detail (skip if not available for this run) + if (runState.graphInterpAvailable) { + graphInterpDetail = { status: "loading" }; + getGraphInterpComponentDetail(layer, cIdx) + .then((data) => { + if (isStale()) return; + graphInterpDetail = { status: "loaded", data }; + }) + .catch((error) => { + if (isStale()) return; + if (error instanceof ApiError && error.status === 
404) { + graphInterpDetail = { status: "loaded", data: null }; + } else { + graphInterpDetail = { status: "error", error }; + } + }); + } else { + graphInterpDetail = { status: "loaded", data: null }; + } } /** @@ -149,6 +181,7 @@ export function useComponentData() { tokenStats = { status: "uninitialized" }; datasetAttributions = { status: "uninitialized" }; interpretationDetail = { status: "uninitialized" }; + graphInterpDetail = { status: "uninitialized" }; } // Interpretation is derived from the global cache - reactive to both coords and cache @@ -202,6 +235,9 @@ export function useComponentData() { get interpretationDetail() { return interpretationDetail; }, + get graphInterpDetail() { + return graphInterpDetail; + }, load, reset, generateInterpretation, diff --git a/spd/app/frontend/src/lib/useComponentDataExpectCached.svelte.ts b/spd/app/frontend/src/lib/useComponentDataExpectCached.svelte.ts new file mode 100644 index 000000000..d76c5da9e --- /dev/null +++ b/spd/app/frontend/src/lib/useComponentDataExpectCached.svelte.ts @@ -0,0 +1,230 @@ +/** + * Hook for lazily loading component data with small initial limits. + * + * Fetches activation contexts (10), correlations (10), and token stats (10) + * in parallel for fast initial render, then background-fetches full activation + * examples (200). Dataset attributions and interpretation detail are on-demand. 
+ */ + +import { getContext, untrack } from "svelte"; +import type { Loadable } from "."; +import { + ApiError, + getActivationContextDetail, + getComponentAttributions, + getComponentCorrelations, + getComponentTokenStats, + getGraphInterpComponentDetail, + getInterpretationDetail, + requestComponentInterpretation, +} from "./api"; +import type { AllMetricAttributions, GraphInterpComponentDetail, InterpretationDetail } from "./api"; +import type { + SubcomponentCorrelationsResponse, + SubcomponentActivationContexts, + TokenStatsResponse, +} from "./promptAttributionsTypes"; +import { RUN_KEY, type InterpretationBackendState, type RunContext } from "./useRun.svelte"; + +const DATASET_ATTRIBUTIONS_TOP_K = 20; +/** Fetch more activation examples in background after initial cached load */ +const ACTIVATION_EXAMPLES_FULL_LIMIT = 200; + +export type { AllMetricAttributions as DatasetAttributions }; + +export type ComponentCoords = { layer: string; cIdx: number }; + +export function useComponentDataExpectCached() { + const runState = getContext<RunContext>(RUN_KEY); + + let componentDetail = $state<Loadable<SubcomponentActivationContexts | null>>({ status: "uninitialized" }); + let correlations = $state<Loadable<SubcomponentCorrelationsResponse | null>>({ status: "uninitialized" }); + let tokenStats = $state<Loadable<TokenStatsResponse | null>>({ status: "uninitialized" }); + let datasetAttributions = $state<Loadable<AllMetricAttributions | null>>({ status: "uninitialized" }); + let interpretationDetail = $state<Loadable<InterpretationDetail | null>>({ status: "uninitialized" }); + let graphInterpDetail = $state<Loadable<GraphInterpComponentDetail | null>>({ status: "uninitialized" }); + + let currentCoords = $state<ComponentCoords | null>(null); + let requestId = 0; + + /** Fetch full activation examples in background (overwrites cached data when complete).
*/ + function startBackgroundFetch( + layer: string, + cIdx: number, + cachedDetail: SubcomponentActivationContexts, + isStale: () => boolean, + ) { + getActivationContextDetail(layer, cIdx, ACTIVATION_EXAMPLES_FULL_LIMIT) + .then((data) => { + if (isStale()) return; + if (data.example_tokens.length > cachedDetail.example_tokens.length) { + componentDetail = { status: "loaded", data }; + } + }) + .catch((error) => { + if (isStale()) return; + componentDetail = { status: "error", error }; + }); + } + + /** Start on-demand fetches (dataset attributions, interpretation detail). */ + function startOnDemandFetches(layer: string, cIdx: number, isStale: () => boolean) { + // Skip fetch entirely if dataset attributions not available for this run + if (runState.datasetAttributionsAvailable) { + datasetAttributions = { status: "loading" }; + getComponentAttributions(layer, cIdx, DATASET_ATTRIBUTIONS_TOP_K) + .then((data) => { + if (isStale()) return; + datasetAttributions = { status: "loaded", data }; + }) + .catch((error) => { + if (isStale()) return; + if (error instanceof ApiError && error.status === 404) { + datasetAttributions = { status: "loaded", data: null }; + } else { + datasetAttributions = { status: "error", error }; + } + }); + } else { + datasetAttributions = { status: "loaded", data: null }; + } + + const interpState = untrack(() => runState.getInterpretation(`${layer}:${cIdx}`)); + if (interpState.status === "loaded" && interpState.data.status !== "none") { + interpretationDetail = { status: "loading" }; + getInterpretationDetail(layer, cIdx) + .then((data) => { + if (isStale()) return; + interpretationDetail = { status: "loaded", data }; + }) + .catch((error) => { + if (isStale()) return; + interpretationDetail = { status: "error", error }; + }); + } else { + interpretationDetail = { status: "loaded", data: null }; + } + + // Fetch graph interp detail + if (runState.graphInterpAvailable) { + graphInterpDetail = { status: "loading" }; + 
getGraphInterpComponentDetail(layer, cIdx) + .then((data) => { + if (isStale()) return; + graphInterpDetail = { status: "loaded", data }; + }) + .catch((error) => { + if (isStale()) return; + if (error instanceof ApiError && error.status === 404) { + graphInterpDetail = { status: "loaded", data: null }; + } else { + graphInterpDetail = { status: "error", error }; + } + }); + } else { + graphInterpDetail = { status: "loaded", data: null }; + } + } + + function load(layer: string, cIdx: number) { + currentCoords = { layer, cIdx }; + const thisRequestId = ++requestId; + + const isStale = () => requestId !== thisRequestId; + + componentDetail = { status: "loading" }; + correlations = { status: "loading" }; + tokenStats = { status: "loading" }; + + Promise.all([ + getActivationContextDetail(layer, cIdx, 10), + getComponentCorrelations(layer, cIdx, 10).catch(() => null), + getComponentTokenStats(layer, cIdx, 10).catch(() => null), + ]) + .then(([detail, corr, stats]) => { + if (isStale()) return; + componentDetail = { status: "loaded", data: detail }; + correlations = { status: "loaded", data: corr }; + tokenStats = { status: "loaded", data: stats }; + startBackgroundFetch(layer, cIdx, detail, isStale); + }) + .catch((error) => { + if (isStale()) return; + componentDetail = { status: "error", error }; + correlations = { status: "error", error }; + tokenStats = { status: "error", error }; + }); + + startOnDemandFetches(layer, cIdx, isStale); + } + + function reset() { + requestId++; + currentCoords = null; + componentDetail = { status: "uninitialized" }; + correlations = { status: "uninitialized" }; + tokenStats = { status: "uninitialized" }; + datasetAttributions = { status: "uninitialized" }; + interpretationDetail = { status: "uninitialized" }; + graphInterpDetail = { status: "uninitialized" }; + } + + // Interpretation is derived from the global cache + const interpretation = $derived.by((): Loadable<InterpretationBackendState> => { + if (!currentCoords) return { status: "uninitialized" }; +
return runState.getInterpretation(`${currentCoords.layer}:${currentCoords.cIdx}`); + }); + + async function generateInterpretation() { + if (!currentCoords) return; + + const { layer, cIdx } = currentCoords; + const componentKey = `${layer}:${cIdx}`; + + try { + runState.setInterpretation(componentKey, { status: "generating" }); + const result = await requestComponentInterpretation(layer, cIdx); + runState.setInterpretation(componentKey, { status: "generated", data: result }); + + // Fetch the detail now that it exists + try { + const detail = await getInterpretationDetail(layer, cIdx); + interpretationDetail = { status: "loaded", data: detail }; + } catch (detailError) { + interpretationDetail = { status: "error", error: detailError }; + } + } catch (e) { + runState.setInterpretation(componentKey, { + status: "generation-error", + error: e instanceof Error ? e.message : String(e), + }); + } + } + + return { + get componentDetail() { + return componentDetail; + }, + get correlations() { + return correlations; + }, + get tokenStats() { + return tokenStats; + }, + get datasetAttributions() { + return datasetAttributions; + }, + get interpretation() { + return interpretation; + }, + get interpretationDetail() { + return interpretationDetail; + }, + get graphInterpDetail() { + return graphInterpDetail; + }, + load, + reset, + generateInterpretation, + }; +} diff --git a/spd/app/frontend/src/lib/useRun.svelte.ts b/spd/app/frontend/src/lib/useRun.svelte.ts index 629b203b8..1cfc3cca6 100644 --- a/spd/app/frontend/src/lib/useRun.svelte.ts +++ b/spd/app/frontend/src/lib/useRun.svelte.ts @@ -7,8 +7,8 @@ import type { Loadable } from "."; import * as api from "./api"; -import type { RunState as RunData, Interpretation } from "./api"; -import type { ActivationContextsSummary, ComponentDetail, PromptPreview, TokenInfo } from "./promptAttributionsTypes"; +import type { LoadedRun as RunData, InterpretationHeadline, GraphInterpHeadline } from "./api"; +import type { PromptPreview, 
SubcomponentActivationContexts, SubcomponentMetadata } from "./promptAttributionsTypes"; /** Maps component keys to cluster IDs. Singletons (unclustered components) have null values. */ export type ClusterMappingData = Record<string, number | null>; @@ -28,7 +28,7 @@ type ClusterMapping = { export type InterpretationBackendState = | { status: "none" } | { status: "generating" } - | { status: "generated"; data: Interpretation } + | { status: "generated"; data: InterpretationHeadline } | { status: "generation-error"; error: unknown }; export function useRun() { @@ -38,26 +38,32 @@ export function useRun() { /** Interpretation labels keyed by component key (layer:cIdx) */ let interpretations = $state<Loadable<Record<string, InterpretationBackendState>>>({ status: "uninitialized" }); + /** Intruder eval scores keyed by component key */ + let intruderScores = $state<Loadable<Record<string, number>>>({ status: "uninitialized" }); + + /** Graph interp labels keyed by component key (layer:cIdx) */ + let graphInterpLabels = $state<Loadable<Record<string, GraphInterpHeadline>>>({ status: "uninitialized" }); + /** Cluster mapping for the current run */ let clusterMapping = $state<ClusterMapping | null>(null); /** Available prompts for the current run */ let prompts = $state<Loadable<PromptPreview[]>>({ status: "uninitialized" }); - /** All tokens in the tokenizer for the current run */ - let allTokens = $state<Loadable<TokenInfo[]>>({ status: "uninitialized" }); + /** Activation contexts summary (null = harvest not available) */ + let activationContextsSummary = $state<Loadable<Record<string, SubcomponentMetadata> | null>>({ + status: "uninitialized", + }); - /** Activation contexts summary */ - let activationContextsSummary = $state<Loadable<ActivationContextsSummary>>({ status: "uninitialized" }); - - /** Cached component details keyed by component key (layer:cIdx) - non-reactive */ - let _componentDetailsCache: Record<string, ComponentDetail> = {}; + // Cached activation context detail keyed by component key (layer:cIdx) - non-reactive + let _componentDetailsCache: Record<string, SubcomponentActivationContexts> = {}; /** Reset all run-scoped state */ function resetRunScopedState() { prompts = { status: "uninitialized" }; - allTokens = { status: "uninitialized" }; interpretations = { status: "uninitialized" }; + intruderScores = {
status: "uninitialized" }; + graphInterpLabels = { status: "uninitialized" }; activationContextsSummary = { status: "uninitialized" }; _componentDetailsCache = {}; clusterMapping = null; @@ -67,10 +73,17 @@ export function useRun() { function fetchRunScopedData() { prompts = { status: "loading" }; interpretations = { status: "loading" }; + intruderScores = { status: "loading" }; api.listPrompts() .then((p) => (prompts = { status: "loaded", data: p })) .catch((error) => (prompts = { status: "error", error })); + api.getIntruderScores() + .then((data) => (intruderScores = { status: "loaded", data })) + .catch((error) => (intruderScores = { status: "error", error })); + api.getAllGraphInterpLabels() + .then((data) => (graphInterpLabels = { status: "loaded", data })) + .catch((error) => (graphInterpLabels = { status: "error", error })); api.getAllInterpretations() .then((i) => { interpretations = { @@ -89,19 +102,11 @@ export function useRun() { .catch((error) => (interpretations = { status: "error", error })); } - /** Fetch tokens - must complete before run is considered loaded */ - async function fetchTokens(): Promise { - allTokens = { status: "loading" }; - const tokens = await api.getAllTokens(); - allTokens = { status: "loaded", data: tokens }; - return tokens; - } - async function loadRun(wandbPath: string, contextLength: number) { run = { status: "loading" }; try { await api.loadRun(wandbPath, contextLength); - const [status] = await Promise.all([api.getStatus(), fetchTokens()]); + const status = await api.getStatus(); if (status) { run = { status: "loaded", data: status }; fetchRunScopedData(); @@ -123,10 +128,6 @@ export function useRun() { try { const status = await api.getStatus(); if (status) { - // Fetch tokens if we don't have them (e.g., page refresh) - if (allTokens.status === "uninitialized") { - await fetchTokens(); - } run = { status: "loaded", data: status }; // Fetch other run-scoped data if we don't have it if (interpretations.status === 
"uninitialized") { @@ -163,6 +164,12 @@ } } + /** Get intruder score for a component, if available */ + function getIntruderScore(componentKey: string): number | null { + if (intruderScores.status !== "loaded") return null; + return intruderScores.data[componentKey] ?? null; + } + /** Set interpretation for a component (updates cache without full reload) */ function setInterpretation(componentKey: string, interpretation: InterpretationBackendState) { if (interpretations.status === "loaded") { @@ -170,12 +177,12 @@ } } - /** Get component detail (fetches once, then cached) */ - async function getComponentDetail(layer: string, cIdx: number): Promise<ComponentDetail> { + /** Get activation context detail (fetches once, then cached) */ + async function getActivationContextDetail(layer: string, cIdx: number): Promise<SubcomponentActivationContexts> { const cacheKey = `${layer}:${cIdx}`; if (cacheKey in _componentDetailsCache) return _componentDetailsCache[cacheKey]; - const detail = await api.getComponentDetail(layer, cIdx); + const detail = await api.getActivationContextDetail(layer, cIdx); _componentDetailsCache[cacheKey] = detail; return detail; } @@ -205,6 +212,11 @@ return clusterMapping?.data[key] ?? null; } + function getGraphInterpLabel(componentKey: string): GraphInterpHeadline | null { + if (graphInterpLabels.status !== "loaded") return null; + return graphInterpLabels.data[componentKey] ??
null; + } + return { get run() { return run; @@ -212,25 +224,36 @@ export function useRun() { get interpretations() { return interpretations; }, + get graphInterpLabels() { + return graphInterpLabels; + }, get clusterMapping() { return clusterMapping; }, get prompts() { return prompts; }, - get allTokens() { - return allTokens; - }, get activationContextsSummary() { return activationContextsSummary; }, + get datasetAttributionsAvailable() { + return run.status === "loaded" && run.data.dataset_attributions_available; + }, + get graphInterpAvailable() { + return run.status === "loaded" && run.data.graph_interp_available; + }, + get autoInterpAvailable() { + return run.status === "loaded" && run.data.autointerp_available; + }, loadRun, clearRun, syncStatus, refreshPrompts, getInterpretation, setInterpretation, - getComponentDetail, + getIntruderScore, + getGraphInterpLabel, + getActivationContextDetail, loadActivationContextsSummary, setClusterMapping, clearClusterMapping, diff --git a/spd/app/frontend/vite.config.ts b/spd/app/frontend/vite.config.ts index bfec3a4ab..a08d086fb 100644 --- a/spd/app/frontend/vite.config.ts +++ b/spd/app/frontend/vite.config.ts @@ -1,10 +1,20 @@ import { defineConfig } from "vite"; import { svelte } from "@sveltejs/vite-plugin-svelte"; +// BACKEND_URL is set by run_app.py when launching the dev server. +// Default to localhost:8000 for type checking and build (proxy only used during dev). +const backendUrl = process.env.BACKEND_URL || "http://localhost:8000"; + // https://vite.dev/config/ export default defineConfig({ plugins: [svelte()], - // server: { - // hmr: false, - // }, + server: { + hmr: false, + proxy: { + "/api": { + target: backendUrl, + changeOrigin: true, + }, + }, + }, }); diff --git a/spd/app/run_app.py b/spd/app/run_app.py index f21dec83d..c61174d1e 100755 --- a/spd/app/run_app.py +++ b/spd/app/run_app.py @@ -1,21 +1,27 @@ """ Development server launcher for SPD app. 
+ Starts backend and frontend with: - Automatic port detection (with --strictPort for Vite) - - TCP-based health checks (no false negatives on 404) - - Graceful shutdown of process groups + - HTTP health checks that validate status codes (and optional content) + - Fail-fast if a child dies during startup + - Graceful shutdown (TERM -> KILL) of process groups - Clear logging & dependency checks """ +from __future__ import annotations + import atexit -import concurrent.futures import contextlib import os +import shutil import signal import socket import subprocess import sys import time +from collections.abc import Callable +from dataclasses import dataclass from datetime import datetime from enum import StrEnum from pathlib import Path @@ -40,14 +46,23 @@ class AnsiEsc(StrEnum): LOGS_DIR.mkdir(parents=True, exist_ok=True) LOGFILE = LOGS_DIR / f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.log" -STARTUP_TIMEOUT_SECONDS = 30 +DEFAULT_STARTUP_TIMEOUT_SECONDS = 90 BACKEND_DEFAULT_START = 8000 FRONTEND_DEFAULT_START = 5173 +def _require_bins(*bins: str) -> None: + missing = [b for b in bins if shutil.which(b) is None] + if missing: + print( + f"{AnsiEsc.RED}✗ Missing dependencies:{AnsiEsc.RESET} {', '.join(missing)}", + file=sys.stderr, + ) + sys.exit(1) + + def is_port_in_use(port: int) -> bool: """Best-effort check: try binding on loopback IPv4 and IPv6.""" - # IPv4 with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s4: s4.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) try: @@ -55,7 +70,6 @@ def is_port_in_use(port: int) -> bool: except OSError: return True - # IPv6 (if available) try: with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s6: s6.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) @@ -64,7 +78,6 @@ def is_port_in_use(port: int) -> bool: except OSError: return True except OSError: - # IPv6 not supported; ignore pass return False @@ -82,6 +95,13 @@ def find_available_port(start_port: int) -> int: sys.exit(1) +def _tcp_open(host: 
str, port: int, timeout: float = 0.5) -> bool: + """Returns True if a TCP connection can be established.""" + with contextlib.suppress(OSError), socket.create_connection((host, port), timeout=timeout): + return True + return False + + def _spawn( cmd: list[str], cwd: Path, @@ -90,16 +110,14 @@ def _spawn( ) -> subprocess.Popen[str]: """Spawn a process in its own process group, streaming stdout/stderr to logfile.""" try: - # Use preexec_fn to set process group on Unix systems - # This allows us to kill the entire process tree later return subprocess.Popen( cmd, cwd=str(cwd), stdout=logfile, stderr=subprocess.STDOUT, bufsize=1, - text=True, # modern alias for universal_newlines=True - preexec_fn=os.setpgrp, # Create new process group + text=True, + preexec_fn=os.setpgrp, env=env, ) except FileNotFoundError as e: @@ -111,34 +129,155 @@ def _spawn( sys.exit(1) +@dataclass(frozen=True) +class HealthCheck: + url: str + ok_statuses: set[int] + timeout: float = 1.0 + headers: dict[str, str] | None = None + allow_redirects: bool = False + body_predicate: Callable[[requests.Response], bool] | None = None + + class AppRunner: """Manages backend and frontend processes with proper cleanup on signals.""" - def __init__(self): + def __init__(self, startup_timeout_seconds: int): self.backend_process: subprocess.Popen[str] | None = None self.frontend_process: subprocess.Popen[str] | None = None self.cleanup_called = False + self.startup_timeout_seconds = startup_timeout_seconds + + self._session = requests.Session() + self._session.headers.update({"User-Agent": "spd-dev-launcher/1.0"}) + + def _kill_process_group(self, proc: subprocess.Popen[str], sig: int) -> None: + if proc.poll() is not None: + return + with contextlib.suppress(ProcessLookupError, PermissionError, OSError): + os.killpg(os.getpgid(proc.pid), sig) + + def cleanup(self) -> None: + """Cleanup all running processes (process groups).""" + if self.cleanup_called: + return + self.cleanup_called = True - def 
wait_to_serve(self, port: int, name: str, pid: int | None = None) -> None: + print("\nShutting down...", flush=True) + + procs = [p for p in (self.backend_process, self.frontend_process) if p] + # Try graceful first + for p in procs: + self._kill_process_group(p, signal.SIGTERM) + + # Wait briefly + deadline = time.time() + 2.0 + for p in procs: + if not p: + continue + remaining = max(0.0, deadline - time.time()) + with contextlib.suppress(subprocess.TimeoutExpired): + p.wait(timeout=remaining) + + # Force kill if still alive + for p in procs: + if p and p.poll() is None: + self._kill_process_group(p, signal.SIGKILL) + with contextlib.suppress(subprocess.TimeoutExpired): + p.wait(timeout=0.5) + + def _fail_child_died(self, name: str) -> None: + print( + f"\n{AnsiEsc.RED}✗{AnsiEsc.RESET} {name} process died unexpectedly", + file=sys.stderr, + ) + print(f"{AnsiEsc.DIM}Check {LOGFILE} for details{AnsiEsc.RESET}", file=sys.stderr) + sys.exit(1) + + def wait_http_ready( + self, + *, + checks: list[HealthCheck], + name: str, + port_for_tcp_hint: int, + proc_getter: Callable[[], subprocess.Popen[str] | None], + pid: int | None = None, + ) -> None: + """ + Wait until ANY check passes. Validates HTTP status codes (and optional body predicate). + Also checks for child liveness while waiting. 
+ """ start = time.time() - while time.time() < (start + STARTUP_TIMEOUT_SECONDS): - try: - response = requests.get(f"http://localhost:{port}/api/health", timeout=1.0) - if response.status_code == 200: - # Print success message immediately when ready - if pid is not None: - print( - f" {AnsiEsc.GREEN}✓{AnsiEsc.RESET} {name} started (pid {pid}, port {port})" - ) - return - except requests.RequestException: - pass - time.sleep(0.5) + last_error: str | None = None + last_status: int | None = None + last_url: str | None = None + last_body_snip: str | None = None + + while time.time() < (start + self.startup_timeout_seconds): + proc = proc_getter() + if proc and proc.poll() is not None: + self._fail_child_died(name) + + # TCP hint first to reduce noisy connect exceptions + if not _tcp_open("localhost", port_for_tcp_hint, timeout=0.25): + time.sleep(0.25) + continue + + for hc in checks: + try: + resp = self._session.get( + hc.url, + timeout=hc.timeout, + headers=hc.headers, + allow_redirects=hc.allow_redirects, + ) + last_url = hc.url + last_status = resp.status_code + last_body_snip = resp.text[:200].replace("\n", "\\n") + + if resp.status_code in hc.ok_statuses: + if hc.body_predicate and not hc.body_predicate(resp): + last_error = "body predicate failed" + continue + + if pid is not None: + print( + f" {AnsiEsc.GREEN}✓{AnsiEsc.RESET} {name} started {AnsiEsc.DIM}(pid {pid}){AnsiEsc.RESET}" + ) + return + + last_error = f"unexpected status {resp.status_code}" + except requests.RequestException as e: + last_error = f"request error: {type(e).__name__}: {e}" + + time.sleep(0.4) + + # Timeout diagnostics print(f"{AnsiEsc.RED}✗{AnsiEsc.RESET} {name} healthcheck failed", file=sys.stderr) + if last_url is not None: + print( + f"{AnsiEsc.DIM}Last check:{AnsiEsc.RESET} {last_url}", + file=sys.stderr, + ) + if last_status is not None: + print( + f"{AnsiEsc.DIM}Last status:{AnsiEsc.RESET} {last_status}", + file=sys.stderr, + ) + if last_error is not None: + print( + 
f"{AnsiEsc.DIM}Last error:{AnsiEsc.RESET} {last_error}", + file=sys.stderr, + ) + if last_body_snip: + print( + f"{AnsiEsc.DIM}Body snippet:{AnsiEsc.RESET} {last_body_snip}", + file=sys.stderr, + ) + print(f"{AnsiEsc.DIM}Check {LOGFILE} for details{AnsiEsc.RESET}", file=sys.stderr) sys.exit(1) def spawn_backend(self, port: int, logfile: TextIO) -> subprocess.Popen[str]: - """Spawn backend process without waiting for it to be ready.""" project_root = APP_DIR.parent.parent cmd = [ "uv", @@ -150,65 +289,28 @@ def spawn_backend(self, port: int, logfile: TextIO) -> subprocess.Popen[str]: str(port), ] proc = _spawn(cmd, cwd=project_root, env=None, logfile=logfile) - self.backend_process = proc # Immediately visible to signal handler + self.backend_process = proc return proc def spawn_frontend( self, port: int, backend_port: int, logfile: TextIO ) -> subprocess.Popen[str]: - """Spawn frontend process without waiting for it to be ready.""" env = os.environ.copy() - env["VITE_API_URL"] = f"http://localhost:{backend_port}" - # strictPort = fail-fast if port is taken (so our "did it die?" 
check works) + env["BACKEND_URL"] = f"http://localhost:{backend_port}" cmd = ["npm", "run", "dev", "--", "--port", str(port), "--strictPort"] proc = _spawn(cmd, cwd=APP_DIR / "frontend", env=env, logfile=logfile) - self.frontend_process = proc # Immediately visible to signal handler + self.frontend_process = proc return proc - def cleanup(self) -> None: - """Cleanup all running processes (process groups).""" - if self.cleanup_called: - return - self.cleanup_called = True - - print("\n👋 Shutting down...", flush=True) - - # Kill all process groups immediately with SIGKILL for reliability - for proc in (self.backend_process, self.frontend_process): - if proc and proc.poll() is None: - # Kill the process group - with contextlib.suppress(ProcessLookupError, PermissionError, OSError): - os.killpg(os.getpgid(proc.pid), signal.SIGKILL) - - # Also kill the direct process as fallback - with contextlib.suppress(ProcessLookupError, PermissionError, OSError): - proc.kill() - - # Brief wait for processes to die - for proc in (self.backend_process, self.frontend_process): - if proc: - with contextlib.suppress(subprocess.TimeoutExpired): - proc.wait(timeout=0.3) - def monitor_child_liveness(self) -> None: - log_lines_to_show = 5 + log_lines_to_show = 20 prev_lines: list[str] = [] while True: if self.backend_process and self.backend_process.poll() is not None: - print( - f"\n{AnsiEsc.RED}✗{AnsiEsc.RESET} Backend process died unexpectedly", - file=sys.stderr, - ) - print(f"{AnsiEsc.DIM}Check {LOGFILE} for details{AnsiEsc.RESET}", file=sys.stderr) - sys.exit(1) + self._fail_child_died("Backend") if self.frontend_process and self.frontend_process.poll() is not None: - print( - f"\n{AnsiEsc.RED}✗{AnsiEsc.RESET} Frontend process died unexpectedly", - file=sys.stderr, - ) - print(f"{AnsiEsc.DIM}Check {LOGFILE} for details{AnsiEsc.RESET}", file=sys.stderr) - sys.exit(1) + self._fail_child_died("Frontend") # Show last N lines of logs in a box try: @@ -222,19 +324,10 @@ def 
monitor_child_liveness(self) -> None: lines_to_clear = len(prev_lines) + 2 print(f"\033[{lines_to_clear}A\033[J", end="") - # Print box with tail - local_logfile = LOGFILE.relative_to(os.getcwd()) - print( - f"{AnsiEsc.DIM}┌─ logs ({local_logfile}) {'─' * (60 - len(str(local_logfile)))}{AnsiEsc.RESET}" - ) + print(f"{AnsiEsc.DIM}┌─ logs {'─' * 32}{AnsiEsc.RESET}") for line in tail: - clipped_line = ( - line.rstrip()[:100] + "..." - if len(line.rstrip()) > 100 - else line.rstrip() - ) - print(f"{AnsiEsc.DIM}│ {clipped_line}{AnsiEsc.RESET}") - print(f"{AnsiEsc.DIM}└{'─' * 80}{AnsiEsc.RESET}") + print(f"{AnsiEsc.DIM}│ {line.rstrip()}{AnsiEsc.RESET}") + print(f"{AnsiEsc.DIM}└{'─' * 40}{AnsiEsc.RESET}") prev_lines = tail except FileNotFoundError: @@ -243,50 +336,76 @@ def monitor_child_liveness(self) -> None: time.sleep(1.0) def run(self) -> None: - """Main entry point to run the development servers.""" + """Launch the backend and frontend development servers.""" print(f"{AnsiEsc.DIM}Logfile: {LOGFILE}{AnsiEsc.RESET}") print(f"{AnsiEsc.DIM}Finding available ports...{AnsiEsc.RESET}") - backend_port = find_available_port(BACKEND_DEFAULT_START) - frontend_port = find_available_port(FRONTEND_DEFAULT_START) - print(f" - {AnsiEsc.DIM}Backend port: {backend_port}{AnsiEsc.RESET}") - print(f" - {AnsiEsc.DIM}Frontend port: {frontend_port}{AnsiEsc.RESET}") - print() + bport = find_available_port(BACKEND_DEFAULT_START) + fport = find_available_port(FRONTEND_DEFAULT_START) - print(f"{AnsiEsc.BOLD}🚀 Starting development servers{AnsiEsc.RESET}") + print(f" - {AnsiEsc.DIM}Backend port: {bport}{AnsiEsc.RESET}") + print(f" - {AnsiEsc.DIM}Frontend port: {fport}{AnsiEsc.RESET}\n") + + print(f"{AnsiEsc.BOLD}Starting development servers{AnsiEsc.RESET}") print(f"{AnsiEsc.DIM}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━{AnsiEsc.RESET}") - # Open logfile for streaming child output with open(LOGFILE, "a", buffering=1, encoding="utf-8") as logfile: - # Spawn both processes in parallel - print(f" 
{AnsiEsc.DIM}▸ Spawning backend and frontend...{AnsiEsc.RESET}") - backend_proc = self.spawn_backend(backend_port, logfile) - frontend_proc = self.spawn_frontend(frontend_port, backend_port, logfile) - - # Wait for both to be ready in parallel - print(f" {AnsiEsc.DIM}▸ Waiting for servers to be ready...{AnsiEsc.RESET}") - with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: - backend_future = executor.submit( - self.wait_to_serve, backend_port, "Backend", backend_proc.pid - ) - frontend_future = executor.submit( - self.wait_to_serve, frontend_port, "Frontend", frontend_proc.pid + check_host = "localhost" + + # Start backend first and wait for it to be ready + print(f" {AnsiEsc.DIM}▸ Spawning backend...{AnsiEsc.RESET}") + backend_proc = self.spawn_backend(bport, logfile) + + backend_checks = [ + HealthCheck( + url=f"http://{check_host}:{bport}/api/health", + ok_statuses={200}, + timeout=1.0, ) - # Wait for both to complete (will raise if either fails) - concurrent.futures.wait([backend_future, frontend_future]) - backend_future.result() - frontend_future.result() + ] + + self.wait_http_ready( + checks=backend_checks, + name="Backend", + port_for_tcp_hint=bport, + proc_getter=lambda: self.backend_process, + pid=backend_proc.pid, + ) + + # Start frontend after backend is ready + print(f" {AnsiEsc.DIM}▸ Spawning frontend...{AnsiEsc.RESET}") + frontend_proc = self.spawn_frontend(fport, bport, logfile) + + frontend_checks = [ + HealthCheck( + url=f"http://{check_host}:{fport}/", + ok_statuses={200, 204, 301, 302, 304}, + timeout=1.0, + allow_redirects=True, + ), + HealthCheck( + url=f"http://{check_host}:{fport}/@vite/client", + ok_statuses={200, 304}, + timeout=1.0, + allow_redirects=True, + ), + ] + + self.wait_http_ready( + checks=frontend_checks, + name="Frontend", + port_for_tcp_hint=fport, + proc_getter=lambda: self.frontend_process, + pid=frontend_proc.pid, + ) print(f"{AnsiEsc.DIM}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━{AnsiEsc.RESET}\n") + 
time.sleep(0.1) - # Success banner - print(f"{AnsiEsc.GREEN}{AnsiEsc.BOLD}✓ Ready!{AnsiEsc.RESET}\n") - print(f"{AnsiEsc.DIM}Backend http://localhost:{backend_port}/{AnsiEsc.RESET}") print( - f"{AnsiEsc.BOLD}Frontend {AnsiEsc.GREEN}{AnsiEsc.BOLD}{AnsiEsc.UNDERLINE}http://localhost:{frontend_port}/{AnsiEsc.RESET}\n" + f"{AnsiEsc.BOLD}Ready: {AnsiEsc.GREEN}{AnsiEsc.UNDERLINE}http://localhost:{fport}/{AnsiEsc.RESET}\n" ) - # Monitor child liveness self.monitor_child_liveness() @@ -295,21 +414,19 @@ def main() -> None: with open(LOGFILE, "w", encoding="utf-8") as lf: lf.write(f"Launcher started at {datetime.now().isoformat()}\n") - # Create runner and register signal handlers - runner = AppRunner() + _require_bins("uv", "npm") + + runner = AppRunner(startup_timeout_seconds=DEFAULT_STARTUP_TIMEOUT_SECONDS) def signal_handler(_signum: int, _frame: FrameType | None) -> None: - """Handle termination signals by cleaning up and exiting.""" runner.cleanup() sys.exit(0) - # Register cleanup handlers atexit.register(runner.cleanup) signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGHUP, signal_handler) - # Run the app runner.run() diff --git a/spd/autointerp/CLAUDE.md b/spd/autointerp/CLAUDE.md index 08ad5ae42..f22e1f376 100644 --- a/spd/autointerp/CLAUDE.md +++ b/spd/autointerp/CLAUDE.md @@ -5,49 +5,78 @@ LLM-based automated interpretation of SPD components. 
Consumes pre-harvested dat ## Usage ```bash -# Run interpretation (requires harvest data to exist first) -python -m spd.autointerp.scripts.run_interpret --model google/gemini-3-flash-preview +# Run interpretation with a config file +python -m spd.autointerp.scripts.run_interpret --config_path path/to/config.yaml + +# Run with inline config JSON +python -m spd.autointerp.scripts.run_interpret --config_json '{"model": "google/gemini-3-flash-preview", "reasoning_effort": null}' # Or via SLURM spd-autointerp +spd-autointerp --model google/gemini-3-flash-preview --reasoning_effort medium ``` Requires `OPENROUTER_API_KEY` env var. ## Data Storage -Data is stored in `SPD_OUT_DIR/autointerp/` (see `spd/settings.py`): +Each autointerp subrun has its own SQLite database: ``` -SPD_OUT_DIR/autointerp// -└── results.jsonl # One InterpretationResult per line (append-only for resume) +SPD_OUT_DIR/autointerp// +└── / # e.g. a-20260206_153040 + ├── interp.db # SQLite DB: interpretations + scores (WAL mode) + └── config.yaml # AutointerpConfig (for reproducibility) ``` +`InterpRepo` reads from the latest subrun (by lexicographic sort of `a-*` dir names). + +The `interp.db` schema has three tables: +- `interpretations`: component_key -> label, confidence, reasoning, raw_response, prompt +- `scores`: (component_key, score_type) -> score, details (JSON blob with trial data) +- `config`: key-value store + +Score types: `detection`, `fuzzing`. + +**Note on intruder scores**: Intruder evaluation lives in `spd/harvest/` (not here) because it tests decomposition quality, not label quality. Intruder scores are stored in `harvest.db`. Detection and fuzzing evaluate interpretation labels and belong here. + ## Architecture +### Config (`config.py`) + +`AutointerpConfig` is a discriminated union over interpretation strategy configs. Each variant specifies everything that affects interpretation output (model, prompt params, reasoning effort). 
Admin/execution params (cost limits, parallelism) are NOT part of the config. + +Current strategies: +- `CompactSkepticalConfig` — compact prompt, skeptical tone, structured JSON output +- `DualViewConfig` — dual-view prompt presenting both input and output evidence + +Also contains `AutointerpEvalConfig` for eval jobs (detection, fuzzing). + +### Strategies (`strategies/`) + +Each strategy config type has a corresponding prompt implementation: +- `strategies/compact_skeptical.py` — prompt formatting for `CompactSkepticalConfig` +- `strategies/dispatch.py` — routes `AutointerpConfig` → strategy implementation via `match` + +### Database (`db.py`) + +`InterpDB` class wrapping SQLite for interpretations and scores. Uses WAL mode for concurrent reads. Serialization via `orjson`. + +### Repository (`repo.py`) + +`InterpRepo` provides read/write access to autointerp data for a run. Lazily opens the SQLite database on first access. Used by the app backend. + +### Interpret (`interpret.py`) - Uses OpenRouter API with structured JSON outputs - Maximum parallelism with exponential backoff on rate limits -- Resume support: Skips already-completed components on restart -- Progress bar via `tqdm_asyncio` - -### Prompt Template (`prompt_template.py`) - -Jinja2 template providing the LLM with: -- Architecture context (model class, layer position, dataset) -- Activation examples with CI values -- Token statistics (PMI for input and output tokens) -- Co-occurring components +- Resume support: Skips already-completed components via `db.get_completed_keys()` +- Progress logging via `spd.log.logger` +- `interpret_component()` interprets a single component +- `run_interpret()` orchestrates batch interpretation with resume support ## Key Types (`schemas.py`) ```python InterpretationResult # LLM's label + confidence + reasoning +ArchitectureInfo # Model architecture context for prompts ``` - -## Status - -Early stage. 
Primary next steps: -- Eval harness for interpretations (precision/recall via LLM activation simulator) -- Integration with app UI to display labels diff --git a/spd/autointerp/__init__.py b/spd/autointerp/__init__.py index e5225e25b..2039f4daf 100644 --- a/spd/autointerp/__init__.py +++ b/spd/autointerp/__init__.py @@ -1,3 +1 @@ """Auto-interpretation pipeline for SPD components.""" - -MAX_EXAMPLES_PER_COMPONENT = 30 diff --git a/spd/autointerp/config.py b/spd/autointerp/config.py new file mode 100644 index 000000000..4d60a1705 --- /dev/null +++ b/spd/autointerp/config.py @@ -0,0 +1,110 @@ +"""Autointerp configuration. + +CompactSkepticalConfig: interpretation strategy config. +AutointerpEvalConfig: eval job config (detection, fuzzing). +AutointerpSlurmConfig: CompactSkepticalConfig + eval + SLURM submission params. +""" + +from typing import Annotated, Literal + +from openrouter.components import Effort +from pydantic import Field + +from spd.base_config import BaseConfig +from spd.settings import DEFAULT_PARTITION_NAME + +FORBIDDEN_WORDS_DEFAULT = [ + "narrative", + "story", + "character", + "theme", + "descriptive", + "content", + "transition", + "scene", +] + + +class CompactSkepticalConfig(BaseConfig): + """Current default strategy: compact prompt, skeptical tone, structured JSON output.""" + + type: Literal["compact_skeptical"] = "compact_skeptical" + max_examples: int = 30 + include_pmi: bool = True + include_spd_context: bool = True + include_dataset_description: bool = True + label_max_words: int = 8 + forbidden_words: list[str] | None = None + + +class DualViewConfig(BaseConfig): + """Dual-view strategy: presents both input and output evidence with dual example views. 
+ + Key differences from compact_skeptical: + - Output data presented first + - Two example sections: "fires on" (current token) and "produces" (next token) + - Task asks for functional description, not detection label + """ + + type: Literal["dual_view"] = "dual_view" + max_examples: int = 30 + include_pmi: bool = True + include_dataset_description: bool = True + label_max_words: int = 8 + forbidden_words: list[str] | None = None + + +StrategyConfig = CompactSkepticalConfig | DualViewConfig + + +class AutointerpConfig(BaseConfig): + model: str = "google/gemini-3-flash-preview" + reasoning_effort: Effort = "low" + limit: int | None = None + cost_limit_usd: float | None = None + max_requests_per_minute: int = 500 + max_concurrent: int = 50 + template_strategy: Annotated[StrategyConfig, Field(discriminator="type")] + + +class DetectionEvalConfig(BaseConfig): + type: Literal["detection"] = "detection" + n_activating: int = 5 + n_non_activating: int = 5 + n_trials: int = 5 + + +class FuzzingEvalConfig(BaseConfig): + type: Literal["fuzzing"] = "fuzzing" + n_correct: int = 5 + n_incorrect: int = 2 + n_trials: int = 5 + + +class AutointerpEvalConfig(BaseConfig): + """Config for label-based autointerp evals (detection, fuzzing).""" + + model: str = "google/gemini-3-flash-preview" + reasoning_effort: Effort = "none" + detection_config: DetectionEvalConfig + fuzzing_config: FuzzingEvalConfig + limit: int | None = None + cost_limit_usd: float | None = None + max_requests_per_minute: int = 500 + max_concurrent: int = 50 + + +class AutointerpSlurmConfig(BaseConfig): + """Config for the autointerp functional unit (interpret + evals). 
+ + Dependency graph within autointerp: + interpret (depends on harvest merge) + ├── detection (depends on interpret) + └── fuzzing (depends on interpret) + """ + + config: AutointerpConfig + partition: str = DEFAULT_PARTITION_NAME + time: str = "12:00:00" + evals: AutointerpEvalConfig | None + evals_time: str = "12:00:00" diff --git a/spd/autointerp/db.py b/spd/autointerp/db.py new file mode 100644 index 000000000..f05227f5c --- /dev/null +++ b/spd/autointerp/db.py @@ -0,0 +1,154 @@ +"""SQLite database for autointerp data (interpretations and scores). NFS-hosted, single writer then read-only.""" + +from pathlib import Path + +import orjson + +from spd.autointerp.schemas import InterpretationResult +from spd.utils.sqlite import open_nfs_sqlite + +_SCHEMA = """\ +CREATE TABLE IF NOT EXISTS interpretations ( + component_key TEXT PRIMARY KEY, + label TEXT NOT NULL, + confidence TEXT NOT NULL, + reasoning TEXT NOT NULL, + raw_response TEXT NOT NULL, + prompt TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS scores ( + component_key TEXT NOT NULL, + score_type TEXT NOT NULL, + score REAL NOT NULL, + details TEXT NOT NULL, + PRIMARY KEY (component_key, score_type) +); + +CREATE TABLE IF NOT EXISTS config ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL +); +""" + +DONE_MARKER = ".done" + + +class InterpDB: + def __init__(self, db_path: Path, readonly: bool = False) -> None: + self._conn = open_nfs_sqlite(db_path, readonly) + if not readonly: + self._conn.executescript(_SCHEMA) + self._db_path = db_path + + def mark_done(self) -> None: + (self._db_path.parent / DONE_MARKER).touch() + + def save_interpretation(self, result: InterpretationResult) -> None: + self._conn.execute( + "INSERT OR REPLACE INTO interpretations VALUES (?, ?, ?, ?, ?, ?)", + ( + result.component_key, + result.label, + result.confidence, + result.reasoning, + result.raw_response, + result.prompt, + ), + ) + self._conn.commit() + + def save_interpretations(self, results: list[InterpretationResult]) -> 
None: + rows = [ + (r.component_key, r.label, r.confidence, r.reasoning, r.raw_response, r.prompt) + for r in results + ] + self._conn.executemany( + "INSERT OR REPLACE INTO interpretations VALUES (?, ?, ?, ?, ?, ?)", + rows, + ) + self._conn.commit() + + def get_interpretation(self, component_key: str) -> InterpretationResult | None: + row = self._conn.execute( + "SELECT * FROM interpretations WHERE component_key = ?", + (component_key,), + ).fetchone() + if row is None: + return None + return InterpretationResult( + component_key=row["component_key"], + label=row["label"], + confidence=row["confidence"], + reasoning=row["reasoning"], + raw_response=row["raw_response"], + prompt=row["prompt"], + ) + + def get_all_interpretations(self) -> dict[str, InterpretationResult]: + rows = self._conn.execute("SELECT * FROM interpretations").fetchall() + return { + row["component_key"]: InterpretationResult( + component_key=row["component_key"], + label=row["label"], + confidence=row["confidence"], + reasoning=row["reasoning"], + raw_response=row["raw_response"], + prompt=row["prompt"], + ) + for row in rows + } + + def get_completed_keys(self) -> set[str]: + rows = self._conn.execute("SELECT component_key FROM interpretations").fetchall() + return {row["component_key"] for row in rows} + + def save_score(self, component_key: str, score_type: str, score: float, details: str) -> None: + self._conn.execute( + "INSERT OR REPLACE INTO scores VALUES (?, ?, ?, ?)", + (component_key, score_type, score, details), + ) + self._conn.commit() + + def save_scores(self, score_type: str, scores: list[tuple[str, float, str]]) -> None: + rows = [(key, score_type, score, details) for key, score, details in scores] + self._conn.executemany( + "INSERT OR REPLACE INTO scores VALUES (?, ?, ?, ?)", + rows, + ) + self._conn.commit() + + def get_scores(self, score_type: str) -> dict[str, float]: + rows = self._conn.execute( + "SELECT component_key, score FROM scores WHERE score_type = ?", + 
(score_type,), + ).fetchall() + return {row["component_key"]: row["score"] for row in rows} + + def save_config(self, key: str, value: object) -> None: + self._conn.execute( + "INSERT OR REPLACE INTO config VALUES (?, ?)", + (key, orjson.dumps(value).decode()), + ) + self._conn.commit() + + def has_interpretations(self) -> bool: + row = self._conn.execute("SELECT EXISTS(SELECT 1 FROM interpretations LIMIT 1)").fetchone() + assert row is not None + return bool(row[0]) + + def get_interpretation_count(self) -> int: + row = self._conn.execute("SELECT COUNT(*) FROM interpretations").fetchone() + assert row is not None + return row[0] + + def has_scores(self, score_type: str) -> bool: + row = self._conn.execute( + "SELECT EXISTS(SELECT 1 FROM scores WHERE score_type = ? LIMIT 1)", + (score_type,), + ).fetchone() + assert row is not None + return bool(row[0]) + + def close(self) -> None: + self._conn.close() diff --git a/spd/autointerp/interpret.py b/spd/autointerp/interpret.py index 5e38e4985..1d4ba7105 100644 --- a/spd/autointerp/interpret.py +++ b/spd/autointerp/interpret.py @@ -1,385 +1,196 @@ import asyncio import json -import random -import time -from dataclasses import asdict, dataclass -from enum import StrEnum +from collections.abc import Iterable from pathlib import Path -import httpx from openrouter import OpenRouter -from openrouter.components import JSONSchemaConfig, MessageTypedDict, ResponseFormatJSONSchema -from openrouter.errors import ( - BadGatewayResponseError, - ChatError, - EdgeNetworkTimeoutResponseError, - ProviderOverloadedResponseError, - RequestTimeoutResponseError, - ServiceUnavailableResponseError, - TooManyRequestsResponseError, +from openrouter.components import Effort, Reasoning + +from spd.app.backend.app_tokenizer import AppTokenizer +from spd.autointerp.config import StrategyConfig +from spd.autointerp.db import InterpDB +from spd.autointerp.llm_api import ( + LLMError, + LLMJob, + LLMResult, + make_response_format, + map_llm_calls, ) 
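[reviewer note] The `INSERT OR REPLACE` calls in `InterpDB` above rely on each table's composite primary key for upsert semantics. A minimal sketch with the stdlib `sqlite3` module, mirroring the `scores` schema from `db.py` — illustrative only, not the project's actual wrapper (in-memory DB, no WAL):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row  # rows addressable by column name, as in InterpDB
conn.executescript(
    """
    CREATE TABLE IF NOT EXISTS scores (
        component_key TEXT NOT NULL,
        score_type TEXT NOT NULL,
        score REAL NOT NULL,
        details TEXT NOT NULL,
        PRIMARY KEY (component_key, score_type)
    );
    """
)

# Two writes to the same (component_key, score_type) pair: the PK makes the
# second INSERT OR REPLACE overwrite the first row instead of adding one.
conn.execute(
    "INSERT OR REPLACE INTO scores VALUES (?, ?, ?, ?)", ("h.0:3", "detection", 0.6, "{}")
)
conn.execute(
    "INSERT OR REPLACE INTO scores VALUES (?, ?, ?, ?)", ("h.0:3", "detection", 0.9, "{}")
)

rows = conn.execute("SELECT score FROM scores WHERE score_type = ?", ("detection",)).fetchall()
print({r["score"] for r in rows})  # the second write wins: {0.9}
```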
-from tqdm.asyncio import tqdm_asyncio -from transformers import AutoTokenizer -from transformers.tokenization_utils_base import PreTrainedTokenizerBase - -from spd.app.backend.compute import get_model_n_blocks -from spd.autointerp.prompt_template import INTERPRETATION_SCHEMA, format_prompt_template -from spd.autointerp.schemas import ArchitectureInfo, InterpretationResult -from spd.configs import LMTaskConfig +from spd.autointerp.schemas import InterpretationResult, ModelMetadata +from spd.autointerp.strategies.dispatch import INTERPRETATION_SCHEMA, format_prompt from spd.harvest.analysis import TokenPRLift, get_input_token_stats, get_output_token_stats -from spd.harvest.harvest import HarvestResult +from spd.harvest.repo import HarvestRepo from spd.harvest.schemas import ComponentData -from spd.harvest.storage import TokenStatsStorage from spd.log import logger -from spd.models.component_model import ComponentModel, SPDRunInfo - -# Retry config -MAX_RETRIES = 8 -BASE_DELAY_S = 0.5 -MAX_DELAY_S = 60.0 -JITTER_FACTOR = 0.5 -MAX_CONCURRENT_REQUESTS = 50 -MAX_REQUESTS_PER_MINUTE = 300 # Gemini flash has 400 RPM limit - - -class RateLimiter: - """Sliding window rate limiter for async code.""" - - def __init__(self, max_requests: int, period_seconds: float): - self.max_requests = max_requests - self.period = period_seconds - self.timestamps: list[float] = [] - self.lock = asyncio.Lock() - - async def acquire(self) -> None: - async with self.lock: - now = time.monotonic() - self.timestamps = [t for t in self.timestamps if now - t < self.period] - - if len(self.timestamps) >= self.max_requests: - sleep_time = self.timestamps[0] + self.period - now - if sleep_time > 0: - await asyncio.sleep(sleep_time) - self.timestamps = self.timestamps[1:] - - self.timestamps.append(time.monotonic()) - - -RETRYABLE_ERRORS = ( - TooManyRequestsResponseError, - ProviderOverloadedResponseError, - ServiceUnavailableResponseError, - BadGatewayResponseError, - RequestTimeoutResponseError, - 
EdgeNetworkTimeoutResponseError, - ChatError, - httpx.TransportError, # Low-level network errors (ReadError, ConnectError, etc.) -) - - -class OpenRouterModelName(StrEnum): - GEMINI_3_FLASH_PREVIEW = "google/gemini-3-flash-preview" - - -@dataclass -class CostTracker: - input_tokens: int = 0 - output_tokens: int = 0 - input_price_per_token: float = 0.0 - output_price_per_token: float = 0.0 - - def add(self, input_tokens: int, output_tokens: int) -> None: - self.input_tokens += input_tokens - self.output_tokens += output_tokens - - def cost_usd(self) -> float: - return ( - self.input_tokens * self.input_price_per_token - + self.output_tokens * self.output_price_per_token - ) - - -async def chat_with_retry( - client: OpenRouter, - model: str, - messages: list[MessageTypedDict], - response_format: ResponseFormatJSONSchema, - max_tokens: int, - context_label: str, -) -> tuple[str, int, int]: - """Send chat request with exponential backoff retry. Returns (content, input_tokens, output_tokens).""" - last_error: Exception | None = None - for attempt in range(MAX_RETRIES): - try: - response = await client.chat.send_async( - model=model, - max_tokens=max_tokens, - messages=messages, - response_format=response_format, - ) - choice = response.choices[0] - message = choice.message - assert isinstance(message.content, str) - assert response.usage is not None - - if choice.finish_reason == "length": - logger.warning(f"{context_label}: Response truncated at {max_tokens} tokens") - - return ( - message.content, - int(response.usage.prompt_tokens), - int(response.usage.completion_tokens), - ) - except RETRYABLE_ERRORS as e: - last_error = e - if attempt == MAX_RETRIES - 1: - break - - delay = min(BASE_DELAY_S * (2**attempt), MAX_DELAY_S) - jitter = delay * JITTER_FACTOR * random.random() - total_delay = delay + jitter - - tqdm_asyncio.write( - f"[retry {attempt + 1}/{MAX_RETRIES}] ({context_label}) " - f"{type(e).__name__}, backing off {total_delay:.1f}s" - ) - await 
asyncio.sleep(total_delay) - assert last_error is not None - raise RuntimeError(f"Max retries exceeded for {context_label}: {last_error}") - - -async def get_model_pricing(client: OpenRouter, model_id: str) -> tuple[float, float]: - """Returns (input_price, output_price) per token.""" - response = await client.models.list_async() - for model in response.data: - if model.id == model_id: - return float(model.pricing.prompt), float(model.pricing.completion) - raise ValueError(f"Model {model_id} not found") +MAX_CONCURRENT = 50 async def interpret_component( - client: OpenRouter, + api: OpenRouter, model: str, + reasoning_effort: Effort, + strategy: StrategyConfig, component: ComponentData, - arch: ArchitectureInfo, - tokenizer: PreTrainedTokenizerBase, + model_metadata: ModelMetadata, + app_tok: AppTokenizer, input_token_stats: TokenPRLift, output_token_stats: TokenPRLift, -) -> tuple[InterpretationResult, int, int] | None: - """Returns (result, input_tokens, output_tokens), or None on failure.""" - prompt = format_prompt_template( +) -> InterpretationResult: + """Interpret a single component. 
Used by the app for on-demand interpretation.""" + prompt = format_prompt( + strategy=strategy, component=component, - arch=arch, - tokenizer=tokenizer, + model_metadata=model_metadata, + app_tok=app_tok, input_token_stats=input_token_stats, output_token_stats=output_token_stats, ) - try: - raw, in_tok, out_tok = await chat_with_retry( - client=client, - model=model, - messages=[{"role": "user", "content": prompt}], - response_format=ResponseFormatJSONSchema( - json_schema=JSONSchemaConfig( - name="interpretation", - schema_={**INTERPRETATION_SCHEMA, "additionalProperties": False}, - strict=True, - ) - ), - max_tokens=1500, - context_label=component.component_key, - ) - except RuntimeError as e: - logger.error(str(e)) - return None + schema = INTERPRETATION_SCHEMA + response_format = make_response_format("interpretation", schema) + + response = await api.chat.send_async( + model=model, + max_tokens=8000, + messages=[{"role": "user", "content": prompt}], + response_format=response_format, + reasoning=Reasoning(effort=reasoning_effort), + ) - try: - parsed = json.loads(raw) - except json.JSONDecodeError: - logger.error(f"Failed to parse JSON: `{raw}`") - return None + choice = response.choices[0] + assert isinstance(choice.message.content, str) + raw = choice.message.content + parsed = json.loads(raw) - assert len(parsed) == 3, f"Expected 3 fields, got {len(parsed)}" + assert len(parsed) == 3, f"Expected 3 fields, got {parsed}" label = parsed["label"] confidence = parsed["confidence"] - reasoning = parsed["reasoning"] - assert isinstance(label, str) and isinstance(confidence, str) and isinstance(reasoning, str) + reasoning_text = parsed["reasoning"] + assert ( + isinstance(label, str) and isinstance(confidence, str) and isinstance(reasoning_text, str) + ) - return ( - InterpretationResult( - component_key=component.component_key, - label=label, - confidence=confidence, - reasoning=reasoning, - raw_response=raw, - prompt=prompt, - ), - in_tok, - out_tok, + return 
InterpretationResult( + component_key=component.component_key, + label=label, + confidence=confidence, + reasoning=reasoning_text, + raw_response=raw, + prompt=prompt, ) -async def interpret_all( - components: list[ComponentData], - arch: ArchitectureInfo, +def run_interpret( openrouter_api_key: str, - interpreter_model: str, - output_path: Path, - token_stats: TokenStatsStorage, - limit: int | None = None, + model: str, + reasoning_effort: Effort, + limit: int | None, + cost_limit_usd: float | None, + max_requests_per_minute: int, + max_concurrent: int, + model_metadata: ModelMetadata, + template_strategy: StrategyConfig, + harvest: HarvestRepo, + db_path: Path, + tokenizer_name: str, ) -> list[InterpretationResult]: - """Interpret all components with maximum parallelism. Rate limits handled via exponential backoff.""" - results: list[InterpretationResult] = [] - completed = set[str]() + summary = harvest.get_summary() + logger.info(f"Loaded summary for {len(summary)} components") - if output_path.exists(): - print(f"Resuming: {output_path} exists") - with open(output_path) as f: - for line in f: - data = json.loads(line) - results.append(InterpretationResult(**data)) - completed.add(data["component_key"]) - print(f"Resuming: {len(completed)} already completed") + token_stats = harvest.get_token_stats() + assert token_stats is not None, "token_stats.pt not found. Run harvest first." 
- components_sorted = sorted(components, key=lambda c: c.mean_ci, reverse=True) - remaining = [c for c in components_sorted if c.component_key not in completed] - if limit is not None: - remaining = remaining[:limit] - print(f"Interpreting {len(remaining)} components") - start_idx = len(results) - - output_lock = asyncio.Lock() - semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS) - rate_limiter = RateLimiter(MAX_REQUESTS_PER_MINUTE, period_seconds=60.0) + app_tok = AppTokenizer.from_pretrained(tokenizer_name) - tokenizer = AutoTokenizer.from_pretrained(arch.tokenizer_name) - assert isinstance(tokenizer, PreTrainedTokenizerBase) + eligible_keys = sorted(summary, key=lambda k: summary[k].firing_density, reverse=True) - async def process_one( - component: ComponentData, - index: int, - client: OpenRouter, - cost_tracker: CostTracker, - ) -> None: - await rate_limiter.acquire() - async with semaphore: - try: - # Compute token stats for this component - input_stats = get_input_token_stats( - token_stats, component.component_key, tokenizer, top_k=20 - ) - output_stats = get_output_token_stats( - token_stats, component.component_key, tokenizer, top_k=50 - ) - assert input_stats is not None, ( - f"No input token stats for {component.component_key}" - ) - assert output_stats is not None, ( - f"No output token stats for {component.component_key}" - ) + if limit is not None: + eligible_keys = eligible_keys[:limit] - res = await interpret_component( - client=client, - model=interpreter_model, - component=component, - arch=arch, - tokenizer=tokenizer, - input_token_stats=input_stats, - output_token_stats=output_stats, - ) - if res is None: - logger.error(f"Failed to interpret {component.component_key}") - return - result, in_tok, out_tok = res + async def _run() -> list[InterpretationResult]: + db = InterpDB(db_path) - async with output_lock: - results.append(result) - cost_tracker.add(in_tok, out_tok) - line = json.dumps(asdict(result)) + "\n" - log_progress = index % 100 
== 0 - progress_msg = ( - f"[{index}] ${cost_tracker.cost_usd():.2f} ({cost_tracker.input_tokens:,} in, {cost_tracker.output_tokens:,} out)" - if log_progress - else "" + try: + completed = db.get_completed_keys() + if completed: + logger.info(f"Resuming: {len(completed)} already completed") + + remaining_keys = [k for k in eligible_keys if k not in completed] + logger.info(f"Interpreting {len(remaining_keys)} components") + + schema = INTERPRETATION_SCHEMA + + def build_jobs() -> Iterable[LLMJob]: + for key in remaining_keys: + component = harvest.get_component(key) + assert component is not None, f"Component {key} not found in harvest" + input_stats = get_input_token_stats(token_stats, key, app_tok, top_k=20) + output_stats = get_output_token_stats(token_stats, key, app_tok, top_k=50) + assert input_stats is not None + assert output_stats is not None + prompt = format_prompt( + strategy=template_strategy, + component=component, + model_metadata=model_metadata, + app_tok=app_tok, + input_token_stats=input_stats, + output_token_stats=output_stats, ) - with open(output_path, "a") as f: - f.write(line) - - if log_progress: - tqdm_asyncio.write(progress_msg) - except Exception as e: - logger.error(f"Fatal error on {component.component_key}: {type(e).__name__}: {e}") - raise - - async with OpenRouter(api_key=openrouter_api_key) as client: - input_price, output_price = await get_model_pricing(client, interpreter_model) - cost_tracker = CostTracker( - input_price_per_token=input_price, output_price_per_token=output_price - ) - print(f"Pricing: ${input_price * 1e6:.2f}/M input, ${output_price * 1e6:.2f}/M output") + yield LLMJob(prompt=prompt, schema=schema, key=key) - await tqdm_asyncio.gather( - *[ - process_one(c, i, client, cost_tracker) - for i, c in enumerate(remaining, start=start_idx) - ], - desc="Interpreting", - ) + results: list[InterpretationResult] = [] + n_errors = 0 - print(f"Final cost: ${cost_tracker.cost_usd():.2f}") - return results - - -def 
get_architecture_info(wandb_path: str) -> ArchitectureInfo: - run_info = SPDRunInfo.from_path(wandb_path) - model = ComponentModel.from_run_info(run_info) - n_blocks = get_model_n_blocks(model.target_model) - config = run_info.config - task_config = config.task_config - assert isinstance(task_config, LMTaskConfig) - assert config.tokenizer_name is not None - return ArchitectureInfo( - n_blocks=n_blocks, - c_per_layer=model.module_to_c, - model_class=config.pretrained_model_class, - dataset_name=task_config.dataset_name, - tokenizer_name=config.tokenizer_name, - ) - - -def run_interpret( - wandb_path: str, - openrouter_api_key: str, - interpreter_model: str, - activation_contexts_dir: Path, - correlations_dir: Path, - autointerp_dir: Path, - limit: int | None = None, -) -> list[InterpretationResult]: - arch = get_architecture_info(wandb_path) - components = HarvestResult.load_components(activation_contexts_dir) - output_path = autointerp_dir / "results.jsonl" + async for outcome in map_llm_calls( + openrouter_api_key=openrouter_api_key, + model=model, + reasoning_effort=reasoning_effort, + jobs=build_jobs(), + max_tokens=8000, + max_concurrent=max_concurrent, + max_requests_per_minute=max_requests_per_minute, + cost_limit_usd=cost_limit_usd, + response_schema=schema, + n_total=len(remaining_keys), + ): + match outcome: + case LLMResult(job=job, parsed=parsed, raw=raw): + assert len(parsed) == 3, f"Expected 3 fields, got {len(parsed)}" + label = parsed["label"] + confidence = parsed["confidence"] + reasoning_text = parsed["reasoning"] + assert ( + isinstance(label, str) + and isinstance(confidence, str) + and isinstance(reasoning_text, str) + ) + result = InterpretationResult( + component_key=job.key, + label=label, + confidence=confidence, + reasoning=reasoning_text, + raw_response=raw, + prompt=job.prompt, + ) + results.append(result) + db.save_interpretation(result) + case LLMError(job=job, error=e): + n_errors += 1 + logger.error(f"Skipping {job.key}: 
{type(e).__name__}: {e}") + + error_rate = n_errors / max(1, n_errors + len(results)) + # require more than 10 errors so a small sample can't trigger a false alarm + if error_rate > 0.2 and n_errors > 10: + raise RuntimeError( + f"Error rate {error_rate:.0%} ({n_errors}/{n_errors + len(results)}) exceeds 20% threshold" + ) - # Load token stats - token_stats_path = correlations_dir / "token_stats.pt" - assert token_stats_path.exists(), ( - f"token_stats.pt not found at {token_stats_path}. Run harvest first." - ) - token_stats = TokenStatsStorage.load(token_stats_path) + finally: + db.close() - results = asyncio.run( - interpret_all( - components=components, - arch=arch, - openrouter_api_key=openrouter_api_key, - interpreter_model=interpreter_model, - output_path=output_path, - token_stats=token_stats, - limit=limit, - ) - ) + db.mark_done() + logger.info(f"Completed {len(results)} interpretations -> {db_path}") + return results - print(f"Completed {len(results)} interpretations -> {output_path}") - return results + return asyncio.run(_run()) diff --git a/spd/autointerp/llm_api.py b/spd/autointerp/llm_api.py new file mode 100644 index 000000000..422afd465 --- /dev/null +++ b/spd/autointerp/llm_api.py @@ -0,0 +1,350 @@ +"""LLM API utilities: batch concurrent calls with rate limiting, retry, and cost tracking.""" + +import asyncio +import contextlib +import json +import random +import time +from collections.abc import AsyncGenerator, Iterable, Sized +from dataclasses import dataclass, field +from typing import Any + +import httpx +from aiolimiter import AsyncLimiter +from openrouter import OpenRouter +from openrouter.components import ( + Effort, + JSONSchemaConfig, + Reasoning, + ResponseFormatJSONSchema, +) +from openrouter.errors import ( + BadGatewayResponseError, + ChatError, + EdgeNetworkTimeoutResponseError, + InternalServerResponseError, + OpenRouterDefaultError, + OpenRouterError, + ProviderOverloadedResponseError, + RequestTimeoutResponseError, + 
ServiceUnavailableResponseError, + TooManyRequestsResponseError, +) + +from spd.log import logger + +_MAX_RETRIES = 8 +_BASE_DELAY_S = 0.5 +_MAX_DELAY_S = 60.0 +_JITTER_FACTOR = 0.5 +_REQUEST_TIMEOUT_MS = 120_000 +_JSON_PARSE_RETRIES = 3 +_MAX_BACKOFF_S = 600.0 + +_RETRYABLE_ERRORS = ( + TooManyRequestsResponseError, + ProviderOverloadedResponseError, + ServiceUnavailableResponseError, + BadGatewayResponseError, + InternalServerResponseError, + RequestTimeoutResponseError, + EdgeNetworkTimeoutResponseError, + ChatError, + OpenRouterDefaultError, + httpx.TransportError, +) + + +def make_response_format(name: str, schema: dict[str, Any]) -> ResponseFormatJSONSchema: + return ResponseFormatJSONSchema( + json_schema=JSONSchemaConfig( + name=name, + schema_={**schema, "additionalProperties": False}, + strict=True, + ) + ) + + +@dataclass +class LLMJob: + prompt: str + schema: dict[str, Any] + key: str + + +@dataclass +class LLMResult: + job: LLMJob + parsed: dict[str, Any] + raw: str + + +@dataclass +class LLMError: + job: LLMJob + error: Exception + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +@dataclass +class CostTracker: + input_tokens: int = 0 + output_tokens: int = 0 + input_price_per_token: float = 0.0 + output_price_per_token: float = 0.0 + limit_usd: float | None = None + _budget_exceeded: asyncio.Event = field(default_factory=asyncio.Event) + _lock: asyncio.Lock = field(default_factory=asyncio.Lock) + + async def add(self, input_tokens: int, output_tokens: int) -> None: + async with self._lock: + self.input_tokens += input_tokens + self.output_tokens += output_tokens + if self.limit_usd is not None and self.cost_usd() >= self.limit_usd: + self._budget_exceeded.set() + + def over_budget(self) -> bool: + return self._budget_exceeded.is_set() + + def cost_usd(self) -> float: + return ( + self.input_tokens * 
self.input_price_per_token + + self.output_tokens * self.output_price_per_token + ) + + +class _BudgetExceededError(Exception): + pass + + +class _GlobalBackoff: + """Shared backoff that pauses all coroutines when the API pushes back.""" + + def __init__(self) -> None: + self._resume_at = 0.0 + self._lock = asyncio.Lock() + + async def set_backoff(self, seconds: float) -> None: + assert seconds <= _MAX_BACKOFF_S, ( + f"Server requested {seconds:.0f}s backoff, exceeds {_MAX_BACKOFF_S:.0f}s cap" + ) + async with self._lock: + self._resume_at = max(self._resume_at, time.monotonic() + seconds) + + async def wait(self) -> None: + delay = self._resume_at - time.monotonic() + if delay > 0: + await asyncio.sleep(delay) + + +async def _get_model_pricing(api: OpenRouter, model_id: str) -> tuple[float, float]: + """Returns (input_price, output_price) per token.""" + response = await api.models.list_async() + for model in response.data: + if model.id == model_id: + return float(model.pricing.prompt), float(model.pricing.completion) + raise ValueError(f"Model {model_id} not found") + + +def _get_retry_after(e: Exception) -> float | None: + if not isinstance(e, OpenRouterError): + return None + val = e.headers.get("retry-after") + if val is None: + return None + try: + return float(val) + except ValueError: + return None + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +async def map_llm_calls( + openrouter_api_key: str, + model: str, + reasoning_effort: Effort, + jobs: Iterable[LLMJob], + max_tokens: int, + max_concurrent: int, + max_requests_per_minute: int, + cost_limit_usd: float | None, + response_schema: dict[str, Any], + n_total: int | None = None, + cost_tracker: CostTracker | None = None, +) -> AsyncGenerator[LLMResult | LLMError]: + """Fan out LLM calls concurrently, yielding results as they complete. 
+ + Handles rate limiting, retry with exponential backoff, JSON parsing, + cost tracking, and progress logging. Yields LLMResult on success, + LLMError on failure. Silently stops remaining jobs on budget exceeded. + + Jobs can be a lazy iterable (e.g. a generator). Prompt building in the + generator body naturally interleaves with async HTTP calls. + + Pass a shared CostTracker to accumulate costs across multiple calls. + """ + if n_total is None and isinstance(jobs, Sized): + n_total = len(jobs) + + assert not (cost_tracker is not None and cost_limit_usd is not None), ( + "Pass cost_limit_usd or cost_tracker, not both" + ) + + async with OpenRouter(api_key=openrouter_api_key) as api: + input_price, output_price = await _get_model_pricing(api, model) + if cost_tracker is not None: + cost = cost_tracker + cost.input_price_per_token = input_price + cost.output_price_per_token = output_price + else: + cost = CostTracker( + input_price_per_token=input_price, + output_price_per_token=output_price, + limit_usd=cost_limit_usd, + ) + rate_limiter = AsyncLimiter(max_rate=max_requests_per_minute, time_period=60) + backoff = _GlobalBackoff() + reasoning = Reasoning(effort=reasoning_effort) + response_format = make_response_format("response", response_schema) + + async def chat(prompt: str, context_label: str) -> str: + if cost.over_budget(): + raise _BudgetExceededError(f"${cost.cost_usd():.2f}") + + last_error: Exception | None = None + for attempt in range(_MAX_RETRIES): + await backoff.wait() + async with rate_limiter: + try: + response = await api.chat.send_async( + model=model, + max_tokens=max_tokens, + messages=[{"role": "user", "content": prompt}], + timeout_ms=_REQUEST_TIMEOUT_MS, + response_format=response_format, + reasoning=reasoning, + ) + choice = response.choices[0] + message = choice.message + assert isinstance(message.content, str) + assert response.usage is not None + + if choice.finish_reason == "length": + logger.warning( + f"{context_label}: Response 
truncated at {max_tokens} tokens" + ) + + await cost.add( + int(response.usage.prompt_tokens), + int(response.usage.completion_tokens), + ) + return message.content + except _RETRYABLE_ERRORS as e: + last_error = e + if attempt == _MAX_RETRIES - 1: + break + + retry_after = _get_retry_after(e) + if retry_after is not None: + await backoff.set_backoff(retry_after) + delay = retry_after + else: + delay = min(_BASE_DELAY_S * (2**attempt), _MAX_DELAY_S) + jitter = delay * _JITTER_FACTOR * random.random() + delay = delay + jitter + + logger.warning( + f"[retry {attempt + 1}/{_MAX_RETRIES}] ({context_label}) " + f"{type(e).__name__}, backing off {delay:.1f}s" + ) + await asyncio.sleep(delay) + + assert last_error is not None + raise RuntimeError(f"Max retries exceeded for {context_label}: {last_error}") + + queue: asyncio.Queue[LLMResult | LLMError | None] = asyncio.Queue() + + n_done = 0 + budget_exceeded = False + + async def process_one(job: LLMJob) -> None: + nonlocal n_done, budget_exceeded + if budget_exceeded: + return + + try: + raw = "" + parsed = None + for attempt in range(_JSON_PARSE_RETRIES): + raw = await chat(job.prompt, job.key) + try: + parsed = json.loads(raw) + break + except json.JSONDecodeError: + if attempt == _JSON_PARSE_RETRIES - 1: + raise + logger.warning( + f"{job.key}: invalid JSON " + f"(attempt {attempt + 1}/{_JSON_PARSE_RETRIES}), retrying" + ) + assert parsed is not None + await queue.put(LLMResult(job=job, parsed=parsed, raw=raw)) + except _BudgetExceededError: + budget_exceeded = True + return + except Exception as e: + await queue.put(LLMError(job=job, error=e)) + + n_done += 1 + total_str = f"/{n_total}" if n_total is not None else "" + if n_done == 1 or n_done % 10 == 0 or n_done == n_total: + logger.info( + f"[{n_done}{total_str}] ${cost.cost_usd():.2f} " + f"({cost.input_tokens:,} in, {cost.output_tokens:,} out)" + ) + + async def run_all() -> None: + job_queue: asyncio.Queue[LLMJob | None] = asyncio.Queue(maxsize=max_concurrent) + 
+ async def worker() -> None: + while (job := await job_queue.get()) is not None: + await process_one(job) + + workers = [asyncio.create_task(worker()) for _ in range(max_concurrent)] + try: + for n_queued, job in enumerate(jobs, 1): + if budget_exceeded: + break + await job_queue.put(job) + if n_queued % 500 == 0: + logger.info(f"Queued {n_queued} jobs") + for _ in workers: + await job_queue.put(None) + await asyncio.gather(*workers) + finally: + await queue.put(None) + + task = asyncio.create_task(run_all()) + try: + while True: + item = await queue.get() + if item is None: + break + yield item + finally: + if not task.done(): + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + logger.info( + f"Final cost: ${cost.cost_usd():.2f} " + f"({cost.input_tokens:,} in, {cost.output_tokens:,} out)" + ) diff --git a/spd/autointerp/loaders.py b/spd/autointerp/loaders.py deleted file mode 100644 index aa1c40de5..000000000 --- a/spd/autointerp/loaders.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Loaders for reading autointerp output files.""" - -import json - -from spd.autointerp.schemas import InterpretationResult, get_autointerp_dir - - -def load_interpretations(wandb_run_id: str) -> dict[str, InterpretationResult] | None: - """Load interpretation results from autointerp output.""" - autointerp_dir = get_autointerp_dir(wandb_run_id) - path = autointerp_dir / "results.jsonl" - if not path.exists(): - return None - - results: dict[str, InterpretationResult] = {} - with open(path) as f: - for line in f: - data = json.loads(line) - result = InterpretationResult(**data) - results[result.component_key] = result - return results diff --git a/spd/autointerp/prompt_helpers.py b/spd/autointerp/prompt_helpers.py new file mode 100644 index 000000000..490e993ef --- /dev/null +++ b/spd/autointerp/prompt_helpers.py @@ -0,0 +1,158 @@ +"""Shared prompt-building helpers for autointerp and graph interpretation. 
+ +Pure functions for formatting component data into LLM prompt sections. +""" + +import re + +from spd.app.backend.app_tokenizer import AppTokenizer +from spd.app.backend.utils import delimit_tokens +from spd.harvest.analysis import TokenPRLift +from spd.harvest.schemas import ComponentData +from spd.utils.markdown import Md + +DATASET_DESCRIPTIONS: dict[str, str] = { + "SimpleStories/SimpleStories": ( + "SimpleStories: 2M+ short stories (200-350 words), grade 1-8 reading level. " + "Simple vocabulary, common narrative elements." + ), + "danbraunai/pile-uncopyrighted-tok-shuffled": ( + "The Pile (uncopyrighted subset): diverse text from books, " + "academic papers, code, web pages, and other sources." + ), + "danbraunai/pile-uncopyrighted-tok": ( + "The Pile (uncopyrighted subset): diverse text from books, " + "academic papers, code, web pages, and other sources." + ), +} + +WEIGHT_NAMES: dict[str, str] = { + "attn.q": "attention query projection", + "attn.k": "attention key projection", + "attn.v": "attention value projection", + "attn.o": "attention output projection", + "mlp.up": "MLP up-projection", + "mlp.down": "MLP down-projection", + "glu.up": "GLU up-projection", + "glu.down": "GLU down-projection", + "glu.gate": "GLU gate projection", +} + +_ORDINALS = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th"] + + +def ordinal(n: int) -> str: + if 1 <= n <= len(_ORDINALS): + return _ORDINALS[n - 1] + return f"{n}th" + + +def human_layer_desc(canonical: str, n_blocks: int) -> str: + """'0.mlp.up' -> 'MLP up-projection in the 1st of 4 blocks'""" + m = re.match(r"(\d+)\.(.*)", canonical) + if not m: + return canonical + layer_idx = int(m.group(1)) + weight_key = m.group(2) + weight_name = WEIGHT_NAMES.get(weight_key, weight_key) + return f"{weight_name} in the {ordinal(layer_idx + 1)} of {n_blocks} blocks" + + +def layer_position_note(canonical: str, n_blocks: int) -> str: + m = re.match(r"(\d+)\.", canonical) + if not m: + return "" + layer_idx = 
int(m.group(1)) + if layer_idx == n_blocks - 1: + return "This is in the final block, so its output directly influences token predictions." + remaining = n_blocks - 1 - layer_idx + return ( + f"This is {remaining} block{'s' if remaining > 1 else ''} from the output, " + f"so its effect on token predictions is indirect — filtered through later layers." + ) + + +def density_note(firing_density: float) -> str: + if firing_density > 0.15: + return ( + "This is a high-density component (fires frequently). " + "High-density components often act as broad biases rather than selective features." + ) + if firing_density < 0.005: + return "This is a very sparse component, likely highly specific." + return "" + + +def build_output_section( + output_stats: TokenPRLift, + output_pmi: list[tuple[str, float]] | None, +) -> Md: + md = Md() + if output_pmi: + md.labeled_list( + "**Output PMI (pointwise mutual information, in nats: how much more likely " + "a token is to be produced when this component fires, vs its base rate. 
" + "0 = no association, 1 = ~3x more likely, 2 = ~7x, 3 = ~20x):**", + [f"{repr(tok)}: {pmi:.2f}" for tok, pmi in output_pmi[:10]], + ) + if output_stats.top_precision: + md.labeled_list( + "**Output precision — of all probability mass for token X, what fraction " + "is at positions where this component fires?**", + [f"{repr(tok)}: {prec * 100:.0f}%" for tok, prec in output_stats.top_precision[:10]], + ) + return md + + +def build_input_section( + input_stats: TokenPRLift, + input_pmi: list[tuple[str, float]] | None, +) -> Md: + md = Md() + if input_pmi: + md.labeled_list( + "**Input PMI (same metric as above, for input tokens):**", + [f"{repr(tok)}: {pmi:.2f}" for tok, pmi in input_pmi[:6]], + ) + if input_stats.top_precision: + md.labeled_list( + "**Input precision — probability the component fires given the current token is X:**", + [f"{repr(tok)}: {prec * 100:.0f}%" for tok, prec in input_stats.top_precision[:8]], + ) + return md + + +def _build_examples( + component: ComponentData, + app_tok: AppTokenizer, + max_examples: int, + shift_firings: bool, +) -> Md: + items: list[str] = [] + for ex in component.activation_examples[:max_examples]: + if not any(ex.firings): + continue + spans = app_tok.get_spans(ex.token_ids) + firings = [False] + ex.firings[:-1] if shift_firings else ex.firings + tokens = list(zip(spans, firings, strict=True)) + items.append(delimit_tokens(tokens)) + md = Md() + if items: + md.numbered(items) + return md + + +def build_fires_on_examples( + component: ComponentData, + app_tok: AppTokenizer, + max_examples: int, +) -> Md: + return _build_examples(component, app_tok, max_examples, shift_firings=False) + + +def build_says_examples( + component: ComponentData, + app_tok: AppTokenizer, + max_examples: int, +) -> Md: + return _build_examples(component, app_tok, max_examples, shift_firings=True) diff --git a/spd/autointerp/prompt_template.py b/spd/autointerp/prompt_template.py deleted file mode 100644 index f01823227..000000000 --- 
a/spd/autointerp/prompt_template.py +++ /dev/null @@ -1,303 +0,0 @@ -"""Prompt templates for component auto-interpretation.""" - -import json - -from transformers.tokenization_utils_base import PreTrainedTokenizerBase - -from spd.app.backend.utils import build_token_lookup -from spd.autointerp import MAX_EXAMPLES_PER_COMPONENT -from spd.autointerp.schemas import ArchitectureInfo -from spd.harvest.analysis import TokenPRLift -from spd.harvest.schemas import ComponentData - - -def _parse_layer_description(layer: str, n_blocks: int) -> str: - """Parse layer name into human-readable description. - - e.g. "h.2.mlp.c_fc" -> "MLP up-projection in layer 3 of 4" - """ - parts = layer.split(".") - assert parts[0] == "h", f"unexpected layer format: {layer}" - layer_idx = int(parts[1]) - layer_num = layer_idx + 1 - - sublayer = ".".join(parts[2:]) - sublayer_desc = { - "mlp.c_fc": "MLP up-projection", - "mlp.c_proj": "MLP down-projection", - "attn.q_proj": "attention Q projection", - "attn.k_proj": "attention K projection", - "attn.v_proj": "attention V projection", - "attn.o_proj": "attention output projection", - }.get(sublayer, sublayer) - - return f"{sublayer_desc} in layer {layer_num} of {n_blocks}" - - -INTERPRETATION_SCHEMA = { - "type": "object", - "properties": { - "label": { - "type": "string", - "description": "3-10 word label describing what the component detects/represents", - }, - "confidence": { - "type": "string", - "enum": ["low", "medium", "high"], - "description": "low = multiple plausible interpretations or weak signal; medium = coherent pattern but some noise; high = clear, consistent pattern across metrics", - }, - "reasoning": { - "type": "string", - "description": "2-4 sentences explaining the evidence and ambiguities", - }, - }, - "required": ["label", "confidence", "reasoning"], - "additionalProperties": False, -} - -INTERPRETATION_SCHEMA_JSON_STR = json.dumps(INTERPRETATION_SCHEMA, indent=2) - - -DATASET_DESCRIPTIONS: dict[str, str] = { - 
"SimpleStories/SimpleStories": """\ -SimpleStories is a dataset of 2M+ short stories (200-350 words each) at a grade 1-8 reading level. -The stories cover diverse themes (friendship, courage, loss, discovery) and settings (magical lands, -schools, forests, space). The vocabulary is simple, everyday English. Stories feature common narrative -elements: characters with names, emotions, dialogue, and simple plot arcs with resolutions.""", -} - -SPD_THEORETICAL_CONTEXT = """\ -SPD (Stochastic Parameter Decomposition) decomposes a neural network's weight matrices into rank-1 -"subcomponents". Each subcomponent has a causal importance (CI) value predicted *per sequence position* -by a small auxiliary neural network. CI indicates how necessary the component is for the model's output -at that position: high CI (close to 1) means the component is essential and cannot be ablated; low CI -(close to 0) means it can be removed without affecting output. The training objective encourages -sparsity: as few components as possible should have high CI for any given input.""" - - -def format_prompt_template( - component: ComponentData, - arch: ArchitectureInfo, - tokenizer: PreTrainedTokenizerBase, - input_token_stats: TokenPRLift, - output_token_stats: TokenPRLift, - ci_display_threshold: float = 0.3, - output_precision_top_k: int = 40, -) -> str: - """Improved prompt template using recall/precision/PMI. 
- - Key improvements over v1: - - Uses recall AND precision for input tokens - - Uses precision AND PMI for output tokens - - Only shows high-CI tokens in examples (reduces noise) - - Includes inline metric definitions - - Better dataset descriptions to avoid vacuous interpretations - """ - lookup = build_token_lookup(tokenizer, tokenizer.name_or_path) - PADDING_SENTINEL = -1 - - # Convert PMI from ComponentData to decoded tokens - input_pmi = ( - [(lookup[tid], pmi) for tid, pmi in component.input_token_pmi.top] - if component.input_token_pmi.top - else None - ) - output_pmi = ( - [(lookup[tid], pmi) for tid, pmi in component.output_token_pmi.top] - if component.output_token_pmi.top - else None - ) - - # Build input token section using recall, precision, and PMI - input_section = _build_input_token_section(input_token_stats, input_pmi) - - # Build output token section using precision and PMI - output_section = _build_output_token_section( - output_token_stats, output_pmi, output_precision_top_k - ) - - # Build examples showing only high-CI tokens - examples_section = _build_examples_section( - component, tokenizer, lookup, ci_display_threshold, PADDING_SENTINEL - ) - - # Get dataset description - dataset_description = DATASET_DESCRIPTIONS.get( - arch.dataset_name, f"Dataset: {arch.dataset_name}" - ) - - # Calculate firing rate context - firing_rate_context = "" - if component.mean_ci > 0: - tokens_per_firing = int(1 / component.mean_ci) - firing_rate_context = f" (fires on ~1 in {tokens_per_firing} tokens)" - - layer_desc = _parse_layer_description(component.layer, arch.n_blocks) - - return f"""\ -Label this neural network component from a Stochastic Parameter Decomposition. 
- -## Background - -{SPD_THEORETICAL_CONTEXT} - -## Model Context - -**Model**: {arch.model_class} ({arch.n_blocks} layers) -**Dataset**: {dataset_description} - -## Component Context - -**Component location**: {layer_desc} -**Activation rate**: {component.mean_ci * 100:.2f}%{firing_rate_context} - ---- - -{input_section} - ---- - -{output_section} - ---- - -{examples_section} - ---- - -## Task - -Based on the above context, what concept or pattern does this component represent? -Consider both what the component *does* (what tokens it helps predict) and what triggers it. - -If the pattern is unclear or the evidence is weak, say so. Use "unclear" or "noisy" in your label if appropriate—do not force an interpretation where none exists. - -Return JSON: -```json -{INTERPRETATION_SCHEMA_JSON_STR} -``` -""" - - -def _build_input_token_section( - input_stats: TokenPRLift, - input_pmi: list[tuple[str, float]] | None, -) -> str: - """Build input token analysis section using recall, precision, and PMI.""" - section = """\ -## Correlations with Input Tokens - -The following metrics concern correlations between this component firing and the "current" token (the token at the position where the component is active). 
- -""" - - # Recall section - if input_stats.top_recall: - section += '**Recall:** _"What % of this component\'s firings occurred on token X?"_\n' - for token, recall in input_stats.top_recall[:10]: - pct = recall * 100 - if pct >= 1: - section += f" {repr(token)}: {pct:.0f}%\n" - elif pct >= 0.1: - section += f" {repr(token)}: {pct:.1f}%\n" - section += "\n" - - # Precision section - very important for detecting deterministic triggers - if input_stats.top_precision: - section += '**Precision:** _"When token X appears, what % of the time does this component fire?"_\n' - for token, prec in input_stats.top_precision[:10]: - section += f" {repr(token)}: {prec * 100:.0f}%\n" - section += "\n" - - # PMI section - shows surprising associations - if input_pmi: - section += "**PMI:** _Tokens with higher-than-expected co-occurrence_\n" - for token, pmi in input_pmi[:8]: - section += f" {repr(token)}: {pmi:.2f}\n" - section += "\n" - - return section - - -def _build_output_token_section( - output_stats: TokenPRLift, - output_pmi: list[tuple[str, float]] | None, - top_k: int, -) -> str: - """Build output token analysis section using precision and PMI.""" - section = """\ -## Correlations with Predicted Tokens - -The following metrics concern correlations between this component firing and the token the model predicts at that position. 
- -""" - - # Precision section - if output_stats.top_precision: - section += '**Precision:** _"When the model predicts token X, what % of the time is this component active?"_\n\n' - # Group by precision ranges - very_high = [(t, p) for t, p in output_stats.top_precision[:top_k] if p > 0.90] - high = [(t, p) for t, p in output_stats.top_precision[:top_k] if 0.70 <= p <= 0.90] - medium = [(t, p) for t, p in output_stats.top_precision[:top_k] if 0.50 <= p < 0.70] - - if very_high: - tokens = [repr(t) for t, _ in very_high[:30]] - section += f"**Very high (>90%)**: {', '.join(tokens)}\n\n" - - if high: - tokens = [repr(t) for t, _ in high[:20]] - section += f"**High (70-90%)**: {', '.join(tokens)}\n\n" - - if medium: - tokens = [repr(t) for t, _ in medium[:15]] - section += f"**Medium (50-70%)**: {', '.join(tokens)}\n\n" - - if not very_high and not high and not medium: - tokens_with_prec = [ - f"{repr(t)} ({p * 100:.0f}%)" for t, p in output_stats.top_precision[:15] - ] - section += f"Top by precision: {', '.join(tokens_with_prec)}\n\n" - - # PMI section for output tokens - if output_pmi: - section += "**PMI:** _Tokens with higher-than-expected co-occurrence_\n" - for token, pmi in output_pmi[:10]: - section += f" {repr(token)}: {pmi:.2f}\n" - - return section - - -def _build_examples_section( - component: ComponentData, - tokenizer: PreTrainedTokenizerBase, - lookup: dict[int, str], - ci_threshold: float, - padding_sentinel: int, -) -> str: - """Build examples section showing only high-CI tokens.""" - section = f"""\ -## Activation Examples - -_Showing tokens where CI > {ci_threshold} (component is active)_ - -""" - examples = component.activation_examples[:MAX_EXAMPLES_PER_COMPONENT] - for i, example in enumerate(examples): - # Decode full text - valid_tokens = [t for t in example.token_ids if t != padding_sentinel and t >= 0] - full_text = tokenizer.decode(valid_tokens) if valid_tokens else "" - display_text = full_text.replace("\n", " ") - - # Get high-CI tokens 
with their CI values - active_tokens = [] - for tid, ci in zip(example.token_ids, example.ci_values, strict=True): - if ci > ci_threshold and tid != padding_sentinel and tid >= 0: - tok = lookup[tid].strip() - active_tokens.append(f'"{tok}" ({ci:.2f})') - - active_str = ", ".join(active_tokens) - - section += f'Ex {i + 1}: "{display_text}"\n' - section += f" Active tokens: {active_str}\n\n" - - return section diff --git a/spd/autointerp/repo.py new file mode 100644 index 000000000..75c3fec29 --- /dev/null +++ b/spd/autointerp/repo.py @@ -0,0 +1,104 @@ +"""Autointerp data repository. + +Owns SPD_OUT_DIR/autointerp/<run_id>/ and provides read access to +interpretations and evaluation scores. + +Each autointerp subrun (a-YYYYMMDD_HHMMSS) has its own interp.db. +Use InterpRepo.open() to construct — returns None if no autointerp data exists. +""" + +from pathlib import Path +from typing import Any + +import yaml + +from spd.autointerp.db import DONE_MARKER, InterpDB +from spd.autointerp.schemas import InterpretationResult, get_autointerp_dir +from spd.log import logger + + +class InterpRepo: + """Read access to autointerp data for a single run. + + Constructed via InterpRepo.open(). DB is opened eagerly at construction. + """ + + def __init__(self, db: InterpDB, subrun_dir: Path, run_id: str) -> None: + self._db = db + self._subrun_dir = subrun_dir + self.subrun_id = subrun_dir.name + self.run_id = run_id + + @classmethod + def _find_latest_done_subrun_dir(cls, run_id: str) -> Path | None: + autointerp_dir = get_autointerp_dir(run_id) + if not autointerp_dir.exists(): + return None + candidates = sorted( + [ + d + for d in autointerp_dir.iterdir() + if d.is_dir() and d.name.startswith("a-") and (d / DONE_MARKER).exists() + ], + key=lambda d: d.name, + ) + return candidates[-1] if candidates else None + + @classmethod + def open(cls, run_id: str) -> "InterpRepo | None": + """Open autointerp data for a run. 
Returns None if no completed autointerp data exists.""" + subrun_dir = cls._find_latest_done_subrun_dir(run_id) + if subrun_dir is None: + return None + db_path = subrun_dir / "interp.db" + if not db_path.exists(): + return None + logger.info(f"Opening autointerp data for {run_id} from {subrun_dir}") + return cls( + db=InterpDB(db_path, readonly=True), + subrun_dir=subrun_dir, + run_id=run_id, + ) + + # -- Provenance ------------------------------------------------------------ + + def get_config(self) -> dict[str, Any] | None: + config_path = self._subrun_dir / "config.yaml" + if not config_path.exists(): + return None + with open(config_path) as f: + return yaml.safe_load(f) + + def get_interpretation_count(self) -> int: + return self._db.get_interpretation_count() + + def get_available_score_types(self) -> list[str]: + return [st for st in ["detection", "fuzzing"] if self._db.has_scores(st)] + + # -- Interpretations ------------------------------------------------------- + + def get_all_interpretations(self) -> dict[str, InterpretationResult]: + return self._db.get_all_interpretations() + + def get_interpretation(self, component_key: str) -> InterpretationResult | None: + return self._db.get_interpretation(component_key) + + def save_interpretation(self, result: InterpretationResult) -> None: + self._db.save_interpretation(result) + + # -- Eval scores (label-dependent only) ------------------------------------ + + def get_detection_scores(self) -> dict[str, float] | None: + scores = self._db.get_scores("detection") + return scores if scores else None + + def get_fuzzing_scores(self) -> dict[str, float] | None: + scores = self._db.get_scores("fuzzing") + return scores if scores else None + + def get_scores(self, score_type: str) -> dict[str, float]: + scores = self._db.get_scores(score_type) + return scores if scores else {} + + def save_score(self, component_key: str, score_type: str, score: float, details: str) -> None: + self._db.save_score(component_key, 
score_type, score, details) diff --git a/spd/autointerp/schemas.py b/spd/autointerp/schemas.py index 8ab563f93..b5cd729b7 100644 --- a/spd/autointerp/schemas.py +++ b/spd/autointerp/schemas.py @@ -9,18 +9,22 @@ AUTOINTERP_DATA_DIR = SPD_OUT_DIR / "autointerp" -def get_autointerp_dir(wandb_run_id: str) -> Path: - """Get the autointerp (interpretations) directory for a run.""" - return AUTOINTERP_DATA_DIR / wandb_run_id +def get_autointerp_dir(decomposition_id: str) -> Path: + """Get the top-level autointerp directory for an SPD run.""" + return AUTOINTERP_DATA_DIR / decomposition_id + + +def get_autointerp_subrun_dir(decomposition_id: str, autointerp_run_id: str) -> Path: + """Get the directory for a specific autointerp run (timestamped subdirectory).""" + return get_autointerp_dir(decomposition_id) / autointerp_run_id @dataclass -class ArchitectureInfo: +class ModelMetadata: n_blocks: int - c_per_layer: dict[str, int] # Maps layer name -> number of components model_class: str dataset_name: str - tokenizer_name: str + layer_descriptions: dict[str, str] @dataclass diff --git a/spd/autointerp/scoring/__init__.py b/spd/autointerp/scoring/__init__.py new file mode 100644 index 000000000..8c602f2bb --- /dev/null +++ b/spd/autointerp/scoring/__init__.py @@ -0,0 +1,4 @@ +"""Scoring module for evaluating autointerp quality. + +Results are written as append-only JSONL for resume support. +""" diff --git a/spd/autointerp/scoring/detection.py b/spd/autointerp/scoring/detection.py new file mode 100644 index 000000000..afc00a163 --- /dev/null +++ b/spd/autointerp/scoring/detection.py @@ -0,0 +1,245 @@ +"""Detection scoring. + +Tests whether a component's interpretation label is predictive of its activations by asking +an LLM to classify plain text examples as activating or non-activating. + +Based on: EleutherAI's sae-auto-interp (https://blog.eleuther.ai/autointerp/). 
+""" + +import json +import random +from collections import defaultdict +from dataclasses import asdict, dataclass + +from openrouter.components import Effort + +from spd.app.backend.app_tokenizer import AppTokenizer +from spd.app.backend.utils import delimit_tokens +from spd.autointerp.config import DetectionEvalConfig +from spd.autointerp.db import InterpDB +from spd.autointerp.llm_api import LLMError, LLMJob, LLMResult, map_llm_calls +from spd.autointerp.repo import InterpRepo +from spd.harvest.schemas import ActivationExample, ComponentData +from spd.log import logger + +DETECTION_SCHEMA = { + "type": "object", + "properties": { + "activating": { + "type": "array", + "items": {"type": "integer"}, + "description": "1-indexed example numbers that activate the component", + }, + }, + "required": ["activating"], +} + + +@dataclass +class DetectionTrial: + predicted_activating: list[int] # 1-indexed example numbers the LLM said activate + actual_activating: list[int] # ground truth 1-indexed + tpr: float + tnr: float + balanced_acc: float + + +@dataclass +class DetectionResult: + component_key: str + score: float # mean balanced accuracy across trials + trials: list[DetectionTrial] + n_errors: int + + +def _format_example_with_center_token( + example: ActivationExample, + app_tok: AppTokenizer, +) -> str: + """Format an example with the center token marked with <>. + + Harvest windows are centered on the firing position, so the center token + is always the one that triggered collection. We mark center for both + activating and non-activating examples to avoid positional leakage. 
+ """ + valid_ids = [tid for tid in example.token_ids if tid >= 0] + center = len(valid_ids) // 2 + spans = app_tok.get_spans(valid_ids) + tokens = [(span, i == center) for i, span in enumerate(spans)] + return delimit_tokens(tokens) + + +def _sample_non_activating_examples( + target_component: ComponentData, + all_components: list[ComponentData], + n: int, + rng: random.Random, +) -> list[ActivationExample]: + """Sample non-activating examples from other components.""" + other_components = [ + c + for c in all_components + if c.component_key != target_component.component_key and len(c.activation_examples) >= 1 + ] + assert other_components, "No other components available for non-activating sampling" + + sampled: list[ActivationExample] = [] + for _ in range(n): + donor = rng.choice(other_components) + sampled.append(rng.choice(donor.activation_examples)) + return sampled + + +def _build_detection_prompt( + label: str, + examples_with_labels: list[tuple[str, bool]], +) -> str: + n_total = len(examples_with_labels) + + examples_text = "" + for i, (text, _) in enumerate(examples_with_labels): + examples_text += f"Example {i + 1}: {text}\n\n" + + return f"""\ +A neural network component has been labeled as: "{label}" + +Below are {n_total} text snippets. In each, one token is marked between <>. \ +For some examples, the marked token is one where this component fires. \ +For others, the marked token is random. + +{examples_text}\ +Based on the label, in which examples is the <> token one where this component fires? 
+ +Respond with the list of activating example numbers.""" + + +@dataclass +class _TrialGroundTruth: + component_key: str + actual_activating: set[int] + actual_non_activating: set[int] + + +async def run_detection_scoring( + components: list[ComponentData], + interp_repo: InterpRepo, + score_db: InterpDB, + model: str, + reasoning_effort: Effort, + openrouter_api_key: str, + tokenizer_name: str, + config: DetectionEvalConfig, + max_concurrent: int, + max_requests_per_minute: int, + limit: int | None, + cost_limit_usd: float | None, +) -> list[DetectionResult]: + app_tok = AppTokenizer.from_pretrained(tokenizer_name) + + labels = {key: result.label for key, result in interp_repo.get_all_interpretations().items()} + + eligible = [ + c + for c in components + if c.component_key in labels and len(c.activation_examples) >= config.n_activating + ] + if limit is not None: + eligible = eligible[:limit] + + existing_scores = score_db.get_scores("detection") + completed = set(existing_scores.keys()) + if completed: + logger.info(f"Resuming: {len(completed)} already scored") + + remaining = [c for c in eligible if c.component_key not in completed] + logger.info(f"Scoring {len(remaining)} components ({len(remaining) * config.n_trials} trials)") + + rng = random.Random() + jobs: list[LLMJob] = [] + ground_truth: dict[str, _TrialGroundTruth] = {} + + for component in remaining: + label = labels[component.component_key] + for trial_idx in range(config.n_trials): + activating = ( + list(component.activation_examples) + if len(component.activation_examples) <= config.n_activating + else rng.sample(component.activation_examples, config.n_activating) + ) + + non_activating = _sample_non_activating_examples( + component, components, config.n_non_activating, rng + ) + + formatted: list[tuple[str, bool]] = [] + for ex in activating: + formatted.append((_format_example_with_center_token(ex, app_tok), True)) + for ex in non_activating: + 
formatted.append((_format_example_with_center_token(ex, app_tok), False)) + rng.shuffle(formatted) + + key = f"{component.component_key}/trial{trial_idx}" + actual_act = {i + 1 for i, (_, is_act) in enumerate(formatted) if is_act} + actual_non_act = {i + 1 for i, (_, is_act) in enumerate(formatted) if not is_act} + jobs.append( + LLMJob( + prompt=_build_detection_prompt(label, formatted), + schema=DETECTION_SCHEMA, + key=key, + ) + ) + ground_truth[key] = _TrialGroundTruth( + component_key=component.component_key, + actual_activating=actual_act, + actual_non_activating=actual_non_act, + ) + + component_trials: defaultdict[str, list[DetectionTrial]] = defaultdict(list) + component_errors: defaultdict[str, int] = defaultdict(int) + + async for outcome in map_llm_calls( + openrouter_api_key=openrouter_api_key, + model=model, + reasoning_effort=reasoning_effort, + jobs=jobs, + max_tokens=5000, + max_concurrent=max_concurrent, + max_requests_per_minute=max_requests_per_minute, + cost_limit_usd=cost_limit_usd, + response_schema=DETECTION_SCHEMA, + n_total=len(jobs), + ): + match outcome: + case LLMResult(job=job, parsed=parsed): + gt = ground_truth[job.key] + predicted = {int(x) for x in parsed["activating"]} + tp = len(predicted & gt.actual_activating) + tn = len(gt.actual_non_activating - predicted) + tpr = tp / len(gt.actual_activating) if gt.actual_activating else 0.0 + tnr = tn / len(gt.actual_non_activating) if gt.actual_non_activating else 0.0 + component_trials[gt.component_key].append( + DetectionTrial( + predicted_activating=sorted(predicted), + actual_activating=sorted(gt.actual_activating), + tpr=tpr, + tnr=tnr, + balanced_acc=(tpr + tnr) / 2, + ) + ) + case LLMError(job=job, error=e): + gt = ground_truth[job.key] + component_errors[gt.component_key] += 1 + logger.error(f"{job.key}: {type(e).__name__}: {e}") + + results: list[DetectionResult] = [] + for component in remaining: + ck = component.component_key + trials = component_trials.get(ck, []) + n_err = 
component_errors.get(ck, 0) + score = sum(t.balanced_acc for t in trials) / len(trials) if trials else 0.0 + result = DetectionResult(component_key=ck, score=score, trials=trials, n_errors=n_err) + results.append(result) + score_db.save_score(ck, "detection", score, json.dumps(asdict(result))) + + logger.info(f"Scored {len(results)} components") + return results diff --git a/spd/autointerp/scoring/fuzzing.py b/spd/autointerp/scoring/fuzzing.py new file mode 100644 index 000000000..9669a46cf --- /dev/null +++ b/spd/autointerp/scoring/fuzzing.py @@ -0,0 +1,243 @@ +"""Fuzzing scoring. + +Tests the *specificity* of an interpretation label by checking if an LLM can +distinguish correctly-highlighted activating tokens from incorrectly-highlighted ones. +Catches labels that are too vague or generic. + +Based on: EleutherAI's sae-auto-interp (https://blog.eleuther.ai/autointerp/). +""" + +import json +import random +from collections import defaultdict +from dataclasses import asdict, dataclass + +from openrouter.components import Effort + +from spd.app.backend.app_tokenizer import AppTokenizer +from spd.app.backend.utils import delimit_tokens +from spd.autointerp.config import FuzzingEvalConfig +from spd.autointerp.db import InterpDB +from spd.autointerp.llm_api import LLMError, LLMJob, LLMResult, map_llm_calls +from spd.autointerp.repo import InterpRepo +from spd.harvest.schemas import ActivationExample, ComponentData +from spd.log import logger + +FUZZING_SCHEMA = { + "type": "object", + "properties": { + "correct_examples": { + "type": "array", + "items": {"type": "integer"}, + "description": "1-indexed example numbers with correct highlighting", + }, + "reasoning": {"type": "string", "description": "Brief explanation"}, + }, + "required": ["correct_examples", "reasoning"], +} + + +@dataclass +class FuzzingTrial: + correct_positions: list[int] # 1-indexed positions with correct highlighting + predicted_correct: list[int] # what the LLM said was correct + tp: int + tn: 
int + n_correct: int + n_incorrect: int + + +@dataclass +class FuzzingResult: + component_key: str + score: float # balanced accuracy = (TPR + TNR) / 2 + trials: list[FuzzingTrial] + n_errors: int + + +def _delimit_tokens( + example: ActivationExample, + app_tok: AppTokenizer, +) -> tuple[str, int]: + """Format example with firing tokens in <>. Returns (text, n_delimited).""" + spans = app_tok.get_spans(example.token_ids) + tokens = [(span, firing) for span, firing in zip(spans, example.firings, strict=True)] + n_delimited = sum(example.firings) + return delimit_tokens(tokens), n_delimited + + +def _delimit_random_tokens( + example: ActivationExample, + app_tok: AppTokenizer, + n_to_delimit: int, + rng: random.Random, +) -> str: + """Format example with random tokens in <> instead of firing ones.""" + n_toks = len(example.token_ids) + + delimit_set = set(rng.sample(range(n_toks), min(n_to_delimit, n_toks))) + spans = app_tok.get_spans(example.token_ids) + tokens = [(span, j in delimit_set) for j, span in enumerate(spans)] + return delimit_tokens(tokens) + + +def _build_fuzzing_prompt( + label: str, + formatted_examples: list[tuple[str, bool]], +) -> str: + n_examples = len(formatted_examples) + + examples_text = "" + for i, (text, _) in enumerate(formatted_examples): + examples_text += f"Example {i + 1}: {text}\n\n" + + return f"""\ +A neural network component has been interpreted as: "{label}" + +Below are {n_examples} text examples where this component is active. In each example, some tokens \ +are marked between <>. In some examples, the <> tokens correctly indicate \ +where the component fires most strongly. In other examples, the <> tokens are random \ +and unrelated to the component's actual firing pattern. + +{examples_text}\ +Based on the interpretation "{label}", which examples have correctly-marked tokens \ +(consistent with the label) vs. randomly-marked tokens? 
+ +Respond with the list of correctly-highlighted example numbers and brief reasoning.\ +""" + + +@dataclass +class _TrialGroundTruth: + component_key: str + correct_positions: set[int] + incorrect_positions: set[int] + + +async def run_fuzzing_scoring( + components: list[ComponentData], + interp_repo: InterpRepo, + score_db: InterpDB, + model: str, + reasoning_effort: Effort, + openrouter_api_key: str, + tokenizer_name: str, + config: FuzzingEvalConfig, + max_concurrent: int, + max_requests_per_minute: int, + limit: int | None, + cost_limit_usd: float | None, +) -> list[FuzzingResult]: + app_tok = AppTokenizer.from_pretrained(tokenizer_name) + + labels = {key: result.label for key, result in interp_repo.get_all_interpretations().items()} + + min_examples = config.n_correct + config.n_incorrect + + eligible = [ + c + for c in components + if c.component_key in labels and len(c.activation_examples) >= min_examples + ] + if limit is not None: + eligible = eligible[:limit] + + existing_scores = score_db.get_scores("fuzzing") + completed = set(existing_scores.keys()) + if completed: + logger.info(f"Resuming: {len(completed)} already scored") + + remaining = [c for c in eligible if c.component_key not in completed] + logger.info(f"Scoring {len(remaining)} components ({len(remaining) * config.n_trials} trials)") + + rng = random.Random() + jobs: list[LLMJob] = [] + ground_truth: dict[str, _TrialGroundTruth] = {} + + for component in remaining: + label = labels[component.component_key] + for trial_idx in range(config.n_trials): + sampled = rng.sample( + component.activation_examples, config.n_correct + config.n_incorrect + ) + correct_examples = sampled[: config.n_correct] + incorrect_examples = sampled[config.n_correct :] + + formatted: list[tuple[str, bool]] = [] + for ex in correct_examples: + text, _ = _delimit_tokens(ex, app_tok) + formatted.append((text, True)) + for ex in incorrect_examples: + _, n_delimited = _delimit_tokens(ex, app_tok) + n_to_delimit = 
max(n_delimited, 1) + text = _delimit_random_tokens(ex, app_tok, n_to_delimit, rng) + formatted.append((text, False)) + rng.shuffle(formatted) + + key = f"{component.component_key}/trial{trial_idx}" + correct_pos = {i + 1 for i, (_, is_correct) in enumerate(formatted) if is_correct} + incorrect_pos = {i + 1 for i, (_, is_correct) in enumerate(formatted) if not is_correct} + jobs.append( + LLMJob( + prompt=_build_fuzzing_prompt(label, formatted), schema=FUZZING_SCHEMA, key=key + ) + ) + ground_truth[key] = _TrialGroundTruth( + component_key=component.component_key, + correct_positions=correct_pos, + incorrect_positions=incorrect_pos, + ) + + component_trials: defaultdict[str, list[FuzzingTrial]] = defaultdict(list) + component_errors: defaultdict[str, int] = defaultdict(int) + + async for outcome in map_llm_calls( + openrouter_api_key=openrouter_api_key, + model=model, + reasoning_effort=reasoning_effort, + jobs=jobs, + max_tokens=5000, + max_concurrent=max_concurrent, + max_requests_per_minute=max_requests_per_minute, + cost_limit_usd=cost_limit_usd, + response_schema=FUZZING_SCHEMA, + ): + match outcome: + case LLMResult(job=job, parsed=parsed): + gt = ground_truth[job.key] + predicted_correct = set(parsed["correct_examples"]) + tp = len(gt.correct_positions & predicted_correct) + tn = len(gt.incorrect_positions - predicted_correct) + component_trials[gt.component_key].append( + FuzzingTrial( + correct_positions=sorted(gt.correct_positions), + predicted_correct=sorted(predicted_correct), + tp=tp, + tn=tn, + n_correct=len(gt.correct_positions), + n_incorrect=len(gt.incorrect_positions), + ) + ) + case LLMError(job=job, error=e): + gt = ground_truth[job.key] + component_errors[gt.component_key] += 1 + logger.error(f"{job.key}: {type(e).__name__}: {e}") + + results: list[FuzzingResult] = [] + for component in remaining: + ck = component.component_key + trials = component_trials.get(ck, []) + n_err = component_errors.get(ck, 0) + total_tp = sum(t.tp for t in trials) + 
total_tn = sum(t.tn for t in trials) + total_pos = sum(t.n_correct for t in trials) + total_neg = sum(t.n_incorrect for t in trials) + tpr = total_tp / total_pos if total_pos > 0 else 0.0 + tnr = total_tn / total_neg if total_neg > 0 else 0.0 + score = (tpr + tnr) / 2 if (total_pos > 0 and total_neg > 0) else 0.0 + result = FuzzingResult(component_key=ck, score=score, trials=trials, n_errors=n_err) + results.append(result) + score_db.save_score(ck, "fuzzing", score, json.dumps(asdict(result))) + + logger.info(f"Scored {len(results)} components") + return results diff --git a/spd/harvest/lib/__init__.py b/spd/autointerp/scoring/scripts/__init__.py similarity index 100% rename from spd/harvest/lib/__init__.py rename to spd/autointerp/scoring/scripts/__init__.py diff --git a/spd/autointerp/scoring/scripts/run_label_scoring.py b/spd/autointerp/scoring/scripts/run_label_scoring.py new file mode 100644 index 000000000..fd95f9763 --- /dev/null +++ b/spd/autointerp/scoring/scripts/run_label_scoring.py @@ -0,0 +1,113 @@ +"""CLI for label-based scoring (detection, fuzzing). + +Usage: + python -m spd.autointerp.scoring.scripts.run_label_scoring --config_json '...' 
--decomposition_id <run_id> --scorer_type detection --harvest_subrun_id h-20260211_120000 +""" + +import asyncio +import os +from typing import Any, Literal + +from dotenv import load_dotenv + +from spd.adapters import adapter_from_id +from spd.autointerp.config import AutointerpEvalConfig +from spd.autointerp.db import InterpDB +from spd.autointerp.repo import InterpRepo +from spd.autointerp.scoring.detection import run_detection_scoring +from spd.autointerp.scoring.fuzzing import run_fuzzing_scoring +from spd.harvest.repo import HarvestRepo + +LabelScorerType = Literal["detection", "fuzzing"] + + +def main( + decomposition_id: str, + scorer_type: LabelScorerType, + config_json: dict[str, Any], + harvest_subrun_id: str, +) -> None: + assert isinstance(config_json, dict), f"Expected dict from fire, got {type(config_json)}" + load_dotenv() + openrouter_api_key = os.environ.get("OPENROUTER_API_KEY") + assert openrouter_api_key, "OPENROUTER_API_KEY not set" + + config = AutointerpEvalConfig.model_validate(config_json) + + tokenizer_name = adapter_from_id(decomposition_id).tokenizer_name + + interp_repo = InterpRepo.open(decomposition_id) + assert interp_repo is not None, ( + f"No autointerp data for {decomposition_id}. Run autointerp first." 
+ ) + + # Separate writable DB for saving scores (the repo's DB is readonly/immutable) + score_db = InterpDB(interp_repo._subrun_dir / "interp.db") + + harvest = HarvestRepo( + decomposition_id=decomposition_id, + subrun_id=harvest_subrun_id, + readonly=True, + ) + + components = harvest.get_all_components() + + match scorer_type: + case "detection": + asyncio.run( + run_detection_scoring( + components=components, + interp_repo=interp_repo, + score_db=score_db, + model=config.model, + reasoning_effort=config.reasoning_effort, + openrouter_api_key=openrouter_api_key, + tokenizer_name=tokenizer_name, + config=config.detection_config, + max_concurrent=config.max_concurrent, + max_requests_per_minute=config.max_requests_per_minute, + limit=config.limit, + cost_limit_usd=config.cost_limit_usd, + ) + ) + case "fuzzing": + asyncio.run( + run_fuzzing_scoring( + components=components, + interp_repo=interp_repo, + score_db=score_db, + model=config.model, + reasoning_effort=config.reasoning_effort, + openrouter_api_key=openrouter_api_key, + tokenizer_name=tokenizer_name, + config=config.fuzzing_config, + max_concurrent=config.max_concurrent, + max_requests_per_minute=config.max_requests_per_minute, + limit=config.limit, + cost_limit_usd=config.cost_limit_usd, + ) + ) + + score_db.close() + + +def get_command( + decomposition_id: str, + scorer_type: LabelScorerType, + config: AutointerpEvalConfig, + harvest_subrun_id: str, +) -> str: + config_json = config.model_dump_json(exclude_none=True) + return ( + f"python -m spd.autointerp.scoring.scripts.run_label_scoring " + f"--decomposition_id {decomposition_id} " + f"--scorer_type {scorer_type} " + f"--config_json '{config_json}' " + f"--harvest_subrun_id {harvest_subrun_id} " + ) + + +if __name__ == "__main__": + import fire + + fire.Fire(main) diff --git a/spd/autointerp/scripts/run_interpret.py b/spd/autointerp/scripts/run_interpret.py index 76f27b44d..263329172 100644 --- a/spd/autointerp/scripts/run_interpret.py +++ 
b/spd/autointerp/scripts/run_interpret.py @@ -1,59 +1,89 @@ """CLI for autointerp pipeline. -Usage (direct execution): - python -m spd.autointerp.scripts.run_interpret - -Usage (SLURM submission): - spd-autointerp +Usage: + python -m spd.autointerp.scripts.run_interpret --config_json '...' + spd-autointerp # SLURM submission """ import os +from datetime import datetime +from typing import Any from dotenv import load_dotenv -from spd.autointerp.interpret import OpenRouterModelName, run_interpret -from spd.autointerp.schemas import get_autointerp_dir -from spd.harvest.schemas import get_activation_contexts_dir, get_correlations_dir -from spd.utils.wandb_utils import parse_wandb_run_path +from spd.adapters import adapter_from_id +from spd.autointerp.config import AutointerpConfig +from spd.autointerp.interpret import run_interpret +from spd.autointerp.schemas import get_autointerp_dir, get_autointerp_subrun_dir +from spd.harvest.repo import HarvestRepo +from spd.log import logger def main( - wandb_path: str, - model: OpenRouterModelName, - limit: int | None = None, + decomposition_id: str, + config_json: dict[str, Any], + harvest_subrun_id: str, + autointerp_subrun_id: str | None = None, ) -> None: - """Interpret harvested components. - - Args: - wandb_path: WandB run path for the target decomposition run. - model: OpenRouter model to use for interpretation. - limit: Maximum number of components to interpret (highest mean CI first). - """ - _, _, run_id = parse_wandb_run_path(wandb_path) + assert isinstance(config_json, dict), f"Expected dict from fire, got {type(config_json)}" + interp_config = AutointerpConfig.model_validate(config_json) load_dotenv() openrouter_api_key = os.environ.get("OPENROUTER_API_KEY") assert openrouter_api_key, "OPENROUTER_API_KEY not set" - activation_contexts_dir = get_activation_contexts_dir(run_id) - assert activation_contexts_dir.exists(), ( - f"Activation contexts not found at {activation_contexts_dir}. Run harvest first." 
- ) + harvest = HarvestRepo(decomposition_id, subrun_id=harvest_subrun_id, readonly=False) + + if autointerp_subrun_id is not None: + subrun_dir = get_autointerp_dir(decomposition_id) / autointerp_subrun_id + assert subrun_dir.exists(), f"Subrun dir not found: {subrun_dir}" + logger.info(f"Resuming existing subrun: {autointerp_subrun_id}") + else: + autointerp_subrun_id = "a-" + datetime.now().strftime("%Y%m%d_%H%M%S") + subrun_dir = get_autointerp_subrun_dir(decomposition_id, autointerp_subrun_id) + subrun_dir.mkdir(parents=True, exist_ok=True) + + # Save config for reproducibility + interp_config.to_file(subrun_dir / "config.yaml") - correlations_dir = get_correlations_dir(run_id) + db_path = subrun_dir / "interp.db" - autointerp_dir = get_autointerp_dir(run_id) - autointerp_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Autointerp run: {subrun_dir}") + + adapter = adapter_from_id(decomposition_id) run_interpret( - wandb_path, - openrouter_api_key, - model, - activation_contexts_dir, - correlations_dir, - autointerp_dir, - limit, + openrouter_api_key=openrouter_api_key, + model=interp_config.model, + reasoning_effort=interp_config.reasoning_effort, + limit=interp_config.limit, + cost_limit_usd=interp_config.cost_limit_usd, + max_requests_per_minute=interp_config.max_requests_per_minute, + model_metadata=adapter.model_metadata, + template_strategy=interp_config.template_strategy, + harvest=harvest, + db_path=db_path, + tokenizer_name=adapter.tokenizer_name, + max_concurrent=interp_config.max_concurrent, + ) + + +def get_command( + decomposition_id: str, + config: AutointerpConfig, + harvest_subrun_id: str, + autointerp_subrun_id: str | None = None, +) -> str: + config_json = config.model_dump_json(exclude_none=True) + cmd = ( + "python -m spd.autointerp.scripts.run_interpret " + f"--decomposition_id {decomposition_id} " + f"--config_json '{config_json}' " + f"--harvest_subrun_id {harvest_subrun_id} " ) + if autointerp_subrun_id is not None: + cmd += 
f"--autointerp_subrun_id {autointerp_subrun_id} " + return cmd if __name__ == "__main__": diff --git a/spd/autointerp/scripts/run_slurm.py b/spd/autointerp/scripts/run_slurm.py index c7f0091d2..4a1cd0bdc 100644 --- a/spd/autointerp/scripts/run_slurm.py +++ b/spd/autointerp/scripts/run_slurm.py @@ -1,79 +1,122 @@ """SLURM launcher for autointerp pipeline. -Submits interpret jobs to SLURM cluster programmatically. +Autointerp is a functional unit: interpret + label-dependent evals. This module +submits all jobs in the unit with proper dependency chaining. -Usage: - spd-autointerp - spd-autointerp --budget_usd 100 +Dependency graph (depends on a prior harvest merge): + interpret (depends on harvest merge) + ├── detection (depends on interpret) + └── fuzzing (depends on interpret) + +(Intruder eval is label-free and belongs to the harvest functional unit.) """ -from spd.autointerp.interpret import OpenRouterModelName +from dataclasses import dataclass + +from spd.autointerp.config import AutointerpSlurmConfig +from spd.autointerp.scoring.scripts import run_label_scoring +from spd.autointerp.scripts import run_interpret from spd.log import logger -from spd.utils.slurm import SlurmConfig, generate_script, submit_slurm_job -from spd.utils.wandb_utils import wandb_path_to_url +from spd.utils.slurm import SlurmConfig, SubmitResult, generate_script, submit_slurm_job + +@dataclass +class AutointerpSubmitResult: + interpret_result: SubmitResult + detection_result: SubmitResult | None + fuzzing_result: SubmitResult | None -def launch_interpret_job( - wandb_path: str, - model: OpenRouterModelName, - partition: str, - time: str, - limit: int | None = None, -) -> None: - """Submit interpret job to SLURM (CPU-only, IO-bound). 
+ +def submit_autointerp( + decomposition_id: str, + config: AutointerpSlurmConfig, + harvest_subrun_id: str, + dependency_job_id: str | None = None, + snapshot_branch: str | None = None, +) -> AutointerpSubmitResult: + """Submit the autointerp pipeline to SLURM. + + Submits interpret + eval jobs as a functional unit. All jobs depend on a + prior harvest merge (passed as dependency_job_id). Args: - wandb_path: WandB run path for the target decomposition run. + decomposition_id: ID of the target decomposition run. + harvest_subrun_id: Harvest subrun whose component data to interpret. - model: OpenRouter model to use for interpretation. - partition: SLURM partition name. - time: Job time limit. - limit: Maximum number of components to interpret (highest mean CI first). + config: Autointerp SLURM configuration. + dependency_job_id: Job to wait for before starting (e.g. harvest merge). + snapshot_branch: Git snapshot branch to use. + + Returns: + AutointerpSubmitResult with interpret, detection, and fuzzing results. """ - job_name = "interpret" - - cmd_parts = [ - "python -m spd.autointerp.scripts.run_interpret", - f'"{wandb_path}"', - f"--model {model.value}", - ] - if limit is not None: - cmd_parts.append(f"--limit {limit}") - interpret_cmd = " \\\n ".join(cmd_parts) - - # Build full command with echoes - full_command = "\n".join( - [ - 'echo "=== Interpret ==="', - f'echo "WANDB_PATH: {wandb_path}"', - f'echo "MODEL: {model.value}"', - 'echo "SLURM_JOB_ID: $SLURM_JOB_ID"', - 'echo "================="', - "", - interpret_cmd, - "", - 'echo "Interpret complete!"', - ] + + # === 1. Interpret job === + interpret_cmd = run_interpret.get_command( + decomposition_id=decomposition_id, + config=config.config, + harvest_subrun_id=harvest_subrun_id, ) - config = SlurmConfig( - job_name=job_name, - partition=partition, - n_gpus=0, # CPU-only job - cpus_per_task=16, # (cluster default is 16cpus/gpu and 15GB memory/cpu. 
We need the memory) - time=time, - snapshot_branch=None, # Autointerp doesn't use git snapshots - comment=wandb_path_to_url(wandb_path), + interpret_slurm = SlurmConfig( + job_name="spd-interpret", + partition=config.partition, + n_gpus=2, + time=config.time, + snapshot_branch=snapshot_branch, + dependency_job_id=dependency_job_id, + comment=decomposition_id, ) - script_content = generate_script(config, full_command) - result = submit_slurm_job(script_content, "interpret") + script_content = generate_script(interpret_slurm, interpret_cmd) + interpret_result = submit_slurm_job(script_content, "spd-interpret") - logger.section("Interpret job submitted!") + logger.section("Interpret job submitted") logger.values( { - "Job ID": result.job_id, - "WandB path": wandb_path, - "Model": model.value, - "Log": result.log_pattern, - "Script": str(result.script_path), + "Job ID": interpret_result.job_id, + "Decomposition ID": decomposition_id, + "Model": config.config.model, + "Log": interpret_result.log_pattern, } ) + + if config.evals is None: + return AutointerpSubmitResult( + interpret_result=interpret_result, + detection_result=None, + fuzzing_result=None, + ) + + # === 2. 
Detection + fuzzing scoring (depend on interpret) === + scoring_results: dict[str, SubmitResult] = {} + for scorer in ("detection", "fuzzing"): + scoring_cmd = run_label_scoring.get_command( + decomposition_id, + scorer_type=scorer, + config=config.evals, + harvest_subrun_id=harvest_subrun_id, + ) + eval_slurm = SlurmConfig( + job_name=f"spd-{scorer}", + partition=config.partition, + n_gpus=2, + time=config.evals_time, + snapshot_branch=snapshot_branch, + dependency_job_id=interpret_result.job_id, + ) + eval_script = generate_script(eval_slurm, scoring_cmd) + scoring_result = submit_slurm_job(eval_script, f"spd-{scorer}") + scoring_results[scorer] = scoring_result + + logger.section(f"{scorer.capitalize()} scoring job submitted") + logger.values( + { + "Job ID": scoring_result.job_id, + "Depends on": f"interpret ({interpret_result.job_id})", + "Log": scoring_result.log_pattern, + } + ) + + return AutointerpSubmitResult( + interpret_result=interpret_result, + detection_result=scoring_results["detection"], + fuzzing_result=scoring_results["fuzzing"], + ) diff --git a/spd/autointerp/scripts/run_slurm_cli.py b/spd/autointerp/scripts/run_slurm_cli.py index ddc86c299..56db16499 100644 --- a/spd/autointerp/scripts/run_slurm_cli.py +++ b/spd/autointerp/scripts/run_slurm_cli.py @@ -1,30 +1,28 @@ """CLI entry point for autointerp SLURM launcher. Thin wrapper for fast --help. Heavy imports deferred to run_slurm.py. 
+
+Usage:
+    spd-autointerp <decomposition_id> <config_path> <harvest_subrun_id>
+    spd-autointerp <decomposition_id> --config autointerp_config.yaml --harvest_subrun_id <harvest_subrun_id>
"""

import fire

-from spd.settings import DEFAULT_PARTITION_NAME
-
-
-def main(
-    wandb_path: str,
-    model: str = "google/gemini-3-flash-preview",
-    partition: str = DEFAULT_PARTITION_NAME,
-    time: str = "12:00:00",
-    limit: int | None = None,
-) -> None:
-    from spd.autointerp.interpret import OpenRouterModelName
-    from spd.autointerp.scripts.run_slurm import launch_interpret_job
-
-    launch_interpret_job(
-        wandb_path=wandb_path,
-        model=OpenRouterModelName(model),
-        partition=partition,
-        time=time,
-        limit=limit,
-    )
+
+def main(decomposition_id: str, config: str, harvest_subrun_id: str) -> None:
+    """Submit autointerp pipeline (interpret + evals) to SLURM.
+
+    Args:
+        decomposition_id: ID of the target decomposition run.
+        config: Path to AutointerpSlurmConfig YAML/JSON.
+        harvest_subrun_id: Harvest subrun to use (e.g. "h-20260306_120000").
+    """
+    from spd.autointerp.config import AutointerpSlurmConfig
+    from spd.autointerp.scripts.run_slurm import submit_autointerp
+
+    slurm_config = AutointerpSlurmConfig.from_file(config)
+    submit_autointerp(decomposition_id, slurm_config, harvest_subrun_id=harvest_subrun_id)


 def cli() -> None:
diff --git a/spd/autointerp/strategies/__init__.py b/spd/autointerp/strategies/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/spd/autointerp/strategies/compact_skeptical.py b/spd/autointerp/strategies/compact_skeptical.py
new file mode 100644
index 000000000..f72effd51
--- /dev/null
+++ b/spd/autointerp/strategies/compact_skeptical.py
@@ -0,0 +1,145 @@
+"""Compact skeptical interpretation strategy.
+
+Short labels (2-5 words), skeptical tone, structured JSON output.
+Extracted from the original prompt_template.py.
+""" + +from spd.app.backend.app_tokenizer import AppTokenizer +from spd.autointerp.config import CompactSkepticalConfig +from spd.autointerp.prompt_helpers import ( + DATASET_DESCRIPTIONS, + build_fires_on_examples, +) +from spd.autointerp.schemas import ModelMetadata +from spd.harvest.analysis import TokenPRLift +from spd.harvest.schemas import ComponentData +from spd.utils.markdown import Md + +SPD_CONTEXT = ( + "Each component has a causal importance (CI) value per token position. " + "High CI (near 1) = essential, cannot be ablated. Low CI (near 0) = ablatable." +) + + +def format_prompt( + config: CompactSkepticalConfig, + component: ComponentData, + model_metadata: ModelMetadata, + app_tok: AppTokenizer, + input_token_stats: TokenPRLift, + output_token_stats: TokenPRLift, +) -> str: + input_pmi: list[tuple[str, float]] | None = None + output_pmi: list[tuple[str, float]] | None = None + + if config.include_pmi: + input_pmi = [ + (app_tok.get_tok_display(tid), pmi) for tid, pmi in component.input_token_pmi.top + ] + output_pmi = [ + (app_tok.get_tok_display(tid), pmi) for tid, pmi in component.output_token_pmi.top + ] + + input_section = _build_input_section(input_token_stats, input_pmi) + output_section = _build_output_section(output_token_stats, output_pmi) + examples_section = build_fires_on_examples(component, app_tok, config.max_examples) + + rate_str = ( + f"~1 in {int(1 / component.firing_density)} tokens" + if component.firing_density > 0.0 + else "extremely rare" + ) + + layer_desc = model_metadata.layer_descriptions.get(component.layer, component.layer) + + dataset_line = "" + if config.include_dataset_description: + dataset_desc = DATASET_DESCRIPTIONS[model_metadata.dataset_name] + dataset_line = f", dataset: {dataset_desc}" + + forbidden = ", ".join(config.forbidden_words) if config.forbidden_words else "(none)" + + md = Md() + md.p("Label this neural network component.") + + if config.include_spd_context: + md.p(SPD_CONTEXT) + + md.h(2, 
"Context").bullets( + [ + f"Model: {model_metadata.model_class} ({model_metadata.n_blocks} blocks){dataset_line}", + f"Component location: {layer_desc}", + f"Component firing rate: {component.firing_density * 100:.2f}% ({rate_str})", + ] + ) + + md.h(2, "Token correlations") + md.extend(input_section).extend(output_section) + + md.h(2, "Activation examples (active tokens in <>)") + md.extend(examples_section) + + md.h(2, "Task") + md.p(f"Give a 2-{config.label_max_words} word label for what this component detects.") + md.p( + "Be SKEPTICAL. If you can't identify specific tokens or a tight grammatical " + 'pattern, say "unclear".' + ) + md.p("Rules:") + md.numbered( + [ + 'Good labels name SPECIFIC tokens: "\'the\'", "##ing suffix", "she/her pronouns"', + 'Say "unclear" if: tokens are too varied, pattern is abstract, or evidence is weak', + f"FORBIDDEN words (too vague): {forbidden}", + "Lowercase only", + 'Confidence: "high" = clear, specific pattern with strong evidence; ' + '"medium" = plausible but noisy; "low" = speculative', + ] + ) + md.p( + 'GOOD: "##ed suffix", "\'and\' conjunction", "she/her/hers", "period then capital", "unclear"\n' + 'BAD: "various words and punctuation", "verbs and adjectives", "tokens near commas"' + ) + + return md.build() + + +def _build_input_section( + input_stats: TokenPRLift, + input_pmi: list[tuple[str, float]] | None, +) -> Md: + md = Md() + if input_stats.top_recall: + md.labeled_list( + "**Input tokens with highest recall (most common current tokens when the component is firing)**", + [f"{repr(tok)}: {recall * 100:.0f}%" for tok, recall in input_stats.top_recall[:8]], + ) + if input_stats.top_precision: + md.labeled_list( + "**Input tokens with highest precision (probability the component fires given the current token is X)**", + [f"{repr(tok)}: {prec * 100:.0f}%" for tok, prec in input_stats.top_precision[:8]], + ) + if input_pmi: + md.labeled_list( + "**Input tokens with highest PMI (pointwise mutual information. 
Tokens with higher-than-base-rate likelihood of co-occurrence with the component firing)**", + [f"{repr(tok)}: {pmi:.2f}" for tok, pmi in input_pmi[:6]], + ) + return md + + +def _build_output_section( + output_stats: TokenPRLift, + output_pmi: list[tuple[str, float]] | None, +) -> Md: + md = Md() + if output_stats.top_precision: + md.labeled_list( + "**Output precision — of all predicted probability for token X, what fraction is at positions where this component fires?**", + [f"{repr(tok)}: {prec * 100:.0f}%" for tok, prec in output_stats.top_precision[:10]], + ) + if output_pmi: + md.labeled_list( + "**Output PMI — tokens the model predicts at higher-than-base-rate when this component fires:**", + [f"{repr(tok)}: {pmi:.2f}" for tok, pmi in output_pmi[:6]], + ) + return md diff --git a/spd/autointerp/strategies/dispatch.py b/spd/autointerp/strategies/dispatch.py new file mode 100644 index 000000000..e14ffa52c --- /dev/null +++ b/spd/autointerp/strategies/dispatch.py @@ -0,0 +1,53 @@ +"""Strategy dispatch: routes AutointerpConfig variants to their implementations.""" + +from spd.app.backend.app_tokenizer import AppTokenizer +from spd.autointerp.config import CompactSkepticalConfig, DualViewConfig, StrategyConfig +from spd.autointerp.schemas import ModelMetadata +from spd.autointerp.strategies.compact_skeptical import ( + format_prompt as compact_skeptical_prompt, +) +from spd.autointerp.strategies.dual_view import ( + format_prompt as dual_view_prompt, +) +from spd.harvest.analysis import TokenPRLift +from spd.harvest.schemas import ComponentData + +INTERPRETATION_SCHEMA = { + "type": "object", + "properties": { + "label": {"type": "string"}, + "confidence": {"type": "string", "enum": ["low", "medium", "high"]}, + "reasoning": {"type": "string"}, + }, + "required": ["label", "confidence", "reasoning"], + "additionalProperties": False, +} + + +def format_prompt( + strategy: StrategyConfig, + component: ComponentData, + model_metadata: ModelMetadata, + app_tok: 
AppTokenizer, + input_token_stats: TokenPRLift, + output_token_stats: TokenPRLift, +) -> str: + match strategy: + case CompactSkepticalConfig(): + return compact_skeptical_prompt( + strategy, + component, + model_metadata, + app_tok, + input_token_stats, + output_token_stats, + ) + case DualViewConfig(): + return dual_view_prompt( + strategy, + component, + model_metadata, + app_tok, + input_token_stats, + output_token_stats, + ) diff --git a/spd/autointerp/strategies/dual_view.py b/spd/autointerp/strategies/dual_view.py new file mode 100644 index 000000000..6d2263806 --- /dev/null +++ b/spd/autointerp/strategies/dual_view.py @@ -0,0 +1,140 @@ +"""Dual-view interpretation strategy. + +Key differences from compact_skeptical: +- Output token data presented first +- Two example sections: "fires on" (current token) and "produces" (next token) +- Human-readable layer descriptions with position context +- Task framing asks for functional description, not detection label +""" + +from spd.app.backend.app_tokenizer import AppTokenizer +from spd.autointerp.config import DualViewConfig +from spd.autointerp.prompt_helpers import ( + DATASET_DESCRIPTIONS, + build_fires_on_examples, + build_input_section, + build_output_section, + build_says_examples, + density_note, + human_layer_desc, + layer_position_note, +) +from spd.autointerp.schemas import ModelMetadata +from spd.harvest.analysis import TokenPRLift +from spd.harvest.schemas import ComponentData +from spd.utils.markdown import Md + + +def format_prompt( + config: DualViewConfig, + component: ComponentData, + model_metadata: ModelMetadata, + app_tok: AppTokenizer, + input_token_stats: TokenPRLift, + output_token_stats: TokenPRLift, +) -> str: + input_pmi: list[tuple[str, float]] | None = None + output_pmi: list[tuple[str, float]] | None = None + + if config.include_pmi: + input_pmi = [ + (app_tok.get_tok_display(tid), pmi) for tid, pmi in component.input_token_pmi.top + ] + output_pmi = [ + (app_tok.get_tok_display(tid), 
pmi) for tid, pmi in component.output_token_pmi.top + ] + + output_section = build_output_section(output_token_stats, output_pmi) + input_section = build_input_section(input_token_stats, input_pmi) + fires_on_examples = build_fires_on_examples(component, app_tok, config.max_examples) + says_examples = build_says_examples(component, app_tok, config.max_examples) + + rate_str = ( + f"~1 in {int(1 / component.firing_density)} tokens" + if component.firing_density > 0.0 + else "extremely rare" + ) + + canonical = model_metadata.layer_descriptions.get(component.layer, component.layer) + layer_desc = human_layer_desc(canonical, model_metadata.n_blocks) + position_note = layer_position_note(canonical, model_metadata.n_blocks) + dens_note = density_note(component.firing_density) + + context_notes = " ".join(filter(None, [position_note, dens_note])) + + dataset_line = "" + if config.include_dataset_description: + dataset_desc = DATASET_DESCRIPTIONS.get( + model_metadata.dataset_name, model_metadata.dataset_name + ) + dataset_line = f", dataset: {dataset_desc}" + + forbidden_sentence = ( + "FORBIDDEN vague words: " + ", ".join(config.forbidden_words) + ". " + if config.forbidden_words + else "" + ) + + md = Md() + md.p( + "Describe what this neural network component does.\n\n" + "Each component is a learned linear transformation inside a weight matrix. " + "It has an input function (what causes it to fire) and an output function " + "(what tokens it causes the model to produce). These are often different — " + "a component might fire on periods but produce sentence-opening words, or " + "fire on prepositions but produce abstract nouns.\n\n" + "Consider all of the evidence below critically. Token statistics can be noisy, " + "especially for high-density components. The activation examples are sampled " + "and may not be representative. Look for patterns that are consistent across " + "multiple sources of evidence." 
+ ) + + md.h(2, "Context").bullets( + [ + f"Model: {model_metadata.model_class} ({model_metadata.n_blocks} blocks){dataset_line}", + f"Component location: {layer_desc}", + f"Component firing rate: {component.firing_density * 100:.2f}% ({rate_str})", + ] + ) + if context_notes: + md.p(context_notes) + + md.h(2, "Output tokens (what the model produces when this component fires)") + md.extend(output_section) + + md.h(2, "Input tokens (what causes this component to fire)") + md.extend(input_section) + + md.h(2, "Activation examples — where the component fires") + md.p("<> mark tokens where this component is active.") + md.extend(fires_on_examples) + + md.h(2, "Activation examples — what the model produces") + md.p( + "Same examples with <> shifted right by one — " + "showing the token that follows each firing position." + ) + md.extend(says_examples) + + md.h(2, "Task") + md.p( + f"Give a {config.label_max_words}-word-or-fewer label describing this component's " + "function. The label should read like a short description of the job this component " + "does in the network. Use both the input and output evidence." + ) + md.p("Examples of good labels across different component types:") + md.bullets( + [ + '"word stem completion (stems → suffixes)"', + '"closes dialogue with quotation marks"', + '"object pronouns after verbs"', + '"story-ending moral resolution vocabulary"', + '"aquatic scene vocabulary (frog, river, pond)"', + "\"'of course' and abstract nouns after prepositions\"", + ] + ) + md.p( + f'Say "unclear" if the evidence is too weak or diffuse. {forbidden_sentence}Lowercase only.' + ) + + return md.build() diff --git a/spd/clustering/CLAUDE.md b/spd/clustering/CLAUDE.md index 30777bc61..f502f8785 100644 --- a/spd/clustering/CLAUDE.md +++ b/spd/clustering/CLAUDE.md @@ -52,7 +52,9 @@ Entry point via `spd-clustering`. Submits clustering runs as SLURM job array, th Performs one clustering run: 1. Load decomposed model from WandB -2. 
Compute component activations on dataset batch +2. Compute component activations: + - **LM tasks**: Uses `n_tokens` and `n_tokens_per_seq` parameters. Iterates through batches of size `batch_size`, picks `n_tokens_per_seq` random token positions per sequence, collects CI values until `n_tokens` samples gathered. Result: `(n_tokens, C)` per layer. + - **resid_mlp tasks**: Single batch of size `batch_size`, no sequence dimension. Uses `component_activations()` directly. 3. Run merge iteration (greedy MDL-based clustering) 4. Save `MergeHistory` with group assignments per iteration @@ -77,7 +79,7 @@ Computes pairwise distances between clustering runs in an ensemble: ```python ClusteringPipelineConfig # Pipeline settings (n_runs, distances_methods, SLURM config) -ClusteringRunConfig # Single run settings (model_path, batch_size, merge_config) +ClusteringRunConfig # Single run settings (model_path, batch_size, n_tokens, merge_config) MergeConfig # Merge algorithm params (alpha, iters, activation_threshold) ``` @@ -106,6 +108,18 @@ DistancesArray # Float[np.ndarray, "n_iters n_ens n_ens"] - `matching_dist.py` - Optimal matching distance via Hungarian algorithm - `merge_pair_samplers.py` - Strategies for selecting which pair to merge +## Utility Scripts + +**`get_cluster_mapping.py`**: Extracts cluster assignments at a specific iteration from a clustering run, outputs JSON mapping component labels to cluster indices (singletons mapped to `null`). + +```bash +python -m spd.clustering.scripts.get_cluster_mapping /path/to/clustering_run --iteration 299 +``` + +## App Integration + +To make a cluster mapping available in the app's dropdown for a run, add its path to `CANONICAL_RUNS` in `spd/app/frontend/src/lib/registry.ts` under the corresponding run's `clusterMappings` array. 
+ ## Config Files Configs live in `spd/clustering/configs/`: diff --git a/spd/clustering/activations.py b/spd/clustering/activations.py index cd6a2b742..2738daf3e 100644 --- a/spd/clustering/activations.py +++ b/spd/clustering/activations.py @@ -1,10 +1,12 @@ from dataclasses import dataclass from functools import cached_property -from typing import Literal, NamedTuple +from typing import Any, Literal, NamedTuple import torch -from jaxtyping import Bool, Float, Float16, Int +from jaxtyping import Bool, Float, Float16 from torch import Tensor +from torch.utils.data import DataLoader +from tqdm import tqdm from spd.clustering.consts import ( ActivationsTensor, @@ -13,13 +15,14 @@ ComponentLabels, ) from spd.clustering.util import ModuleFilterFunc +from spd.log import logger from spd.models.component_model import ComponentModel, OutputWithCache def component_activations( model: ComponentModel, device: torch.device | str, - batch: Int[Tensor, "batch_size n_ctx"], + batch: Tensor, ) -> dict[str, ActivationsTensor]: """Get the component activations over a **single** batch.""" causal_importances: dict[str, ActivationsTensor] @@ -29,16 +32,78 @@ def component_activations( cache_type="input", ) - # TODO: !!!IMPORTANT!!! unclear what the right thing from CIOutputs is causal_importances = model.calc_causal_importances( pre_weight_acts=model_output.cache, sampling="continuous", detach_inputs=False, - ).upper_leaky + ).lower_leaky return causal_importances +def collect_activations( + model: ComponentModel, + dataloader: DataLoader[Any], + n_tokens: int, + n_tokens_per_seq: int, + device: torch.device | str, + seed: int, +) -> dict[str, Float[Tensor, "n_tokens C"]]: + """Collect activation samples by picking random tokens per sequence. + + Iterates through batches from dataloader, runs component_activations on each, + then selects n_tokens_per_seq random token positions per sequence. Collects until + n_tokens samples are gathered. 
+ + Args: + model: ComponentModel to get activations from + dataloader: DataLoader yielding batches of sequences + n_tokens: Total number of token samples to collect + n_tokens_per_seq: Number of random token positions to sample per sequence + device: Device to run on + seed: Random seed for reproducible token position selection + """ + rng = torch.Generator().manual_seed(seed) + collected: dict[str, list[Tensor]] = {} + n_collected = 0 + + pbar = tqdm(dataloader, desc="Collecting activations", unit="batch") + for batch_data in pbar: + input_ids = batch_data["input_ids"] + batch_size, n_ctx = input_ids.shape + + activations = component_activations(model=model, batch=input_ids, device=device) + + # Pick n_tokens_per_seq random token positions per sequence + positions = torch.randint(0, n_ctx, (batch_size, n_tokens_per_seq), generator=rng) + batch_indices = torch.arange(batch_size).unsqueeze(1).expand_as(positions) + + for key, act in activations.items(): + # act shape: (batch_size, n_ctx, C) + sampled = act[batch_indices, positions] # (batch_size, n_tokens_per_seq, C) + sampled = sampled.reshape(batch_size * n_tokens_per_seq, -1) + if key not in collected: + collected[key] = [] + collected[key].append(sampled.cpu()) + + n_collected += batch_size * n_tokens_per_seq + pbar.set_postfix(tokens=f"{min(n_collected, n_tokens)}/{n_tokens}") + if n_collected >= n_tokens: + break + + assert n_collected >= n_tokens, ( + f"Dataloader exhausted: collected {n_collected} tokens but needed {n_tokens}" + ) + + logger.info(f"Collected {n_collected} token activations (requested {n_tokens})") + + # Concatenate and truncate to exactly n_tokens. Pop chunks to free memory before next module. 
+ result: dict[str, Float[Tensor, "n_tokens C"]] = {} + for key in list(collected.keys()): + result[key] = torch.cat(collected.pop(key), dim=0)[:n_tokens] + return result + + def compute_coactivatons( activations: ActivationsTensor | BoolActivationsTensor, ) -> ClusterCoactivationShaped: @@ -118,8 +183,11 @@ def filter_dead_components( class ProcessedActivations: """Processed activations after filtering and concatenation""" - activations_raw: dict[str, ActivationsTensor] - "activations after filtering, but prior to concatenation" + module_component_counts: dict[str, int] + "total component count per module (including dead), preserving module order" + + module_alive_counts: dict[str, int] + "alive component count per module, preserving module order" activations: ActivationsTensor "activations after filtering and concatenation" @@ -137,12 +205,10 @@ def validate(self) -> None: @property def n_components_original(self) -> int: - """Total number of components before filtering. equal to the sum of all components in `activations_raw`, or to `n_components_alive + n_components_dead`""" - return sum(act.shape[1] for act in self.activations_raw.values()) + return sum(self.module_component_counts.values()) @property def n_components_alive(self) -> int: - """Number of alive components after filtering. equal to the length of `labels`""" n_alive: int = len(self.labels) assert n_alive + self.n_components_dead == self.n_components_original, ( f"({n_alive = }) + ({self.n_components_dead = }) != ({self.n_components_original = })" @@ -155,7 +221,6 @@ def n_components_alive(self) -> int: @property def n_components_dead(self) -> int: - """Number of dead components after filtering. 
equal to the length of `dead_components_lst` if it is not None, or 0 otherwise""" return len(self.dead_components_lst) if self.dead_components_lst else 0 @cached_property @@ -183,14 +248,23 @@ def get_label_index_alive(self, label: str) -> int: @property def module_keys(self) -> list[str]: - """Get the module keys from the activations_raw""" - return list(self.activations_raw.keys()) + return list(self.module_component_counts.keys()) def get_module_indices(self, module_key: str) -> list[int | None]: - """given a module key, return a list len "num components in that moduel", with int index in alive components, or None if dead""" - num_components: int = self.activations_raw[module_key].shape[1] + """given a module key, return a list len "num components in that module", with int index in alive components, or None if dead""" + num_components: int = self.module_component_counts[module_key] return [self.label_index[f"{module_key}:{i}"] for i in range(num_components)] + def get_module_activations(self) -> dict[str, ActivationsTensor]: + """Reconstruct per-module activation views (alive components only) from the concatenated tensor.""" + result: dict[str, ActivationsTensor] = {} + offset = 0 + for key, n_alive in self.module_alive_counts.items(): + if n_alive > 0: + result[key] = self.activations[:, offset : offset + n_alive] + offset += n_alive + return result + def process_activations( activations: dict[ @@ -198,70 +272,96 @@ def process_activations( Float[Tensor, "samples C"] # (sample x component gate activations) | Float[Tensor, " n_sample n_ctx C"], # (sample x seq index x component gate activations) ], - filter_dead_threshold: float = 0.01, + filter_dead_threshold: float, seq_mode: Literal["concat", "seq_mean", None] = None, filter_modules: ModuleFilterFunc | None = None, ) -> ProcessedActivations: - """get back a dict of coactivations, slices, and concated activations + """Concatenate per-module activations and filter dead components. 
- Args: - activations: Dictionary of activations by module - filter_dead_threshold: Threshold for filtering dead components - seq_mode: How to handle sequence dimension - filter_modules: Function to filter modules - sort_components: Whether to sort components by similarity within each module + Fuses concatenation and filtering into a single pass to avoid holding two full + copies (~2x total components * n_samples) in memory simultaneously. """ # reshape -- special cases for llms # ============================================================ activations_: dict[str, ActivationsTensor] if seq_mode == "concat": - # Concatenate the sequence dimension into the sample dimension activations_ = { key: act.reshape(act.shape[0] * act.shape[1], act.shape[2]) for key, act in activations.items() } elif seq_mode == "seq_mean": - # Take the mean over the sequence dimension activations_ = { key: act.mean(dim=1) if act.ndim == 3 else act for key, act in activations.items() } else: - # Use the activations as they are activations_ = activations - # put the labelled activations into one big matrix and filter them - # ============================================================ - # filter activations for only the modules we want if filter_modules is not None: activations_ = {key: act for key, act in activations_.items() if filter_modules(key)} - # compute the labels and total component count - total_c: int = 0 - labels: ComponentLabels = ComponentLabels(list()) + # First pass: compute per-module component counts and alive masks + module_component_counts: dict[str, int] = {} + alive_masks: dict[str, Bool[Tensor, " c"]] = {} + total_alive = 0 for key, act in activations_.items(): - c: int = act.shape[-1] - labels.extend([f"{key}:{i}" for i in range(c)]) - total_c += c - - # concat the activations - act_concat: ActivationsTensor = torch.cat([activations_[key] for key in activations_], dim=-1) - - # filter dead components - filtered_components: FilteredActivations = filter_dead_components( 
- activations=act_concat, - labels=labels, - filter_dead_threshold=filter_dead_threshold, - ) - - assert filtered_components.n_alive + filtered_components.n_dead == total_c, ( - f"({filtered_components.n_alive = }) + ({filtered_components.n_dead = }) != ({total_c = })" - ) + c = act.shape[-1] + module_component_counts[key] = c + if filter_dead_threshold > 0: + max_act: Float[Tensor, " c"] = act.max(dim=0).values + alive = max_act >= filter_dead_threshold + alive_masks[key] = alive + total_alive += int(alive.sum().item()) + else: + total_alive += c + + total_c = sum(module_component_counts.values()) + + # Second pass: pre-allocate output and copy alive components one module at a time, + # freeing each module's tensor after copying to keep peak memory ~= 1x total size. + first_act = next(iter(activations_.values())) + n_samples = first_act.shape[0] + dtype = first_act.dtype + act_filtered = torch.empty(n_samples, total_alive, dtype=dtype) + + offset = 0 + alive_labels = ComponentLabels(list()) + dead_labels = ComponentLabels(list()) + module_alive_counts: dict[str, int] = {} + + for key in list(activations_.keys()): + tensor = activations_.pop(key) + c = tensor.shape[-1] + + if filter_dead_threshold > 0: + alive = alive_masks[key] + n_alive = int(alive.sum().item()) + for i in range(c): + label = f"{key}:{i}" + if alive[i]: + alive_labels.append(label) + else: + dead_labels.append(label) + if n_alive > 0: + act_filtered[:, offset : offset + n_alive] = tensor[:, alive] + else: + n_alive = c + alive_labels.extend([f"{key}:{i}" for i in range(c)]) + act_filtered[:, offset : offset + n_alive] = tensor + + module_alive_counts[key] = n_alive + offset += n_alive + del tensor + + assert offset == total_alive + assert list(module_alive_counts.keys()) == list(module_component_counts.keys()) + assert len(alive_labels) + len(dead_labels) == total_c return ProcessedActivations( - activations_raw=activations_, - activations=filtered_components.activations, - 
labels=filtered_components.labels, - dead_components_lst=filtered_components.dead_components_labels, + module_component_counts=module_component_counts, + module_alive_counts=module_alive_counts, + activations=act_filtered, + labels=alive_labels, + dead_components_lst=dead_labels if dead_labels else None, ) diff --git a/spd/clustering/clustering_run_config.py b/spd/clustering/clustering_run_config.py index cfc2e5e51..4ab830319 100644 --- a/spd/clustering/clustering_run_config.py +++ b/spd/clustering/clustering_run_config.py @@ -43,6 +43,14 @@ class ClusteringRunConfig(BaseConfig): ) batch_size: PositiveInt = Field(..., description="Batch size for processing") + n_tokens: PositiveInt | None = Field( + default=None, + description="Number of token activation samples to collect (LM only)", + ) + n_tokens_per_seq: PositiveInt | None = Field( + default=None, + description="Number of random token positions to sample per sequence (LM only)", + ) dataset_seed: int = Field(0, description="Seed for dataset generation/loading") ensemble_id: str | None = Field( default=None, @@ -59,10 +67,6 @@ class ClusteringRunConfig(BaseConfig): description="WandB project name (None to disable WandB logging)", ) wandb_entity: str = Field(default="goodfire", description="WandB entity (team/user) name") - dataset_streaming: bool = Field( - default=False, - description="Whether to use streaming dataset loading (if supported by the dataset). 
see https://github.com/goodfire-ai/spd/pull/199", - ) @model_validator(mode="before") def process_experiment_key(cls, values: dict[str, Any]) -> dict[str, Any]: diff --git a/spd/clustering/configs/crc/pile_llama_simple_mlp-4L.json b/spd/clustering/configs/crc/pile_llama_simple_mlp-4L.json new file mode 100644 index 000000000..aace3deab --- /dev/null +++ b/spd/clustering/configs/crc/pile_llama_simple_mlp-4L.json @@ -0,0 +1,22 @@ +{ + "merge_config": { + "activation_threshold": 0.1, + "alpha": 10, + "iters": 20000, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.001}, + "filter_dead_threshold": 0.1, + "module_name_filter": null + }, + "model_path": "wandb:goodfire/spd/runs/s-55ea3f9b", + "batch_size": 128, + "n_tokens": 500000, + "n_tokens_per_seq": 5, + "wandb_project": null, + "logging_intervals": { + "stat": 10, + "tensor": 200, + "plot": 2000, + "artifact": 2000 + } + } diff --git a/spd/clustering/configs/crc/simplestories_dev.json b/spd/clustering/configs/crc/simplestories_dev.json deleted file mode 100644 index e1647b6e4..000000000 --- a/spd/clustering/configs/crc/simplestories_dev.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "merge_config": { - "activation_threshold": 0.1, - "alpha": 1.0, - "iters": 100, - "merge_pair_sampling_method": "range", - "merge_pair_sampling_kwargs": {"threshold": 0.001}, - "filter_dead_threshold": 0.1, - "module_name_filter": null - }, - "model_path": "wandb:goodfire/spd/runs/lxs77xye", - "batch_size": 32, - "wandb_project": null, - "logging_intervals": { - "stat": 1, - "tensor": 200, - "plot": 2000, - "artifact": 2000 - } -} \ No newline at end of file diff --git a/spd/clustering/configs/crc/ss_llama_simple_mlp-1L.json b/spd/clustering/configs/crc/ss_llama_simple_mlp-1L.json index 687173f6c..9fd4fb9d5 100644 --- a/spd/clustering/configs/crc/ss_llama_simple_mlp-1L.json +++ b/spd/clustering/configs/crc/ss_llama_simple_mlp-1L.json @@ -9,7 +9,9 @@ "module_name_filter": null }, "model_path": 
"wandb:goodfire/spd/runs/5cr21lbs", - "batch_size": 2048, + "batch_size": 64, + "n_tokens": 500000, + "n_tokens_per_seq": 10, "wandb_project": null, "logging_intervals": { "stat": 10, @@ -17,4 +19,4 @@ "plot": 2000, "artifact": 2000 } - } \ No newline at end of file + } diff --git a/spd/clustering/configs/crc/ss_llama_simple_mlp-2L.json b/spd/clustering/configs/crc/ss_llama_simple_mlp-2L.json index c6426fe9a..060684c76 100644 --- a/spd/clustering/configs/crc/ss_llama_simple_mlp-2L.json +++ b/spd/clustering/configs/crc/ss_llama_simple_mlp-2L.json @@ -9,7 +9,9 @@ "module_name_filter": null }, "model_path": "wandb:goodfire/spd/runs/itmexlj0", - "batch_size": 2048, + "batch_size": 64, + "n_tokens": 500000, + "n_tokens_per_seq": 10, "wandb_project": null, "logging_intervals": { "stat": 10, @@ -17,4 +19,4 @@ "plot": 2000, "artifact": 2000 } - } \ No newline at end of file + } diff --git a/spd/clustering/configs/crc/ss_llama_simple_mlp.json b/spd/clustering/configs/crc/ss_llama_simple_mlp.json index 6cf534ec5..2d5537f04 100644 --- a/spd/clustering/configs/crc/ss_llama_simple_mlp.json +++ b/spd/clustering/configs/crc/ss_llama_simple_mlp.json @@ -9,7 +9,9 @@ "module_name_filter": null }, "model_path": "wandb:goodfire/spd/runs/vjbol27n", - "batch_size": 512, + "batch_size": 64, + "n_tokens": 4096, + "n_tokens_per_seq": 5, "wandb_project": null, "logging_intervals": { "stat": 10, @@ -17,4 +19,4 @@ "plot": 2000, "artifact": 2000 } - } \ No newline at end of file + } diff --git a/spd/clustering/configs/crc/test-simplestories.json b/spd/clustering/configs/crc/test-simplestories.json deleted file mode 100644 index 911f71529..000000000 --- a/spd/clustering/configs/crc/test-simplestories.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "merge_config": { - "activation_threshold": 0.9, - "alpha": 1.0, - "iters": 5, - "merge_pair_sampling_method": "range", - "merge_pair_sampling_kwargs": {"threshold": 0.05}, - "filter_dead_threshold": 0.9, - "module_name_filter": "model.layers.0" - }, - 
"model_path": "wandb:goodfire/spd/runs/lxs77xye", - "batch_size": 1, - "wandb_project": null, - "logging_intervals": { - "stat": 1, - "tensor": 2, - "plot": 3, - "artifact": 4 - } -} \ No newline at end of file diff --git a/spd/clustering/configs/pipeline-dev-simplestories.yaml b/spd/clustering/configs/pipeline-dev-simplestories.yaml index eccda019f..80e9c63bc 100644 --- a/spd/clustering/configs/pipeline-dev-simplestories.yaml +++ b/spd/clustering/configs/pipeline-dev-simplestories.yaml @@ -6,4 +6,4 @@ slurm_partition: null wandb_project: "spd" wandb_entity: "goodfire" create_git_snapshot: false -clustering_run_config_path: "spd/clustering/configs/crc/simplestories_dev.json" \ No newline at end of file +clustering_run_config_path: "spd/clustering/configs/crc/ss_llama_simple_mlp.json" \ No newline at end of file diff --git a/spd/clustering/configs/pipeline_config.yaml b/spd/clustering/configs/pipeline_config.yaml index 6dbcc37eb..fd3bd9d7a 100644 --- a/spd/clustering/configs/pipeline_config.yaml +++ b/spd/clustering/configs/pipeline_config.yaml @@ -1,4 +1,4 @@ -clustering_run_config_path: "spd/clustering/configs/crc/ss_llama_simple_mlp-1L.json" +clustering_run_config_path: "spd/clustering/configs/crc/pile_llama_simple_mlp-4L.json" n_runs: 2 distances_methods: ["perm_invariant_hamming"] slurm_job_name_prefix: "spd" diff --git a/spd/clustering/dataset.py b/spd/clustering/dataset.py index c8e86f0fc..b097f4087 100644 --- a/spd/clustering/dataset.py +++ b/spd/clustering/dataset.py @@ -1,11 +1,12 @@ """Dataset loading utilities for clustering runs. -Each clustering run loads its own dataset batch, seeded by the run index. +Each clustering run loads its own dataset, seeded by the run index. 
""" from typing import Any -from spd.clustering.consts import BatchTensor +from torch.utils.data import DataLoader + from spd.configs import LMTaskConfig, ResidMLPTaskConfig from spd.data import DatasetConfig, create_data_loader from spd.experiments.resid_mlp.models import ResidMLP @@ -13,16 +14,13 @@ from spd.spd_types import TaskName -def load_dataset( +def create_clustering_dataloader( model_path: str, task_name: TaskName, batch_size: int, seed: int, - **kwargs: Any, -) -> BatchTensor: - """Load a single batch for clustering. - - Each run gets its own dataset batch, seeded by index in ensemble. +) -> DataLoader[Any]: + """Create a dataloader for clustering. Args: model_path: Path to decomposed model @@ -31,31 +29,27 @@ def load_dataset( seed: Random seed for dataset Returns: - Single batch of data + DataLoader yielding batches """ match task_name: case "lm": - return _load_lm_batch( + return _create_lm_dataloader( model_path=model_path, batch_size=batch_size, seed=seed, - **kwargs, ) case "resid_mlp": - return _load_resid_mlp_batch( + return _create_resid_mlp_dataloader( model_path=model_path, batch_size=batch_size, seed=seed, - **kwargs, ) case _: raise ValueError(f"Unsupported task: {task_name}") -def _load_lm_batch( - model_path: str, batch_size: int, seed: int, config_kwargs: dict[str, Any] | None = None -) -> BatchTensor: - """Load a batch for language model task.""" +def _create_lm_dataloader(model_path: str, batch_size: int, seed: int) -> DataLoader[Any]: + """Create a dataloader for language model task.""" spd_run = SPDRunInfo.from_path(model_path) cfg = spd_run.config @@ -63,20 +57,6 @@ def _load_lm_batch( f"Expected task_config to be of type LMTaskConfig, but got {type(cfg.task_config) = }" ) - try: - pretrained_model_name: str = cfg.pretrained_model_name # pyright: ignore[reportAssignmentType] - assert pretrained_model_name is not None - except Exception as e: - raise AttributeError("Could not find 'pretrained_model_name' in the SPD Run config") from 
e - - config_kwargs_: dict[str, Any] = { - **dict( - is_tokenized=False, - streaming=False, - ), - **(config_kwargs or {}), - } - dataset_config = DatasetConfig( name=cfg.task_config.dataset_name, hf_tokenizer_path=cfg.tokenizer_name, @@ -84,7 +64,8 @@ def _load_lm_batch( n_ctx=cfg.task_config.max_seq_len, seed=seed, # Use run-specific seed column_name=cfg.task_config.column_name, - **config_kwargs_, + is_tokenized=cfg.task_config.is_tokenized, + streaming=cfg.task_config.streaming, ) dataloader, _ = create_data_loader( @@ -94,13 +75,11 @@ def _load_lm_batch( global_seed=seed, # Use run-specific seed ) - # Get first batch - batch = next(iter(dataloader)) - return batch["input_ids"] + return dataloader -def _load_resid_mlp_batch(model_path: str, batch_size: int, seed: int) -> BatchTensor: - """Load a batch for ResidMLP task.""" +def _create_resid_mlp_dataloader(model_path: str, batch_size: int, seed: int) -> DataLoader[Any]: + """Create a dataloader for ResidMLP task.""" from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset from spd.utils.data_utils import DatasetGeneratedDataLoader @@ -128,7 +107,5 @@ def _load_resid_mlp_batch(model_path: str, batch_size: int, seed: int) -> BatchT data_generation_type=cfg.task_config.data_generation_type, ) - # Generate batch dataloader = DatasetGeneratedDataLoader(dataset, batch_size=batch_size, shuffle=False) - batch, _ = next(iter(dataloader)) - return batch + return dataloader diff --git a/spd/clustering/plotting/activations.py b/spd/clustering/plotting/activations.py index 81a147f37..3b55c6a8a 100644 --- a/spd/clustering/plotting/activations.py +++ b/spd/clustering/plotting/activations.py @@ -47,7 +47,6 @@ def plot_activations( if save_dir is not None: save_dir.mkdir(parents=True, exist_ok=True) - act_dict: dict[str, ActivationsTensor] = processed_activations.activations_raw act_concat: ActivationsTensor = processed_activations.activations coact: ClusterCoactivationShaped = compute_coactivatons(act_concat) 
labels: ComponentLabels = ComponentLabels(processed_activations.labels) @@ -56,9 +55,12 @@ def plot_activations( # trim the activations if n_samples_max is specified # clone here so we don't modify the original tensor act_concat = act_concat[:n_samples_max].clone() - # we don't use the stuff in this dict again, so we can modify it in-place - for key in act_dict: - act_dict[key] = act_dict[key][:n_samples_max] + + # Reconstruct per-module views (alive components only), truncated to n_samples_max + act_dict: dict[str, ActivationsTensor] = { + key: act[:n_samples_max] + for key, act in processed_activations.get_module_activations().items() + } # Update n_samples to reflect the truncated size n_samples = act_concat.shape[0] diff --git a/spd/clustering/scripts/get_cluster_mapping.py b/spd/clustering/scripts/get_cluster_mapping.py index c2212e5cd..1877b2fbf 100644 --- a/spd/clustering/scripts/get_cluster_mapping.py +++ b/spd/clustering/scripts/get_cluster_mapping.py @@ -1,13 +1,12 @@ -"""Extract cluster mapping from an ensemble at a specific iteration. +"""Extract cluster mapping from a clustering run at a specific iteration. 
Usage: - python -m spd.clustering.scripts.get_cluster_mapping /path/to/ensemble --iteration 299 - python -m spd.clustering.scripts.get_cluster_mapping /path/to/ensemble --iteration 299 --run-idx 0 - python -m spd.clustering.scripts.get_cluster_mapping /path/to/ensemble --iteration 299 --notes "some notes" + python -m spd.clustering.scripts.get_cluster_mapping /path/to/clustering_run --iteration 299 + python -m spd.clustering.scripts.get_cluster_mapping /path/to/clustering_run --iteration 299 --notes "some notes" Output format: { - "ensemble_id": "e-5f228e5f", + "clustering_run_id": "cr-5f228e5f", "notes": "", "spd_run": "spd/goodfire/5cr21lbs", "clusters": {"h.0.mlp.down_proj:1": 0, "h.0.mlp.down_proj:2": null, ...} @@ -22,138 +21,102 @@ import fire import numpy as np -import yaml -from spd.settings import REPO_ROOT +from spd.clustering.merge_history import MergeHistory from spd.utils.wandb_utils import parse_wandb_run_path def get_cluster_mapping( - ensemble_dir: str | Path, - n_iterations: int, - run_idx: int = 0, -) -> dict[str, int | None]: + run_dir: Path, + iteration: int, +) -> tuple[dict[str, int | None], list[str]]: """Get mapping from component labels to cluster indices at a specific iteration. Args: - ensemble_dir: Path to ensemble directory containing ensemble_merge_array.npz - and ensemble_meta.json - n_iterations: Number of iterations to extract clusters from - run_idx: Run index within the ensemble (default 0) + run_dir: Path to clustering run directory containing history.zip + iteration: Iteration index to extract clusters from Returns: - Mapping from component label (e.g. "h.0.mlp.down_proj:42") to cluster index, - or None for singleton clusters (clusters with only one member). + Tuple of (mapping, labels) where mapping maps component label + (e.g. "h.0.mlp.down_proj:42") to cluster index, or None for + singleton clusters (clusters with only one member). 
""" - ensemble_dir = Path(ensemble_dir) + history_path = run_dir / "history.zip" + assert history_path.exists(), f"History not found: {history_path}" - merge_array_path = ensemble_dir / "ensemble_merge_array.npz" - meta_path = ensemble_dir / "ensemble_meta.json" + history = MergeHistory.read(history_path) - assert merge_array_path.exists(), f"Merge array not found: {merge_array_path}" - assert meta_path.exists(), f"Metadata not found: {meta_path}" - - merge_data = np.load(merge_array_path) - merge_array = merge_data["merge_array"] # shape: (n_runs, n_iterations, n_components) - - with open(meta_path) as f: - meta = json.load(f) - - component_labels: list[str] = meta["component_labels"] - n_runs, n_iterations_stored, n_components = merge_array.shape - - assert 0 <= run_idx < n_runs, f"run_idx {run_idx} out of bounds [0, {n_runs})" - assert 0 <= n_iterations < n_iterations_stored, ( - f"n_iterations {n_iterations} out of bounds [0, {n_iterations_stored})" - ) - assert len(component_labels) == n_components, ( - f"Label count mismatch: {len(component_labels)} labels vs {n_components} components" + assert 0 <= iteration < history.n_iters_current, ( + f"iteration {iteration} out of bounds [0, {history.n_iters_current})" ) - assignments = merge_array[run_idx, n_iterations, :] + merge = history.merges[iteration] + assignments = merge.group_idxs.numpy() + labels = list(history.labels) # Count members per cluster to identify singletons - cluster_ids, counts = np.unique(assignments, return_counts=True) - singleton_clusters = set(cluster_ids[counts == 1]) + unique_ids, counts = np.unique(assignments, return_counts=True) + singleton_clusters = set(unique_ids[counts == 1].tolist()) - return { - label: None if cluster_id in singleton_clusters else int(cluster_id) - for label, cluster_id in zip(component_labels, assignments, strict=True) + mapping = { + label: None if int(cluster_id) in singleton_clusters else int(cluster_id) + for label, cluster_id in zip(labels, assignments, 
strict=True) } + return mapping, labels -def get_spd_run_path(ensemble_dir: Path) -> str: - """Extract the SPD run path from the ensemble's pipeline config. - Follows pipeline_config.yaml -> clustering_run_config_path -> model_path, - then parses the wandb path. +def get_spd_run_path(run_dir: Path) -> str: + """Extract the SPD run path from the clustering run config. Returns: Formatted path like "spd/goodfire/5cr21lbs" """ - pipeline_config_path = ensemble_dir / "pipeline_config.yaml" - assert pipeline_config_path.exists(), f"Pipeline config not found: {pipeline_config_path}" + config_path = run_dir / "clustering_run_config.json" + assert config_path.exists(), f"Clustering run config not found: {config_path}" - with open(pipeline_config_path) as f: - pipeline_config = yaml.safe_load(f) + with open(config_path) as f: + config = json.load(f) - clustering_run_config_path = REPO_ROOT / pipeline_config["clustering_run_config_path"] - assert clustering_run_config_path.exists(), ( - f"Clustering run config not found: {clustering_run_config_path}" - ) - - with open(clustering_run_config_path) as f: - clustering_run_config = json.load(f) - - model_path = clustering_run_config["model_path"] + model_path = config["model_path"] entity, project, run_id = parse_wandb_run_path(model_path) return f"{entity}/{project}/{run_id}" def main( - ensemble_dir: str, - n_iterations: int, - run_idx: int = 0, + run_dir: str, + iteration: int, notes: str = "", output: str | None = None, ) -> None: """Extract cluster mapping with metadata and output as JSON. Args: - ensemble_dir: Path to ensemble directory - n_iterations: Number of iterations to extract clusters from - run_idx: Run index within the ensemble (default 0) + run_dir: Path to clustering run directory (containing history.zip) + iteration: Iteration index to extract clusters from notes: Optional notes to include in the output output: Optional output file path. 
If not provided, writes to - {ensemble_dir}/cluster_mapping_{ensemble_id}.json + {run_dir}/cluster_mapping.json """ - ensemble_path = Path(ensemble_dir) + run_path = Path(run_dir) - clusters = get_cluster_mapping( - ensemble_dir=ensemble_dir, - n_iterations=n_iterations, - run_idx=run_idx, - ) + clusters, _ = get_cluster_mapping(run_dir=run_path, iteration=iteration) - ensemble_id = ensemble_path.name - spd_run = get_spd_run_path(ensemble_path) + clustering_run_id = run_path.name + spd_run = get_spd_run_path(run_path) result = { - "ensemble_id": ensemble_id, + "clustering_run_id": clustering_run_id, "notes": notes, "spd_run": spd_run, - "n_iterations": n_iterations, - "run_idx": run_idx, + "iteration": iteration, "clusters": clusters, } json_str = json.dumps(result, indent=2) - if output is None: - out_path = ensemble_path / f"cluster_mapping_{ensemble_id}.json" - else: - out_path = Path(output) + out_path = run_path / "cluster_mapping.json" if output is None else Path(output) out_path.write_text(json_str) print(f"Wrote mapping ({len(clusters)} components) to {out_path}", file=sys.stderr) diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index f76e64d5a..844365681 100644 --- a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -29,17 +29,17 @@ from spd.clustering.activations import ( ProcessedActivations, + collect_activations, component_activations, process_activations, ) from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.consts import ( ActivationsTensor, - BatchTensor, ClusterCoactivationShaped, ComponentLabels, ) -from spd.clustering.dataset import load_dataset +from spd.clustering.dataset import create_clustering_dataloader from spd.clustering.ensemble_registry import _ENSEMBLE_REGISTRY_DB, register_clustering_run from spd.clustering.math.merge_matrix import GroupMerge from spd.clustering.math.semilog import semilog @@ -259,23 +259,14 @@ 
def main(run_config: ClusteringRunConfig) -> Path: spd_run = SPDRunInfo.from_path(run_config.model_path) task_name: TaskName = spd_run.config.task_config.task_name - # 1. Load dataset + # 1. Create dataloader logger.info(f"Loading dataset (seed={run_config.dataset_seed})") - load_dataset_kwargs: dict[str, Any] = dict() - if run_config.dataset_streaming: - logger.info("Using streaming dataset loading") - load_dataset_kwargs["config_kwargs"] = dict(streaming=True) - assert task_name == "lm", ( - f"Streaming dataset loading only supported for 'lm' task, got '{task_name = }'. Remove dataset_streaming=True from config or use a different task." - ) - batch: BatchTensor = load_dataset( + dataloader = create_clustering_dataloader( model_path=run_config.model_path, task_name=task_name, batch_size=run_config.batch_size, seed=run_config.dataset_seed, - **load_dataset_kwargs, ) - batch = batch.to(device) # 2. Setup WandB for this run wandb_run: Run | None = None @@ -294,7 +285,6 @@ def main(run_config: ClusteringRunConfig) -> Path: f"assigned_idx:{assigned_idx}", ], ) - # logger.info(f"WandB run: {wandb_run.url}") # 3. Load model logger.info("Loading model") @@ -302,20 +292,33 @@ def main(run_config: ClusteringRunConfig) -> Path: # 4. 
Compute activations logger.info("Computing activations") - activations_dict: ( - dict[str, Float[Tensor, "batch seq C"]] | dict[str, Float[Tensor, "batch C"]] - ) = component_activations( - model=model, - batch=batch, - device=device, - ) + if task_name == "lm": + assert run_config.n_tokens is not None, "n_tokens must be set for LM tasks" + assert run_config.n_tokens_per_seq is not None, "n_tokens_per_seq must be set for LM tasks" + activations_dict = collect_activations( + model=model, + dataloader=dataloader, + n_tokens=run_config.n_tokens, + n_tokens_per_seq=run_config.n_tokens_per_seq, + device=device, + seed=run_config.dataset_seed, + ) + else: + # resid_mlp: single batch, no sequence dimension + batch_data = next(iter(dataloader)) + batch, _ = batch_data # DatasetGeneratedDataLoader yields (batch, labels) + activations_dict = component_activations( + model=model, + batch=batch, + device=device, + ) # 5. Process activations logger.info("Processing activations") processed_activations: ProcessedActivations = process_activations( activations=activations_dict, filter_dead_threshold=run_config.merge_config.filter_dead_threshold, - seq_mode="concat" if task_name == "lm" else None, + seq_mode=None, filter_modules=run_config.merge_config.filter_modules, ) @@ -336,14 +339,14 @@ def main(run_config: ClusteringRunConfig) -> Path: single=True, ) - # Clean up memory - activations: ActivationsTensor = processed_activations.activations + # Extract what we need, then free the model and temporary objects + activations: ActivationsTensor = processed_activations.activations.to(device) component_labels: ComponentLabels = ComponentLabels(processed_activations.labels.copy()) del processed_activations del activations_dict del model - del batch gc.collect() + torch.cuda.empty_cache() # 7. 
Run merge iteration logger.info("Starting merging") @@ -408,21 +411,13 @@ def cli() -> None: default=None, help="WandB entity name (user or team)", ) - parser.add_argument( - "--dataset-streaming", - action="store_true", - help="Whether to use streaming dataset loading (if supported by the dataset)", - ) - args: argparse.Namespace = parser.parse_args() # Load base config run_config = ClusteringRunConfig.from_file(args.config) # Override config values from CLI - overrides: dict[str, Any] = { - "dataset_streaming": args.dataset_streaming, - } + overrides: dict[str, Any] = {} # Handle ensemble-related overrides if args.pipeline_run_id is not None: diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index c383adb95..d02d94360 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -151,14 +151,12 @@ def create_clustering_workspace_view(ensemble_id: str, project: str, entity: str def generate_clustering_commands( pipeline_config: ClusteringPipelineConfig, pipeline_run_id: str, - dataset_streaming: bool = False, ) -> list[str]: """Generate commands for each clustering run. Args: pipeline_config: Pipeline configuration pipeline_run_id: Pipeline run ID (each run will create its own ExecutionStamp) - dataset_streaming: Whether to use dataset streaming Returns: List of shell-safe command strings @@ -180,8 +178,6 @@ def generate_clustering_commands( "--wandb-entity", pipeline_config.wandb_entity, ] - if dataset_streaming: - cmd_parts.append("--dataset-streaming") commands.append(shlex.join(cmd_parts)) @@ -222,7 +218,6 @@ def main( local: bool = False, local_clustering_parallel: bool = False, local_calc_distances_parallel: bool = False, - dataset_streaming: bool = False, track_resources_calc_distances: bool = False, ) -> None: """Submit clustering runs to SLURM. 
@@ -272,7 +267,6 @@ def main( clustering_commands = generate_clustering_commands( pipeline_config=pipeline_config, pipeline_run_id=pipeline_run_id, - dataset_streaming=dataset_streaming, ) # Generate commands for calculating distances @@ -440,11 +434,6 @@ def cli(): action="store_true", help="If running locally, whether to track resource usage during distance calculations", ) - parser.add_argument( - "--dataset-streaming", - action="store_true", - help="Whether to use streaming dataset loading (if supported by the dataset). see https://github.com/goodfire-ai/spd/pull/199", - ) args = parser.parse_args() @@ -467,7 +456,6 @@ def cli(): main( pipeline_config=pipeline_config, local=args.local, - dataset_streaming=args.dataset_streaming, local_clustering_parallel=args.local_clustering_parallel, local_calc_distances_parallel=args.local_calc_distances_parallel, track_resources_calc_distances=args.track_resources_calc_distances, diff --git a/spd/configs.py b/spd/configs.py index 14ddd316d..efbfb3bb4 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -13,7 +13,162 @@ from spd.base_config import BaseConfig from spd.log import logger -from spd.spd_types import CiFnType, ModelPath, Probability +from spd.spd_types import GlobalCiFnType, LayerwiseCiFnType, ModelPath, Probability + + +class LayerwiseCiConfig(BaseConfig): + """Configuration for layerwise CI functions (one per layer).""" + + mode: Literal["layerwise"] = "layerwise" + fn_type: LayerwiseCiFnType = Field( + ..., description="Type of layerwise CI function: mlp, vector_mlp, or shared_mlp" + ) + hidden_dims: list[NonNegativeInt] = Field( + ..., description="Hidden dimensions for the CI function MLP" + ) + + +class BlockGroupConfig(BaseConfig): + """Defines a group of modules processed together in global reverse residual CI. + + Modules within a block have their activations concatenated, projected to the residual + stream dimension, and processed together by a single reader network. 
+ """ + + name: str = Field(..., description="Block identifier (e.g. 'unembed', 'layer_2_mlp')") + patterns: list[str] = Field( + ..., + description="Module patterns for this block (fnmatch-style, e.g. ['layers.2.mlp_*'])", + ) + + +class AttnConfig(BaseConfig): + """Configuration for self-attention. + + Uses RoPE (Rotary Position Embeddings) for sequence length generalization. + """ + + n_heads: PositiveInt = Field( + ..., + description="Number of attention heads. Must divide the input dimension.", + ) + max_len: PositiveInt = Field( + default=2048, + description="Maximum sequence length for RoPE embeddings.", + ) + rope_base: float = Field( + default=10000.0, + description="Base for RoPE frequency computation.", + ) + + +class GlobalSharedTransformerCiConfig(BaseConfig): + d_model: PositiveInt + n_blocks: PositiveInt + mlp_hidden_dim: list[NonNegativeInt] = Field( + description="Hidden dimension for transformer MLP blocks. " + "If None, defaults to [4 * d_model].", + ) + attn_config: AttnConfig + + @model_validator(mode="after") + def validate_config(self) -> Self: + assert self.d_model % self.attn_config.n_heads == 0, ( + f"d_model ({self.d_model}) must be divisible by " + f"attn_config.n_heads ({self.attn_config.n_heads})" + ) + d_head = self.d_model // self.attn_config.n_heads + assert d_head % 2 == 0, ( + f"d_head ({d_head}) must be even for RoPE. " + f"d_model={self.d_model}, " + f"n_heads={self.attn_config.n_heads}" + ) + return self + + +class GlobalCiConfig(BaseConfig): + """Configuration for global CI function (single function for all layers). + + For fn_type='global_shared_mlp': Concatenates all activations, processes through MLP. + For fn_type='global_reverse_residual': Processes blocks in reverse order with residual stream. + For fn_type='global_shared_transformer': Concatenates activations, projects to shared d_model, + and applies transformer blocks over the sequence dimension. 
+ """ + + mode: Literal["global"] = "global" + fn_type: GlobalCiFnType = Field( + ..., + description="Type of global CI function: global_shared_mlp, " + "global_reverse_residual, or global_shared_transformer", + ) + hidden_dims: list[NonNegativeInt] | None = Field( + default=None, + description="Hidden dimensions for global_shared_mlp CI function. " + "Use reader_hidden_dims for global_reverse_residual.", + ) + reader_hidden_dims: list[NonNegativeInt] | None = Field( + default=None, + description="Hidden dimensions for reader MLPs in global_reverse_residual. " + "Required when fn_type='global_reverse_residual', ignored otherwise.", + ) + d_resid_ci_fn: PositiveInt | None = Field( + default=None, + description="Residual stream dimension for global_reverse_residual. " + "Required when fn_type='global_reverse_residual', ignored otherwise.", + ) + block_groups: list[BlockGroupConfig] | None = Field( + default=None, + description="Ordered list of block groups for global_reverse_residual. " + "Order determines processing sequence (first = processed first, typically unembed). " + "Required when fn_type='global_reverse_residual', ignored otherwise.", + ) + transition_attn_config: AttnConfig | None = Field( + default=None, + description="Self-attention config for transitions in global_reverse_residual. " + "If None, uses MLP-only transitions (original behavior). " + "Only applies when fn_type='global_reverse_residual'.", + ) + transition_hidden_dim: PositiveInt | None = Field( + default=None, + description="Hidden dimension for transition MLP in global_reverse_residual. " + "MLP structure: d_resid_ci_fn -> transition_hidden_dim -> d_resid_ci_fn with GeLU. 
" + "Required when fn_type='global_reverse_residual', ignored otherwise.", + ) + simple_transformer_ci_cfg: GlobalSharedTransformerCiConfig | None = None + + @model_validator(mode="after") + def validate_ci_config(self) -> Self: + if self.fn_type == "global_reverse_residual": + assert self.d_resid_ci_fn is not None, ( + "d_resid_ci_fn must be specified when fn_type='global_reverse_residual'" + ) + assert self.block_groups is not None and len(self.block_groups) > 0, ( + "block_groups must be specified with at least one block when " + "fn_type='global_reverse_residual'" + ) + assert self.reader_hidden_dims is not None, ( + "reader_hidden_dims must be specified when fn_type='global_reverse_residual'" + ) + if self.transition_attn_config is not None: + assert self.d_resid_ci_fn % self.transition_attn_config.n_heads == 0, ( + f"d_resid_ci_fn ({self.d_resid_ci_fn}) must be divisible by " + f"transition_attn_config.n_heads ({self.transition_attn_config.n_heads})" + ) + elif self.fn_type == "global_shared_mlp": + assert self.hidden_dims is not None, ( + "hidden_dims must be specified when fn_type='global_shared_mlp'" + ) + assert self.transition_attn_config is None, ( + "transition_attn_config is only valid for global_reverse_residual" + ) + elif self.fn_type == "global_shared_transformer": + assert self.simple_transformer_ci_cfg is not None, ( + "simple_transformer_ci_cfg must be specified when fn_type='global_shared_transformer'" + ) + return self + + +CiConfig = LayerwiseCiConfig | GlobalCiConfig class ScheduleConfig(BaseConfig): @@ -146,6 +301,10 @@ class LMTaskConfig(BaseConfig): default=False, description="Whether to use a streaming dataset", ) + dataset_seed: int | None = Field( + default=None, + description="Seed for dataset shuffling/sampling. 
When None, uses the global `seed`.", + ) class ModulePatternInfoConfig(BaseConfig): @@ -184,7 +343,15 @@ class ImportanceMinimalityLossConfig(LossMetricConfig): @model_validator(mode="before") @classmethod - def default_beta(cls, data: dict[str, Any]) -> dict[str, Any]: + def migrate_old_fields(cls, data: dict[str, Any]) -> dict[str, Any]: + # Migrate pnorm_1 to pnorm (intermediate format) + if "pnorm_1" in data and "pnorm" not in data: + data["pnorm"] = data.pop("pnorm_1") + elif "pnorm_1" in data: + data.pop("pnorm_1") + # Remove deprecated pnorm_2 + data.pop("pnorm_2", None) + # Default beta if missing if "beta" not in data: logger.warning("beta not in ImportanceMinimalityLossConfig, defaulting to 0.0") data["beta"] = 0.0 @@ -282,10 +449,190 @@ class PGDMultiBatchReconSubsetLossConfig(PGDMultiBatchConfig): ] +class SignPGDConfig(BaseConfig): + type: Literal["sign"] = "sign" + lr_schedule: ScheduleConfig + + @model_validator(mode="before") + @classmethod + def migrate_step_size(cls, data: Any) -> Any: + if isinstance(data, dict) and "step_size" in data and "lr_schedule" not in data: + data["lr_schedule"] = { + "start_val": data.pop("step_size"), + "warmup_pct": 0.0, + "final_val_frac": 1.0, + "fn_type": "constant", + } + return data + + +class AdamPGDConfig(BaseConfig): + type: Literal["adam"] = "adam" + beta1: Probability = Field(default=0.9, description="Adam beta1 for masks") + beta2: Probability = Field(default=0.999, description="Adam beta2 for masks") + eps: NonNegativeFloat = Field(default=1e-8, description="Adam epsilon for masks") + lr_schedule: ScheduleConfig + + @model_validator(mode="before") + @classmethod + def migrate_lr(cls, data: Any) -> Any: + if isinstance(data, dict) and "lr" in data and "lr_schedule" not in data: + data["lr_schedule"] = { + "start_val": data.pop("lr"), + "warmup_pct": 0.0, + "final_val_frac": 1.0, + "fn_type": "constant", + } + return data + + +PGDOptimizerConfig = SignPGDConfig | AdamPGDConfig + + +class 
SingleSourceScope(BaseConfig): + type: Literal["single_source"] = "single_source" + + +class BroadcastAcrossBatchScope(BaseConfig): + type: Literal["broadcast_across_batch"] = "broadcast_across_batch" + + +class RepeatAcrossBatchScope(BaseConfig): + """Sources of shape (N, S, C) where N divides both batch_size and eval_batch_size. + + Repeated along batch dim at forward time: (N, S, C) -> (B, S, C). + """ + + type: Literal["repeat_across_batch"] = "repeat_across_batch" + n_sources: PositiveInt + + +class PerBatchPerPositionScope(BaseConfig): + """Sources of shape (B, S, C) — one source per batch element per position, separate across + ranks. + + Unlike other scopes, gradients are NOT all-reduced across ranks, so each rank + maintains fully independent sources for its own batch elements. + """ + + type: Literal["per_batch_per_position"] = "per_batch_per_position" + + +PersistentPGDSourceScope = Annotated[ + SingleSourceScope + | BroadcastAcrossBatchScope + | RepeatAcrossBatchScope + | PerBatchPerPositionScope, + Field(discriminator="type"), +] + + +def _coerce_ppgd_scope(config_dict: dict[str, Any]) -> None: + """Backwards compat: migrate old scope format/names to current names.""" + scope = config_dict.get("scope") + if isinstance(scope, str): + scope = {"type": scope} + config_dict["scope"] = scope + if not isinstance(scope, dict): + return + match scope.get("type"): + case "single_mask": + scope["type"] = "single_source" + case "batch_invariant": + scope["type"] = "repeat_across_batch" + if "n_masks" in scope: + scope["n_sources"] = scope.pop("n_masks") + case "per_batch" | "unique_per_batch_per_token": + scope["type"] = "per_batch_per_position" + case _: + pass + + +class _PersistentPGDBaseConfig(LossMetricConfig): + """Shared fields for persistent PGD configs. + + Persistent PGD maintains persistent masks that receive one gradient update per training step, + amortizing PGD optimization across training. 
+ """ + + optimizer: Annotated[PGDOptimizerConfig, Field(discriminator="type")] + scope: PersistentPGDSourceScope + use_sigmoid_parameterization: bool = False + n_warmup_steps: Annotated[ + NonNegativeInt, + Field( + description="Number of additional inner PGD source-optimization steps to run on each " + "batch before the final loss computation. Each training step always performs one PPGD " + "source update (grad + step) as part of the outer loop; these warmup steps add extra " + "source refinement iterations on the same batch in an inner loop beforehand." + ), + ] = 0 + + @model_validator(mode="before") + @classmethod + def _compat_scope(cls, data: Any) -> Any: + if isinstance(data, dict): + _coerce_ppgd_scope(data) + return data + + +class PersistentPGDReconLossConfig(_PersistentPGDBaseConfig): + classname: Literal["PersistentPGDReconLoss"] = "PersistentPGDReconLoss" + + +class PersistentPGDReconSubsetLossConfig(_PersistentPGDBaseConfig): + classname: Literal["PersistentPGDReconSubsetLoss"] = "PersistentPGDReconSubsetLoss" + routing: Annotated[ + SubsetRoutingType, Field(discriminator="type", default=UniformKSubsetRoutingConfig()) + ] + + class StochasticHiddenActsReconLossConfig(LossMetricConfig): classname: Literal["StochasticHiddenActsReconLoss"] = "StochasticHiddenActsReconLoss" +class CIHiddenActsReconLossConfig(BaseConfig): + classname: Literal["CIHiddenActsReconLoss"] = "CIHiddenActsReconLoss" + + +class PersistentPGDReconEvalConfig(BaseConfig): + classname: Literal["PersistentPGDReconEval"] = "PersistentPGDReconEval" + + +class PersistentPGDReconSubsetEvalConfig(BaseConfig): + classname: Literal["PersistentPGDReconSubsetEval"] = "PersistentPGDReconSubsetEval" + + +class _AttnPatternsReconLossBaseConfig(BaseConfig): + """Attention pattern reconstruction loss config. + + Supports standard attention and RoPE attention (auto-detected from the parent attention + module). Models using ALiBi, QK-norm, sliding window, etc. are not supported. 
+ """ + + n_heads: int + q_proj_path: str | None = None + k_proj_path: str | None = None + c_attn_path: str | None = None + + @model_validator(mode="after") + def _validate_paths(self) -> Self: + has_separate = self.q_proj_path is not None and self.k_proj_path is not None + has_combined = self.c_attn_path is not None + assert has_separate != has_combined, ( + "Specify either (q_proj_path, k_proj_path) or c_attn_path, not both/neither" + ) + return self + + +class CIMaskedAttnPatternsReconLossConfig(_AttnPatternsReconLossBaseConfig): + classname: Literal["CIMaskedAttnPatternsReconLoss"] = "CIMaskedAttnPatternsReconLoss" + + +class StochasticAttnPatternsReconLossConfig(_AttnPatternsReconLossBaseConfig): + classname: Literal["StochasticAttnPatternsReconLoss"] = "StochasticAttnPatternsReconLoss" + + #### Metrics that can only be used in eval #### class CEandKLLossesConfig(BaseConfig): classname: Literal["CEandKLLosses"] = "CEandKLLosses" @@ -352,22 +699,29 @@ class UVPlotsConfig(BaseConfig): | PGDReconSubsetLossConfig | PGDReconLayerwiseLossConfig | StochasticHiddenActsReconLossConfig + | PersistentPGDReconLossConfig + | PersistentPGDReconSubsetLossConfig ) LossMetricConfigType = FaithfulnessLossConfig | ImportanceMinimalityLossConfig | ReconLossConfigType EvalOnlyMetricConfigType = ( CEandKLLossesConfig + | CIHiddenActsReconLossConfig | CIHistogramsConfig | CI_L0Config | CIMeanPerComponentConfig | ComponentActivationDensityConfig | IdentityCIErrorConfig + | PersistentPGDReconEvalConfig + | PersistentPGDReconSubsetEvalConfig | PermutedCIPlotsConfig | UVPlotsConfig | StochasticReconSubsetCEAndKLConfig | PGDMultiBatchReconLossConfig | PGDMultiBatchReconSubsetLossConfig + | CIMaskedAttnPatternsReconLossConfig + | StochasticAttnPatternsReconLossConfig ) MetricConfigType = LossMetricConfigType | EvalOnlyMetricConfigType @@ -392,7 +746,10 @@ class Config(BaseConfig): ) # --- General --- - seed: int = Field(default=0, description="Random seed for reproducibility") + seed: int 
= Field( + default=0, + description="Random seed for reproducibility. Does not affect dataset shuffling if dataset_seed is set in TaskConfig.", + ) autocast_bf16: bool = Field( default=True, description="Whether to use torch.autocast with bfloat16 mixed precision", @@ -401,13 +758,11 @@ class Config(BaseConfig): ..., description="Number of stochastic masks to sample when using stochastic recon losses", ) - ci_fn_type: CiFnType = Field( - default="vector_mlp", - description="Type of causal importance function used to calculate the causal importance.", - ) - ci_fn_hidden_dims: list[NonNegativeInt] = Field( - default=[8], - description="Hidden dimensions for the causal importance function used to calculate the causal importance", + ci_config: CiConfig = Field( + ..., + discriminator="mode", + description="Configuration for the causal importance function. " + "Use LayerwiseCiConfig for per-layer CI functions or GlobalCiConfig for a single global CI function.", ) sampling: SamplingType = Field( default="continuous", @@ -446,6 +801,11 @@ def all_module_info(self) -> list[ModulePatternInfoConfig]: return result + init_spd_checkpoint: str | None = Field( + default=None, + description="Path to a .pth checkpoint from a prior SPD run for component/CI initialization", + ) + use_delta_component: bool = Field( default=True, description="If True, use an extra component containing the difference between the target " @@ -604,8 +964,6 @@ def all_module_info(self) -> list[ModulePatternInfoConfig]: "pretrained_model_name_hf": "pretrained_model_name", "recon_coeff": "ci_recon_coeff", "recon_layerwise_coeff": "ci_recon_layerwise_coeff", - "gate_type": "ci_fn_type", - "gate_hidden_dims": "ci_fn_hidden_dims", } @model_validator(mode="before") @@ -616,6 +974,7 @@ def handle_deprecated_config_keys(cls, config_dict: dict[str, Any]) -> dict[str, config_dict.pop("eval_metrics", None) cls._migrate_to_module_info(config_dict) + cls._migrate_to_ci_config(config_dict) 
migrate_to_lr_schedule_config(config_dict) for key in list(config_dict.keys()): @@ -673,6 +1032,44 @@ def _migrate_to_module_info(cls, config_dict: dict[str, Any]) -> None: {"module_pattern": p, "C": global_c} for p in identity_patterns ] + @classmethod + def _migrate_to_ci_config(cls, config_dict: dict[str, Any]) -> None: + """Migrate old ci_fn_type/ci_fn_hidden_dims/use_global_ci to new ci_config structure.""" + has_old_fields = ( + "ci_fn_type" in config_dict + or "ci_fn_hidden_dims" in config_dict + or "use_global_ci" in config_dict + ) + if not has_old_fields: + return + + logger.info( + "Migrating old ci_fn_type/ci_fn_hidden_dims/use_global_ci to ci_config structure" + ) + + ci_fn_type = config_dict.pop("ci_fn_type", "vector_mlp") + ci_fn_hidden_dims = config_dict.pop("ci_fn_hidden_dims", [8]) + use_global_ci = config_dict.pop("use_global_ci", False) + + # Determine if this is a global CI function + is_global = use_global_ci or ci_fn_type.startswith("global_") + + if is_global: + # Map layerwise type to global type if use_global_ci was set + if not ci_fn_type.startswith("global_"): + ci_fn_type = "global_shared_mlp" + config_dict["ci_config"] = { + "mode": "global", + "fn_type": ci_fn_type, + "hidden_dims": ci_fn_hidden_dims, + } + else: + config_dict["ci_config"] = { + "mode": "layerwise", + "fn_type": ci_fn_type, + "hidden_dims": ci_fn_hidden_dims, + } + @model_validator(mode="after") def validate_model(self) -> Self: assert self.slow_eval_freq % self.eval_freq == 0, ( @@ -685,4 +1082,12 @@ def validate_model(self) -> Self: for cfg in self.loss_metric_configs: assert cfg.coeff is not None, "All loss_metric_configs must have a coeff" + if any( + isinstance(cfg, PersistentPGDReconLossConfig | PersistentPGDReconSubsetLossConfig) + for cfg in self.loss_metric_configs + ): + assert isinstance(self.task_config, LMTaskConfig), ( + "Persistent PGD losses are only supported with LM tasks" + ) + return self diff --git a/spd/data.py b/spd/data.py index 
b1ed33a62..ad3b18f4c 100644 --- a/spd/data.py +++ b/spd/data.py @@ -304,6 +304,7 @@ def train_loader_and_tokenizer( streaming=task_config.streaming, column_name=task_config.column_name, shuffle_each_epoch=task_config.shuffle_each_epoch, + seed=task_config.dataset_seed, ) train_loader, tokenizer = create_data_loader( diff --git a/spd/dataset_attributions/CLAUDE.md b/spd/dataset_attributions/CLAUDE.md index 16a378d5d..2e2c0bd4c 100644 --- a/spd/dataset_attributions/CLAUDE.md +++ b/spd/dataset_attributions/CLAUDE.md @@ -5,148 +5,111 @@ Multi-GPU pipeline for computing component-to-component attribution strengths ag ## Usage (SLURM) ```bash -# Process specific number of batches spd-attributions --n_batches 1000 --n_gpus 8 - -# Process entire training dataset (omit --n_batches) -spd-attributions --n_gpus 24 - -# With optional parameters -spd-attributions --n_batches 1000 --n_gpus 8 \ - --batch_size 64 --ci_threshold 1e-6 --time 48:00:00 +spd-attributions --n_gpus 24 # whole dataset ``` The command: -1. Creates a git snapshot branch for reproducibility (jobs may be queued) -2. Submits a SLURM job array with N tasks (one per GPU) +1. Creates a git snapshot branch for reproducibility +2. Submits a SLURM job array (one per GPU) 3. Each task processes batches where `batch_idx % world_size == rank` -4. Submits a merge job (depends on array completion) that combines all worker results - -**Note**: `--n_batches` is optional. If omitted, the pipeline processes the entire training dataset. +4. Submits a merge job (depends on array completion) ## Usage (non-SLURM) -For environments without SLURM, run the worker script directly: - ```bash -# Single GPU with specific number of batches -python -m spd.dataset_attributions.scripts.run --n_batches 1000 - -# Single GPU processing entire dataset (omit --n_batches) -python -m spd.dataset_attributions.scripts.run - -# Multi-GPU (run in parallel via shell, tmux, etc.) 
-python -m spd.dataset_attributions.scripts.run --n_batches 1000 --rank 0 --world_size 4 & -python -m spd.dataset_attributions.scripts.run --n_batches 1000 --rank 1 --world_size 4 & -python -m spd.dataset_attributions.scripts.run --n_batches 1000 --rank 2 --world_size 4 & -python -m spd.dataset_attributions.scripts.run --n_batches 1000 --rank 3 --world_size 4 & +# Single GPU +python -m spd.dataset_attributions.scripts.run_worker + +# Multi-GPU +SUBRUN="da-$(date +%Y%m%d_%H%M%S)" +python -m spd.dataset_attributions.scripts.run_worker --config_json '{"n_batches": 1000}' --rank 0 --world_size 4 --subrun_id $SUBRUN & +python -m spd.dataset_attributions.scripts.run_worker --config_json '{"n_batches": 1000}' --rank 1 --world_size 4 --subrun_id $SUBRUN & +# ... wait - -# Merge results after all workers complete -python -m spd.dataset_attributions.scripts.run --merge +python -m spd.dataset_attributions.scripts.run_merge --wandb_path <wandb_path> --subrun_id $SUBRUN ``` -Each worker processes batches where `batch_idx % world_size == rank`, then the merge step combines all partial results. - ## Data Storage -Data is stored in `SPD_OUT_DIR/dataset_attributions/` (see `spd/settings.py`): - ``` SPD_OUT_DIR/dataset_attributions/<run_id>/ -├── dataset_attributions.pt # Final merged attributions -└── dataset_attributions_rank_*.pt # Per-worker results (cleaned up after merge) +├── da-20260223_183250/ # sub-run (latest picked by repo) +│ ├── dataset_attributions.pt # merged result +│ └── worker_states/ +│ └── dataset_attributions_rank_*.pt ``` -## Architecture +`AttributionRepo.open(run_id)` loads the latest `da-*` subrun that has a `dataset_attributions.pt`. -### SLURM Launcher (`scripts/run_slurm.py`, `scripts/run_slurm_cli.py`) +## Attribution Metrics -Entry point via `spd-attributions`. Submits array job + dependent merge job.
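Both the old `scripts/run.py` and the new `run_worker` shard work with the same rule: each worker keeps only the batches where `batch_idx % world_size == rank`. A minimal sketch of that partitioning (the function name `batches_for_rank` is illustrative, not from the codebase):

```python
def batches_for_rank(n_batches: int, rank: int, world_size: int) -> list[int]:
    # Each rank keeps only the batch indices congruent to it mod world_size,
    # so the ranks partition the batch stream disjointly with no coordination.
    return [b for b in range(n_batches) if b % world_size == rank]


shards = [batches_for_rank(10, rank, 4) for rank in range(4)]
assert shards[0] == [0, 4, 8]
assert sorted(b for shard in shards for b in shard) == list(range(10))  # disjoint cover
```

Because the assignment depends only on `(rank, world_size)`, every worker can iterate the same dataloader independently and skip foreign batches, which is exactly what the worker loop in `harvest.py` does.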
+Two metrics: `AttrMetric = Literal["attr", "attr_abs"]` -### Worker Script (`scripts/run.py`) +| Metric | Formula | Description | +|--------|---------|-------------| +| `attr` | E[∂y/∂x · x] | Signed mean attribution | +| `attr_abs` | E[∂\|y\|/∂x · x] | Attribution to absolute value of target (2 backward passes) | -Internal script called by SLURM jobs. Supports: -- `--rank R --world_size N`: Process subset of batches -- `--merge`: Combine per-rank results into final file +Naming convention: modifier *after* `attr` applies to the target (e.g. `attr_abs` = attribution to |target|). -### Harvest Logic (`harvest.py`) +## Architecture -Main harvesting functions: -- `harvest_attributions()`: Process batches for a single rank -- `merge_attributions()`: Combine results from all ranks +### Storage (`storage.py`) -### Attribution Harvester (`harvester.py`) +`DatasetAttributionStorage` stores four structurally distinct edge types: -Core class that accumulates attribution strengths using gradient × activation formula: +| Edge type | Fields | Shape | Has abs? | +|-----------|--------|-------|----------| +| component → component | `regular_attr`, `regular_attr_abs` | `dict[target, dict[source, (tgt_c, src_c)]]` | yes | +| embed → component | `embed_attr`, `embed_attr_abs` | `dict[target, (tgt_c, vocab)]` | yes | +| component → unembed | `unembed_attr` | `dict[source, (d_model, src_c)]` | no | +| embed → unembed | `embed_unembed_attr` | `(d_model, vocab)` | no | -``` -attribution[src, tgt] = Σ_batch Σ_pos (∂out[pos, tgt] / ∂in[pos, src]) × in_act[pos, src] -``` +All layer names use **canonical addressing** (`"embed"`, `"0.glu.up"`, `"output"`). -Key optimizations: -1. Sum outputs over positions before gradients (reduces backward passes) -2. For output targets, store attributions to output residual stream instead of vocab tokens (reduces storage from O((V+C)²) to O((V+C)×(C+d_model))) +Unembed edges are stored in residual space (d_model dimensions).
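Because unembed attributions live in residual space, projecting onto a specific output token is just a dot product with that token's unembedding column. A toy sketch with made-up numbers (the real storage performs this projection internally):

```python
# attr_to_token = residual_attr · w_unembed[:, token_id], sketched with d_model = 3.
residual_attr = [1.0, -2.0, 0.5]  # attribution to the output residual stream
w_unembed_col = [2.0, 1.0, 4.0]   # unembedding column for one (hypothetical) token

attr_to_token = sum(r * w for r, w in zip(residual_attr, w_unembed_col))
assert attr_to_token == 2.0  # 1*2 + (-2)*1 + 0.5*4
```

This is also why there is no `attr_abs` variant for unembed edges: |·| does not commute with this linear projection, so an abs metric cannot be recovered from residual-space storage.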
`w_unembed` is stored alongside the attribution data, so output token attributions are computed on-the-fly internally — callers never need to provide the projection matrix. No abs variant for unembed edges because abs is a nonlinear operation incompatible with residual-space storage. -### Storage (`storage.py`) +**Normalization**: `normed[t, s] = raw[t, s] / source_denom[s] / target_rms[t]`. Component sources use `ci_sum[s]` as denominator, embed sources use `embed_token_count[s]` (per-token occurrence count). This puts both source types on comparable per-occurrence scales. -`DatasetAttributionStorage` class using output-residual-based storage for scalability. +Key methods: `get_top_sources(key, k, sign, metric)`, `get_top_targets(key, k, sign, metric)`. Both return `[]` for nonexistent components. `merge(paths)` classmethod for combining worker results via weighted average by n_tokens. -**Storage structure:** -- `source_to_component`: (n_sources, n_components) - direct attributions to component targets -- `source_to_out_residual`: (n_sources, d_model) - attributions to output residual stream for output queries +### Harvester (`harvester.py`) -**Source indexing (rows):** -- `[0, vocab_size)`: wte tokens -- `[vocab_size, vocab_size + n_components)`: component layers +Accumulates attributions using gradient × activation. Uses **concrete module paths** internally (talks to model cache/CI). Four accumulator groups mirror the storage edge types. Key optimizations: +1. Sum outputs over positions before gradients (reduces backward passes) +2. Output-residual storage (O(d_model) instead of O(vocab)) +3. 
`scatter_add_` for embed sources, vectorized `.add_()` for components (>14x faster than per-element loops) -**Target handling:** -- Component targets: direct lookup in `source_to_component` -- Output targets: computed on-the-fly via `source_to_out_residual @ w_unembed[:, token_id]` +### Harvest (`harvest.py`) -**Why output-residual-based storage?** +Orchestrates the pipeline: loads model, builds gradient connectivity, runs batches, translates concrete→canonical at storage boundary via `topology.target_to_canon()`. -For large vocab models (V=32K), the naive approach would require O((V+C)²) storage (~4 GB). -The output-residual-based approach requires only O((V+C)×(C+d)) storage (~670 MB for Llama-scale), -a 6.5x reduction. Output attributions are computed on-the-fly at query time with negligible latency. +### Scripts -### Loaders (`loaders.py`) +- `scripts/run_worker.py` — worker entrypoint (single GPU) +- `scripts/run_merge.py` — merge entrypoint (CPU only, needs ~200G RAM) +- `scripts/run_slurm.py` — SLURM launcher (array + merge jobs) +- `scripts/run_slurm_cli.py` — CLI wrapper for `spd-attributions` -```python -from spd.dataset_attributions.loaders import load_dataset_attributions +### Config (`config.py`) -storage = load_dataset_attributions(run_id) -if storage: - # Get top sources attributing to a component (no w_unembed needed) - top_sources = storage.get_top_sources("h.0.mlp.c_fc:5", k=10, sign="positive") +- `DatasetAttributionConfig`: n_batches, batch_size, ci_threshold +- `AttributionsSlurmConfig`: adds n_gpus, partition, time, merge_time, merge_mem (default 200G) - # Get top component targets (no w_unembed needed) - top_comp_targets = storage.get_top_component_targets("h.0.mlp.c_fc:5", k=10, sign="positive") +### Repository (`repo.py`) - # Get top targets including outputs (requires w_unembed) - w_unembed = model.target_model.lm_head.weight.T.detach() - top_targets = storage.get_top_targets("h.0.mlp.c_fc:5", k=10, sign="positive", w_unembed=w_unembed) 
+`AttributionRepo.open(run_id)` → loads latest subrun. Returns `None` if no data. - # Get top output targets only (requires w_unembed) - top_outputs = storage.get_top_output_targets("h.0.mlp.c_fc:5", k=10, sign="positive", w_unembed=w_unembed) -``` +## Query Methods -## Key Types +All query methods take `metric: AttrMetric` (`"attr"` or `"attr_abs"`). -```python -DatasetAttributionStorage # Main storage class with split matrices -DatasetAttributionEntry # Single entry: component_key, layer, component_idx, value -DatasetAttributionConfig # Config: wandb_path, n_batches, batch_size, ci_threshold -``` +| Method | Description | +|--------|-------------| +| `get_top_sources(target_key, k, sign, metric)` | Top sources → target | +| `get_top_targets(source_key, k, sign, metric)` | Top targets ← source | -## Query Methods +Key format: `"embed:{token_id}"`, `"0.glu.up:{c_idx}"`, `"output:{token_id}"`. -| Method | w_unembed required? | Description | -|--------|---------------------|-------------| -| `get_top_sources(component_key, k, sign)` | No | Top sources → component target | -| `get_top_sources(output_key, k, sign, w_unembed)` | Yes | Top sources → output token | -| `get_top_component_targets(source_key, k, sign)` | No | Top component targets | -| `get_top_output_targets(source_key, k, sign, w_unembed)` | Yes | Top output token targets | -| `get_top_targets(source_key, k, sign, w_unembed)` | Yes | All targets (components + outputs) | -| `get_attribution(source_key, component_key)` | No | Single component attribution | -| `get_attribution(source_key, output_key, w_unembed)` | Yes | Single output attribution | +Note: `attr_abs` returns empty for output targets (unembed edges have no abs variant). diff --git a/spd/dataset_attributions/__init__.py b/spd/dataset_attributions/__init__.py index 0cf90729f..89408acab 100644 --- a/spd/dataset_attributions/__init__.py +++ b/spd/dataset_attributions/__init__.py @@ -4,14 +4,15 @@ training dataset. 
""" -from spd.dataset_attributions.harvest import DatasetAttributionConfig, harvest_attributions -from spd.dataset_attributions.loaders import load_dataset_attributions +from spd.dataset_attributions.config import DatasetAttributionConfig +from spd.dataset_attributions.harvest import harvest_attributions +from spd.dataset_attributions.repo import AttributionRepo from spd.dataset_attributions.storage import DatasetAttributionEntry, DatasetAttributionStorage __all__ = [ + "AttributionRepo", "DatasetAttributionConfig", "DatasetAttributionEntry", "DatasetAttributionStorage", "harvest_attributions", - "load_dataset_attributions", ] diff --git a/spd/dataset_attributions/config.py b/spd/dataset_attributions/config.py new file mode 100644 index 000000000..e1167fc5b --- /dev/null +++ b/spd/dataset_attributions/config.py @@ -0,0 +1,30 @@ +"""Dataset attribution configuration. + +DatasetAttributionConfig: tuning params for the attribution pipeline. +AttributionsSlurmConfig: DatasetAttributionConfig + SLURM submission params. 
+""" + +from typing import Literal + +from pydantic import PositiveInt + +from spd.base_config import BaseConfig +from spd.settings import DEFAULT_PARTITION_NAME + + +class DatasetAttributionConfig(BaseConfig): + spd_run_wandb_path: str + n_batches: int | Literal["whole_dataset"] = 10_000 + batch_size: int = 32 + ci_threshold: float = 0.0 + + +class AttributionsSlurmConfig(BaseConfig): + """Config for dataset attributions SLURM submission.""" + + config: DatasetAttributionConfig + n_gpus: PositiveInt = 8 + partition: str = DEFAULT_PARTITION_NAME + time: str = "48:00:00" + merge_time: str = "01:00:00" + merge_mem: str = "200G" diff --git a/spd/dataset_attributions/harvest.py b/spd/dataset_attributions/harvest.py index 0651a3a7b..b7ff56e35 100644 --- a/spd/dataset_attributions/harvest.py +++ b/spd/dataset_attributions/harvest.py @@ -13,7 +13,6 @@ """ import itertools -from dataclasses import dataclass from pathlib import Path import torch @@ -21,158 +20,83 @@ from jaxtyping import Bool from torch import Tensor -from spd.app.backend.compute import get_sources_by_target from spd.data import train_loader_and_tokenizer +from spd.dataset_attributions.config import DatasetAttributionConfig from spd.dataset_attributions.harvester import AttributionHarvester -from spd.dataset_attributions.loaders import get_attributions_dir from spd.dataset_attributions.storage import DatasetAttributionStorage -from spd.harvest.loaders import load_activation_contexts_summary +from spd.harvest.repo import HarvestRepo from spd.log import logger from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.topology import TransformerTopology, get_sources_by_target from spd.utils.distributed_utils import get_device from spd.utils.general_utils import extract_batch_data from spd.utils.wandb_utils import parse_wandb_run_path -@dataclass -class DatasetAttributionConfig: - wandb_path: str - n_batches: int | None - batch_size: int - ci_threshold: float - - -def 
_build_component_layer_keys(model: ComponentModel) -> list[str]: - """Build list of component layer keys in canonical order. - - Returns keys like ["h.0.attn.q_proj:0", "h.0.attn.q_proj:1", ...] for all layers. - wte and output keys are not included - they're constructed from vocab_size. - """ - component_layer_keys = [] - for layer in model.target_module_paths: - n_components = model.module_to_c[layer] - for c_idx in range(n_components): - component_layer_keys.append(f"{layer}:{c_idx}") - return component_layer_keys - - def _build_alive_masks( model: ComponentModel, run_id: str, - ci_threshold: float, - n_components: int, - vocab_size: int, -) -> tuple[Bool[Tensor, " n_sources"], Bool[Tensor, " n_components"]]: - """Build masks of alive components (mean_ci > threshold) for sources and targets. - - Falls back to all-alive if harvest summary not available. + harvest_subrun_id: str, +) -> dict[str, Bool[Tensor, " n_components"]]: + """Build masks of alive components (firing_density > 0) per target layer. - Index structure: - - Sources: [0, vocab_size) = wte tokens, [vocab_size, vocab_size + n_components) = component layers - - Targets: [0, n_components) = component layers (output handled via out_residual) + Only covers component layers — embed is always a valid source (not filtered). 
""" - summary = load_activation_contexts_summary(run_id) - n_sources = vocab_size + n_components + component_alive = { + layer: torch.zeros(model.module_to_c[layer], dtype=torch.bool) + for layer in model.target_module_paths + } - source_alive = torch.zeros(n_sources, dtype=torch.bool) - target_alive = torch.zeros(n_components, dtype=torch.bool) + harvest = HarvestRepo(decomposition_id=run_id, subrun_id=harvest_subrun_id, readonly=True) - # All wte tokens are always alive (source indices [0, vocab_size)) - source_alive[:vocab_size] = True - - if summary is None: - logger.warning("Harvest summary not available, using all components as alive") - source_alive.fill_(True) - target_alive.fill_(True) - return source_alive, target_alive - - # Build masks for component layers - source_idx = vocab_size # Start after wte tokens - target_idx = 0 + summary = harvest.get_summary() + assert summary is not None, "Harvest summary not available" for layer in model.target_module_paths: n_layer_components = model.module_to_c[layer] for c_idx in range(n_layer_components): component_key = f"{layer}:{c_idx}" - is_alive = component_key in summary and summary[component_key].mean_ci > ci_threshold - source_alive[source_idx] = is_alive - target_alive[target_idx] = is_alive - source_idx += 1 - target_idx += 1 - - n_source_alive = int(source_alive.sum().item()) - n_target_alive = int(target_alive.sum().item()) - logger.info( - f"Alive components: {n_source_alive}/{n_sources} sources, " - f"{n_target_alive}/{n_components} component targets (ci > {ci_threshold})" - ) - return source_alive, target_alive - + is_alive = component_key in summary and summary[component_key].firing_density > 0.0 + component_alive[layer][c_idx] = is_alive -def _get_output_path(run_id: str, rank: int | None) -> Path: - """Get output path for attributions.""" - output_dir = get_attributions_dir(run_id) - if rank is not None: - return output_dir / f"dataset_attributions_rank_{rank}.pt" - return output_dir / 
"dataset_attributions.pt" + return component_alive def harvest_attributions( config: DatasetAttributionConfig, - rank: int | None = None, - world_size: int | None = None, + output_dir: Path, + harvest_subrun_id: str, + rank: int, + world_size: int, ) -> None: - """Compute dataset attributions over the training dataset. - - Args: - config: Configuration for attribution harvesting. - rank: Worker rank for parallel execution (0 to world_size-1). - world_size: Total number of workers. If specified with rank, only processes - batches where batch_idx % world_size == rank. - """ - - assert (rank is None) == (world_size is None), "rank and world_size must both be set or unset" - device = torch.device(get_device()) logger.info(f"Loading model on {device}") - _, _, run_id = parse_wandb_run_path(config.wandb_path) + _, _, run_id = parse_wandb_run_path(config.spd_run_wandb_path) - run_info = SPDRunInfo.from_path(config.wandb_path) + run_info = SPDRunInfo.from_path(config.spd_run_wandb_path) model = ComponentModel.from_run_info(run_info).to(device) model.eval() spd_config = run_info.config - train_loader, tokenizer = train_loader_and_tokenizer(spd_config, config.batch_size) - vocab_size = tokenizer.vocab_size - assert isinstance(vocab_size, int), f"vocab_size must be int, got {type(vocab_size)}" - logger.info(f"Vocab size: {vocab_size}") - - # Build component keys and alive masks - component_layer_keys = _build_component_layer_keys(model) - n_components = len(component_layer_keys) - source_alive, target_alive = _build_alive_masks( - model, run_id, config.ci_threshold, n_components, vocab_size - ) - source_alive = source_alive.to(device) - target_alive = target_alive.to(device) - - n_sources = vocab_size + n_components - logger.info(f"Component layers: {n_components}, Sources: {n_sources}") + train_loader, _ = train_loader_and_tokenizer(spd_config, config.batch_size) # Get gradient connectivity logger.info("Computing sources_by_target...") - sources_by_target_raw = 
get_sources_by_target(model, str(device), spd_config.sampling) - - # Filter sources_by_target: - # - Valid targets: component layers + output - # - Valid sources: wte + component layers + topology = TransformerTopology(model.target_model) + embed_path = topology.path_schema.embedding_path + unembed_path = topology.path_schema.unembed_path + sources_by_target_raw = get_sources_by_target(model, topology, str(device), spd_config.sampling) + + # Filter to valid source/target pairs: + # - Valid sources: embedding + component layers + # - Valid targets: component layers + unembed component_layers = set(model.target_module_paths) - valid_sources = component_layers | {"wte"} - valid_targets = component_layers | {"output"} + valid_sources = component_layers | {embed_path} + valid_targets = component_layers | {unembed_path} - sources_by_target = {} + sources_by_target: dict[str, list[str]] = {} for target, sources in sources_by_target_raw.items(): if target not in valid_targets: continue @@ -181,124 +105,57 @@ def harvest_attributions( sources_by_target[target] = filtered_sources logger.info(f"Found {len(sources_by_target)} target layers with gradient connections") - # Create harvester + # Build alive masks + component_alive = _build_alive_masks(model, run_id, harvest_subrun_id) + harvester = AttributionHarvester( model=model, + topology=topology, sources_by_target=sources_by_target, - n_components=n_components, - vocab_size=vocab_size, - source_alive=source_alive, - target_alive=target_alive, + component_alive=component_alive, sampling=spd_config.sampling, - device=device, - show_progress=True, ) # Process batches train_iter = iter(train_loader) - batch_range = range(config.n_batches) if config.n_batches is not None else itertools.count() + match config.n_batches: + case int(n_batches): + batch_range = range(n_batches) + case "whole_dataset": + batch_range = itertools.count() + for batch_idx in tqdm.tqdm(batch_range, desc="Attribution batches"): try: batch_data = 
next(train_iter) except StopIteration: logger.info(f"Dataset exhausted at batch {batch_idx}. Processing complete.") break - # Skip batches not assigned to this rank - if world_size is not None and batch_idx % world_size != rank: + + if batch_idx % world_size != rank: continue + batch = extract_batch_data(batch_data).to(device) harvester.process_batch(batch) - logger.info( - f"Processing complete. Tokens: {harvester.n_tokens:,}, Batches: {harvester.n_batches}" - ) + logger.info(f"Processing complete. Tokens: {harvester.n_tokens:,}") - # Normalize by n_tokens to get per-token average attribution - normalized_comp = harvester.comp_accumulator / harvester.n_tokens - normalized_out_residual = harvester.out_residual_accumulator / harvester.n_tokens - - # Build and save storage - storage = DatasetAttributionStorage( - component_layer_keys=component_layer_keys, - vocab_size=vocab_size, - d_model=harvester.d_model, - source_to_component=normalized_comp.cpu(), - source_to_out_residual=normalized_out_residual.cpu(), - n_batches_processed=harvester.n_batches, - n_tokens_processed=harvester.n_tokens, - ci_threshold=config.ci_threshold, - ) + storage = harvester.finalize(config.ci_threshold) - output_path = _get_output_path(run_id, rank) + worker_dir = output_dir / "worker_states" + worker_dir.mkdir(parents=True, exist_ok=True) + output_path = worker_dir / f"dataset_attributions_rank_{rank}.pt" storage.save(output_path) - logger.info(f"Saved dataset attributions to {output_path}") - -def merge_attributions(wandb_path: str) -> None: - """Merge partial attribution files from parallel workers. - Looks for dataset_attributions_rank_*.pt files and merges them into - dataset_attributions.pt. - - Uses streaming merge to avoid OOM - loads one file at a time instead of all at once. 
- """ - _, _, run_id = parse_wandb_run_path(wandb_path) - output_dir = get_attributions_dir(run_id) - - # Find all rank files - rank_files = sorted(output_dir.glob("dataset_attributions_rank_*.pt")) - assert rank_files, f"No rank files found in {output_dir}" +def merge_attributions(output_dir: Path) -> None: + """Merge partial attribution files from parallel workers.""" + worker_dir = output_dir / "worker_states" + rank_files = sorted(worker_dir.glob("dataset_attributions_rank_*.pt")) + assert rank_files, f"No rank files found in {worker_dir}" logger.info(f"Found {len(rank_files)} rank files to merge") - # Load first file to get metadata and initialize accumulators - # Use double precision for accumulation to prevent precision loss with billions of tokens - first = DatasetAttributionStorage.load(rank_files[0]) - total_comp = (first.source_to_component * first.n_tokens_processed).double() - total_out_residual = (first.source_to_out_residual * first.n_tokens_processed).double() - total_tokens = first.n_tokens_processed - total_batches = first.n_batches_processed - logger.info(f"Loaded rank 0: {first.n_tokens_processed:,} tokens") - - # Stream remaining files one at a time - for rank_file in tqdm.tqdm(rank_files[1:], desc="Merging rank files"): - storage = DatasetAttributionStorage.load(rank_file) - - # Validate consistency - assert storage.component_layer_keys == first.component_layer_keys, ( - "Component layer keys mismatch" - ) - assert storage.vocab_size == first.vocab_size, "Vocab size mismatch" - assert storage.d_model == first.d_model, "d_model mismatch" - assert storage.ci_threshold == first.ci_threshold, "CI threshold mismatch" - - # Accumulate de-normalized values - total_comp += storage.source_to_component * storage.n_tokens_processed - total_out_residual += storage.source_to_out_residual * storage.n_tokens_processed - total_tokens += storage.n_tokens_processed - total_batches += storage.n_batches_processed - - # Normalize by total tokens and convert back 
to float32 for storage - merged_comp = (total_comp / total_tokens).float() - merged_out_residual = (total_out_residual / total_tokens).float() - - # Save merged result - merged = DatasetAttributionStorage( - component_layer_keys=first.component_layer_keys, - vocab_size=first.vocab_size, - d_model=first.d_model, - source_to_component=merged_comp, - source_to_out_residual=merged_out_residual, - n_batches_processed=total_batches, - n_tokens_processed=total_tokens, - ci_threshold=first.ci_threshold, - ) + merged = DatasetAttributionStorage.merge(rank_files) output_path = output_dir / "dataset_attributions.pt" merged.save(output_path) - logger.info(f"Merged {len(rank_files)} files -> {output_path}") - logger.info(f"Total: {total_batches} batches, {total_tokens:,} tokens") - - # Clean up per-rank files after successful merge - for rank_file in rank_files: - rank_file.unlink() - logger.info(f"Deleted {len(rank_files)} per-rank files") + logger.info(f"Total: {merged.n_tokens_processed:,} tokens") diff --git a/spd/dataset_attributions/harvester.py b/spd/dataset_attributions/harvester.py index bea3b37d2..137089142 100644 --- a/spd/dataset_attributions/harvester.py +++ b/spd/dataset_attributions/harvester.py @@ -4,27 +4,33 @@ training dataset using gradient x activation formula, summed over all positions and batches. -Uses residual-based storage for scalability: -- Component targets: accumulated directly to comp_accumulator -- Output targets: accumulated as attributions to output residual stream (source_to_out_residual) - Output attributions computed on-the-fly at query time via w_unembed +Two metrics are accumulated: +- attr: E[∂y/∂x · x] (signed mean attribution) +- attr_abs: E[∂|y|/∂x · x] (attribution to the absolute value of the target) + +Output (pseudo-) component attributions are handled differently: we accumulate attributions +to the output residual stream, then later project this into token space. + +All layer keys are concrete module paths (e.g.
"wte", "h.0.attn.q_proj", "lm_head"). +Translation to canonical names happens at the storage boundary in harvest.py. """ from typing import Any import torch -from jaxtyping import Bool, Float, Int +from jaxtyping import Bool, Int from torch import Tensor, nn -from tqdm.auto import tqdm from spd.configs import SamplingType +from spd.dataset_attributions.storage import DatasetAttributionStorage from spd.models.component_model import ComponentModel, OutputWithCache from spd.models.components import make_mask_infos +from spd.topology import TransformerTopology from spd.utils.general_utils import bf16_autocast class AttributionHarvester: - """Accumulates attribution strengths across batches. + """Accumulates attribution strengths across batches using concrete module paths. The attribution formula is: attribution[src, tgt] = Σ_batch Σ_pos (∂out[pos, tgt] / ∂in[pos, src]) × in_act[pos, src] @@ -35,11 +41,6 @@ class AttributionHarvester: 2. For output targets, store attributions to the pre-unembed residual (d_model dimensions) instead of vocab tokens. This eliminates the expensive O((V+C) × d_model × V) matmul during harvesting and reduces storage. - - Index structure: - - Sources: wte tokens [0, vocab_size) + component layers [vocab_size, ...) 
- - Component targets: [0, n_components) in comp_accumulator - - Output targets: via out_residual_accumulator (computed on-the-fly at query time) """ sampling: SamplingType @@ -47,96 +48,112 @@ class AttributionHarvester: def __init__( self, model: ComponentModel, + topology: TransformerTopology, sources_by_target: dict[str, list[str]], - n_components: int, - vocab_size: int, - source_alive: Bool[Tensor, " n_sources"], - target_alive: Bool[Tensor, " n_components"], + component_alive: dict[str, Bool[Tensor, " n_components"]], sampling: SamplingType, - device: torch.device, - show_progress: bool = False, ): self.model = model + self.topology = topology self.sources_by_target = sources_by_target - self.n_components = n_components - self.vocab_size = vocab_size - self.source_alive = source_alive - self.target_alive = target_alive + self.component_alive = component_alive self.sampling = sampling - self.device = device - self.show_progress = show_progress - - self.n_sources = vocab_size + n_components - self.n_batches = 0 - self.n_tokens = 0 - - # Split accumulators for component and output targets - self.comp_accumulator = torch.zeros(self.n_sources, n_components, device=device) - - # For output targets: store attributions to output residual dimensions - assert hasattr(model.target_model, "lm_head"), "Model must have lm_head" - lm_head = model.target_model.lm_head - assert isinstance(lm_head, nn.Linear), f"lm_head must be nn.Linear, got {type(lm_head)}" - self.d_model = lm_head.in_features - self.out_residual_accumulator = torch.zeros(self.n_sources, self.d_model, device=device) - self.lm_head = lm_head - - # Build per-layer index ranges for sources - self.component_layer_names = list(model.target_module_paths) - self.source_layer_to_idx_range = self._build_source_layer_index_ranges() - self.target_layer_to_idx_range = self._build_target_layer_index_ranges() - - # Pre-compute alive indices per layer - self.alive_source_idxs_per_layer = self._build_alive_indices( - 
self.source_layer_to_idx_range, source_alive + self.embed_path = topology.path_schema.embedding_path + self.embedding_module = topology.embedding_module + self.unembed_path = topology.path_schema.unembed_path + self.unembed_module = topology.unembed_module + self.output_d_model = self.unembed_module.in_features + self.device = next(model.parameters()).device + + # attribution accumulators + self._straight_through_attr_acc = torch.zeros( + (self.output_d_model, self.embedding_module.num_embeddings), device=self.device ) - self.alive_target_idxs_per_layer = self._build_alive_indices( - self.target_layer_to_idx_range, target_alive + self._embed_tgts_acc = self._get_embed_targets_attr_accumulator(sources_by_target) + self._embed_tgts_acc_abs = self._get_embed_targets_attr_accumulator(sources_by_target) + self._unembed_srcs_acc = self._get_unembed_sources_attr_accumulator(sources_by_target) + self._regular_layers_acc = self._get_regular_layer_attr_accumulator(sources_by_target) + self._regular_layers_acc_abs = self._get_regular_layer_attr_accumulator(sources_by_target) + + # embed token occurrence counts for normalization (analogous to ci_sum for components) + self._embed_token_count = torch.zeros( + (self.embedding_module.num_embeddings,), dtype=torch.long, device=self.device ) - def _build_source_layer_index_ranges(self) -> dict[str, tuple[int, int]]: - """Source order: wte tokens [0, vocab_size), then component layers.""" - ranges: dict[str, tuple[int, int]] = {"wte": (0, self.vocab_size)} - idx = self.vocab_size - for layer in self.component_layer_names: - n = self.model.module_to_c[layer] - ranges[layer] = (idx, idx + n) - idx += n - return ranges - - def _build_target_layer_index_ranges(self) -> dict[str, tuple[int, int]]: - """Target order: component layers [0, n_components). 
Output handled separately.""" - ranges: dict[str, tuple[int, int]] = {} - idx = 0 - for layer in self.component_layer_names: - n = self.model.module_to_c[layer] - ranges[layer] = (idx, idx + n) - idx += n - # Note: "output" not included - handled via out_residual_accumulator - return ranges - - def _build_alive_indices( - self, layer_ranges: dict[str, tuple[int, int]], alive_mask: Bool[Tensor, " n"] - ) -> dict[str, list[int]]: - """Get alive local indices for each layer.""" - return { - layer: torch.where(alive_mask[start:end])[0].tolist() - for layer, (start, end) in layer_ranges.items() + # rms normalization accumulators + self.n_tokens = 0 + self._ci_sum_accumulator = { + layer: torch.zeros((c,), device=self.device) + for layer, c in self.model.module_to_c.items() + } + self._square_component_act_accumulator = { + layer: torch.zeros((c,), device=self.device) + for layer, c in self.model.module_to_c.items() } + self._logit_sq_sum = torch.zeros((self.unembed_module.out_features,), device=self.device) + + def _get_embed_targets_attr_accumulator( + self, sources_by_target: dict[str, list[str]] + ) -> dict[str, Tensor]: + # extract targets whose sources include the embedding + embed_targets_attr_accumulators: dict[str, Tensor] = {} + for target, sources in sources_by_target.items(): + if target == self.unembed_path: + # ignore straight-through edge + continue + if self.embed_path in sources: + embed_targets_attr_accumulators[target] = torch.zeros( + (self.model.module_to_c[target], self.embedding_module.num_embeddings), + device=self.device, + ) + return embed_targets_attr_accumulators + + def _get_unembed_sources_attr_accumulator( + self, sources_by_target: dict[str, list[str]] + ) -> dict[str, Tensor]: + # extract the unembed's sources + unembed_sources_attr_accumulators: dict[str, Tensor] = {} + for source in sources_by_target[self.unembed_path]: + if source == self.embed_path: + # ignore straight-through edge + continue
unembed_sources_attr_accumulators[source] = torch.zeros( + (self.output_d_model, self.model.module_to_c[source]), device=self.device + ) + return unembed_sources_attr_accumulators + + def _get_regular_layer_attr_accumulator( + self, sources_by_target: dict[str, list[str]] + ) -> dict[str, dict[str, Tensor]]: + regular_layers_shapes: dict[str, dict[str, Tensor]] = {} + for target, sources in sources_by_target.items(): + if target == self.unembed_path: + continue + regular_layers_shapes[target] = {} + for source in sources: + if source == self.embed_path: + continue + regular_layers_shapes[target][source] = torch.zeros( + (self.model.module_to_c[target], self.model.module_to_c[source]), + device=self.device, + ) + return regular_layers_shapes def process_batch(self, tokens: Int[Tensor, "batch seq"]) -> None: """Accumulate attributions from one batch.""" - self.n_batches += 1 self.n_tokens += tokens.numel() + self._embed_token_count.add_( + torch.bincount(tokens.flatten(), minlength=self.embedding_module.num_embeddings) + ) - # Setup hooks to capture wte output and pre-unembed residual - wte_out: list[Tensor] = [] + # Setup hooks to capture embedding output and pre-unembed residual + embed_out: list[Tensor] = [] pre_unembed: list[Tensor] = [] - def wte_hook(_mod: nn.Module, _args: Any, _kwargs: Any, out: Tensor) -> Tensor: + def embed_hook(_mod: nn.Module, _args: Any, _kwargs: Any, out: Tensor) -> Tensor: out.requires_grad_(True) - wte_out.clear() - wte_out.append(out) + embed_out.clear() + embed_out.append(out) return out def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> None: @@ -144,10 +161,8 @@ def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> No pre_unembed.clear() pre_unembed.append(args[0]) - wte = self.model.target_model.wte - assert isinstance(wte, nn.Module) - h1 = wte.register_forward_hook(wte_hook, with_kwargs=True) - h2 = self.lm_head.register_forward_pre_hook(pre_unembed_hook, with_kwargs=True) + h1 
= self.embedding_module.register_forward_hook(embed_hook, with_kwargs=True) + h2 = self.unembed_module.register_forward_pre_hook(pre_unembed_hook, with_kwargs=True) # Get masks with all components active with torch.no_grad(), bf16_autocast(): @@ -155,6 +170,7 @@ def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> No ci = self.model.calc_causal_importances( pre_weight_acts=out.cache, sampling=self.sampling, detach_inputs=False ) + mask_infos = make_mask_infos( component_masks={k: torch.ones_like(v) for k, v in ci.lower_leaky.items()}, routing_masks="all", @@ -162,100 +178,144 @@ def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> No # Forward pass with gradients with torch.enable_grad(), bf16_autocast(): - comp_output: OutputWithCache = self.model( + model_output: OutputWithCache = self.model( tokens, mask_infos=mask_infos, cache_type="component_acts" ) h1.remove() h2.remove() - cache = comp_output.cache - cache["wte_post_detach"] = wte_out[0] - cache["pre_unembed"] = pre_unembed[0] - cache["tokens"] = tokens - - # Process each target layer - layers = list(self.sources_by_target.items()) - pbar = tqdm(layers, desc="Targets", disable=not self.show_progress, leave=False) - for target_layer, source_layers in pbar: - if target_layer == "output": - self._process_output_targets(source_layers, cache) + cache = model_output.cache + cache[f"{self.embed_path}_post_detach"] = embed_out[0] + cache[f"{self.unembed_path}_pre_detach"] = pre_unembed[0] + + with torch.no_grad(): + for real_layer, ci_vals in ci.lower_leaky.items(): + self._ci_sum_accumulator[real_layer].add_(ci_vals.sum(dim=(0, 1))) + self._logit_sq_sum.add_(model_output.output.detach().square().sum(dim=(0, 1))) + + for target_layer in self.sources_by_target: + if target_layer == self.unembed_path: + self._process_output_targets(cache, tokens, ci.lower_leaky) else: - self._process_component_targets(target_layer, source_layers, cache) + with torch.no_grad(): + 
sum_sq_acts = cache[f"{target_layer}_post_detach"].square().sum(dim=(0, 1)) + self._square_component_act_accumulator[target_layer].add_(sum_sq_acts) + self._process_component_targets(cache, tokens, ci.lower_leaky, target_layer) - def _process_component_targets( + def _process_output_targets( self, - target_layer: str, - source_layers: list[str], cache: dict[str, Tensor], + tokens: Int[Tensor, "batch seq"], + ci: dict[str, Tensor], ) -> None: - """Process attributions to a component layer.""" - target_start, _ = self.target_layer_to_idx_range[target_layer] - alive_targets = self.alive_target_idxs_per_layer[target_layer] - if not alive_targets: - return + """Process output attributions via output-residual-space storage.""" + out_residual = cache[f"{self.unembed_path}_pre_detach"] + + out_residual_sum = out_residual.sum(dim=(0, 1)) + + source_layers = self.sources_by_target[self.unembed_path] + assert self.embed_path in source_layers - # Sum over batch and sequence - target_acts = cache[f"{target_layer}_pre_detach"].sum(dim=(0, 1)) source_acts = [cache[f"{s}_post_detach"] for s in source_layers] - for t_idx in alive_targets: - grads = torch.autograd.grad(target_acts[t_idx], source_acts, retain_graph=True) - self._accumulate_attributions( - self.comp_accumulator[:, target_start + t_idx], - source_layers, - grads, - source_acts, - cache["tokens"], - ) + for d_idx in range(self.output_d_model): + grads = torch.autograd.grad(out_residual_sum[d_idx], source_acts, retain_graph=True) + with torch.no_grad(): + for source_layer, act, grad in zip(source_layers, source_acts, grads, strict=True): + if source_layer == self.embed_path: + token_attr = (grad * act).sum(dim=-1) # (B S) + self._straight_through_attr_acc[d_idx].scatter_add_( + 0, tokens.flatten(), token_attr.flatten() + ) + else: + ci_weighted_attr = (grad * act * ci[source_layer]).sum(dim=(0, 1)) + self._unembed_srcs_acc[source_layer][d_idx].add_(ci_weighted_attr) - def _process_output_targets( + def 
_process_component_targets( self, - source_layers: list[str], cache: dict[str, Tensor], + tokens: Int[Tensor, "batch seq"], + ci: dict[str, Tensor], + target_layer: str, ) -> None: - """Process output attributions via output-residual-space storage. - - Instead of computing and storing attributions to vocab tokens directly, - we store attributions to output residual dimensions. Output attributions are - computed on-the-fly at query time via: attr[src, token] = out_residual[src] @ w_unembed[:, token] - """ - # Sum output residual over batch and sequence -> [d_model] - out_residual = cache["pre_unembed"].sum(dim=(0, 1)) + """Process attributions to a component layer.""" + alive_targets = self.component_alive[target_layer] + if not alive_targets.any(): + return + + target_acts_raw = cache[f"{target_layer}_pre_detach"] + + target_acts = target_acts_raw.sum(dim=(0, 1)) + # abs() before sum — needs its own backward pass because each element has a different + # sign, so sign·grad can't be factored out of the sum. (In the app backend's per-prompt + # computation the target is a single scalar, so sign·grad works as an analytical shortcut + # and avoids a second backward. See app/backend/compute.py::_compute_edges_for_target.) 
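The comment above can be checked with a tiny standalone sketch (made-up toy tensors, not code from this repo): for a vector-valued target whose elements have mixed signs, differentiating abs-then-sum gives a different gradient than factoring a single sign out of the plain sum, so the abs variant genuinely needs its own backward pass.

```python
import torch

x = torch.tensor([1.0, -2.0], requires_grad=True)
# Vector-valued target with mixed element signs: y = [-2.0, 3.0]
y = torch.stack([x[0] * x[1], x[0] - x[1]])

# Correct: take abs() before the sum, then backprop once through the abs
(g_abs,) = torch.autograd.grad(y.abs().sum(), x, retain_graph=True)

# Shortcut: factor one sign out of the summed target (only valid for a scalar target)
(g_sum,) = torch.autograd.grad(y.sum(), x)
g_shortcut = torch.sign(y.sum()) * g_sum

print(g_abs)       # [3., -2.]
print(g_shortcut)  # [-1., 0.]
```

For a single scalar target y, sign(y) * grad(y) does equal grad(|y|) away from zero, which is why the per-prompt code can use the shortcut.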
+ target_acts_abs = target_acts_raw.abs().sum(dim=(0, 1)) + + source_layers = self.sources_by_target[target_layer] source_acts = [cache[f"{s}_post_detach"] for s in source_layers] - for d_idx in range(self.d_model): - grads = torch.autograd.grad(out_residual[d_idx], source_acts, retain_graph=True) - self._accumulate_attributions( - self.out_residual_accumulator[:, d_idx], - source_layers, - grads, - source_acts, - cache["tokens"], + def _accumulate_grads( + grads: tuple[Tensor, ...], + t_idx: int, + embed_acc: dict[str, Tensor], + regular_acc: dict[str, dict[str, Tensor]], + ) -> None: + with torch.no_grad(): + for source_layer, act, grad in zip(source_layers, source_acts, grads, strict=True): + if source_layer == self.embed_path: + token_attr = (grad * act).sum(dim=-1) # (B S) + embed_acc[target_layer][t_idx].scatter_add_( + 0, tokens.flatten(), token_attr.flatten() + ) + else: + ci_weighted = (grad * act * ci[source_layer]).sum(dim=(0, 1)) # (C,) + regular_acc[target_layer][source_layer][t_idx].add_(ci_weighted) + + for t_idx in torch.where(alive_targets)[0].tolist(): + grads = torch.autograd.grad(target_acts[t_idx], source_acts, retain_graph=True) + _accumulate_grads( + grads=grads, + t_idx=t_idx, + embed_acc=self._embed_tgts_acc, + regular_acc=self._regular_layers_acc, ) - def _accumulate_attributions( - self, - target_col: Float[Tensor, " n_sources"], - source_layers: list[str], - grads: tuple[Tensor, ...], - source_acts: list[Tensor], - tokens: Int[Tensor, "batch seq"], - ) -> None: - """Accumulate grad*act attributions from sources to a target column.""" - with torch.no_grad(): - for layer, grad, act in zip(source_layers, grads, source_acts, strict=True): - alive = self.alive_source_idxs_per_layer[layer] - if not alive: - continue + grads_abs = torch.autograd.grad(target_acts_abs[t_idx], source_acts, retain_graph=True) + _accumulate_grads( + grads=grads_abs, + t_idx=t_idx, + embed_acc=self._embed_tgts_acc_abs, + regular_acc=self._regular_layers_acc_abs, + ) 
- if layer == "wte": - # Per-token: sum grad*act over d_model, scatter by token id - attr = (grad * act).sum(dim=-1).flatten() - target_col.scatter_add_(0, tokens.flatten(), attr) - else: - # Per-component: sum grad*act over batch and sequence - start, _ = self.source_layer_to_idx_range[layer] - attr = (grad * act).sum(dim=(0, 1)) - for c in alive: - target_col[start + c] += attr[c] + def finalize(self, ci_threshold: float) -> DatasetAttributionStorage: + """Package raw accumulators into storage. No normalization — that happens at query time.""" + assert self.n_tokens > 0, "No batches processed" + + to_canon = self.topology.target_to_canon + + def _canon_nested(acc: dict[str, dict[str, Tensor]]) -> dict[str, dict[str, Tensor]]: + return { + to_canon(t): {to_canon(s): v for s, v in srcs.items()} for t, srcs in acc.items() + } + + def _canon(acc: dict[str, Tensor]) -> dict[str, Tensor]: + return {to_canon(k): v for k, v in acc.items()} + + return DatasetAttributionStorage( + regular_attr=_canon_nested(self._regular_layers_acc), + regular_attr_abs=_canon_nested(self._regular_layers_acc_abs), + embed_attr=_canon(self._embed_tgts_acc), + embed_attr_abs=_canon(self._embed_tgts_acc_abs), + unembed_attr=_canon(self._unembed_srcs_acc), + embed_unembed_attr=self._straight_through_attr_acc, + w_unembed=self.topology.get_unembed_weight(), + ci_sum=_canon(self._ci_sum_accumulator), + component_act_sq_sum=_canon(self._square_component_act_accumulator), + logit_sq_sum=self._logit_sq_sum, + embed_token_count=self._embed_token_count, + ci_threshold=ci_threshold, + n_tokens_processed=self.n_tokens, + ) diff --git a/spd/dataset_attributions/loaders.py b/spd/dataset_attributions/loaders.py deleted file mode 100644 index 545393049..000000000 --- a/spd/dataset_attributions/loaders.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Loaders for dataset attributions.""" - -from pathlib import Path - -from spd.dataset_attributions.storage import DatasetAttributionStorage -from spd.settings import 
SPD_OUT_DIR - -# Base directory for dataset attributions -DATASET_ATTRIBUTIONS_DIR = SPD_OUT_DIR / "dataset_attributions" - - -def get_attributions_dir(wandb_run_id: str) -> Path: - """Get the dataset attributions directory for a run.""" - return DATASET_ATTRIBUTIONS_DIR / wandb_run_id - - -def load_dataset_attributions(wandb_run_id: str) -> DatasetAttributionStorage | None: - """Load dataset attributions, if available.""" - path = get_attributions_dir(wandb_run_id) / "dataset_attributions.pt" - if not path.exists(): - return None - return DatasetAttributionStorage.load(path) diff --git a/spd/dataset_attributions/repo.py b/spd/dataset_attributions/repo.py new file mode 100644 index 000000000..1175d584e --- /dev/null +++ b/spd/dataset_attributions/repo.py @@ -0,0 +1,54 @@ +"""Dataset attributions data repository. + +Owns SPD_OUT_DIR/dataset_attributions/<run_id>/ and provides read access +to the attribution matrix. + +Use AttributionRepo.open() to construct — returns None if no attribution data exists. +Layout: dataset_attributions/<run_id>/da-YYYYMMDD_HHMMSS/dataset_attributions.pt +""" + +from pathlib import Path + +from spd.dataset_attributions.storage import DatasetAttributionStorage +from spd.settings import SPD_OUT_DIR + +DATASET_ATTRIBUTIONS_DIR = SPD_OUT_DIR / "dataset_attributions" + + +def get_attributions_dir(run_id: str) -> Path: + return DATASET_ATTRIBUTIONS_DIR / run_id + + +def get_attributions_subrun_dir(run_id: str, subrun_id: str) -> Path: + return get_attributions_dir(run_id) / subrun_id + + +class AttributionRepo: + """Read access to dataset attribution data for a single run. + + Constructed via AttributionRepo.open(). Storage is loaded eagerly at construction. + """ + + def __init__(self, storage: DatasetAttributionStorage, subrun_id: str) -> None: + self._storage = storage + self.subrun_id = subrun_id + + @classmethod + def open(cls, run_id: str) -> "AttributionRepo | None": + """Open attribution data for a run.
Returns None if no attribution data exists.""" + base_dir = get_attributions_dir(run_id) + if not base_dir.exists(): + return None + candidates = sorted( + [d for d in base_dir.iterdir() if d.is_dir() and d.name.startswith("da-")], + key=lambda d: d.name, + reverse=True, + ) + for subrun_dir in candidates: + path = subrun_dir / "dataset_attributions.pt" + if path.exists(): + return cls(DatasetAttributionStorage.load(path), subrun_id=subrun_dir.name) + return None + + def get_attributions(self) -> DatasetAttributionStorage: + return self._storage diff --git a/spd/dataset_attributions/scripts/run.py b/spd/dataset_attributions/scripts/run.py deleted file mode 100644 index 42d75b935..000000000 --- a/spd/dataset_attributions/scripts/run.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Worker script for dataset attribution computation. - -Called by SLURM jobs submitted via spd-attributions, or run directly for non-SLURM environments. - -Usage: - # Single GPU - python -m spd.dataset_attributions.scripts.run --n_batches 1000 - - # Multi-GPU (run in parallel) - python -m spd.dataset_attributions.scripts.run --n_batches 1000 --rank 0 --world_size 4 - python -m spd.dataset_attributions.scripts.run --n_batches 1000 --rank 1 --world_size 4 - ... - python -m spd.dataset_attributions.scripts.run --merge -""" - -from spd.dataset_attributions.harvest import ( - DatasetAttributionConfig, - harvest_attributions, - merge_attributions, -) - - -def main( - wandb_path: str, - n_batches: int | None = None, - batch_size: int = 256, - ci_threshold: float = 0.0, - rank: int | None = None, - world_size: int | None = None, - merge: bool = False, -) -> None: - """Compute dataset attributions, or merge results. - - Args: - wandb_path: WandB run path for the target decomposition run. - n_batches: Number of batches to process. If None, processes entire training dataset. - batch_size: Batch size for processing. - ci_threshold: CI threshold for filtering components. 
Components with mean_ci <= threshold - are excluded. Default 0.0 includes all components. - rank: Worker rank for parallel execution (0 to world_size-1). - world_size: Total number of workers. If specified with rank, only processes - batches where batch_idx % world_size == rank. - merge: If True, merge partial results from workers. - """ - if merge: - assert rank is None and world_size is None, "Cannot specify rank/world_size with --merge" - print(f"Merging attribution results for {wandb_path}") - merge_attributions(wandb_path) - return - - assert (rank is None) == (world_size is None), "rank and world_size must both be set or unset" - - if world_size is not None: - print(f"Distributed harvest: {wandb_path} (rank {rank}/{world_size})") - else: - print(f"Single-GPU harvest: {wandb_path}") - - config = DatasetAttributionConfig( - wandb_path=wandb_path, - n_batches=n_batches, - batch_size=batch_size, - ci_threshold=ci_threshold, - ) - - harvest_attributions(config, rank=rank, world_size=world_size) - - -def cli() -> None: - import fire - - fire.Fire(main) - - -if __name__ == "__main__": - cli() diff --git a/spd/dataset_attributions/scripts/run_merge.py b/spd/dataset_attributions/scripts/run_merge.py new file mode 100644 index 000000000..913ea5374 --- /dev/null +++ b/spd/dataset_attributions/scripts/run_merge.py @@ -0,0 +1,37 @@ +"""Merge script for dataset attribution rank files. + +Combines per-rank attribution files into a single merged result. 
+ +Usage: + python -m spd.dataset_attributions.scripts.run_merge --wandb_path <wandb_path> --subrun_id da-xxx +""" + +from spd.dataset_attributions.harvest import merge_attributions +from spd.dataset_attributions.repo import get_attributions_subrun_dir +from spd.log import logger +from spd.utils.wandb_utils import parse_wandb_run_path + + +def main( + *, + wandb_path: str, + subrun_id: str, +) -> None: + _, _, run_id = parse_wandb_run_path(wandb_path) + output_dir = get_attributions_subrun_dir(run_id, subrun_id) + logger.info(f"Merging attribution results for {wandb_path} (subrun {subrun_id})") + merge_attributions(output_dir) + + +def get_command(wandb_path: str, subrun_id: str) -> str: + return ( + f"python -m spd.dataset_attributions.scripts.run_merge " + f'--wandb_path "{wandb_path}" ' + f"--subrun_id {subrun_id}" + ) + + +if __name__ == "__main__": + import fire + + fire.Fire(main) diff --git a/spd/dataset_attributions/scripts/run_slurm.py b/spd/dataset_attributions/scripts/run_slurm.py index 9ffb0a4fc..97ea65e75 100644 --- a/spd/dataset_attributions/scripts/run_slurm.py +++ b/spd/dataset_attributions/scripts/run_slurm.py @@ -10,13 +10,17 @@ """ import secrets +from dataclasses import dataclass +from datetime import datetime +from spd.dataset_attributions.config import AttributionsSlurmConfig +from spd.dataset_attributions.scripts import run_merge, run_worker from spd.log import logger -from spd.settings import DEFAULT_PARTITION_NAME from spd.utils.git_utils import create_git_snapshot from spd.utils.slurm import ( SlurmArrayConfig, SlurmConfig, + SubmitResult, generate_array_script, generate_script, submit_slurm_job,
batch_size: int = 256, - ci_threshold: float = 0.0, - partition: str = DEFAULT_PARTITION_NAME, - time: str = "48:00:00", + config: AttributionsSlurmConfig, + harvest_subrun_id: str, job_suffix: str | None = None, -) -> None: - """Submit multi-GPU attribution harvesting job to SLURM. - - Submits a job array where each task processes a subset of batches, then - submits a merge job that depends on all workers completing. Creates a git - snapshot to ensure consistent code across all workers. - - Args: - wandb_path: WandB run path for the target decomposition run. - n_batches: Total number of batches to process (divided among workers). - If None, processes entire training dataset. - n_gpus: Number of GPUs (each gets its own array task). - batch_size: Batch size for processing. - ci_threshold: CI threshold for filtering components. - partition: SLURM partition name. - time: Job time limit. - job_suffix: Optional suffix for SLURM job names (e.g., "1h" -> "spd-attr-1h"). - """ - launch_id = f"attr-{secrets.token_hex(4)}" - snapshot_branch, commit_hash = create_git_snapshot(snapshot_id=launch_id) - logger.info(f"Created git snapshot: {snapshot_branch} ({commit_hash[:8]})") + snapshot_branch: str | None = None, + dependency_job_id: str | None = None, +) -> AttributionsSubmitResult: + """Submit multi-GPU attribution harvesting job to SLURM.""" + n_gpus = config.n_gpus + partition = config.partition + time = config.time + + if snapshot_branch is None: + run_id = f"attr-{secrets.token_hex(4)}" + snapshot_branch, commit_hash = create_git_snapshot(snapshot_id=run_id) + logger.info(f"Created git snapshot: {snapshot_branch} ({commit_hash[:8]})") + else: + commit_hash = "shared" + + subrun_id = "da-" + datetime.now().strftime("%Y%m%d_%H%M%S") suffix = f"-{job_suffix}" if job_suffix else "" array_job_name = f"spd-attr{suffix}" + config_json = config.config.model_dump_json(exclude_none=True) + # SLURM arrays are 1-indexed, so task ID 1 -> rank 0, etc. 
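The 1-indexed task-ID mapping noted in the comment, combined with the workers' round-robin batch filter (`batch_idx % world_size == rank` in harvest.py above), can be sketched as follows (a standalone toy, names are illustrative only):

```python
# Each of n_gpus SLURM array tasks runs one worker command; SLURM numbers
# tasks from 1, so task ID 1 maps to rank 0, task ID 2 to rank 1, etc.
n_gpus = 4
task_id_to_rank = {task_id: task_id - 1 for task_id in range(1, n_gpus + 1)}

def batches_for_rank(rank: int, world_size: int, n_batches: int) -> list[int]:
    # Mirrors the worker-side filter: a worker keeps batch_idx % world_size == rank
    return [i for i in range(n_batches) if i % world_size == rank]

print(task_id_to_rank[1])            # 0
print(batches_for_rank(1, 4, 10))    # [1, 5, 9]
```

Every batch index is claimed by exactly one rank, so no coordination between workers is needed beyond the final merge job.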
worker_commands = [] for rank in range(n_gpus): - n_batches_arg = f"--n_batches {n_batches} " if n_batches is not None else "" - cmd = ( - f"python -m spd.dataset_attributions.scripts.run " - f'"{wandb_path}" ' - f"{n_batches_arg}" - f"--batch_size {batch_size} " - f"--ci_threshold {ci_threshold} " - f"--rank {rank} " - f"--world_size {n_gpus}" + cmd = run_worker.get_command( + wandb_path, + config_json, + harvest_subrun_id=harvest_subrun_id, + rank=rank, + world_size=n_gpus, + subrun_id=subrun_id, ) worker_commands.append(cmd) @@ -81,6 +87,7 @@ def submit_attributions( n_gpus=1, # 1 GPU per worker time=time, snapshot_branch=snapshot_branch, + dependency_job_id=dependency_job_id, comment=wandb_url, ) array_script = generate_array_script(array_config, worker_commands) @@ -92,12 +99,13 @@ def submit_attributions( ) # Submit merge job with dependency on array completion - merge_cmd = f'python -m spd.dataset_attributions.scripts.run "{wandb_path}" --merge' + merge_cmd = run_merge.get_command(wandb_path, subrun_id) merge_config = SlurmConfig( job_name="spd-attr-merge", partition=partition, - n_gpus=0, # No GPU needed for merge - time="01:00:00", # Merge is quick + n_gpus=0, + time=config.merge_time, + mem=config.merge_mem, snapshot_branch=snapshot_branch, dependency_job_id=array_result.job_id, comment=wandb_url, @@ -109,9 +117,10 @@ def submit_attributions( logger.values( { "WandB path": wandb_path, - "N batches": n_batches, + "Sub-run ID": subrun_id, + "N batches": config.config.n_batches, "N GPUs": n_gpus, - "Batch size": batch_size, + "Batch size": config.config.batch_size, "Snapshot": f"{snapshot_branch} ({commit_hash[:8]})", "Array Job ID": array_result.job_id, "Merge Job ID": merge_result.job_id, @@ -121,3 +130,9 @@ def submit_attributions( "Merge script": str(merge_result.script_path), } ) + + return AttributionsSubmitResult( + array_result=array_result, + merge_result=merge_result, + subrun_id=subrun_id, + ) diff --git 
a/spd/dataset_attributions/scripts/run_slurm_cli.py b/spd/dataset_attributions/scripts/run_slurm_cli.py index 776dc0878..50fc48fe8 100644 --- a/spd/dataset_attributions/scripts/run_slurm_cli.py +++ b/spd/dataset_attributions/scripts/run_slurm_cli.py @@ -3,55 +3,38 @@ Thin wrapper for fast --help. Heavy imports deferred to run_slurm.py. Usage: - spd-attributions --n_gpus 24 - spd-attributions --n_batches 1000 --n_gpus 8 # Only process 1000 batches + spd-attributions <wandb_path> --n_gpus 8 + spd-attributions <wandb_path> --config attr_config.yaml """ import fire -from spd.settings import DEFAULT_PARTITION_NAME - def submit_attributions( wandb_path: str, - n_gpus: int, - n_batches: int | None = None, - batch_size: int = 256, - ci_threshold: float = 0.0, - partition: str = DEFAULT_PARTITION_NAME, - time: str = "48:00:00", + config: str, + harvest_subrun_id: str, job_suffix: str | None = None, ) -> None: """Submit multi-GPU dataset attribution harvesting to SLURM. - Submits a job array where each GPU processes a subset of batches, - then a merge job that combines results after all workers complete. - - Examples: - spd-attributions wandb:spd/runs/abc123 --n_gpus 24 - spd-attributions wandb:spd/runs/abc123 --n_batches 1000 --n_gpus 8 # Only process 1000 batches - Args: wandb_path: WandB run path for the target decomposition run. - n_batches: Total number of batches to process (divided among workers). - If None, processes entire training dataset. - n_gpus: Number of GPUs (each gets its own array task). - batch_size: Batch size for processing. - ci_threshold: CI threshold for filtering components. - partition: SLURM partition name. - time: Job time limit for worker jobs. + config: Path to AttributionsSlurmConfig YAML/JSON. + harvest_subrun_id: Harvest subrun to use for alive masks (e.g. "h-20260306_120000"). job_suffix: Optional suffix for SLURM job names (e.g., "v2" -> "spd-attr-v2").
""" + from spd.dataset_attributions.config import AttributionsSlurmConfig from spd.dataset_attributions.scripts.run_slurm import submit_attributions as impl + from spd.utils.wandb_utils import parse_wandb_run_path + + parse_wandb_run_path(wandb_path) + slurm_config = AttributionsSlurmConfig.from_file(config) impl( wandb_path=wandb_path, - n_batches=n_batches, - n_gpus=n_gpus, - batch_size=batch_size, - ci_threshold=ci_threshold, - partition=partition, - time=time, + config=slurm_config, + harvest_subrun_id=harvest_subrun_id, job_suffix=job_suffix, ) diff --git a/spd/dataset_attributions/scripts/run_worker.py b/spd/dataset_attributions/scripts/run_worker.py new file mode 100644 index 000000000..8e4079e95 --- /dev/null +++ b/spd/dataset_attributions/scripts/run_worker.py @@ -0,0 +1,66 @@ +"""Worker script for dataset attribution computation. + +Called by SLURM jobs submitted via spd-attributions. + +Usage: + python -m spd.dataset_attributions.scripts.run_worker \ + --config_json '{"n_batches": 500}' \ + --rank 0 --world_size 4 --subrun_id da-xxx +""" + +from typing import Any + +from spd.dataset_attributions.config import DatasetAttributionConfig +from spd.dataset_attributions.harvest import harvest_attributions +from spd.dataset_attributions.repo import get_attributions_subrun_dir +from spd.utils.wandb_utils import parse_wandb_run_path + + +def main( + wandb_path: str, + config_json: dict[str, Any], + harvest_subrun_id: str, + rank: int, + world_size: int, + subrun_id: str, +) -> None: + # Fire parses JSON strings into dicts automatically + assert isinstance(config_json, dict), f"Expected dict from Fire, got {type(config_json)}" + _, _, run_id = parse_wandb_run_path(wandb_path) + + config = DatasetAttributionConfig.model_validate(config_json) + assert config.spd_run_wandb_path == wandb_path + output_dir = get_attributions_subrun_dir(run_id, subrun_id) + + harvest_attributions( + config=config, + output_dir=output_dir, + harvest_subrun_id=harvest_subrun_id, + 
rank=rank, + world_size=world_size, + ) + + +def get_command( + wandb_path: str, + config_json: str, + harvest_subrun_id: str, + rank: int, + world_size: int, + subrun_id: str, +) -> str: + return ( + f"python -m spd.dataset_attributions.scripts.run_worker " + f'"{wandb_path}" ' + f"--config_json '{config_json}' " + f"--harvest_subrun_id {harvest_subrun_id} " + f"--rank {rank} " + f"--world_size {world_size} " + f"--subrun_id {subrun_id}" + ) + + +if __name__ == "__main__": + import fire + + fire.Fire(main) diff --git a/spd/dataset_attributions/storage.py b/spd/dataset_attributions/storage.py index 16181201d..9e6041cca 100644 --- a/spd/dataset_attributions/storage.py +++ b/spd/dataset_attributions/storage.py @@ -1,22 +1,37 @@ """Storage classes for dataset attributions. -Uses a residual-based storage approach for scalability: -- Component targets: stored directly in source_to_component matrix -- Output targets: stored as attributions to residual stream, computed on-the-fly via w_unembed +Stores raw (unnormalized) attribution sums. Normalization happens at query time using +stored metadata (CI sums, activation RMS, logit RMS). + +Four edge types, each with its own shape: +- regular: component → component [tgt_c, src_c] (signed + abs) +- embed: embed → component [tgt_c, vocab] (signed + abs) +- unembed: component → unembed [d_model, src_c] (signed only, residual space) +- embed_unembed: embed → unembed [d_model, vocab] (signed only, residual space) + +Abs variants are unavailable for unembed edges because abs is a nonlinear operation +incompatible with the residual-space storage trick. 
+ +Normalization formula: + normed[t, s] = raw[t, s] / source_denom[s] / target_rms[t] +- source_denom is ci_sum[s] for component sources, embed_token_count[s] for embed sources +- target_rms is component activation RMS for component targets, logit RMS for output targets """ -import dataclasses -from collections.abc import Callable +import bisect from dataclasses import dataclass from pathlib import Path from typing import Literal import torch -from jaxtyping import Float from torch import Tensor from spd.log import logger +AttrMetric = Literal["attr", "attr_abs"] + +EPS = 1e-10 + @dataclass class DatasetAttributionEntry: @@ -28,318 +43,323 @@ class DatasetAttributionEntry: value: float -@dataclass class DatasetAttributionStorage: """Dataset-aggregated attribution strengths between components. - Uses residual-based storage for scalability with large vocabularies: - - source_to_component: direct attributions to component targets - - source_to_out_residual: attributions to output residual stream (for computing output attributions) - - Output attributions are computed on-the-fly: attr[src, output_token] = out_residual[src] @ w_unembed[:, token] + All layer names use canonical addressing (e.g., "embed", "0.glu.up", "output"). - Source indexing (rows): - - [0, vocab_size): wte tokens - - [vocab_size, vocab_size + n_components): component layers - - Target indexing: - - Component targets: [0, n_components) in source_to_component - - Output targets: computed via source_to_out_residual @ w_unembed + Internally stores raw sums — normalization applied at query time. + Public interface: get_top_sources(), get_top_targets(), save/load/merge. 
Key formats: - - wte tokens: "wte:{token_id}" - - component layers: "layer:c_idx" (e.g., "h.0.attn.q_proj:5") + - embed tokens: "embed:{token_id}" + - component layers: "canonical_layer:c_idx" (e.g., "0.glu.up:5") - output tokens: "output:{token_id}" """ - component_layer_keys: list[str] - """Component layer keys in order: ["h.0.attn.q_proj:0", "h.0.attn.q_proj:1", ...]""" + def __init__( + self, + regular_attr: dict[str, dict[str, Tensor]], + regular_attr_abs: dict[str, dict[str, Tensor]], + embed_attr: dict[str, Tensor], + embed_attr_abs: dict[str, Tensor], + unembed_attr: dict[str, Tensor], + embed_unembed_attr: Tensor, + w_unembed: Tensor, + ci_sum: dict[str, Tensor], + component_act_sq_sum: dict[str, Tensor], + logit_sq_sum: Tensor, + embed_token_count: Tensor, + ci_threshold: float, + n_tokens_processed: int, + ): + self._regular_attr = regular_attr + self._regular_attr_abs = regular_attr_abs + self._embed_attr = embed_attr + self._embed_attr_abs = embed_attr_abs + self._unembed_attr = unembed_attr + self._embed_unembed_attr = embed_unembed_attr + self._w_unembed = w_unembed + self._ci_sum = ci_sum + self._component_act_sq_sum = component_act_sq_sum + self._logit_sq_sum = logit_sq_sum + self._embed_token_count = embed_token_count + self.ci_threshold = ci_threshold + self.n_tokens_processed = n_tokens_processed - vocab_size: int - """Vocabulary size (number of wte and output tokens)""" + @property + def target_layers(self) -> set[str]: + return self._regular_attr.keys() | self._embed_attr.keys() - d_model: int - """Model hidden dimension (residual stream size)""" + def _target_n_components(self, layer: str) -> int | None: + if layer in self._embed_attr: + return self._embed_attr[layer].shape[0] + if layer in self._regular_attr: + first_source = next(iter(self._regular_attr[layer].values())) + return first_source.shape[0] + return None - source_to_component: Float[Tensor, "n_sources n_components"] - """Attributions from sources to component targets. 
Shape: (vocab_size + n_components, n_components)""" + @property + def n_components(self) -> int: + total = 0 + for layer in self.target_layers: + n = self._target_n_components(layer) + assert n is not None + total += n + return total + + @staticmethod + def _parse_key(key: str) -> tuple[str, int]: + layer, idx_str = key.rsplit(":", 1) + return layer, int(idx_str) - source_to_out_residual: Float[Tensor, "n_sources d_model"] - """Attributions from sources to output residual dimensions. Shape: (vocab_size + n_components, d_model)""" + def _select_metric( + self, metric: AttrMetric + ) -> tuple[dict[str, dict[str, Tensor]], dict[str, Tensor]]: + match metric: + case "attr": + return self._regular_attr, self._embed_attr + case "attr_abs": + return self._regular_attr_abs, self._embed_attr_abs - n_batches_processed: int - n_tokens_processed: int - ci_threshold: float + def _component_activation_rms(self, layer: str) -> Tensor: + """RMS activation for a component layer. Shape (n_components,).""" + return (self._component_act_sq_sum[layer] / self.n_tokens_processed).sqrt().clamp(min=EPS) - _component_key_to_idx: dict[str, int] = dataclasses.field( - default_factory=dict, repr=False, init=False - ) + def _logit_activation_rms(self) -> Tensor: + """RMS logit per token. Shape (vocab,).""" + return (self._logit_sq_sum / self.n_tokens_processed).sqrt().clamp(min=EPS) - def __post_init__(self) -> None: - self._component_key_to_idx = {k: i for i, k in enumerate(self.component_layer_keys)} + def _layer_ci_sum(self, layer: str) -> Tensor: + """CI sum for a source layer, clamped. Shape (n_components,).""" + return self._ci_sum[layer].clamp(min=EPS) - n_components = len(self.component_layer_keys) - n_sources = self.vocab_size + n_components + def _embed_count(self) -> Tensor: + """Per-token occurrence count, clamped. 
Shape (vocab,).""" + return self._embed_token_count.float().clamp(min=EPS) - expected_comp_shape = (n_sources, n_components) - assert self.source_to_component.shape == expected_comp_shape, ( - f"source_to_component shape {self.source_to_component.shape} " - f"doesn't match expected {expected_comp_shape}" - ) + def get_top_sources( + self, + target_key: str, + k: int, + sign: Literal["positive", "negative"], + metric: AttrMetric, + ) -> list[DatasetAttributionEntry]: + target_layer, target_idx = self._parse_key(target_key) + + value_segments: list[Tensor] = [] + layer_names: list[str] = [] + if target_layer == "embed": + return [] + + if target_layer == "output": + if metric == "attr_abs": + return [] + w = self._w_unembed[:, target_idx].to(self._embed_unembed_attr.device) + target_act_rms = self._logit_activation_rms()[target_idx] + + for source_layer, attr_matrix in self._unembed_attr.items(): + raw = w @ attr_matrix # (src_c,) + value_segments.append(raw / self._layer_ci_sum(source_layer) / target_act_rms) + layer_names.append(source_layer) + + raw = w @ self._embed_unembed_attr # (vocab,) + value_segments.append(raw / self._embed_count() / target_act_rms) + layer_names.append("embed") + else: + regular_attr, embed_target_attr = self._select_metric(metric) + target_act_rms = self._component_activation_rms(target_layer)[target_idx] - expected_resid_shape = (n_sources, self.d_model) - assert self.source_to_out_residual.shape == expected_resid_shape, ( - f"source_to_out_residual shape {self.source_to_out_residual.shape} " - f"doesn't match expected {expected_resid_shape}" - ) + if target_layer in regular_attr: + for source_layer, attr_matrix in regular_attr[target_layer].items(): + raw = attr_matrix[target_idx, :] # (src_c,) + value_segments.append(raw / self._layer_ci_sum(source_layer) / target_act_rms) + layer_names.append(source_layer) - @property - def n_components(self) -> int: - return len(self.component_layer_keys) + if target_layer in embed_target_attr: + 
raw = embed_target_attr[target_layer][target_idx, :] # (vocab,) + value_segments.append(raw / self._embed_count() / target_act_rms) + layer_names.append("embed") - @property - def n_sources(self) -> int: - return self.vocab_size + self.n_components + return self._top_k_from_segments(value_segments, layer_names, k, sign) - def _parse_key(self, key: str) -> tuple[str, int]: - """Parse a key into (layer, idx).""" - layer, idx_str = key.rsplit(":", 1) - return layer, int(idx_str) + def get_top_targets( + self, + source_key: str, + k: int, + sign: Literal["positive", "negative"], + metric: AttrMetric, + include_outputs: bool = True, + ) -> list[DatasetAttributionEntry]: + source_layer, source_idx = self._parse_key(source_key) - def _source_idx(self, key: str) -> int: - """Get source (row) index for a key. Raises KeyError if not a valid source.""" - layer, idx = self._parse_key(key) - match layer: - case "wte": - assert 0 <= idx < self.vocab_size, ( - f"wte index {idx} out of range [0, {self.vocab_size})" - ) - return idx - case "output": - raise KeyError(f"output tokens cannot be sources: {key}") - case _: - return self.vocab_size + self._component_key_to_idx[key] - - def _component_target_idx(self, key: str) -> int: - """Get target index for a component key. 
Raises KeyError if output or invalid.""" - if key.startswith(("wte:", "output:")): - raise KeyError(f"Not a component target: {key}") - return self._component_key_to_idx[key] - - def _source_idx_to_key(self, idx: int) -> str: - """Convert source (row) index to key.""" - if idx < self.vocab_size: - return f"wte:{idx}" - return self.component_layer_keys[idx - self.vocab_size] - - def _component_target_idx_to_key(self, idx: int) -> str: - """Convert component target index to key.""" - return self.component_layer_keys[idx] - - def _output_target_idx_to_key(self, idx: int) -> str: - """Convert output token index to key.""" - return f"output:{idx}" - - def _is_output_target(self, key: str) -> bool: - """Check if key is an output target.""" - return key.startswith("output:") - - def _output_token_id(self, key: str) -> int: - """Extract token_id from an output key like 'output:123'. Asserts valid range.""" - _, token_id = self._parse_key(key) - assert 0 <= token_id < self.vocab_size, f"output index {token_id} out of range" - return token_id - - def has_source(self, key: str) -> bool: - """Check if a key can be a source (wte token or component layer).""" - layer, idx = self._parse_key(key) - match layer: - case "wte": - return 0 <= idx < self.vocab_size - case "output": - return False - case _: - return key in self._component_key_to_idx - - def has_target(self, key: str) -> bool: - """Check if a key can be a target (component layer or output token).""" - layer, idx = self._parse_key(key) - match layer: - case "wte": - return False - case "output": - return 0 <= idx < self.vocab_size - case _: - return key in self._component_key_to_idx + value_segments: list[Tensor] = [] + layer_names: list[str] = [] - def save(self, path: Path) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - torch.save( - { - "component_layer_keys": self.component_layer_keys, - "vocab_size": self.vocab_size, - "d_model": self.d_model, - "source_to_component": self.source_to_component.cpu(), - 
"source_to_out_residual": self.source_to_out_residual.cpu(), - "n_batches_processed": self.n_batches_processed, - "n_tokens_processed": self.n_tokens_processed, - "ci_threshold": self.ci_threshold, - }, - path, - ) - size_mb = path.stat().st_size / (1024 * 1024) - logger.info(f"Saved dataset attributions to {path} ({size_mb:.1f} MB)") + if source_layer == "output": + return [] + elif source_layer == "embed": + regular, embed = self._select_metric(metric) + embed_count = self._embed_count()[source_idx] - @classmethod - def load(cls, path: Path) -> "DatasetAttributionStorage": - data = torch.load(path, weights_only=True, mmap=True) - return cls( - component_layer_keys=data["component_layer_keys"], - vocab_size=data["vocab_size"], - d_model=data["d_model"], - source_to_component=data["source_to_component"], - source_to_out_residual=data["source_to_out_residual"], - n_batches_processed=data["n_batches_processed"], - n_tokens_processed=data["n_tokens_processed"], - ci_threshold=data["ci_threshold"], - ) + for target_layer, attr_matrix in embed.items(): + raw = attr_matrix[:, source_idx] # (tgt_c,) + value_segments.append( + raw / embed_count / self._component_activation_rms(target_layer) + ) + layer_names.append(target_layer) - def get_attribution( - self, - source_key: str, - target_key: str, - w_unembed: Float[Tensor, "d_model vocab"] | None = None, - ) -> float: - """Get attribution strength from source to target. 
- - Args: - source_key: Source component key (wte or component layer) - target_key: Target component key (component layer or output token) - w_unembed: Unembedding matrix, required if target is an output token - """ - src_idx = self._source_idx(source_key) + if include_outputs and metric == "attr": + residual = self._embed_unembed_attr[:, source_idx] # (d_model,) + raw = residual @ self._w_unembed # (vocab,) + value_segments.append(raw / embed_count / self._logit_activation_rms()) + layer_names.append("output") + else: + regular, embed = self._select_metric(metric) + ci = self._layer_ci_sum(source_layer)[source_idx] + + for target_layer, sources in regular.items(): + if source_layer not in sources: + continue + raw = sources[source_layer][:, source_idx] # (tgt_c,) + value_segments.append(raw / ci / self._component_activation_rms(target_layer)) + layer_names.append(target_layer) - if self._is_output_target(target_key): - assert w_unembed is not None, "w_unembed required for output target queries" - token_id = self._output_token_id(target_key) - w_unembed = w_unembed.to(self.source_to_out_residual.device) - return (self.source_to_out_residual[src_idx] @ w_unembed[:, token_id]).item() + if include_outputs and metric == "attr" and source_layer in self._unembed_attr: + residual = self._unembed_attr[source_layer][:, source_idx] # (d_model,) + raw = residual @ self._w_unembed # (vocab,) + value_segments.append(raw / ci / self._logit_activation_rms()) + layer_names.append("output") - tgt_idx = self._component_target_idx(target_key) - return self.source_to_component[src_idx, tgt_idx].item() + return self._top_k_from_segments(value_segments, layer_names, k, sign) - def _get_top_k( + def _top_k_from_segments( self, - values: Tensor, + value_segments: list[Tensor], + layer_names: list[str], k: int, sign: Literal["positive", "negative"], - idx_to_key: Callable[[int], str], ) -> list[DatasetAttributionEntry]: - """Get top-k entries from a 1D tensor of attribution values.""" + if 
not value_segments: + return [] + + all_values = torch.cat(value_segments) + offsets = [0] + for seg in value_segments: + offsets.append(offsets[-1] + len(seg)) + is_positive = sign == "positive" - top_vals, top_idxs = torch.topk(values, min(k, len(values)), largest=is_positive) + top_vals, top_idxs = torch.topk(all_values, min(k, len(all_values)), largest=is_positive) - # Filter to only values matching the requested sign mask = top_vals > 0 if is_positive else top_vals < 0 top_vals, top_idxs = top_vals[mask], top_idxs[mask] results = [] - for idx, val in zip(top_idxs.tolist(), top_vals.tolist(), strict=True): - key = idx_to_key(idx) - layer, c_idx = self._parse_key(key) + for flat_idx, val in zip(top_idxs.tolist(), top_vals.tolist(), strict=True): + seg_idx = bisect.bisect_right(offsets, flat_idx) - 1 + local_idx = flat_idx - offsets[seg_idx] + layer = layer_names[seg_idx] results.append( DatasetAttributionEntry( - component_key=key, + component_key=f"{layer}:{local_idx}", layer=layer, - component_idx=c_idx, + component_idx=local_idx, value=val, ) ) return results - def get_top_sources( - self, - target_key: str, - k: int, - sign: Literal["positive", "negative"], - w_unembed: Float[Tensor, "d_model vocab"] | None = None, - ) -> list[DatasetAttributionEntry]: - """Get top-k source components that attribute TO this target. 
+ def save(self, path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + torch.save( + { + "regular_attr": _to_cpu_nested(self._regular_attr), + "regular_attr_abs": _to_cpu_nested(self._regular_attr_abs), + "embed_attr": _to_cpu(self._embed_attr), + "embed_attr_abs": _to_cpu(self._embed_attr_abs), + "unembed_attr": _to_cpu(self._unembed_attr), + "embed_unembed_attr": self._embed_unembed_attr.detach().cpu(), + "w_unembed": self._w_unembed.detach().cpu(), + "ci_sum": _to_cpu(self._ci_sum), + "component_act_sq_sum": _to_cpu(self._component_act_sq_sum), + "logit_sq_sum": self._logit_sq_sum.detach().cpu(), + "embed_token_count": self._embed_token_count.detach().cpu(), + "ci_threshold": self.ci_threshold, + "n_tokens_processed": self.n_tokens_processed, + }, + path, + ) + size_mb = path.stat().st_size / (1024 * 1024) + logger.info(f"Saved dataset attributions to {path} ({size_mb:.1f} MB)") - Args: - target_key: Target component key (component layer or output token) - k: Number of top sources to return - sign: "positive" for strongest positive, "negative" for strongest negative - w_unembed: Unembedding matrix, required if target is an output token - """ - if self._is_output_target(target_key): - assert w_unembed is not None, "w_unembed required for output target queries" - token_id = self._output_token_id(target_key) - w_unembed = w_unembed.to(self.source_to_out_residual.device) - values = self.source_to_out_residual @ w_unembed[:, token_id] # (n_sources,) - else: - tgt_idx = self._component_target_idx(target_key) - values = self.source_to_component[:, tgt_idx] + @classmethod + def load(cls, path: Path) -> "DatasetAttributionStorage": + data = torch.load(path, weights_only=True) + return cls( + regular_attr=data["regular_attr"], + regular_attr_abs=data["regular_attr_abs"], + embed_attr=data["embed_attr"], + embed_attr_abs=data["embed_attr_abs"], + unembed_attr=data["unembed_attr"], + embed_unembed_attr=data["embed_unembed_attr"], + 
w_unembed=data["w_unembed"], + ci_sum=data["ci_sum"], + component_act_sq_sum=data["component_act_sq_sum"], + logit_sq_sum=data["logit_sq_sum"], + embed_token_count=data["embed_token_count"], + ci_threshold=data["ci_threshold"], + n_tokens_processed=data["n_tokens_processed"], + ) - return self._get_top_k(values, k, sign, self._source_idx_to_key) + @classmethod + def merge(cls, paths: list[Path]) -> "DatasetAttributionStorage": + """Merge partial attribution files from parallel workers. - def get_top_targets( - self, - source_key: str, - k: int, - sign: Literal["positive", "negative"], - w_unembed: Float[Tensor, "d_model vocab"] | None = None, - include_outputs: bool = True, - ) -> list[DatasetAttributionEntry]: - """Get top-k target components this source attributes TO. - - Args: - source_key: Source component key (wte or component layer) - k: Number of top targets to return - sign: "positive" for strongest positive, "negative" for strongest negative - w_unembed: Unembedding matrix, required if include_outputs=True - include_outputs: Whether to include output tokens in results + All stored values are raw sums — merge is element-wise addition. 
""" - src_idx = self._source_idx(source_key) - comp_values = self.source_to_component[src_idx, :] # (n_components,) + assert paths, "No files to merge" - if include_outputs: - assert w_unembed is not None, "w_unembed required when include_outputs=True" - # Compute attributions to all output tokens - w_unembed = w_unembed.to(self.source_to_out_residual.device) - output_values = self.source_to_out_residual[src_idx, :] @ w_unembed # (vocab,) - all_values = torch.cat([comp_values, output_values]) + merged = cls.load(paths[0]) - def combined_idx_to_key(idx: int) -> str: - if idx < self.n_components: - return self._component_target_idx_to_key(idx) - return self._output_target_idx_to_key(idx - self.n_components) + for path in paths[1:]: + other = cls.load(path) + assert other.ci_threshold == merged.ci_threshold, "CI threshold mismatch" - return self._get_top_k(all_values, k, sign, combined_idx_to_key) + for target, sources in other._regular_attr.items(): + for source, tensor in sources.items(): + merged._regular_attr[target][source] += tensor + merged._regular_attr_abs[target][source] += other._regular_attr_abs[target][ + source + ] - return self._get_top_k(comp_values, k, sign, self._component_target_idx_to_key) + for target, tensor in other._embed_attr.items(): + merged._embed_attr[target] += tensor + merged._embed_attr_abs[target] += other._embed_attr_abs[target] - def get_top_component_targets( - self, - source_key: str, - k: int, - sign: Literal["positive", "negative"], - ) -> list[DatasetAttributionEntry]: - """Get top-k component targets (excluding outputs) this source attributes TO. + for source, tensor in other._unembed_attr.items(): + merged._unembed_attr[source] += tensor - Convenience method that doesn't require w_unembed. 
- """ - return self.get_top_targets(source_key, k, sign, w_unembed=None, include_outputs=False) + merged._embed_unembed_attr += other._embed_unembed_attr - def get_top_output_targets( - self, - source_key: str, - k: int, - sign: Literal["positive", "negative"], - w_unembed: Float[Tensor, "d_model vocab"], - ) -> list[DatasetAttributionEntry]: - """Get top-k output token targets this source attributes TO.""" - src_idx = self._source_idx(source_key) - w_unembed = w_unembed.to(self.source_to_out_residual.device) - output_values = self.source_to_out_residual[src_idx, :] @ w_unembed # (vocab,) - return self._get_top_k(output_values, k, sign, self._output_target_idx_to_key) + for layer in other._ci_sum: + merged._ci_sum[layer] += other._ci_sum[layer] + + for layer in other._component_act_sq_sum: + merged._component_act_sq_sum[layer] += other._component_act_sq_sum[layer] + + merged._logit_sq_sum += other._logit_sq_sum + merged._embed_token_count += other._embed_token_count + merged.n_tokens_processed += other.n_tokens_processed + + return merged + + +def _to_cpu_nested(d: dict[str, dict[str, Tensor]]) -> dict[str, dict[str, Tensor]]: + return { + target: {source: v.detach().cpu() for source, v in sources.items()} + for target, sources in d.items() + } + + +def _to_cpu(d: dict[str, Tensor]) -> dict[str, Tensor]: + return {k: v.detach().cpu() for k, v in d.items()} diff --git a/spd/editing/README.md b/spd/editing/README.md new file mode 100644 index 000000000..2860a7650 --- /dev/null +++ b/spd/editing/README.md @@ -0,0 +1,95 @@ +# spd.editing + +Component-level model editing for VPD decompositions. 
+ +## Setup + +```python +from spd.editing import EditableModel, generate, measure_kl, measure_token_probs +from spd.harvest.repo import HarvestRepo +from spd.autointerp.repo import InterpRepo + +em, tok = EditableModel.from_wandb("wandb:goodfire/spd/s-892f140b") +harvest = HarvestRepo("s-892f140b") +interp = InterpRepo("s-892f140b") +``` + +## Finding components + +By autointerp label: +```python +from spd.editing import search_interpretations +matches = search_interpretations(harvest, interp, r"male pronoun") +# -> [ComponentMatch(key='h.1.attn.v_proj:52', label='male pronouns', ...)] +``` + +By output token PMI (best for ablation targets): +```python +from spd.editing import search_by_token_pmi +he_id = tok.encode("he") +matches = search_by_token_pmi(harvest, he_id, side="output", min_pmi=1.0) +``` + +By circuit optimization across examples: +```python +examples = [(tokens1, target_pos1), (tokens2, target_pos2), ...] +components = em.find_components_by_examples(examples, optim_steps=100) +# -> [('h.1.attn.v_proj:52', 0.9), ('h.1.mlp.down_proj:798', 0.8), ...] 
+``` + +## Inspecting components + +```python +from spd.editing import inspect_component +data = inspect_component(harvest, interp, "h.1.mlp.down_proj:798", tok) +# Prints: label, input/output PMI tokens, activation examples +``` + +Component geometry: +```python +vecs = em.get_component_vectors("h.1.mlp.down_proj:798") # read (V) and write (U) vectors +alignment = em.component_alignment("h.1.attn.o_proj:82", "h.1.mlp.c_fc:144") # cosine, percentile +boosted, suppressed = em.unembed_alignment("h.1.mlp.down_proj:798", tok) # top logit-lens tokens +``` + +## Editing (runtime masks) + +```python +# 0.0 = ablate, 2.0 = boost +edit_fn = em.make_edit_fn({"h.1.mlp.down_proj:798": 0.0, "h.1.attn.v_proj:52": 0.0}) + +# Generate with edits +text = generate(edit_fn, tokens, tok) + +# Measure effect +effect = measure_kl(em, edit_fn, eval_seqs) +print(f"KL={effect.mean_kl:.3f}, PPL: {effect.baseline_ppl:.1f} -> {effect.edited_ppl:.1f}") + +# Token group probability shifts +shifts = measure_token_probs(em, edit_fn, eval_seqs, { + "he": tok.encode("he"), + "she": tok.encode("she"), +}) +print(f"P(he) change: {shifts['he'].change_pct:+.1f}%") +``` + +CI-conditional editing (only edit where component is active): +```python +edit_fn = em.make_edit_fn({"h.1.mlp.down_proj:798": 0.0}, ci_threshold=0.1) +``` + +## Permanent weight editing + +```python +clean_em = em.without_components(["h.1.mlp.down_proj:798"]) +# Returns a new EditableModel with rank-1 subtraction baked into weights +text = generate(clean_em, tokens, tok) +``` + +## Circuit analysis + +```python +circuit = em.optimize_circuit(tokens, target_position=15, target_token=tok.encode("he")[0]) +em.print_circuit(circuit, tokens, tok, interp=interp) +# Prints: edges, node CI, component labels +``` diff --git a/spd/editing/__init__.py b/spd/editing/__init__.py new file mode 100644 index 000000000..784a4a2d4 --- /dev/null +++ b/spd/editing/__init__.py @@ -0,0 +1,39 @@ +"""Component-level model editing for VPD decompositions.""" + 
+from spd.editing._editing import ( + AblationEffect, + AlignmentResult, + ComponentMatch, + ComponentVectors, + EditableModel, + ForwardFn, + TokenGroupShift, + TokenPMIMatch, + UnembedMatch, + generate, + inspect_component, + measure_kl, + measure_token_probs, + parse_component_key, + search_by_token_pmi, + search_interpretations, +) + +__all__ = [ + "AblationEffect", + "AlignmentResult", + "ComponentMatch", + "ComponentVectors", + "EditableModel", + "ForwardFn", + "TokenGroupShift", + "TokenPMIMatch", + "UnembedMatch", + "generate", + "inspect_component", + "measure_kl", + "measure_token_probs", + "parse_component_key", + "search_by_token_pmi", + "search_interpretations", +] diff --git a/spd/editing/_editing.py b/spd/editing/_editing.py new file mode 100644 index 000000000..f34e384a4 --- /dev/null +++ b/spd/editing/_editing.py @@ -0,0 +1,808 @@ +"""Component-level model editing for VPD decompositions. + +Core class: EditableModel wraps ComponentModel + TransformerTopology and provides +methods for component analysis, editing, and measurement. It's callable +(tokens → logits) so it works as a ForwardFn anywhere. 
+ +Usage: + from spd.editing import EditableModel, search_interpretations, generate + + em = EditableModel.from_wandb("wandb:goodfire/spd/s-892f140b") + matches = search_interpretations(harvest, interp, r"male pronoun") + + edit_fn = em.make_edit_fn({m.key: 0.0 for m in matches[:3]}) + text = generate(edit_fn, tokens, tokenizer) + effect = em.measure_kl(edit_fn, token_seqs) +""" + +import copy +import re +import sqlite3 +from collections.abc import Callable +from dataclasses import dataclass + +import orjson +import torch +import torch.nn.functional as F +from jaxtyping import Float, Int +from torch import Tensor + +from spd.app.backend.app_tokenizer import AppTokenizer +from spd.app.backend.compute import OptimizedPromptAttributionResult +from spd.autointerp.repo import InterpRepo +from spd.harvest.repo import HarvestRepo +from spd.harvest.schemas import ComponentData +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.models.components import make_mask_infos +from spd.topology.topology import TransformerTopology + +ForwardFn = Callable[[Int[Tensor, " seq"]], Float[Tensor, "seq vocab"]] + + +# -- Component key utilities --------------------------------------------------- + + +def parse_component_key(key: str) -> tuple[str, int]: + """'h.1.mlp.c_fc:802' -> ('h.1.mlp.c_fc', 802).""" + layer, idx_str = key.rsplit(":", 1) + return layer, int(idx_str) + + +# -- Search (free functions, don't need the model) ----------------------------- + + +@dataclass +class ComponentMatch: + key: str + label: str + confidence: str + firing_density: float + mean_activations: dict[str, float] + + +def search_interpretations( + harvest: HarvestRepo, + interp: InterpRepo, + pattern: str, + min_firing_density: float = 0.0, +) -> list[ComponentMatch]: + """Search component interpretations by regex on label. 
Sorted by firing density desc.""" + all_interps = interp.get_all_interpretations() + summary = harvest.get_summary() + + matches = [] + for key, result in all_interps.items(): + if key not in summary: + continue + if not re.search(pattern, result.label, re.IGNORECASE): + continue + s = summary[key] + if s.firing_density < min_firing_density: + continue + matches.append( + ComponentMatch( + key=key, + label=result.label, + confidence=result.confidence, + firing_density=s.firing_density, + mean_activations=s.mean_activations, + ) + ) + + matches.sort(key=lambda m: -m.firing_density) + return matches + + +@dataclass +class TokenPMIMatch: + key: str + pmi: float + firing_density: float + + +def search_by_token_pmi( + harvest: HarvestRepo, + token_ids: list[int], + side: str, + min_pmi: float = 0.5, + min_firing_density: float = 0.01, + top_k: int = 20, +) -> list[TokenPMIMatch]: + """Find components by input or output token PMI. + + side="output" finds components that PREDICT the given tokens. + side="input" finds components that RESPOND TO (fire on) the given tokens. + + For ablation, you almost always want side="output" — ablating output-side + components suppresses token production with far less collateral damage than + ablating input-side components. 
+ """ + assert side in ("input", "output") + column = "output_token_pmi" if side == "output" else "input_token_pmi" + target_set = set(token_ids) + summary = harvest.get_summary() + + db_path = harvest._dir / "harvest.db" + conn = sqlite3.connect(f"file:{db_path}?immutable=1", uri=True) + + results = [] + for row in conn.execute(f"SELECT component_key, {column} FROM components"): + key: str = row[0] + if key not in summary or summary[key].firing_density < min_firing_density: + continue + pmi_data: dict[str, list[list[float]]] = orjson.loads(row[1]) + max_pmi = 0.0 + for tok_id, pmi in pmi_data.get("top", []): + if int(tok_id) in target_set and pmi > max_pmi: + max_pmi = pmi + if max_pmi >= min_pmi: + results.append( + TokenPMIMatch( + key=key, + pmi=max_pmi, + firing_density=summary[key].firing_density, + ) + ) + + conn.close() + results.sort(key=lambda r: -r.pmi) + return results[:top_k] + + +def inspect_component( + harvest: HarvestRepo, + interp: InterpRepo, + key: str, + tokenizer: AppTokenizer, + n_examples: int = 5, + n_pmi_tokens: int = 10, +) -> ComponentData: + """Print a detailed inspection of a component and return its data.""" + comp = harvest.get_component(key) + assert comp is not None, f"No harvest data for {key}" + interp_result = interp.get_interpretation(key) + + ci = comp.mean_activations.get("causal_importance", None) + ci_str = f", ci={ci:.4f}" if ci is not None else "" + print(f"{'=' * 70}") + print(f"{key} (density={comp.firing_density:.4f}{ci_str})") + if interp_result: + print(f"Label: [{interp_result.confidence}] {interp_result.label}") + print() + + decode = tokenizer.decode + + print("INPUT tokens (what makes it fire):") + for tok_id, pmi in comp.input_token_pmi.top[:n_pmi_tokens]: + print(f" {decode([tok_id]):15s} PMI={pmi:.2f}") + + print("\nOUTPUT tokens (what it predicts):") + for tok_id, pmi in comp.output_token_pmi.top[:n_pmi_tokens]: + print(f" {decode([tok_id]):15s} PMI={pmi:.2f}") + + print(f"\nActivation examples 
({n_examples}):") + for ex in comp.activation_examples[:n_examples]: + parts = [] + for tid, firing in zip(ex.token_ids, ex.firings, strict=True): + tok_str = decode([tid]) + parts.append(f">>>{tok_str}<<<" if firing else tok_str) + act_vals = ex.activations.get("causal_importance", ex.activations.get("activation", [])) + max_act = max(act_vals) if act_vals else 0 + print(f" [max_act={max_act:.3f}] {''.join(parts)}") + print() + + return comp + + +# -- Result types -------------------------------------------------------------- + + +@dataclass +class ComponentVectors: + """Read (V) and write (U) vectors for a single rank-1 component. + + The component forward is: act = x @ read, out = act * write. + So `read` is the input direction (d_in) and `write` is the output direction (d_out). + """ + + key: str + read: Tensor + write: Tensor + d_in: int + d_out: int + + +@dataclass +class AlignmentResult: + cosine: float + dot: float + norm_a: float + norm_b: float + percentile: float + space_dim: int + space_name: str + + +@dataclass +class UnembedMatch: + token_id: int + token_str: str + cosine: float + dot: float + + +@dataclass +class AblationEffect: + mean_kl: float + baseline_ppl: float + edited_ppl: float + n_tokens: int + + @property + def ppl_increase_pct(self) -> float: + return (self.edited_ppl / self.baseline_ppl - 1) * 100 + + +@dataclass +class TokenGroupShift: + group_name: str + baseline_mean_prob: float + edited_mean_prob: float + n_positions: int + + @property + def change_pct(self) -> float: + if self.baseline_mean_prob == 0: + return float("inf") if self.edited_mean_prob > 0 else 0.0 + return (self.edited_mean_prob / self.baseline_mean_prob - 1) * 100 + + +# -- EditableModel ------------------------------------------------------------- + + +class EditableModel: + """ComponentModel + TransformerTopology with methods for editing and analysis. + + Callable: em(tokens) returns logits, so it works as a ForwardFn. 
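
    A toy illustration of the ForwardFn contract (pure Python, no spd objects;
    the 5-token vocabulary and the "predict last token + 1" rule are made up):

    ```python
    # Any tokens -> per-position logits callable can drive the same greedy
    # decoding loop that the real model or an edited forward function uses.
    def toy_forward(tokens):
        logits = [[0.0] * 5 for _ in tokens]
        logits[-1][(tokens[-1] + 1) % 5] = 1.0  # boost "last + 1"
        return logits

    def greedy_step(forward_fn, tokens):
        last = forward_fn(tokens)[-1]
        return last.index(max(last))

    tokens = [0]
    for _ in range(3):
        tokens.append(greedy_step(toy_forward, tokens))
    ```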
+ """ + + def __init__(self, model: ComponentModel) -> None: + self.model = model + self.topology = TransformerTopology(model.target_model) + + @classmethod + def from_wandb( + cls, wandb_path: str, device: str = "cuda" + ) -> tuple["EditableModel", AppTokenizer]: + """Load from wandb path. Returns (editable_model, tokenizer).""" + run_info = SPDRunInfo.from_path(wandb_path) + model = ComponentModel.from_run_info(run_info).to(device).eval() + assert run_info.config.tokenizer_name is not None + tokenizer = AppTokenizer.from_pretrained(run_info.config.tokenizer_name) + return cls(model), tokenizer + + def __call__(self, tokens: Int[Tensor, " seq"]) -> Float[Tensor, "seq vocab"]: + return self.model(tokens.unsqueeze(0)).squeeze(0) + + # -- Component geometry ---------------------------------------------------- + + def get_component_vectors(self, key: str) -> ComponentVectors: + """Get the read (V[:, c]) and write (U[c, :]) vectors for a component.""" + layer, idx = parse_component_key(key) + comp = self.model.components[layer] + return ComponentVectors( + key=key, + read=comp.V[:, idx], + write=comp.U[idx, :], + d_in=int(comp.d_in), # pyright: ignore[reportArgumentType] + d_out=int(comp.d_out), # pyright: ignore[reportArgumentType] + ) + + def component_alignment(self, key_a: str, key_b: str) -> AlignmentResult: + """Cosine/dot between key_a's write direction and key_b's read direction. + + Asserts they share a space (key_a's d_out == key_b's d_in). + Percentile is empirical over all pairs in the same two layers. 
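
        The statistic can be sketched with illustrative 2-d vectors (not real
        model weights):

        ```python
        import math

        # Cosine between a write and a read vector, plus the empirical
        # percentile of |cos| over all candidate pairs.
        def _norm(v):
            return math.sqrt(sum(x * x for x in v))

        def cosine(a, b):
            return sum(x * y for x, y in zip(a, b)) / (_norm(a) * _norm(b))

        write = [1.0, 0.0]
        reads = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
        cos_ab = cosine(write, reads[2])
        all_cos = [cosine(write, r) for r in reads]
        percentile = 100 * sum(abs(c) < abs(cos_ab) for c in all_cos) / len(all_cos)
        ```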
+ """ + a = self.get_component_vectors(key_a) + b = self.get_component_vectors(key_b) + assert a.d_out == b.d_in, ( + f"{key_a} writes d={a.d_out}, {key_b} reads d={b.d_in} — no shared space" + ) + + cos = F.cosine_similarity(a.write.unsqueeze(0), b.read.unsqueeze(0)).item() + dot = (a.write * b.read).sum().item() + + layer_a, _ = parse_component_key(key_a) + layer_b, _ = parse_component_key(key_b) + all_writes = self.model.components[layer_a].U + all_reads = self.model.components[layer_b].V + all_cos = F.normalize(all_writes, dim=1) @ F.normalize(all_reads, dim=0) + percentile = (all_cos.abs() < abs(cos)).float().mean().item() * 100 + + resid_dim = self.topology.unembed_module.in_features + space_name = "residual" if a.d_out == resid_dim else "neuron" + + return AlignmentResult( + cosine=cos, + dot=dot, + norm_a=a.write.norm().item(), + norm_b=b.read.norm().item(), + percentile=percentile, + space_dim=a.d_out, + space_name=space_name, + ) + + def unembed_alignment( + self, + key: str, + tokenizer: AppTokenizer, + top_k: int = 10, + ) -> tuple[list[UnembedMatch], list[UnembedMatch]]: + """Top boosted and suppressed tokens by alignment with write direction. + + Only works for components that write to the residual stream. + Returns (top_boosted, top_suppressed). 
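
        A toy version of the ranking (illustrative 2-d unembed rows over a
        3-token vocabulary, not a real unembedding matrix):

        ```python
        import math

        # Score each vocab row by cosine with the write direction; the top-k
        # are "boosted" tokens and the bottom-k are "suppressed" tokens.
        def rank_tokens(write, unembed, k=1):
            def cos(a, b):
                na = math.sqrt(sum(x * x for x in a))
                nb = math.sqrt(sum(x * x for x in b))
                return sum(x * y for x, y in zip(a, b)) / (na * nb)

            scored = sorted(range(len(unembed)), key=lambda t: cos(write, unembed[t]))
            return scored[-k:][::-1], scored[:k]  # (boosted, suppressed)

        unembed = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
        boosted, suppressed = rank_tokens([1.0, 0.0], unembed)
        ```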
+ """ + vecs = self.get_component_vectors(key) + unembed = self.topology.unembed_module.weight # [vocab, d_model] + assert vecs.d_out == unembed.shape[1], ( + f"{key} writes d={vecs.d_out}, unembed expects d={unembed.shape[1]}" + ) + + all_cos = F.cosine_similarity(vecs.write.unsqueeze(0), unembed, dim=1) + all_dot = (vecs.write.unsqueeze(0) * unembed).sum(dim=1) + + decode = tokenizer.decode + + top_vals, top_ids = all_cos.topk(top_k) + boosted = [ + UnembedMatch(int(t), decode([int(t)]), v.item(), all_dot[t].item()) + for v, t in zip(top_vals, top_ids, strict=True) + ] + + bot_vals, bot_ids = all_cos.topk(top_k, largest=False) + suppressed = [ + UnembedMatch(int(t), decode([int(t)]), v.item(), all_dot[t].item()) + for v, t in zip(bot_vals, bot_ids, strict=True) + ] + + return boosted, suppressed + + def get_component_activations( + self, + tokens: Int[Tensor, " seq"], + key: str, + ) -> Float[Tensor, " seq"]: + """Component activation (v_c^T @ x) at each sequence position.""" + layer, idx = parse_component_key(key) + with torch.no_grad(): + out = self.model(tokens.unsqueeze(0), cache_type="input") + pre_weight_acts = out.cache[layer] # [1, seq, d_in] + comp = self.model.components[layer] + return (pre_weight_acts @ comp.V[:, idx]).squeeze(0) # [seq] + + def get_ci( + self, + tokens: Int[Tensor, " seq"], + ) -> dict[str, Float[Tensor, " seq C"]]: + """Get CI values for all components at all positions. 
Returns {layer: [seq, C]}.""" + with torch.no_grad(): + out = self.model(tokens.unsqueeze(0), cache_type="input") + ci = self.model.calc_causal_importances( + pre_weight_acts=out.cache, + sampling="continuous", + detach_inputs=False, + ) + return {layer: vals.squeeze(0) for layer, vals in ci.lower_leaky.items()} + + def find_components_by_examples( + self, + examples: list[tuple[Int[Tensor, " seq"], int]], + optim_steps: int = 100, + context_window: int = 10, + ci_alive_threshold: float = 0.0, + min_frequency: float = 0.7, + top_k: int = 20, + ) -> list[tuple[str, float]]: + """Find components needed for a behavior by optimizing sparse CI on examples. + + For each (token_sequence, target_position) pair, runs CI optimization + to find the minimal set of components needed to predict the token at + target_position. Components that appear in the sparse set across + >= min_frequency of examples are returned. + + Args: + examples: List of (token_sequence, target_position) pairs. + target_position is the sequence index of the token whose + prediction we want to explain. + optim_steps: Number of optimization steps per example. + ci_alive_threshold: CI threshold for considering a component "active" + in the optimized mask. + min_frequency: Fraction of examples where a component must be active. + top_k: Number of components to return. + + Returns: + List of (component_key, frequency) sorted by frequency descending. 
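
        The aggregation step alone can be sketched with toy per-example active
        sets (hypothetical component keys):

        ```python
        # A component survives if it is active in at least min_frequency of
        # the per-example sparse sets; results are sorted by frequency.
        def aggregate(active_sets, min_frequency=0.7, top_k=20):
            counts = {}
            for s in active_sets:
                for key in s:
                    counts[key] = counts.get(key, 0) + 1
            n = len(active_sets)
            min_count = int(min_frequency * n)
            kept = [(k, c / n) for k, c in counts.items() if c >= min_count]
            kept.sort(key=lambda kv: -kv[1])
            return kept[:top_k]

        sets = [{"a", "b"}, {"a"}, {"a", "c"}]
        kept = aggregate(sets)
        ```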
+ """ + from spd.app.backend.optim_cis import ( + CELossConfig, + OptimCIConfig, + optimize_ci_values, + ) + from spd.configs import ImportanceMinimalityLossConfig + + counts: dict[str, int] = {} + n_examples = len(examples) + + for i, (tokens, target_pos) in enumerate(examples): + assert target_pos > 0, "target_position must be > 0 (need a previous position)" + + # Truncate to context window ending at target_pos (inclusive) + start = max(0, target_pos - context_window + 1) + window = tokens[start : target_pos + 1] + window_target_pos = target_pos - start + target_token = window[window_target_pos].item() + + config = OptimCIConfig( + seed=42, + lr=0.1, + steps=optim_steps, + weight_decay=0.0, + lr_schedule="cosine", + lr_exponential_halflife=None, + lr_warmup_pct=0.1, + log_freq=optim_steps + 1, # suppress logging + imp_min_config=ImportanceMinimalityLossConfig(coeff=0.1, pnorm=0.5, beta=1.0), + loss_config=CELossConfig( + coeff=20.0, + position=window_target_pos - 1, + label_token=int(target_token), + ), + sampling="continuous", + ce_kl_rounding_threshold=0.5, + mask_type="ci", + adv_pgd=None, + ) + + result = optimize_ci_values( + model=self.model, + tokens=window.unsqueeze(0), + config=config, + device=str(tokens.device), + ) + + # Extract active components from optimized CI + ci_outputs = result.params.create_ci_outputs(self.model, str(tokens.device)) + for layer_name, ci_vals in ci_outputs.lower_leaky.items(): + # ci_vals: [1, window_len, C] + pred_pos = window_target_pos - 1 + active = ci_vals[0, pred_pos, :] > ci_alive_threshold + for c in active.nonzero(as_tuple=True)[0]: + key = f"{layer_name}:{c.item()}" + counts[key] = counts.get(key, 0) + 1 + + print(f" Example {i + 1}/{n_examples}: L0={result.metrics.l0_total:.0f}") + + min_count = int(min_frequency * n_examples) + freq_results = [ + (key, count / n_examples) for key, count in counts.items() if count >= min_count + ] + freq_results.sort(key=lambda x: -x[1]) + return freq_results[:top_k] + + def 
optimize_circuit(
+        self,
+        tokens: Int[Tensor, " seq"],
+        target_position: int,
+        target_token: int,
+        optim_steps: int = 200,
+        imp_min_coeff: float = 0.1,
+        ce_coeff: float = 20.0,
+    ) -> OptimizedPromptAttributionResult:
+        """Optimize a sparse circuit for predicting target_token at target_position.
+
+        Returns the full attribution graph (edges between components) from the
+        app's compute pipeline. The result includes node CI values, component
+        activations, and edge strengths.
+
+        target_position is the position whose logits should predict target_token
+        (i.e. the position immediately before the token), so the CE loss is
+        applied to the logits at position target_position.
+        """
+        from spd.app.backend.compute import compute_prompt_attributions_optimized
+        from spd.app.backend.optim_cis import CELossConfig, OptimCIConfig
+        from spd.configs import ImportanceMinimalityLossConfig
+        from spd.topology.gradient_connectivity import get_sources_by_target
+
+        device = str(tokens.device)
+        batched = tokens.unsqueeze(0)
+
+        sources_by_target = get_sources_by_target(self.model, self.topology, device, "continuous")
+
+        config = OptimCIConfig(
+            seed=42,
+            lr=0.1,
+            steps=optim_steps,
+            weight_decay=0.0,
+            lr_schedule="cosine",
+            lr_exponential_halflife=None,
+            lr_warmup_pct=0.1,
+            log_freq=optim_steps + 1,
+            imp_min_config=ImportanceMinimalityLossConfig(coeff=imp_min_coeff, pnorm=0.5, beta=1.0),
+            loss_config=CELossConfig(
+                coeff=ce_coeff,
+                position=target_position,
+                label_token=target_token,
+            ),
+            sampling="continuous",
+            ce_kl_rounding_threshold=0.5,
+            mask_type="ci",
+            adv_pgd=None,
+        )
+
+        return compute_prompt_attributions_optimized(
+            model=self.model,
+            topology=self.topology,
+            tokens=batched,
+            sources_by_target=sources_by_target,
+            optim_config=config,
+            output_prob_threshold=0.01,
+            device=device,
+        )
+
+    def print_circuit(
+        self,
+        circuit: OptimizedPromptAttributionResult,
+        tokens: Int[Tensor, " seq"],
+        tok: AppTokenizer,
+        interp: "InterpRepo | None" = 
None, + top_edges: int = 5, + min_ci: float = 0.0, + ) -> None: + """Print a human-readable summary of an optimized circuit.""" + from collections import defaultdict + + spans = tok.get_spans(tokens.tolist()) + + def parse_node(key: str) -> tuple[str, int, int]: + parts = key.split(":") + return ":".join(parts[:-2]), int(parts[-2]), int(parts[-1]) + + def node_label(key: str) -> str: + layer, seq, cidx = parse_node(key) + label = "" + if interp is not None: + ir = interp.get_interpretation(f"{layer}:{cidx}") + if ir: + label = f" [{ir.label[:35]}]" + return f"{layer}:{cidx}@{spans[seq].strip()}(p{seq}){label}" + + edges_by_target: dict[str, list[tuple[str, float, bool]]] = defaultdict(list) + for e in circuit.edges: + edges_by_target[str(e.target)].append((str(e.source), e.strength, e.is_cross_seq)) + + print(f"Circuit: {len(circuit.edges)} edges, L0={circuit.metrics.l0_total:.0f}") + print(f"Tokens: {list(enumerate(spans))}\n") + + for tgt_key in sorted(edges_by_target.keys()): + ci = circuit.node_ci_vals.get(tgt_key, 0) + if ci <= min_ci: + continue + + sources = edges_by_target[tgt_key] + sources.sort(key=lambda x: -abs(x[1])) + + print(f"{node_label(tgt_key)} ci={ci:.3f}") + for src_key, strength, cross_seq in sources[:top_edges]: + cross = " [x-seq]" if cross_seq else "" + print(f" <- {node_label(src_key)} attr={strength:+.4f}{cross}") + print() + + # -- Editing (mask-based, runtime) ----------------------------------------- + + def _edited_forward_batched( + self, + tokens: Int[Tensor, "1 seq"], + edits: dict[str, float], + ) -> Float[Tensor, "1 seq vocab"]: + """Forward with component mask edits applied uniformly (batched internal).""" + seq_len = tokens.shape[1] + device = tokens.device + + component_masks = { + layer: torch.ones(1, seq_len, C, device=device) + for layer, C in self.model.module_to_c.items() + } + for key, value in edits.items(): + layer, idx = parse_component_key(key) + assert layer in component_masks, f"Unknown layer: {layer}" + 
component_masks[layer][0, :, idx] = value + + mask_infos = make_mask_infos(component_masks, routing_masks="all") + return self.model(tokens, mask_infos=mask_infos) + + def _ci_guided_forward_batched( + self, + tokens: Int[Tensor, "1 seq"], + edits: dict[str, float], + ci_threshold: float, + ) -> Float[Tensor, "1 seq vocab"]: + """Forward with edits applied only where component CI exceeds threshold (batched).""" + seq_len = tokens.shape[1] + device = tokens.device + + output_with_cache = self.model(tokens, cache_type="input") + ci_outputs = self.model.calc_causal_importances( + pre_weight_acts=output_with_cache.cache, + sampling="continuous", + detach_inputs=False, + ) + ci_vals = ci_outputs.lower_leaky + + component_masks = { + layer: torch.ones(1, seq_len, C, device=device) + for layer, C in self.model.module_to_c.items() + } + for key, value in edits.items(): + layer, idx = parse_component_key(key) + assert layer in component_masks, f"Unknown layer: {layer}" + high_ci = ci_vals[layer][0, :, idx] > ci_threshold + component_masks[layer][0, high_ci, idx] = value + + mask_infos = make_mask_infos(component_masks, routing_masks="all") + return self.model(tokens, mask_infos=mask_infos) + + def make_edit_fn( + self, + edits: dict[str, float], + ci_threshold: float | None = None, + ) -> ForwardFn: + """Create a reusable unbatched tokens [seq] → logits [seq, vocab] function.""" + if ci_threshold is not None: + return lambda tokens: self._ci_guided_forward_batched( + tokens.unsqueeze(0), edits, ci_threshold + ).squeeze(0) + return lambda tokens: self._edited_forward_batched(tokens.unsqueeze(0), edits).squeeze(0) + + # -- Permanent weight editing ---------------------------------------------- + + def without_components(self, ablate_keys: list[str]) -> "EditableModel": + """Deep copy with components permanently subtracted from target model weights. + + The returned model's target_model is a standard transformer — no CI + function or mask_infos needed at inference. 
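
        The underlying identity, on a toy 2x2 weight with illustrative numbers:
        subtracting the rank-1 outer product from W removes exactly the
        component's contribution to the output.

        ```python
        # (W - u v^T) x == W x - (v . x) * u, where v is the read vector
        # (d_in) and u is the write vector (d_out).
        def matvec(W, x):
            return [sum(w * xi for w, xi in zip(row, x)) for row in W]

        W = [[2.0, 1.0], [0.0, 3.0]]
        u, v = [1.0, 2.0], [0.5, 1.0]  # write (d_out) and read (d_in) vectors
        W_edited = [[W[i][j] - u[i] * v[j] for j in range(2)] for i in range(2)]

        x = [4.0, 2.0]
        act = sum(vj * xj for vj, xj in zip(v, x))  # component activation v . x
        y = matvec(W, x)
        y_edited = matvec(W_edited, x)
        ```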
+ """ + edited_model = copy.deepcopy(self.model) + + by_layer: dict[str, list[int]] = {} + for key in ablate_keys: + layer, idx = parse_component_key(key) + by_layer.setdefault(layer, []).append(idx) + + for layer_name, indices in by_layer.items(): + components = edited_model.components[layer_name] + target_module = edited_model.target_model.get_submodule(layer_name) + + for idx in indices: + contribution = (components.V[:, idx : idx + 1] @ components.U[idx : idx + 1, :]).T + target_module.weight.data -= contribution # pyright: ignore[reportOperatorIssue] + + return EditableModel(edited_model) + + +# -- Free functions (work with any ForwardFn) ---------------------------------- + + +def generate( + forward_fn: ForwardFn, + tokens: Int[Tensor, " seq"], + tokenizer: AppTokenizer, + max_new_tokens: int = 30, + temperature: float = 0.0, +) -> str: + """Greedy (temperature=0) or sampled generation from an arbitrary forward function. + + Takes unbatched tokens [seq]. Strips trailing EOS to avoid the model + treating the prompt as complete. + """ + eos_id = tokenizer.eos_token_id + if tokens[-1].item() == eos_id: + tokens = tokens[:-1] + generated = tokens.clone() + for _ in range(max_new_tokens): + logits = forward_fn(generated) + next_logits = logits[-1] + if temperature == 0: + next_id = next_logits.argmax() + else: + probs = F.softmax(next_logits / temperature, dim=-1) + next_id = torch.multinomial(probs, 1).squeeze() + generated = torch.cat([generated, next_id.unsqueeze(0)]) + if next_id.item() == tokenizer.eos_token_id: + break + return tokenizer.decode(generated.tolist()) + + +def measure_kl( + baseline_fn: ForwardFn, + edited_fn: ForwardFn, + token_seqs: list[Int[Tensor, " seq"]], +) -> AblationEffect: + """KL divergence and perplexity shift between two forward functions. + + Takes unbatched token sequences [seq]. 
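
    The reported quantities reduce to elementary arithmetic; a toy sketch with
    made-up distributions and per-token NLLs:

    ```python
    import math

    # Mean KL(baseline || edited) per token, and perplexity as exp(mean NLL).
    baseline = [0.7, 0.2, 0.1]  # baseline next-token distribution
    edited = [0.5, 0.3, 0.2]    # edited distribution at the same position
    kl = sum(p * math.log(p / q) for p, q in zip(baseline, edited))

    nlls = [1.0, 3.0]           # per-token -log p(target) values
    ppl = math.exp(sum(nlls) / len(nlls))
    ```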
+ """ + total_kl = 0.0 + total_baseline_nll = 0.0 + total_edited_nll = 0.0 + total_tokens = 0 + + for tokens in token_seqs: + if tokens.shape[0] < 3: + continue + + with torch.no_grad(): + baseline_logits = baseline_fn(tokens) + edited_logits = edited_fn(tokens) + + baseline_lp = F.log_softmax(baseline_logits[:-1], dim=-1) + edited_lp = F.log_softmax(edited_logits[:-1], dim=-1) + + kl = F.kl_div(edited_lp, baseline_lp.exp(), reduction="sum", log_target=False) + + targets = tokens[1:] + baseline_nll = -baseline_lp[range(len(targets)), targets].sum() + edited_nll = -edited_lp[range(len(targets)), targets].sum() + + total_kl += kl.item() + total_baseline_nll += baseline_nll.item() + total_edited_nll += edited_nll.item() + total_tokens += len(targets) + + assert total_tokens > 0, "No tokens to evaluate" + return AblationEffect( + mean_kl=total_kl / total_tokens, + baseline_ppl=torch.exp(torch.tensor(total_baseline_nll / total_tokens)).item(), + edited_ppl=torch.exp(torch.tensor(total_edited_nll / total_tokens)).item(), + n_tokens=total_tokens, + ) + + +def measure_token_probs( + baseline_fn: ForwardFn, + edited_fn: ForwardFn, + token_seqs: list[Int[Tensor, " seq"]], + token_groups: dict[str, list[int]], +) -> dict[str, TokenGroupShift]: + """Probability shift for named groups of token IDs between two forward functions. + + Takes unbatched token sequences [seq]. 
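
    The change_pct convention (including the zero-baseline edge case) in
    isolation, with made-up probabilities:

    ```python
    # Same rule as TokenGroupShift.change_pct: percentage change of the
    # group's mean probability, guarding against a zero baseline.
    def change_pct(baseline_mean, edited_mean):
        if baseline_mean == 0:
            return float("inf") if edited_mean > 0 else 0.0
        return (edited_mean / baseline_mean - 1) * 100

    suppressed = change_pct(0.04, 0.01)  # the edit suppressed the group
    ```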
+ """ + baseline_sums: dict[str, float] = {name: 0.0 for name in token_groups} + edited_sums: dict[str, float] = {name: 0.0 for name in token_groups} + total_positions = 0 + + for tokens in token_seqs: + with torch.no_grad(): + baseline_logits = baseline_fn(tokens) + edited_logits = edited_fn(tokens) + + bp = F.softmax(baseline_logits, dim=-1) + ep = F.softmax(edited_logits, dim=-1) + + for name, ids in token_groups.items(): + baseline_sums[name] += bp[:, ids].sum().item() + edited_sums[name] += ep[:, ids].sum().item() + total_positions += bp.shape[0] + + assert total_positions > 0 + return { + name: TokenGroupShift( + group_name=name, + baseline_mean_prob=baseline_sums[name] / total_positions, + edited_mean_prob=edited_sums[name] / total_positions, + n_positions=total_positions, + ) + for name in token_groups + } diff --git a/spd/editing/generate_token_divergence.py b/spd/editing/generate_token_divergence.py new file mode 100644 index 000000000..f7569df72 --- /dev/null +++ b/spd/editing/generate_token_divergence.py @@ -0,0 +1,198 @@ +"""Generate per-token divergence data for the token divergence visualisation. + +Runs forward passes on dataset text under named component ablations, +computes KL, reverse KL, JSD, and CE diff per token, writes JSON. 
+ +Usage: + python -m spd.editing.generate_token_divergence \\ + wandb:goodfire/spd/s-892f140b \\ + --edits edits.yaml \\ + --n_tokens 1500 \\ + --out_path /path/to/www/data/kl_tokens.json + +edits.yaml format: + Male pronouns: + - h.1.mlp.down_proj:798 + - h.1.mlp.c_fc:144 + - h.1.attn.o_proj:82 + Question marks: + - h.1.mlp.down_proj:534 +""" + +import json +from pathlib import Path +from typing import Any + +import torch +import torch.nn.functional as F +import yaml +from datasets import load_dataset + +from spd.app.backend.app_tokenizer import AppTokenizer +from spd.editing import EditableModel, ForwardFn +from spd.settings import SPD_OUT_DIR + +TokenData = dict[str, Any] + + +def compute_token_divergence( + em: EditableModel, + edit_fn: ForwardFn, + token_ids: list[int], + tok: AppTokenizer, + top_k: int = 5, +) -> list[TokenData]: + tokens = torch.tensor(token_ids, device="cuda") + spans = tok.get_spans(token_ids) + + with torch.no_grad(): + bl_logits = em(tokens) + ed_logits = edit_fn(tokens) + + bl_probs = F.softmax(bl_logits, dim=-1) + ed_probs = F.softmax(ed_logits, dim=-1) + bl_lp = F.log_softmax(bl_logits, dim=-1) + ed_lp = F.log_softmax(ed_logits, dim=-1) + + # All metrics at positions [0..seq-2], predicting tokens [1..seq-1] + fwd_kl_per_vocab = bl_probs[:-1] * (bl_lp[:-1] - ed_lp[:-1]) + fwd_kl = fwd_kl_per_vocab.sum(dim=-1) + rev_kl = (ed_probs[:-1] * (ed_lp[:-1] - bl_lp[:-1])).sum(dim=-1) + + m_probs = 0.5 * (bl_probs[:-1] + ed_probs[:-1]) + m_lp = m_probs.log() + jsd = 0.5 * (bl_probs[:-1] * (bl_lp[:-1] - m_lp)).sum(-1) + 0.5 * ( + ed_probs[:-1] * (ed_lp[:-1] - m_lp) + ).sum(-1) + + targets = tokens[1:] + ce_diff = -ed_lp[:-1][range(len(targets)), targets] - ( + -bl_lp[:-1][range(len(targets)), targets] + ) + + result: list[TokenData] = [] + for i in range(len(tokens)): + if i == 0: + result.append( + {"s": spans[i], "kl": 0, "rkl": 0, "jsd": 0, "ce": 0, "bl": [], "ed": [], "kc": []} + ) + continue + + prev = i - 1 + bl_top_v, bl_top_i = 
bl_probs[prev].topk(top_k) + ed_top_v, ed_top_i = ed_probs[prev].topk(top_k) + + bl_top = [ + [tok.decode([int(t)]), round(v.item(), 4)] + for v, t in zip(bl_top_v, bl_top_i, strict=True) + ] + ed_top = [ + [tok.decode([int(t)]), round(v.item(), 4)] + for v, t in zip(ed_top_v, ed_top_i, strict=True) + ] + + kl_contribs = fwd_kl_per_vocab[prev] + _, kl_top_i = kl_contribs.abs().topk(top_k) + kl_top = [ + [ + tok.decode([int(idx)]), + round(bl_probs[prev, idx].item(), 4), + round(ed_probs[prev, idx].item(), 4), + round(kl_contribs[idx].item(), 5), + ] + for idx in kl_top_i + ] + + result.append( + { + "s": spans[i], + "kl": round(fwd_kl[prev].item(), 5), + "rkl": round(rev_kl[prev].item(), 5), + "jsd": round(jsd[prev].item(), 5), + "ce": round(ce_diff[prev].item(), 5), + "bl": bl_top, + "ed": ed_top, + "kc": kl_top, + } + ) + + return result + + +def load_stories(n_tokens: int, max_seq_len: int = 300) -> list[list[int]]: + """Load stories from SimpleStories until we have >= n_tokens.""" + ds = load_dataset("SimpleStories/SimpleStories", split="train", streaming=True) + tok = AppTokenizer.from_pretrained("goodfire/SimpleStories-Llama-tokenizer") + stories = [] + total = 0 + for item in ds: + token_ids = tok.encode(item["story"]) + if len(token_ids) > max_seq_len: + token_ids = token_ids[:max_seq_len] + stories.append(token_ids) + total += len(token_ids) + if total >= n_tokens: + break + return stories + + +def main( + wandb_path: str, + edits: str, + n_tokens: int = 1500, + out_path: str | None = None, +) -> None: + edits_path = Path(edits) + assert edits_path.exists(), f"Edits file not found: {edits_path}" + with open(edits_path) as f: + edits_config: dict[str, list[str]] = yaml.safe_load(f) + + if out_path is None: + out_path = str(SPD_OUT_DIR / "www" / "data" / "kl_tokens.json") + out = Path(out_path) + out.parent.mkdir(parents=True, exist_ok=True) + + em, tok = EditableModel.from_wandb(wandb_path) + stories = load_stories(n_tokens) + total_tokens = sum(len(s) for 
s in stories) + print(f"Loaded {len(stories)} stories, {total_tokens} tokens") + + all_data: dict[str, Any] = {} + for edit_name, component_keys in edits_config.items(): + edit_dict = {k: 0.0 for k in component_keys} + edit_fn = em.make_edit_fn(edit_dict) + + edit_stories = [] + for story_ids in stories: + tokens = compute_token_divergence(em, edit_fn, story_ids, tok) + edit_stories.append(tokens) + + all_data[edit_name] = {"components": component_keys, "stories": edit_stories} + print(f" {edit_name}: done") + + # Global p99 scales + def p99(vals: list[float]) -> float: + s = sorted(vals) + return s[int(0.99 * len(s))] + + def collect(key: str) -> list[float]: + return [t[key] for e in all_data.values() for s in e["stories"] for t in s if t[key] != 0] + + all_data["_meta"] = { + "kl_max": round(p99(collect("kl")), 4), + "rkl_max": round(p99(collect("rkl")), 4), + "jsd_max": round(p99(collect("jsd")), 4), + "ce_max": round(p99([abs(v) for v in collect("ce")]), 4), + } + + with open(out, "w") as f: + json.dump(all_data, f, separators=(",", ":")) + + size_kb = out.stat().st_size / 1024 + print(f"Wrote {size_kb:.0f} KB to {out}") + + +if __name__ == "__main__": + import fire + + fire.Fire(main) diff --git a/spd/eval.py b/spd/eval.py index c6f0b47ff..a3f999e5f 100644 --- a/spd/eval.py +++ b/spd/eval.py @@ -12,7 +12,9 @@ from spd.configs import ( CEandKLLossesConfig, CI_L0Config, + CIHiddenActsReconLossConfig, CIHistogramsConfig, + CIMaskedAttnPatternsReconLossConfig, CIMaskedReconLayerwiseLossConfig, CIMaskedReconLossConfig, CIMaskedReconSubsetLossConfig, @@ -24,11 +26,16 @@ ImportanceMinimalityLossConfig, MetricConfigType, PermutedCIPlotsConfig, + PersistentPGDReconEvalConfig, + PersistentPGDReconLossConfig, + PersistentPGDReconSubsetEvalConfig, + PersistentPGDReconSubsetLossConfig, PGDMultiBatchReconLossConfig, PGDMultiBatchReconSubsetLossConfig, PGDReconLayerwiseLossConfig, PGDReconLossConfig, PGDReconSubsetLossConfig, + StochasticAttnPatternsReconLossConfig, 
    StochasticHiddenActsReconLossConfig,
    StochasticReconLayerwiseLossConfig,
    StochasticReconLossConfig,
@@ -38,6 +45,10 @@
     UVPlotsConfig,
 )
 from spd.metrics import UnmaskedReconLoss
+from spd.metrics.attn_patterns_recon_loss import (
+    CIMaskedAttnPatternsReconLoss,
+    StochasticAttnPatternsReconLoss,
+)
 from spd.metrics.base import Metric
 from spd.metrics.ce_and_kl_losses import CEandKLLosses
 from spd.metrics.ci_histograms import CIHistograms
@@ -48,6 +59,7 @@
 from spd.metrics.ci_mean_per_component import CIMeanPerComponent
 from spd.metrics.component_activation_density import ComponentActivationDensity
 from spd.metrics.faithfulness_loss import FaithfulnessLoss
+from spd.metrics.hidden_acts_recon_loss import CIHiddenActsReconLoss, StochasticHiddenActsReconLoss
 from spd.metrics.identity_ci_error import IdentityCIError
 from spd.metrics.importance_minimality_loss import ImportanceMinimalityLoss
 from spd.metrics.permuted_ci_plots import PermutedCIPlots
@@ -55,13 +67,14 @@
 from spd.metrics.pgd_masked_recon_loss import PGDReconLoss
 from spd.metrics.pgd_masked_recon_subset_loss import PGDReconSubsetLoss
 from spd.metrics.pgd_utils import CreateDataIter, calc_multibatch_pgd_masked_recon_loss
-from spd.metrics.stochastic_hidden_acts_recon_loss import StochasticHiddenActsReconLoss
+from spd.metrics.ppgd_eval_losses import PPGDReconEval
 from spd.metrics.stochastic_recon_layerwise_loss import StochasticReconLayerwiseLoss
 from spd.metrics.stochastic_recon_loss import StochasticReconLoss
 from spd.metrics.stochastic_recon_subset_ce_and_kl import StochasticReconSubsetCEAndKL
 from spd.metrics.stochastic_recon_subset_loss import StochasticReconSubsetLoss
 from spd.metrics.uv_plots import UVPlots
 from spd.models.component_model import ComponentModel, OutputWithCache
+from spd.persistent_pgd import PersistentPGDState
 from spd.routing import AllLayersRouter, get_subset_router
 from spd.utils.distributed_utils import avg_metrics_across_ranks, is_distributed
 from spd.utils.general_utils import dict_safe_update_, extract_batch_data
@@ -121,6 +134,9 @@ def init_metric(
     model: ComponentModel,
     run_config: Config,
     device: str,
+    ppgd_states: dict[
+        PersistentPGDReconLossConfig | PersistentPGDReconSubsetLossConfig, PersistentPGDState
+    ],
 ) -> Metric:
     match cfg:
         case ImportanceMinimalityLossConfig():
@@ -260,6 +276,57 @@ def init_metric(
                 use_delta_component=run_config.use_delta_component,
                 n_mask_samples=run_config.n_mask_samples,
             )
+        case CIHiddenActsReconLossConfig():
+            metric = CIHiddenActsReconLoss(model=model, device=device)
+        case PersistentPGDReconEvalConfig():
+            matching = [
+                s for k, s in ppgd_states.items() if isinstance(k, PersistentPGDReconLossConfig)
+            ]
+            assert len(matching) == 1
+            metric = PPGDReconEval(
+                model=model,
+                device=device,
+                effective_sources=matching[0].get_effective_sources(),
+                use_delta_component=run_config.use_delta_component,
+                output_loss_type=run_config.output_loss_type,
+                metric_name=cfg.classname,
+            )
+        case PersistentPGDReconSubsetEvalConfig():
+            matching = [
+                s
+                for k, s in ppgd_states.items()
+                if isinstance(k, PersistentPGDReconSubsetLossConfig)
+            ]
+            assert len(matching) == 1
+            metric = PPGDReconEval(
+                model=model,
+                device=device,
+                effective_sources=matching[0].get_effective_sources(),
+                use_delta_component=run_config.use_delta_component,
+                output_loss_type=run_config.output_loss_type,
+                metric_name=cfg.classname,
+            )
+        case CIMaskedAttnPatternsReconLossConfig():
+            metric = CIMaskedAttnPatternsReconLoss(
+                model=model,
+                device=device,
+                n_heads=cfg.n_heads,
+                q_proj_path=cfg.q_proj_path,
+                k_proj_path=cfg.k_proj_path,
+                c_attn_path=cfg.c_attn_path,
+            )
+        case StochasticAttnPatternsReconLossConfig():
+            metric = StochasticAttnPatternsReconLoss(
+                model=model,
+                device=device,
+                sampling=run_config.sampling,
+                use_delta_component=run_config.use_delta_component,
+                n_mask_samples=run_config.n_mask_samples,
+                n_heads=cfg.n_heads,
+                q_proj_path=cfg.q_proj_path,
+                k_proj_path=cfg.k_proj_path,
+                c_attn_path=cfg.c_attn_path,
+            )
         case UVPlotsConfig():
             metric = UVPlots(
                 model=model,
@@ -273,10 +340,12 @@ def init_metric(
                 device=device,
                 output_loss_type=run_config.output_loss_type,
             )
-
-        case _:
-            # We shouldn't handle **all** cases because PGDMultiBatch metrics should be handled by
-            # the evaluate_multibatch_pgd function below.
+        case (
+            PGDMultiBatchReconLossConfig()
+            | PGDMultiBatchReconSubsetLossConfig()
+            | PersistentPGDReconLossConfig()
+            | PersistentPGDReconSubsetLossConfig()
+        ):
             raise ValueError(f"Unsupported metric config for eval: {cfg}")

     return metric
@@ -290,12 +359,28 @@ def evaluate(
     slow_step: bool,
     n_eval_steps: int,
     current_frac_of_training: float,
+    ppgd_states: dict[
+        PersistentPGDReconLossConfig | PersistentPGDReconSubsetLossConfig, PersistentPGDState
+    ],
 ) -> MetricOutType:
     """Run evaluation and return a mapping of metric names to values/images."""
+    # Persistent PGD losses are training-only (sources are coupled to train batch size)
+    eval_metric_configs = [
+        cfg
+        for cfg in eval_metric_configs
+        if not isinstance(cfg, PersistentPGDReconLossConfig | PersistentPGDReconSubsetLossConfig)
+    ]
+
     metrics: list[Metric] = []
     for cfg in eval_metric_configs:
-        metric = init_metric(cfg=cfg, model=model, run_config=run_config, device=device)
+        metric = init_metric(
+            cfg=cfg,
+            model=model,
+            run_config=run_config,
+            device=device,
+            ppgd_states=ppgd_states,
+        )
         if metric.slow and not slow_step:
             continue
         metrics.append(metric)
diff --git a/spd/experiments/ih/ih_config.yaml b/spd/experiments/ih/ih_config.yaml
index 9c844723a..4ce329ff2 100644
--- a/spd/experiments/ih/ih_config.yaml
+++ b/spd/experiments/ih/ih_config.yaml
@@ -34,8 +34,10 @@ stochastic_recon_layerwise_coeff: 1
 importance_minimality_coeff: 1e-2
 pnorm: 0.1
 output_loss_type: kl
-ci_fn_type: "vector_mlp"
-ci_fn_hidden_dims: [128]
+ci_config:
+  mode: layerwise
+  fn_type: vector_mlp
+  hidden_dims: [128]


 # --- Faithfulness Warmup ---
diff --git a/spd/experiments/lm/gpt2_config.yaml b/spd/experiments/lm/gpt2_config.yaml
index fc58ab8c0..0d45c610e 100644
--- a/spd/experiments/lm/gpt2_config.yaml
+++ b/spd/experiments/lm/gpt2_config.yaml
@@ -6,8 +6,10 @@ wandb_run_name_prefix: ""
 # --- General ---
 seed: 0
 n_mask_samples: 1
-ci_fn_type: "vector_mlp"
-ci_fn_hidden_dims: [12]
+ci_config:
+  mode: layerwise
+  fn_type: vector_mlp
+  hidden_dims: [12]
 sigmoid_type: "leaky_hard"
 module_info:
 - module_pattern: "transformer.h.1.attn.c_attn"
diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py
index e06d01481..b663a3013 100644
--- a/spd/experiments/lm/lm_decomposition.py
+++ b/spd/experiments/lm/lm_decomposition.py
@@ -4,7 +4,12 @@

 import fire

-from spd.configs import LMTaskConfig
+from spd.configs import (
+    LMTaskConfig,
+    PersistentPGDReconLossConfig,
+    PersistentPGDReconSubsetLossConfig,
+    RepeatAcrossBatchScope,
+)
 from spd.data import DatasetConfig, create_data_loader
 from spd.log import logger
 from spd.pretrain.run_info import PretrainRunInfo
@@ -79,7 +84,7 @@ def main(
         streaming=config.task_config.streaming,
         column_name=config.task_config.column_name,
         shuffle_each_epoch=config.task_config.shuffle_each_epoch,
-        seed=None,
+        seed=config.task_config.dataset_seed,
     )

     match dist_state:
@@ -93,6 +98,16 @@ def main(
         case None:
             train_rank_batch_size = config.batch_size

+    for cfg in config.loss_metric_configs:
+        if isinstance(
+            cfg, PersistentPGDReconLossConfig | PersistentPGDReconSubsetLossConfig
+        ) and isinstance(cfg.scope, RepeatAcrossBatchScope):
+            n = cfg.scope.n_sources
+            assert train_rank_batch_size % n == 0, (
+                f"repeat_across_batch n_sources={n} must divide per-rank batch_size="
+                f"{train_rank_batch_size}"
+            )
+
     train_loader, _tokenizer = create_data_loader(
         dataset_config=train_data_config,
         batch_size=train_rank_batch_size,
@@ -110,7 +125,7 @@ def main(
         streaming=config.task_config.streaming,
         column_name=config.task_config.column_name,
         shuffle_each_epoch=config.task_config.shuffle_each_epoch,
-        seed=None,
+        seed=config.task_config.dataset_seed,
     )

     match dist_state:
diff --git a/spd/experiments/lm/pile_llama_simple_mlp-12L.yaml b/spd/experiments/lm/pile_llama_simple_mlp-12L.yaml
new file mode 100644
index 000000000..13cea558d
--- /dev/null
+++ b/spd/experiments/lm/pile_llama_simple_mlp-12L.yaml
@@ -0,0 +1,154 @@
+wandb_project: spd
+wandb_run_name: null
+wandb_run_name_prefix: ''
+seed: 0
+autocast_bf16: true
+n_mask_samples: 1
+ci_config:
+  mode: global
+  fn_type: global_shared_transformer
+  hidden_dims: null
+  reader_hidden_dims: null
+  d_resid_ci_fn: null
+  block_groups: null
+  transition_attn_config: null
+  transition_hidden_dim: null
+  simple_transformer_ci_cfg:
+    d_model: 2048
+    n_blocks: 8
+    mlp_hidden_dim:
+    - 8192
+    attn_config:
+      n_heads: 16
+      max_len: 512
+      rope_base: 10000.0
+sampling: continuous
+sigmoid_type: leaky_hard
+module_info:
+- module_pattern: h.*.mlp.c_fc
+  C: 3072
+- module_pattern: h.*.mlp.down_proj
+  C: 3584
+- module_pattern: h.*.attn.q_proj
+  C: 512
+- module_pattern: h.*.attn.k_proj
+  C: 512
+- module_pattern: h.*.attn.v_proj
+  C: 1024
+- module_pattern: h.*.attn.o_proj
+  C: 1024
+identity_module_info: null
+init_spd_checkpoint: null
+use_delta_component: true
+loss_metric_configs:
+- coeff: 0.0001
+  classname: ImportanceMinimalityLoss
+  pnorm: 2.0
+  beta: 0.5
+  p_anneal_start_frac: 0.0
+  p_anneal_final_p: 0.4
+  p_anneal_end_frac: 1.0
+  eps: 1.0e-12
+- coeff: 0.5
+  classname: StochasticReconSubsetLoss
+  routing:
+    type: uniform_k_subset
+- coeff: 0.5
+  optimizer:
+    type: adam
+    beta1: 0.5
+    beta2: 0.99
+    eps: 1.0e-08
+    lr_schedule:
+      start_val: 0.01
+      warmup_pct: 0.025
+      final_val_frac: 1.0
+      fn_type: constant
+  scope:
+    type: per_batch_per_position
+  use_sigmoid_parameterization: false
+  n_warmup_steps: 2
+  classname: PersistentPGDReconLoss
+- coeff: 10000000.0
+  classname: FaithfulnessLoss
+output_loss_type: kl
+lr_schedule:
+  start_val: 5.0e-05
+  warmup_pct: 0.0
+  final_val_frac: 0.1
+  fn_type: cosine
+steps: 400000
+batch_size: 64
+grad_clip_norm_components: 0.01
+grad_clip_norm_ci_fns: null
+faithfulness_warmup_steps: 400
+faithfulness_warmup_lr: 0.001
+faithfulness_warmup_weight_decay: 0.0
+train_log_freq: 200
+eval_freq: 1000
+eval_batch_size: 64
+slow_eval_freq: 10000
+n_eval_steps: 1
+slow_eval_on_first_step: true
+save_freq: null
+eval_metric_configs:
+- classname: CIHistograms
+  n_batches_accum: 1
+- classname: ComponentActivationDensity
+- classname: CI_L0
+  groups:
+    layer_0:
+    - h.0.*
+    layer_1:
+    - h.1.*
+    layer_2:
+    - h.2.*
+    layer_3:
+    - h.3.*
+    layer_4:
+    - h.4.*
+    layer_5:
+    - h.5.*
+    layer_6:
+    - h.6.*
+    layer_7:
+    - h.7.*
+    layer_8:
+    - h.8.*
+    layer_9:
+    - h.9.*
+    layer_10:
+    - h.10.*
+    layer_11:
+    - h.11.*
+    total:
+    - '*'
+- classname: CEandKLLosses
+  rounding_threshold: 0.0
+- classname: CIMeanPerComponent
+- coeff: null
+  classname: StochasticHiddenActsReconLoss
+- classname: CIHiddenActsReconLoss
+- coeff: null
+  init: random
+  step_size: 0.1
+  n_steps: 20
+  mask_scope: shared_across_batch
+  classname: PGDReconLoss
+ci_alive_threshold: 0.0
+pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP
+pretrained_model_path: null
+pretrained_model_name: goodfire/spd/runs/t-f99617bb
+pretrained_model_output_attr: idx_0
+tokenizer_name: EleutherAI/gpt-neox-20b
+task_config:
+  task_name: lm
+  max_seq_len: 512
+  buffer_size: 1000
+  dataset_name: danbraunai/pile-uncopyrighted-tok-shuffled
+  column_name: input_ids
+  train_data_split: train
+  eval_data_split: val
+  shuffle_each_epoch: true
+  is_tokenized: true
+  streaming: true
diff --git a/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml b/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml
index 0d57991a9..b8ffb26ce 100644
--- a/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml
+++ b/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml
@@ -56,13 +56,17 @@ loss_metric_configs:
   classname: PersistentPGDReconSubsetLoss
   optimizer:
     type: adam
-    lr: 0.1
+    lr_schedule:
+      start_val: 0.1
+      warmup_pct: 0.0
+      final_val_frac: 1.0
+      fn_type: constant
     beta1: 0.9
     beta2: 0.99
     eps: 1.0e-08
   scope:
-    type: batch_invariant
-    n_masks: 8
+    type: repeat_across_batch
+    n_sources: 8
   routing:
     type: uniform_k_subset
 - coeff: 1000000.0
@@ -102,10 +106,10 @@ eval_metric_configs:
 - classname: CEandKLLosses
   rounding_threshold: 0.0
 - classname: CIMeanPerComponent
-- coeff: null
-  classname: StochasticHiddenActsReconLoss
-- coeff: null
-  init: random
+- classname: StochasticHiddenActsReconLoss
+- classname: CIHiddenActsReconLoss
+- classname: PersistentPGDReconSubsetEval
+- init: random
   step_size: 0.1
   n_steps: 20
   mask_scope: shared_across_batch
diff --git a/spd/experiments/lm/pile_llama_simple_mlp-4L-finetune-s-788ccb89.yaml b/spd/experiments/lm/pile_llama_simple_mlp-4L-finetune-s-788ccb89.yaml
new file mode 100644
index 000000000..ba2fe9e19
--- /dev/null
+++ b/spd/experiments/lm/pile_llama_simple_mlp-4L-finetune-s-788ccb89.yaml
@@ -0,0 +1,130 @@
+wandb_project: spd
+wandb_run_name: null
+wandb_run_name_prefix: 'finetune-s-788ccb89-'
+seed: 0
+autocast_bf16: true
+n_mask_samples: 1
+ci_config:
+  mode: global
+  fn_type: global_shared_transformer
+  hidden_dims: null
+  reader_hidden_dims: null
+  d_resid_ci_fn: null
+  block_groups: null
+  transition_attn_config: null
+  transition_hidden_dim: null
+  simple_transformer_ci_cfg:
+    d_model: 2048
+    n_blocks: 8
+    mlp_hidden_dim:
+    - 8192
+    attn_config:
+      n_heads: 16
+      max_len: 512
+      rope_base: 10000.0
+sampling: continuous
+sigmoid_type: leaky_hard
+module_info:
+- module_pattern: h.*.mlp.c_fc
+  C: 3072
+- module_pattern: h.*.mlp.down_proj
+  C: 3584
+- module_pattern: h.*.attn.q_proj
+  C: 512
+- module_pattern: h.*.attn.k_proj
+  C: 512
+- module_pattern: h.*.attn.v_proj
+  C: 1024
+- module_pattern: h.*.attn.o_proj
+  C: 1024
+identity_module_info: null
+use_delta_component: true
+
+init_spd_checkpoint: /mnt/polished-lake/artifacts/mechanisms/spd/spd/s-788ccb89/model_400000.pth
+
+loss_metric_configs:
+# ImpMin: pnorm fixed at 0.4 (end-of-training annealed value), coeff 1.5x original
+- coeff: 0.0004
+  classname: ImportanceMinimalityLoss
+  pnorm: 0.4
+  beta: 0.2
+  p_anneal_start_frac: 1.0
+  p_anneal_final_p: null
+  p_anneal_end_frac: 1.0
+  eps: 1.0e-12
+- coeff: 0.5
+  classname: StochasticReconSubsetLoss
+  routing:
+    type: uniform_k_subset
+# PGDReconLoss replaces PersistentPGDReconSubsetLoss
+- coeff: 0.5
+  classname: PGDReconLoss
+  init: random
+  step_size: 0.1
+  n_steps: 20
+  mask_scope: shared_across_batch
+- coeff: 10000000.0
+  classname: FaithfulnessLoss
+
+output_loss_type: kl
+lr_schedule:
+  start_val: 5.0e-06
+  warmup_pct: 0.0
+  final_val_frac: 1.0
+  fn_type: constant
+steps: 5000
+batch_size: 64
+grad_clip_norm_components: 0.01
+grad_clip_norm_ci_fns: null
+faithfulness_warmup_steps: 0
+faithfulness_warmup_lr: 0.001
+faithfulness_warmup_weight_decay: 0.0
+train_log_freq: 50
+eval_freq: 500
+eval_batch_size: 128
+slow_eval_freq: 1000
+n_eval_steps: 1
+slow_eval_on_first_step: true
+save_freq: null
+eval_metric_configs:
+- classname: CIHistograms
+  n_batches_accum: 1
+- classname: ComponentActivationDensity
+- classname: CI_L0
+  groups:
+    layer_0:
+    - h.0.*
+    layer_1:
+    - h.1.*
+    layer_2:
+    - h.2.*
+    layer_3:
+    - h.3.*
+    total:
+    - '*'
+- classname: CEandKLLosses
+  rounding_threshold: 0.0
+- classname: CIMeanPerComponent
+- classname: StochasticHiddenActsReconLoss
+- init: random
+  step_size: 0.1
+  n_steps: 20
+  mask_scope: shared_across_batch
+  classname: PGDReconLoss
+ci_alive_threshold: 0.0
+pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP
+pretrained_model_path: null
+pretrained_model_name: goodfire/spd/runs/t-9d2b8f02
+pretrained_model_output_attr: idx_0
+tokenizer_name: EleutherAI/gpt-neox-20b
+task_config:
+  task_name: lm
+  max_seq_len: 512
+  buffer_size: 1000
+  dataset_name: danbraunai/pile-uncopyrighted-tok
+  column_name: input_ids
+  train_data_split: train
+  eval_data_split: val
+  shuffle_each_epoch: true
+  is_tokenized: true
+  streaming: true
diff --git a/spd/experiments/lm/pile_llama_simple_mlp-4L-hidden-acts-recon.yaml b/spd/experiments/lm/pile_llama_simple_mlp-4L-hidden-acts-recon.yaml
new file mode 100644
index 000000000..331e348ee
--- /dev/null
+++ b/spd/experiments/lm/pile_llama_simple_mlp-4L-hidden-acts-recon.yaml
@@ -0,0 +1,112 @@
+wandb_project: spd
+wandb_run_name: null
+wandb_run_name_prefix: ''
+seed: 0
+autocast_bf16: true
+n_mask_samples: 1
+ci_config:
+  mode: global
+  fn_type: global_shared_transformer
+  hidden_dims: null
+  reader_hidden_dims: null
+  d_resid_ci_fn: null
+  block_groups: null
+  transition_attn_config: null
+  transition_hidden_dim: null
+  simple_transformer_ci_cfg:
+    d_model: 2048
+    n_blocks: 8
+    mlp_hidden_dim:
+    - 819
+    attn_config:
+      n_heads: 16
+      max_len: 512
+      rope_base: 10000.0
+sampling: continuous
+sigmoid_type: leaky_hard
+module_info:
+- module_pattern: h.*.mlp.c_fc
+  C: 300
+- module_pattern: h.*.mlp.down_proj
+  C: 35
+- module_pattern: h.*.attn.q_proj
+  C: 51
+- module_pattern: h.*.attn.k_proj
+  C: 51
+- module_pattern: h.*.attn.v_proj
+  C: 102
+- module_pattern: h.*.attn.o_proj
+  C: 102
+identity_module_info: null
+use_delta_component: true
+loss_metric_configs:
+- coeff: 0.0004
+  classname: ImportanceMinimalityLoss
+  pnorm: 2.0
+  beta: 0.2
+  p_anneal_start_frac: 0.0
+  p_anneal_final_p: 0.4
+  p_anneal_end_frac: 1.0
+  eps: 1.0e-12
+- coeff: 0.5
+  classname: StochasticReconSubsetLoss
+  routing:
+    type: uniform_k_subset
+- coeff: 0.5
+  classname: PersistentPGDReconLoss
+  optimizer:
+    type: adam
+    beta1: 0.5
+    beta2: 0.99
+    eps: 1.0e-08
+    lr_schedule:
+      start_val: 0.01
+      warmup_pct: 0.001
+      final_val_frac: 0.1
+      fn_type: cosine
+  scope:
+    type: per_batch_per_position
+  use_sigmoid_parameterization: false
+  n_warmup_steps: 2
+- coeff: 10000000.0
+  classname: FaithfulnessLoss
+output_loss_type: kl
+lr_schedule:
+  start_val: 5.0e-05
+  warmup_pct: 0.0
+  final_val_frac: 0.1
+  fn_type: cosine
+steps: 10000
+batch_size: 16
+grad_clip_norm_components: 0.01
+grad_clip_norm_ci_fns: null
+faithfulness_warmup_steps: 400
+faithfulness_warmup_lr: 0.001
+faithfulness_warmup_weight_decay: 0.0
+train_log_freq: 200
+eval_freq: 100
+eval_batch_size: 16
+slow_eval_freq: 10000
+n_eval_steps: 1
+slow_eval_on_first_step: true
+save_freq: null
+eval_metric_configs:
+- classname: StochasticHiddenActsReconLoss
+- classname: CIHiddenActsReconLoss
+ci_alive_threshold: 0.0
+pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP
+pretrained_model_path: null
+pretrained_model_name: goodfire/spd/runs/t-9d2b8f02
+pretrained_model_output_attr: idx_0
+tokenizer_name: EleutherAI/gpt-neox-20b
+task_config:
+  task_name: lm
+  max_seq_len: 512
+  buffer_size: 1000
+  dataset_name: danbraunai/pile-uncopyrighted-tok
+  column_name: input_ids
+  train_data_split: train
+  eval_data_split: val
+  shuffle_each_epoch: true
+  is_tokenized: true
+  streaming: true
diff --git a/spd/experiments/lm/pile_llama_simple_mlp-4L-lucius.yaml b/spd/experiments/lm/pile_llama_simple_mlp-4L-lucius.yaml
new file mode 100644
index 000000000..3407b39be
--- /dev/null
+++ b/spd/experiments/lm/pile_llama_simple_mlp-4L-lucius.yaml
@@ -0,0 +1,135 @@
+wandb_project: spd
+wandb_run_name: null
+wandb_run_name_prefix: ''
+seed: 0
+autocast_bf16: true
+n_mask_samples: 1
+ci_config:
+  mode: global
+  fn_type: global_shared_transformer
+  hidden_dims: null
+  reader_hidden_dims: null
+  d_resid_ci_fn: null
+  block_groups: null
+  transition_attn_config: null
+  transition_hidden_dim: null
+  simple_transformer_ci_cfg:
+    d_model: 2048
+    n_blocks: 8
+    mlp_hidden_dim:
+    - 8192
+    attn_config:
+      n_heads: 16
+      max_len: 512
+      rope_base: 10000.0
+sampling: continuous
+sigmoid_type: leaky_hard
+module_info:
+- module_pattern: h.*.mlp.c_fc
+  C: 4096
+- module_pattern: h.*.mlp.down_proj
+  C: 3072
+- module_pattern: h.*.attn.q_proj
+  C: 512
+- module_pattern: h.*.attn.k_proj
+  C: 256
+- module_pattern: h.*.attn.v_proj
+  C: 1024
+- module_pattern: h.*.attn.o_proj
+  C: 1024
+identity_module_info: null
+use_delta_component: true
+loss_metric_configs:
+- coeff: 0.0005
+  classname: ImportanceMinimalityLoss
+  pnorm: 2.0
+  beta: 0.1
+  p_anneal_start_frac: 0.0
+  p_anneal_final_p: 0.4
+  p_anneal_end_frac: 1.0
+  eps: 1.0e-12
+- coeff: 0.5
+  classname: StochasticReconSubsetLoss
+  routing:
+    type: uniform_k_subset
+- coeff: 0.5
+  classname: PersistentPGDReconSubsetLoss
+  optimizer:
+    type: adam
+    lr_schedule:
+      start_val: 0.01
+      warmup_pct: 0.0
+      final_val_frac: 1.0
+      fn_type: constant
+    beta1: 0.8
+    beta2: 0.99
+    eps: 1.0e-08
+  scope:
+    type: repeat_across_batch
+    n_sources: 8
+  routing:
+    type: uniform_k_subset
+- coeff: 1000000.0
+  classname: FaithfulnessLoss
+output_loss_type: kl
+lr_schedule:
+  start_val: 3.0e-05
+  warmup_pct: 0.0
+  final_val_frac: 0.1
+  fn_type: cosine
+steps: 400000
+batch_size: 128
+grad_clip_norm_components: 0.01
+grad_clip_norm_ci_fns: null
+faithfulness_warmup_steps: 400
+faithfulness_warmup_lr: 0.001
+faithfulness_warmup_weight_decay: 0.0
+train_log_freq: 200
+eval_freq: 1000
+eval_batch_size: 128
+slow_eval_freq: 10000
+n_eval_steps: 1
+slow_eval_on_first_step: true
+save_freq: null
+eval_metric_configs:
+- classname: CIHistograms
+  n_batches_accum: 1
+- classname: ComponentActivationDensity
+- classname: CI_L0
+  groups:
+    layer_0:
+    - h.0.*
+    layer_1:
+    - h.1.*
+    layer_2:
+    - h.2.*
+    layer_3:
+    - h.3.*
+    total:
+    - '*'
+- classname: CEandKLLosses
+  rounding_threshold: 0.0
+- classname: CIMeanPerComponent
+- classname: StochasticHiddenActsReconLoss
+- init: random
+  step_size: 0.1
+  n_steps: 20
+  mask_scope: shared_across_batch
+  classname: PGDReconLoss
+ci_alive_threshold: 0.0
+pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP
+pretrained_model_path: null
+pretrained_model_name: goodfire/spd/runs/t-9d2b8f02
+pretrained_model_output_attr: idx_0
+tokenizer_name: EleutherAI/gpt-neox-20b
+task_config:
+  task_name: lm
+  max_seq_len: 512
+  buffer_size: 1000
+  dataset_name: danbraunai/pile-uncopyrighted-tok
+  column_name: input_ids
+  train_data_split: train
+  eval_data_split: val
+  shuffle_each_epoch: true
+  is_tokenized: true
+  streaming: true
\ No newline at end of file
diff --git a/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml b/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml
index 570c8023d..7e091e4e7 100644
--- a/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml
+++ b/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml
@@ -14,7 +14,7 @@ ci_config:
   transition_attn_config: null
   transition_hidden_dim: null
   simple_transformer_ci_cfg:
-    d_model: 1024
+    d_model: 2048
     n_blocks: 8
     mlp_hidden_dim:
     - 8192
@@ -25,25 +25,62 @@ ci_config:
 sampling: continuous
 sigmoid_type: leaky_hard
 module_info:
-- module_pattern: h.*.mlp.c_fc
-  C: 4096
-- module_pattern: h.*.mlp.down_proj
-  C: 3072
-- module_pattern: h.*.attn.q_proj
+- module_pattern: h.3.mlp.c_fc
+  C: 1536
+- module_pattern: h.2.mlp.c_fc
   C: 512
-- module_pattern: h.*.attn.k_proj
+- module_pattern: h.1.mlp.c_fc
+  C: 512
+- module_pattern: h.0.mlp.c_fc
+  C: 1536
+- module_pattern: h.3.mlp.down_proj
+  C: 2560
+- module_pattern: h.2.mlp.down_proj
+  C: 512
+- module_pattern: h.1.mlp.down_proj
+  C: 512
+- module_pattern: h.0.mlp.down_proj
+  C: 1536
+- module_pattern: h.3.attn.q_proj
+  C: 128
+- module_pattern: h.2.attn.q_proj
+  C: 128
+- module_pattern: h.1.attn.q_proj
+  C: 128
+- module_pattern: h.0.attn.q_proj
+  C: 256
+- module_pattern: h.3.attn.k_proj
+  C: 128
+- module_pattern: h.2.attn.k_proj
+  C: 256
+- module_pattern: h.1.attn.k_proj
+  C: 128
+- module_pattern: h.0.attn.k_proj
+  C: 256
+- module_pattern: h.3.attn.v_proj
+  C: 512
+- module_pattern: h.2.attn.v_proj
+  C: 640
+- module_pattern: h.1.attn.v_proj
   C: 256
-- module_pattern: h.*.attn.v_proj
-  C: 1024
-- module_pattern: h.*.attn.o_proj
-  C: 1024
+- module_pattern: h.0.attn.v_proj
+  C: 512
+- module_pattern: h.3.attn.o_proj
+  C: 512
+- module_pattern: h.2.attn.o_proj
+  C: 640
+- module_pattern: h.1.attn.o_proj
+  C: 256
+- module_pattern: h.0.attn.o_proj
+  C: 512
 identity_module_info: null
+init_spd_checkpoint: null
 use_delta_component: true
 loss_metric_configs:
-- coeff: 0.0005
+- coeff: 0.0002
   classname: ImportanceMinimalityLoss
   pnorm: 2.0
-  beta: 0.1
+  beta: 0.5
   p_anneal_start_frac: 0.0
   p_anneal_final_p: 0.4
   p_anneal_end_frac: 1.0
@@ -53,23 +90,26 @@ loss_metric_configs:
   routing:
     type: uniform_k_subset
 - coeff: 0.5
-  classname: PersistentPGDReconSubsetLoss
   optimizer:
     type: adam
-    lr: 0.1
-    beta1: 0.9
+    beta1: 0.5
     beta2: 0.99
     eps: 1.0e-08
+    lr_schedule:
+      start_val: 0.01
+      warmup_pct: 0.025
+      final_val_frac: 0.1
+      fn_type: cosine
   scope:
-    type: batch_invariant
-    n_masks: 8
-  routing:
-    type: uniform_k_subset
-- coeff: 1000000.0
+    type: per_batch_per_position
+  use_sigmoid_parameterization: false
+  n_warmup_steps: 2
+  classname: PersistentPGDReconLoss
+- coeff: 10000000.0
   classname: FaithfulnessLoss
 output_loss_type: kl
 lr_schedule:
-  start_val: 1.0e-04
+  start_val: 5.0e-05
   warmup_pct: 0.0
   final_val_frac: 0.1
   fn_type: cosine
@@ -82,7 +122,7 @@ faithfulness_warmup_lr: 0.001
 faithfulness_warmup_weight_decay: 0.0
 train_log_freq: 200
 eval_freq: 1000
-eval_batch_size: 256
+eval_batch_size: 128
 slow_eval_freq: 10000
 n_eval_steps: 1
 slow_eval_on_first_step: true
@@ -108,6 +148,7 @@ eval_metric_configs:
 - classname: CIMeanPerComponent
 - coeff: null
   classname: StochasticHiddenActsReconLoss
+- classname: CIHiddenActsReconLoss
 - coeff: null
   init: random
   step_size: 0.1
@@ -117,17 +158,18 @@
 ci_alive_threshold: 0.0
 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP
 pretrained_model_path: null
-pretrained_model_name: wandb:goodfire/spd/t-32d1bb3b
+pretrained_model_name: goodfire/spd/runs/t-9d2b8f02
 pretrained_model_output_attr: idx_0
 tokenizer_name: EleutherAI/gpt-neox-20b
 task_config:
   task_name: lm
   max_seq_len: 512
   buffer_size: 1000
-  dataset_name: danbraunai/pile-uncopyrighted-tok
+  dataset_name: danbraunai/pile-uncopyrighted-tok-shuffled
   column_name: input_ids
   train_data_split: train
   eval_data_split: val
   shuffle_each_epoch: true
   is_tokenized: true
-  streaming: true
\ No newline at end of file
+  streaming: true
+  dataset_seed: null
diff --git a/spd/experiments/lm/ss_gpt2_config.yaml b/spd/experiments/lm/ss_gpt2_config.yaml
index a08002e99..1c11fa6aa 100644
--- a/spd/experiments/lm/ss_gpt2_config.yaml
+++ b/spd/experiments/lm/ss_gpt2_config.yaml
@@ -6,8 +6,10 @@ wandb_run_name_prefix: ""
 # --- General ---
 seed: 0
 n_mask_samples: 1
-ci_fn_type: "vector_mlp"
-ci_fn_hidden_dims: [12]
+ci_config:
+  mode: layerwise
+  fn_type: vector_mlp
+  hidden_dims: [12]
 sigmoid_type: "leaky_hard"
 module_info:
 - module_pattern: "transformer.h.1.mlp.c_fc"
diff --git a/spd/experiments/lm/ss_gpt2_simple-1L.yaml b/spd/experiments/lm/ss_gpt2_simple-1L.yaml
index 9d7bf3164..f0f857629 100644
--- a/spd/experiments/lm/ss_gpt2_simple-1L.yaml
+++ b/spd/experiments/lm/ss_gpt2_simple-1L.yaml
@@ -6,8 +6,10 @@ wandb_run_name_prefix: ""
 # --- General ---
 seed: 1
 n_mask_samples: 1
-ci_fn_type: "shared_mlp"
-ci_fn_hidden_dims: [550]
+ci_config:
+  mode: layerwise
+  fn_type: shared_mlp
+  hidden_dims: [550]
 sigmoid_type: "leaky_hard"
 module_info:
 - module_pattern: "h.*.mlp.c_fc"
@@ -78,10 +80,16 @@ eval_metric_configs:
 - classname: CEandKLLosses
   rounding_threshold: 0.0
 - classname: CIMeanPerComponent
-- coeff: null
-  classname: StochasticHiddenActsReconLoss
-- coeff: null
-  init: random
+- classname: StochasticHiddenActsReconLoss
+- classname: StochasticAttnPatternsReconLoss
+  n_heads: 4
+  q_proj_path: "h.*.attn.q_proj"
+  k_proj_path: "h.*.attn.k_proj"
+- classname: CIMaskedAttnPatternsReconLoss
+  n_heads: 4
+  q_proj_path: "h.*.attn.q_proj"
+  k_proj_path: "h.*.attn.k_proj"
+- init: random
   step_size: 0.1
   n_steps: 20
   mask_scope: shared_across_batch
@@ -111,4 +119,4 @@ task_config:
 #     n_head=4,
 #     n_embd=128,
 #     flash_attention=False,
-# ),
\ No newline at end of file
+# ),
diff --git a/spd/experiments/lm/ss_gpt2_simple-2L.yaml b/spd/experiments/lm/ss_gpt2_simple-2L.yaml
index f4676d18c..e1ff2a0d9 100644
--- a/spd/experiments/lm/ss_gpt2_simple-2L.yaml
+++ b/spd/experiments/lm/ss_gpt2_simple-2L.yaml
@@ -6,8 +6,10 @@ wandb_run_name_prefix: ""
 # --- General ---
 seed: 1
 n_mask_samples: 1
-ci_fn_type: "shared_mlp"
-ci_fn_hidden_dims: [550]
+ci_config:
+  mode: layerwise
+  fn_type: shared_mlp
+  hidden_dims: [550]
 sigmoid_type: "leaky_hard"
 module_info:
 - module_pattern: "h.*.mlp.c_fc"
@@ -80,10 +82,8 @@ eval_metric_configs:
 - classname: CEandKLLosses
   rounding_threshold: 0.0
 - classname: CIMeanPerComponent
-- coeff: null
-  classname: StochasticHiddenActsReconLoss
-- coeff: null
-  init: random
+- classname: StochasticHiddenActsReconLoss
+- init: random
   step_size: 0.1
   n_steps: 20
   mask_scope: shared_across_batch
diff --git a/spd/experiments/lm/ss_gpt2_simple_config.yaml b/spd/experiments/lm/ss_gpt2_simple_config.yaml
index dbe6bc70e..5d6b2fdeb 100644
--- a/spd/experiments/lm/ss_gpt2_simple_config.yaml
+++ b/spd/experiments/lm/ss_gpt2_simple_config.yaml
@@ -6,8 +6,10 @@ wandb_run_name_prefix: ""
 # --- General ---
 seed: 0
 n_mask_samples: 1
-ci_fn_type: "shared_mlp"
-ci_fn_hidden_dims: [1000]
+ci_config:
+  mode: layerwise
+  fn_type: shared_mlp
+  hidden_dims: [1000]
 sigmoid_type: "leaky_hard"
 module_info:
 - module_pattern: "h.*.mlp.c_fc"
diff --git a/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml b/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml
index 83616f355..5e7893035 100644
--- a/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml
+++ b/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml
@@ -6,8 +6,10 @@ wandb_run_name_prefix: ""
 # --- General ---
 seed: 0
 n_mask_samples: 1
-ci_fn_type: "vector_mlp"
-ci_fn_hidden_dims: [12]
+ci_config:
+  mode: layerwise
+  fn_type: vector_mlp
+  hidden_dims: [12]
 sigmoid_type: "leaky_hard"
 module_info:
 - module_pattern: "h.*.mlp.c_fc"
diff --git a/spd/experiments/lm/ss_llama_simple-1L.yaml b/spd/experiments/lm/ss_llama_simple-1L.yaml
index 5852ea0cc..b9a330e48 100644
--- a/spd/experiments/lm/ss_llama_simple-1L.yaml
+++ b/spd/experiments/lm/ss_llama_simple-1L.yaml
@@ -3,9 +3,10 @@ wandb_run_name: null
 wandb_run_name_prefix: ''
 seed: 0
 n_mask_samples: 1
-ci_fn_type: shared_mlp
-ci_fn_hidden_dims:
-- 550
+ci_config:
+  mode: layerwise
+  fn_type: shared_mlp
+  hidden_dims: [550]
 sampling: continuous
 sigmoid_type: leaky_hard
 module_info:
@@ -79,10 +80,8 @@ eval_metric_configs:
 - classname: CEandKLLosses
   rounding_threshold: 0.0
 - classname: CIMeanPerComponent
-- coeff: null
-  classname: StochasticHiddenActsReconLoss
-- coeff: null
-  init: random
+- classname: StochasticHiddenActsReconLoss
+- init: random
   step_size: 0.1
   n_steps: 20
   mask_scope: shared_across_batch
diff --git a/spd/experiments/lm/ss_llama_simple-2L.yaml b/spd/experiments/lm/ss_llama_simple-2L.yaml
index bf4837709..34d6106ff 100644
--- a/spd/experiments/lm/ss_llama_simple-2L.yaml
+++ b/spd/experiments/lm/ss_llama_simple-2L.yaml
@@ -3,9 +3,10 @@ wandb_run_name: null
 wandb_run_name_prefix: ''
 seed: 0
 n_mask_samples: 1
-ci_fn_type: shared_mlp
-ci_fn_hidden_dims:
-- 550
+ci_config:
+  mode: layerwise
+  fn_type: shared_mlp
+  hidden_dims: [550]
 sampling: continuous
 sigmoid_type: leaky_hard
 module_info:
@@ -81,10 +82,8 @@ eval_metric_configs:
 - classname: CEandKLLosses
   rounding_threshold: 0.0
 - classname: CIMeanPerComponent
-- coeff: null
-  classname: StochasticHiddenActsReconLoss
-- coeff: null
-  init: random
+- classname: StochasticHiddenActsReconLoss
+- init: random
   step_size: 0.1
   n_steps: 20
   mask_scope: shared_across_batch
diff --git a/spd/experiments/lm/ss_llama_simple_config.yaml b/spd/experiments/lm/ss_llama_simple_config.yaml
index 3aae41850..ade4cc047 100644
--- a/spd/experiments/lm/ss_llama_simple_config.yaml
+++ b/spd/experiments/lm/ss_llama_simple_config.yaml
@@ -6,8 +6,10 @@ wandb_run_name_prefix: ""
 # --- General ---
 seed: 0
 n_mask_samples: 1
-ci_fn_type: "shared_mlp"
-ci_fn_hidden_dims: [1000]
+ci_config:
+  mode: layerwise
+  fn_type: shared_mlp
+  hidden_dims: [1000]
 sigmoid_type: "leaky_hard"
 module_info:
 - module_pattern: "h.*.mlp.gate_proj"
diff --git a/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml b/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml
index 5cb08db5a..146bc1362 100644
--- a/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml
+++ b/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml
@@ -3,9 +3,10 @@ wandb_run_name: null
 wandb_run_name_prefix: ''
 seed: 0
 n_mask_samples: 1
-ci_fn_type: shared_mlp
-ci_fn_hidden_dims:
-- 550
+ci_config:
+  mode: layerwise
+  fn_type: shared_mlp
+  hidden_dims: [550]
 sampling: continuous
 sigmoid_type: leaky_hard
 module_info:
@@ -73,10 +74,8 @@ eval_metric_configs:
 - classname: CEandKLLosses
   rounding_threshold: 0.0
 - classname: CIMeanPerComponent
-- coeff: null
-  classname: StochasticHiddenActsReconLoss
-- coeff: null
-  init: random
+- classname: StochasticHiddenActsReconLoss
+- init: random
   step_size: 0.1
   n_steps: 20
   mask_scope: shared_across_batch
diff --git a/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml b/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml
index c7a3e57ba..79ab8fea7 100644
--- a/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml
+++ b/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml
@@ -3,9 +3,10 @@ wandb_run_name: null
 wandb_run_name_prefix: ''
 seed: 0
 n_mask_samples: 1
-ci_fn_type: shared_mlp
-ci_fn_hidden_dims:
-- 1250
+ci_config:
+  mode: layerwise
+  fn_type: shared_mlp
+  hidden_dims: [1250]
 sampling: continuous
 sigmoid_type: leaky_hard
 module_info:
@@ -82,9 +83,7 @@ eval_metric_configs:
   rounding_threshold: 0
 - classname: CIMeanPerComponent
 - classname: StochasticHiddenActsReconLoss
-  coeff: null
 - classname: PGDReconLoss
-  coeff: null
   init: random
   step_size: 0.1
   n_steps: 20
diff --git a/spd/experiments/lm/ss_llama_simple_mlp-2L-wide_global_reverse.yaml b/spd/experiments/lm/ss_llama_simple_mlp-2L-wide_global_reverse.yaml
new file mode 100644
index 000000000..84c63b083
--- /dev/null
+++ b/spd/experiments/lm/ss_llama_simple_mlp-2L-wide_global_reverse.yaml
@@ -0,0 +1,126 @@
+wandb_project: spd
+wandb_run_name: null
+wandb_run_name_prefix: ''
+seed: 0
+n_mask_samples: 1
+ci_config:
+  mode: global
+  fn_type: global_reverse_residual
+  transition_hidden_dim: 1024
+  d_resid_ci_fn: 1024
+  reader_hidden_dims: [1024]
+  block_groups:
+  - name: layer_1_mlp
+    patterns:
+    - h.1.mlp.*
+  - name: layer_1_attn
+    patterns:
+    - h.1.attn.*
+  - name: layer_0_mlp
+    patterns:
+    - h.0.mlp.*
+  - name: layer_0_attn
+    patterns:
+    - h.0.attn.*
+  transition_attn_config:
+    n_heads: 4
+    max_len: 512
+    rope_base: 10000.0
+sampling: continuous
+sigmoid_type: leaky_hard
+module_info:
+- module_pattern: h.*.mlp.c_fc
+  C: 1152
+- module_pattern: h.*.mlp.down_proj
+  C: 960
+- module_pattern: h.*.attn.q_proj
+  C: 288
+- module_pattern: h.*.attn.k_proj
+  C: 288
+- module_pattern: h.*.attn.v_proj
+  C: 384
+- module_pattern: h.*.attn.o_proj
+  C: 480
+identity_module_info: null
+use_delta_component: true
+loss_metric_configs:
+- coeff: 0.003
+  classname: ImportanceMinimalityLoss
+  pnorm: 2.0
+  beta: 0.1
+  p_anneal_start_frac: 0.0
+  p_anneal_final_p: 0.5
+  p_anneal_end_frac: 1.0
+  eps: 1.0e-12
+- coeff: 0.5
+  classname: StochasticReconSubsetLoss
+  routing:
+    type: uniform_k_subset
+- coeff: 0.5
+  init: random
+  step_size: 0.5
+  n_steps: 4
+  mask_scope: shared_across_batch
+  classname: PGDReconSubsetLoss
+  routing:
+    type: uniform_k_subset
+- coeff: 1000000.0
+  classname: FaithfulnessLoss
+output_loss_type: kl
+lr_schedule:
+  start_val: 0.0005
+  warmup_pct: 0.0
+  final_val_frac: 0.1
+  fn_type: cosine
+steps: 400000
+batch_size: 256
+grad_clip_norm_components: 0.01
+grad_clip_norm_ci_fns: null
+faithfulness_warmup_steps: 200
+faithfulness_warmup_lr: 0.001
+faithfulness_warmup_weight_decay: 0.0
+train_log_freq: 200
+eval_freq: 1000
+eval_batch_size: 128
+slow_eval_freq: 10000
+n_eval_steps: 5
+slow_eval_on_first_step: true
+save_freq: null
+eval_metric_configs:
+- classname: CIHistograms
+  n_batches_accum: 5
+- classname: ComponentActivationDensity
+- classname: CI_L0
+  groups:
+    layer_0:
+    - h.0.*
+    layer_1:
+    - h.1.*
+    total:
+    - '*'
+- classname: CEandKLLosses
+  rounding_threshold: 0.0
+- classname: CIMeanPerComponent
+- classname: StochasticHiddenActsReconLoss
+- init: random
+  step_size: 0.1
+  n_steps: 20
+  mask_scope: shared_across_batch
+  classname: PGDReconLoss
+ci_alive_threshold: 0.0
+pretrained_model_class: simple_stories_train.models.llama_simple_mlp.LlamaSimpleMLP
+pretrained_model_path: null
+pretrained_model_name: wandb:goodfire/spd/runs/gf6rbga0
+pretrained_model_output_attr: idx_0
+tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M
+task_config:
+  task_name: lm
+  max_seq_len: 512
+  buffer_size: 1000
+  dataset_name: SimpleStories/SimpleStories
+  column_name: story
+  train_data_split: train
+  eval_data_split: test
+  shuffle_each_epoch: true
+  is_tokenized: false
+  streaming: false
\ No newline at end of file
diff --git a/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml b/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml
index c8e5726ce..f2d6262f1 100644
--- a/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml
+++ b/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml
@@ -3,9 +3,10 @@ wandb_run_name: null
 wandb_run_name_prefix: ''
 seed: 0
 n_mask_samples: 1
-ci_fn_type: shared_mlp
-ci_fn_hidden_dims:
-- 550
+ci_config:
+  mode: layerwise
+  fn_type: shared_mlp
+  hidden_dims: [550]
 sampling: continuous
 sigmoid_type: leaky_hard
 module_info:
@@ -37,11 +38,20 @@ loss_metric_configs:
   routing:
     type: uniform_k_subset
 - coeff: 0.5
-  init: random
-  step_size: 1.0
-  n_steps: 1
-  mask_scope: shared_across_batch
-  classname: PGDReconSubsetLoss
+  classname: PersistentPGDReconSubsetLoss
+  optimizer:
+    type: adam
+    lr_schedule:
+      start_val: 0.1
+      warmup_pct: 0.0
+      final_val_frac: 1.0
+      fn_type: constant
+    beta1: 0.9
+    beta2: 0.99
+    eps: 1.0e-08
+  scope:
+    type: repeat_across_batch
+    n_sources: 8
   routing:
     type: uniform_k_subset
 - coeff: 1000000.0
@@ -79,10 +89,10 @@ eval_metric_configs:
 - classname: CEandKLLosses
   rounding_threshold: 0.0
 - classname: CIMeanPerComponent
-- coeff: null
-  classname: StochasticHiddenActsReconLoss
-- coeff: null
-  init: random
+- classname: StochasticHiddenActsReconLoss
+- classname: CIHiddenActsReconLoss
+- classname: PersistentPGDReconSubsetEval
+- init: random
   step_size: 0.1
   n_steps: 20
   mask_scope: shared_across_batch
diff --git a/spd/experiments/lm/ss_llama_simple_mlp.yaml b/spd/experiments/lm/ss_llama_simple_mlp.yaml
index c2ad572a1..f77da81ff 100644
--- a/spd/experiments/lm/ss_llama_simple_mlp.yaml
+++ b/spd/experiments/lm/ss_llama_simple_mlp.yaml
@@ -3,9 +3,10 @@ wandb_run_name: null
 wandb_run_name_prefix: ''
 seed: 0
 n_mask_samples: 1
-ci_fn_type: shared_mlp
-ci_fn_hidden_dims:
-- 800
+ci_config:
+  mode: layerwise
+  fn_type: shared_mlp
+  hidden_dims: [800]
 sampling: continuous
 sigmoid_type: leaky_hard
 module_info:
@@ -104,10 +105,8 @@ eval_metric_configs:
     - h.2.*
     all_but_layer_3:
     - h.3.*
-- coeff: null
-  classname: StochasticHiddenActsReconLoss
-- coeff: null
-  init: random
+- classname: StochasticHiddenActsReconLoss
+- init: random
   step_size: 0.1
   n_steps: 20
   mask_scope: shared_across_batch
diff --git a/spd/experiments/lm/ts_config.yaml b/spd/experiments/lm/ts_config.yaml
index 4f8c966ac..f60c4e589 100644
--- a/spd/experiments/lm/ts_config.yaml
+++ b/spd/experiments/lm/ts_config.yaml
@@ -9,8 +9,10 @@ wandb_run_name_prefix: ""
 # --- General ---
 seed: 0
 n_mask_samples: 1
-ci_fn_type: "vector_mlp"
-ci_fn_hidden_dims: [8]
+ci_config:
+  mode: layerwise
+  fn_type: vector_mlp
+  hidden_dims: [8]
 sigmoid_type: "leaky_hard"
 module_info:
 - module_pattern: "transformer.h.3.mlp.c_fc"
diff --git a/spd/experiments/lm/z-jan22.yaml b/spd/experiments/lm/z-jan22.yaml
new file mode 100644
index 000000000..e0d14f13f
--- /dev/null
+++ b/spd/experiments/lm/z-jan22.yaml
@@ -0,0 +1,129 @@
+wandb_project: spd
+wandb_run_name: null
+wandb_run_name_prefix: ''
+seed: 0
+n_mask_samples: 1
+ci_config:
+  mode: global
+  fn_type: global_reverse_residual
+  transition_hidden_dim: 1024
+  reader_hidden_dims:
+  - 1024
+  - 1024
+  d_resid_ci_fn: 2048
+  block_groups:
+  - name: layer_1_mlp
+    patterns:
+    - h.1.mlp.*
+  - name: layer_1_attn
+    patterns:
+    - h.1.attn.*
+  - name: layer_0_mlp
+    patterns:
+    - h.0.mlp.*
+  - name: layer_0_attn
+    patterns:
+    - h.0.attn.*
+  transition_attn_config:
+    n_heads: 8
+    max_len: 2048
+    rope_base: 10000.0
+sampling: continuous
+sigmoid_type: leaky_hard
+module_info:
+- module_pattern: h.*.mlp.c_fc
+  C: 1152
+- module_pattern: h.*.mlp.down_proj
+  C: 960
+- module_pattern: h.*.attn.q_proj
+  C: 288
+- module_pattern: h.*.attn.k_proj
+  C: 288
+- module_pattern: h.*.attn.v_proj
+  C: 384
+- module_pattern: h.*.attn.o_proj
+  C: 480
+identity_module_info: null
+use_delta_component: true
+loss_metric_configs:
+- coeff: 0.003
+  classname: ImportanceMinimalityLoss
+  pnorm: 2.0
+  beta: 0.1
+  p_anneal_start_frac: 0.0
+  p_anneal_final_p: 0.5
+  p_anneal_end_frac: 1.0
+  eps: 1.0e-12
+- coeff: 0.5
+  classname: StochasticReconSubsetLoss
+  routing:
+    type: uniform_k_subset
+- coeff: 0.5
+  init: random
+  step_size: 0.5
+  n_steps: 4
+  mask_scope: shared_across_batch
+  classname: PGDReconSubsetLoss
+  routing:
+    type: uniform_k_subset
+- coeff: 1000000.0
+  classname: FaithfulnessLoss
+output_loss_type: kl
+lr_schedule:
+  start_val: 0.0005
+  warmup_pct: 0.0
+  final_val_frac: 0.1
+  fn_type: cosine
+steps: 400000
+batch_size: 16
+gradient_accumulation_steps: 1
+grad_clip_norm_components: 0.01
+grad_clip_norm_ci_fns: null
+faithfulness_warmup_steps: 200
+faithfulness_warmup_lr: 0.001
+faithfulness_warmup_weight_decay: 0.0
+train_log_freq: 50
+eval_freq: 1000
+eval_batch_size: 128
+slow_eval_freq: 10000
+n_eval_steps: 5
+slow_eval_on_first_step: true
+save_freq: null
+eval_metric_configs:
+- classname: CIHistograms
+  n_batches_accum: 5
+- classname: ComponentActivationDensity
+- classname: CI_L0
+  groups:
+    layer_0:
+    - h.0.*
+    layer_1:
+    - h.1.*
+    total:
+    - '*'
+- classname: CEandKLLosses
+  rounding_threshold: 0.0
+- classname: CIMeanPerComponent
+-
classname: StochasticHiddenActsReconLoss +- init: random + step_size: 0.1 + n_steps: 20 + mask_scope: shared_across_batch + classname: PGDReconLoss +ci_alive_threshold: 0.0 +pretrained_model_class: simple_stories_train.models.llama_simple_mlp.LlamaSimpleMLP +pretrained_model_path: null +pretrained_model_name: wandb:goodfire/spd/runs/gf6rbga0 +pretrained_model_output_attr: idx_0 +tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M +task_config: + task_name: lm + max_seq_len: 512 + buffer_size: 1000 + dataset_name: SimpleStories/SimpleStories + column_name: story + train_data_split: train + eval_data_split: test + shuffle_each_epoch: true + is_tokenized: false + streaming: false diff --git a/spd/experiments/lm/z-jan22_ppgd.yaml b/spd/experiments/lm/z-jan22_ppgd.yaml new file mode 100644 index 000000000..d2f4b63e7 --- /dev/null +++ b/spd/experiments/lm/z-jan22_ppgd.yaml @@ -0,0 +1,147 @@ +wandb_project: spd +wandb_run_name: null +wandb_run_name_prefix: "" +seed: 0 +n_mask_samples: 1 +ci_config: + __variants__: + - mode: global + fn_type: global_shared_transformer + simple_transformer_ci_cfg: + d_model: 512 + n_blocks: 4 + mlp_hidden_dim: [2048] + attn_config: + n_heads: 8 + max_len: 1024 + rope_base: 10000.0 + - mode: global + fn_type: global_reverse_residual + transition_hidden_dim: 1024 + reader_hidden_dims: + - 1024 + - 1024 + d_resid_ci_fn: 2048 + block_groups: + - name: layer_1_mlp + patterns: + - h.1.mlp.* + - name: layer_1_attn + patterns: + - h.1.attn.* + - name: layer_0_mlp + patterns: + - h.0.mlp.* + - name: layer_0_attn + patterns: + - h.0.attn.* + transition_attn_config: + n_heads: 8 + max_len: 2048 + rope_base: 10000.0 +sampling: continuous +sigmoid_type: leaky_hard +module_info: + - module_pattern: h.*.mlp.c_fc + C: 1152 + - module_pattern: h.*.mlp.down_proj + C: 960 + - module_pattern: h.*.attn.q_proj + C: 288 + - module_pattern: h.*.attn.k_proj + C: 288 + - module_pattern: h.*.attn.v_proj + C: 384 + - module_pattern: h.*.attn.o_proj + C: 480 
+identity_module_info: null +use_delta_component: true +loss_metric_configs: + - classname: FaithfulnessLoss + coeff: 1000000.0 + - classname: ImportanceMinimalityLoss + coeff: 0.003 + pnorm: 2.0 + beta: 0.1 + p_anneal_start_frac: 0.0 + p_anneal_final_p: 0.5 + p_anneal_end_frac: 1.0 + eps: 1.0e-12 + - classname: StochasticReconSubsetLoss + coeff: 0.5 + routing: + type: uniform_k_subset + - classname: PersistentPGDReconSubsetLoss + coeff: 0.5 + optimizer: + type: adam + lr_schedule: + start_val: 0.01 + warmup_pct: 0.0 + final_val_frac: 1.0 + fn_type: constant + beta1: 0.8 + beta2: 0.99 + scope: + type: per_batch_per_position + routing: + type: uniform_k_subset +output_loss_type: kl +lr_schedule: + start_val: 0.0003 + warmup_pct: 0.0 + final_val_frac: 0.1 + fn_type: cosine +steps: 400_000 +batch_size: 32 +gradient_accumulation_steps: 1 +grad_clip_norm_components: 0.01 +grad_clip_norm_ci_fns: null +faithfulness_warmup_steps: 200 +faithfulness_warmup_lr: 0.001 +faithfulness_warmup_weight_decay: 0.0 +train_log_freq: 50 +eval_freq: 1000 +eval_batch_size: 128 +slow_eval_freq: 10000 +n_eval_steps: 5 +slow_eval_on_first_step: true +save_freq: null +eval_metric_configs: + - classname: CIHistograms + n_batches_accum: 5 + - classname: ComponentActivationDensity + - classname: CI_L0 + groups: + layer_0: + - h.0.* + layer_1: + - h.1.* + total: + - "*" + - classname: CEandKLLosses + rounding_threshold: 0.0 + - classname: CIMeanPerComponent + - classname: StochasticHiddenActsReconLoss + - classname: PGDReconLoss + init: random + step_size: 0.1 + n_steps: 20 + mask_scope: shared_across_batch +ci_alive_threshold: 0.0 +pretrained_model_class: simple_stories_train.models.llama_simple_mlp.LlamaSimpleMLP +pretrained_model_path: null +pretrained_model_name: wandb:goodfire/spd/runs/gf6rbga0 +pretrained_model_output_attr: idx_0 +tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M +task_config: + task_name: lm + max_seq_len: 512 + buffer_size: 1000 + dataset_name: 
SimpleStories/SimpleStories + column_name: story + train_data_split: train + eval_data_split: test + shuffle_each_epoch: true + is_tokenized: false + streaming: false diff --git a/spd/experiments/lm/z-jan22_ppgd_reverse_resid_ablations.yaml b/spd/experiments/lm/z-jan22_ppgd_reverse_resid_ablations.yaml new file mode 100644 index 000000000..d5a102eaa --- /dev/null +++ b/spd/experiments/lm/z-jan22_ppgd_reverse_resid_ablations.yaml @@ -0,0 +1,158 @@ +wandb_project: spd +wandb_run_name: null +wandb_run_name_prefix: "reverse-resid-ci" +seed: 0 +n_mask_samples: 1 +ci_config: + __variants__: + - mode: global + fn_type: global_reverse_residual + d_resid_ci_fn: 1024 + reader_hidden_dims: [1024] + transition_hidden_dim: 1024 + block_groups: + - name: layer_1_mlp + patterns: + - h.1.mlp.* + - name: layer_1_attn + patterns: + - h.1.attn.* + - name: layer_0_mlp + patterns: + - h.0.mlp.* + - name: layer_0_attn + patterns: + - h.0.attn.* + transition_attn_config: + n_heads: 8 + max_len: 2048 + rope_base: 10000.0 + - mode: global + fn_type: global_reverse_residual + d_resid_ci_fn: 1024 + reader_hidden_dims: [2048] + transition_hidden_dim: 2048 + block_groups: + - name: layer_1_mlp + patterns: + - h.1.mlp.* + - name: layer_1_attn + patterns: + - h.1.attn.* + - name: layer_0_mlp + patterns: + - h.0.mlp.* + - name: layer_0_attn + patterns: + - h.0.attn.* + transition_attn_config: + n_heads: 8 + max_len: 2048 + rope_base: 10000.0 +sampling: continuous +sigmoid_type: leaky_hard +module_info: + - module_pattern: h.*.mlp.c_fc + C: 1152 + - module_pattern: h.*.mlp.down_proj + C: 960 + - module_pattern: h.*.attn.q_proj + C: 288 + - module_pattern: h.*.attn.k_proj + C: 288 + - module_pattern: h.*.attn.v_proj + C: 384 + - module_pattern: h.*.attn.o_proj + C: 480 +identity_module_info: null +use_delta_component: true +loss_metric_configs: + - classname: FaithfulnessLoss + coeff: 1000000.0 + - classname: ImportanceMinimalityLoss + coeff: 0.003 + pnorm: 2.0 + beta: 0.1 + p_anneal_start_frac: 
0.0 + p_anneal_final_p: 0.5 + p_anneal_end_frac: 1.0 + eps: 1.0e-12 + - classname: StochasticReconSubsetLoss + coeff: 0.5 + routing: + type: uniform_k_subset + - classname: PersistentPGDReconSubsetLoss + coeff: 0.5 + optimizer: + type: adam + lr_schedule: + start_val: 0.01 + warmup_pct: 0.0 + final_val_frac: 1.0 + fn_type: constant + beta1: 0.8 + beta2: 0.99 + scope: + type: repeat_across_batch + n_sources: 32 + routing: + type: uniform_k_subset +output_loss_type: kl +lr_schedule: + start_val: 0.0003 + warmup_pct: 0.0 + final_val_frac: 0.1 + fn_type: cosine +steps: 400_000 +batch_size: 32 +gradient_accumulation_steps: 1 +grad_clip_norm_components: 0.01 +grad_clip_norm_ci_fns: null +faithfulness_warmup_steps: 200 +faithfulness_warmup_lr: 0.001 +faithfulness_warmup_weight_decay: 0.0 +train_log_freq: 50 +eval_freq: 1000 +eval_batch_size: 128 +slow_eval_freq: 10000 +n_eval_steps: 5 +slow_eval_on_first_step: true +save_freq: null +eval_metric_configs: + - classname: CIHistograms + n_batches_accum: 5 + - classname: ComponentActivationDensity + - classname: CI_L0 + groups: + layer_0: + - h.0.* + layer_1: + - h.1.* + total: + - "*" + - classname: CEandKLLosses + rounding_threshold: 0.0 + - classname: CIMeanPerComponent + - classname: StochasticHiddenActsReconLoss + - classname: PGDReconLoss + init: random + step_size: 0.1 + n_steps: 20 + mask_scope: shared_across_batch +ci_alive_threshold: 0.0 +pretrained_model_class: simple_stories_train.models.llama_simple_mlp.LlamaSimpleMLP +pretrained_model_path: null +pretrained_model_name: wandb:goodfire/spd/runs/gf6rbga0 +pretrained_model_output_attr: idx_0 +tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M +task_config: + task_name: lm + max_seq_len: 512 + buffer_size: 1000 + dataset_name: SimpleStories/SimpleStories + column_name: story + train_data_split: train + eval_data_split: test + shuffle_each_epoch: true + is_tokenized: false + streaming: false diff --git a/spd/experiments/lm/z-jan22_ppgd_transformer_normed.yaml 
b/spd/experiments/lm/z-jan22_ppgd_transformer_normed.yaml new file mode 100644 index 000000000..8487a9ce8 --- /dev/null +++ b/spd/experiments/lm/z-jan22_ppgd_transformer_normed.yaml @@ -0,0 +1,122 @@ +wandb_project: spd +wandb_run_name: null +wandb_run_name_prefix: "" +seed: 0 +n_mask_samples: 1 +ci_config: + mode: global + fn_type: global_shared_transformer + simple_transformer_ci_cfg: + d_model: 512 + n_blocks: 4 + mlp_hidden_dim: [2048] + attn_config: + n_heads: 8 + max_len: 1024 + rope_base: 10000.0 +sampling: continuous +sigmoid_type: leaky_hard +module_info: + - module_pattern: h.*.mlp.c_fc + C: 1152 + - module_pattern: h.*.mlp.down_proj + C: 960 + - module_pattern: h.*.attn.q_proj + C: 288 + - module_pattern: h.*.attn.k_proj + C: 288 + - module_pattern: h.*.attn.v_proj + C: 384 + - module_pattern: h.*.attn.o_proj + C: 480 +identity_module_info: null +use_delta_component: true +loss_metric_configs: + - classname: FaithfulnessLoss + coeff: 1000000.0 + - classname: ImportanceMinimalityLoss + coeff: 0.003 + pnorm: 2.0 + beta: 0.1 + p_anneal_start_frac: 0.0 + p_anneal_final_p: 0.5 + p_anneal_end_frac: 1.0 + eps: 1.0e-12 + - classname: StochasticReconSubsetLoss + coeff: 0.5 + routing: + type: uniform_k_subset + - classname: PersistentPGDReconSubsetLoss + coeff: 0.5 + optimizer: + type: adam + lr_schedule: + start_val: 0.01 + warmup_pct: 0.0 + final_val_frac: 1.0 + fn_type: constant + beta1: 0.8 + beta2: 0.99 + scope: + type: per_batch_per_position + routing: + type: uniform_k_subset +output_loss_type: kl +lr_schedule: + start_val: 0.0003 + warmup_pct: 0.0 + final_val_frac: 0.1 + fn_type: cosine +steps: 400_000 +batch_size: 32 +gradient_accumulation_steps: 1 +grad_clip_norm_components: 0.01 +grad_clip_norm_ci_fns: null +faithfulness_warmup_steps: 200 +faithfulness_warmup_lr: 0.001 +faithfulness_warmup_weight_decay: 0.0 +train_log_freq: 50 +eval_freq: 1000 +eval_batch_size: 128 +slow_eval_freq: 10000 +n_eval_steps: 5 +slow_eval_on_first_step: true +save_freq: null 
+eval_metric_configs: + - classname: CIHistograms + n_batches_accum: 5 + - classname: ComponentActivationDensity + - classname: CI_L0 + groups: + layer_0: + - h.0.* + layer_1: + - h.1.* + total: + - "*" + - classname: CEandKLLosses + rounding_threshold: 0.0 + - classname: CIMeanPerComponent + - classname: StochasticHiddenActsReconLoss + - classname: PGDReconLoss + init: random + step_size: 0.1 + n_steps: 20 + mask_scope: shared_across_batch +ci_alive_threshold: 0.0 +pretrained_model_class: simple_stories_train.models.llama_simple_mlp.LlamaSimpleMLP +pretrained_model_path: null +pretrained_model_name: wandb:goodfire/spd/runs/gf6rbga0 +pretrained_model_output_attr: idx_0 +tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M +task_config: + task_name: lm + max_seq_len: 512 + buffer_size: 1000 + dataset_name: SimpleStories/SimpleStories + column_name: story + train_data_split: train + eval_data_split: test + shuffle_each_epoch: true + is_tokenized: false + streaming: false diff --git a/spd/experiments/lm/z-jan22_ppgd_transformer_normed_ablations.yaml b/spd/experiments/lm/z-jan22_ppgd_transformer_normed_ablations.yaml new file mode 100644 index 000000000..b20a94d06 --- /dev/null +++ b/spd/experiments/lm/z-jan22_ppgd_transformer_normed_ablations.yaml @@ -0,0 +1,137 @@ +wandb_project: spd +wandb_run_name: null +wandb_run_name_prefix: "transformer-ci-normed" +seed: 0 +n_mask_samples: 1 +ci_config: + mode: global + fn_type: global_shared_transformer + simple_transformer_ci_cfg: + __variants__: + - d_model: 1024 + n_blocks: 2 + mlp_hidden_dim: [4096] + attn_config: + n_heads: 8 + max_len: 1024 + rope_base: 10000.0 + - d_model: 256 + n_blocks: 8 + mlp_hidden_dim: [1024] + attn_config: + n_heads: 8 + max_len: 1024 + rope_base: 10000.0 + - d_model: 512 + n_blocks: 4 + mlp_hidden_dim: [2048] + attn_config: + n_heads: 16 + max_len: 1024 + rope_base: 10000.0 +sampling: continuous +sigmoid_type: leaky_hard +module_info: + - module_pattern: h.*.mlp.c_fc + C: 1152 + - 
module_pattern: h.*.mlp.down_proj + C: 960 + - module_pattern: h.*.attn.q_proj + C: 288 + - module_pattern: h.*.attn.k_proj + C: 288 + - module_pattern: h.*.attn.v_proj + C: 384 + - module_pattern: h.*.attn.o_proj + C: 480 +identity_module_info: null +use_delta_component: true +loss_metric_configs: + - classname: FaithfulnessLoss + coeff: 1000000.0 + - classname: ImportanceMinimalityLoss + coeff: 0.003 + pnorm: 2.0 + beta: 0.1 + p_anneal_start_frac: 0.0 + p_anneal_final_p: 0.5 + p_anneal_end_frac: 1.0 + eps: 1.0e-12 + - classname: StochasticReconSubsetLoss + coeff: 0.5 + routing: + type: uniform_k_subset + - classname: PersistentPGDReconSubsetLoss + coeff: 0.5 + optimizer: + type: adam + lr_schedule: + start_val: 0.01 + warmup_pct: 0.0 + final_val_frac: 1.0 + fn_type: constant + beta1: 0.8 + beta2: 0.99 + scope: + type: per_batch_per_position + routing: + type: uniform_k_subset +output_loss_type: kl +lr_schedule: + start_val: 0.0003 + warmup_pct: 0.0 + final_val_frac: 0.1 + fn_type: cosine +steps: 400_000 +batch_size: 32 +gradient_accumulation_steps: 1 +grad_clip_norm_components: 0.01 +grad_clip_norm_ci_fns: null +faithfulness_warmup_steps: 200 +faithfulness_warmup_lr: 0.001 +faithfulness_warmup_weight_decay: 0.0 +train_log_freq: 50 +eval_freq: 1000 +eval_batch_size: 128 +slow_eval_freq: 10000 +n_eval_steps: 5 +slow_eval_on_first_step: true +save_freq: null +eval_metric_configs: + - classname: CIHistograms + n_batches_accum: 5 + - classname: ComponentActivationDensity + - classname: CI_L0 + groups: + layer_0: + - h.0.* + layer_1: + - h.1.* + total: + - "*" + - classname: CEandKLLosses + rounding_threshold: 0.0 + - classname: CIMeanPerComponent + - classname: StochasticHiddenActsReconLoss + - classname: PGDReconLoss + init: random + step_size: 0.1 + n_steps: 20 + mask_scope: shared_across_batch +ci_alive_threshold: 0.0 +pretrained_model_class: simple_stories_train.models.llama_simple_mlp.LlamaSimpleMLP +pretrained_model_path: null +pretrained_model_name: 
wandb:goodfire/spd/runs/gf6rbga0 +pretrained_model_output_attr: idx_0 +tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M +task_config: + task_name: lm + max_seq_len: 512 + buffer_size: 1000 + dataset_name: SimpleStories/SimpleStories + column_name: story + train_data_split: train + eval_data_split: test + shuffle_each_epoch: true + is_tokenized: false + streaming: false diff --git a/spd/experiments/lm/z-jan22_ppgd_transformer_normed_deep.yaml b/spd/experiments/lm/z-jan22_ppgd_transformer_normed_deep.yaml new file mode 100644 index 000000000..de678b368 --- /dev/null +++ b/spd/experiments/lm/z-jan22_ppgd_transformer_normed_deep.yaml @@ -0,0 +1,123 @@ +wandb_project: spd +wandb_run_name: null +wandb_run_name_prefix: "transformer-ci-normed" +seed: 0 +n_mask_samples: 1 +ci_config: + mode: global + fn_type: global_shared_transformer + simple_transformer_ci_cfg: + d_model: 256 + n_blocks: 8 + mlp_hidden_dim: [1024] + attn_config: + n_heads: 8 + max_len: 1024 + rope_base: 10000.0 +sampling: continuous +sigmoid_type: leaky_hard +module_info: + - module_pattern: h.*.mlp.c_fc + C: 1152 + - module_pattern: h.*.mlp.down_proj + C: 960 + - module_pattern: h.*.attn.q_proj + C: 288 + - module_pattern: h.*.attn.k_proj + C: 288 + - module_pattern: h.*.attn.v_proj + C: 384 + - module_pattern: h.*.attn.o_proj + C: 480 +identity_module_info: null +use_delta_component: true +loss_metric_configs: + - classname: FaithfulnessLoss + coeff: 1000000.0 + - classname: ImportanceMinimalityLoss + coeff: 0.003 + pnorm: 2.0 + beta: 0.1 + p_anneal_start_frac: 0.0 + p_anneal_final_p: 0.5 + p_anneal_end_frac: 1.0 + eps: 1.0e-12 + - classname: StochasticReconSubsetLoss + coeff: 0.5 + routing: + type: uniform_k_subset + - classname: PersistentPGDReconSubsetLoss + coeff: 0.5 + optimizer: + type: adam + lr_schedule: + start_val: 0.01 + warmup_pct: 0.0 + final_val_frac: 1.0 + fn_type: constant + beta1: 0.8 + beta2: 0.99 + scope: + type: repeat_across_batch + n_sources: 32 + routing: + type: 
uniform_k_subset +output_loss_type: kl +lr_schedule: + start_val: 0.0003 + warmup_pct: 0.0 + final_val_frac: 0.1 + fn_type: cosine +steps: 400_000 +batch_size: 32 +gradient_accumulation_steps: 1 +grad_clip_norm_components: 0.01 +grad_clip_norm_ci_fns: null +faithfulness_warmup_steps: 200 +faithfulness_warmup_lr: 0.001 +faithfulness_warmup_weight_decay: 0.0 +train_log_freq: 50 +eval_freq: 1000 +eval_batch_size: 128 +slow_eval_freq: 10000 +n_eval_steps: 5 +slow_eval_on_first_step: true +save_freq: null +eval_metric_configs: + - classname: CIHistograms + n_batches_accum: 5 + - classname: ComponentActivationDensity + - classname: CI_L0 + groups: + layer_0: + - h.0.* + layer_1: + - h.1.* + total: + - "*" + - classname: CEandKLLosses + rounding_threshold: 0.0 + - classname: CIMeanPerComponent + - classname: StochasticHiddenActsReconLoss + - classname: PGDReconLoss + init: random + step_size: 0.1 + n_steps: 20 + mask_scope: shared_across_batch +ci_alive_threshold: 0.0 +pretrained_model_class: simple_stories_train.models.llama_simple_mlp.LlamaSimpleMLP +pretrained_model_path: null +pretrained_model_name: wandb:goodfire/spd/runs/gf6rbga0 +pretrained_model_output_attr: idx_0 +tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M +task_config: + task_name: lm + max_seq_len: 512 + buffer_size: 1000 + dataset_name: SimpleStories/SimpleStories + column_name: story + train_data_split: train + eval_data_split: test + shuffle_each_epoch: true + is_tokenized: false + streaming: false diff --git a/spd/experiments/resid_mlp/resid_mlp1_config.yaml b/spd/experiments/resid_mlp/resid_mlp1_config.yaml index 1d8bc7fed..6178f579b 100644 --- a/spd/experiments/resid_mlp/resid_mlp1_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp1_config.yaml @@ -7,8 +7,10 @@ wandb_run_name_prefix: "" # --- General --- seed: 0 n_mask_samples: 1 -ci_fn_type: "mlp" -ci_fn_hidden_dims: [16] +ci_config: + mode: layerwise + fn_type: mlp + hidden_dims: [16] sigmoid_type: "leaky_hard" module_info: - 
module_pattern: "layers.*.mlp_in" diff --git a/spd/experiments/resid_mlp/resid_mlp1_global_config.yaml b/spd/experiments/resid_mlp/resid_mlp1_global_config.yaml new file mode 100644 index 000000000..058db7a63 --- /dev/null +++ b/spd/experiments/resid_mlp/resid_mlp1_global_config.yaml @@ -0,0 +1,84 @@ +# ResidualMLP 1 layer - Global CI +# --- WandB --- +wandb_project: spd +wandb_run_name: null +wandb_run_name_prefix: "" + +# --- General --- +seed: 0 +n_mask_samples: 1 +ci_config: + mode: global + fn_type: global_shared_mlp + hidden_dims: [400, 300] +sigmoid_type: "leaky_hard" +module_info: + - module_pattern: "layers.*.mlp_in" + C: 100 + - module_pattern: "layers.*.mlp_out" + C: 100 +identity_module_info: null +use_delta_component: true + +# --- Loss config --- +loss_metric_configs: + - classname: "ImportanceMinimalityLoss" + coeff: 1e-5 + pnorm: 2.0 + beta: 0 + - classname: "StochasticReconLayerwiseLoss" + coeff: 1.0 + - classname: "StochasticReconLoss" + coeff: 1.0 +output_loss_type: mse + +# --- Training --- +batch_size: 2048 +eval_batch_size: 2048 +steps: 20_000 +lr_schedule: + start_val: 2e-3 + fn_type: constant + warmup_pct: 0.0 + +# --- Faithfulness Warmup --- +faithfulness_warmup_steps: 200 +faithfulness_warmup_lr: 0.01 +faithfulness_warmup_weight_decay: 0.1 + +# --- Logging & Saving --- +train_log_freq: 100 +eval_freq: 500 +n_eval_steps: 100 +slow_eval_freq: 5_000 +slow_eval_on_first_step: true +save_freq: null +ci_alive_threshold: 0.1 +eval_metric_configs: + - classname: "CIHistograms" + n_batches_accum: 5 + - classname: "ComponentActivationDensity" + - classname: "PermutedCIPlots" + identity_patterns: ["layers.*.mlp_in"] + dense_patterns: ["layers.*.mlp_out"] + - classname: "IdentityCIError" + identity_ci: + - layer_pattern: "layers.*.mlp_in" + n_features: 100 + dense_ci: + - layer_pattern: "layers.*.mlp_out" + k: 50 + - classname: "CI_L0" + groups: null + - classname: "CIMeanPerComponent" + - classname: "StochasticHiddenActsReconLoss" + +# --- Pretrained 
model info --- +pretrained_model_class: "spd.experiments.resid_mlp.models.ResidMLP" +pretrained_model_path: "wandb:goodfire/spd-pre-Sep-2025/runs/pziyck78" + +# --- Task Specific --- +task_config: + task_name: resid_mlp + feature_probability: 0.01 + data_generation_type: "at_least_zero_active" diff --git a/spd/experiments/resid_mlp/resid_mlp1_global_reverse_config.yaml b/spd/experiments/resid_mlp/resid_mlp1_global_reverse_config.yaml new file mode 100644 index 000000000..4e708862e --- /dev/null +++ b/spd/experiments/resid_mlp/resid_mlp1_global_reverse_config.yaml @@ -0,0 +1,106 @@ +# ResidualMLP 1 layer - Global Reverse Residual CI +# --- WandB --- +wandb_project: spd +wandb_run_name: null +wandb_run_name_prefix: "" + +# --- General --- +seed: 0 +n_mask_samples: 1 +ci_config: + mode: global + fn_type: global_reverse_residual + d_resid_ci_fn: 512 + reader_hidden_dims: [512] + block_groups: + - name: "layer_0_mlp" + patterns: ["layers.0.mlp_*"] +sampling: continuous +sigmoid_type: "leaky_hard" +module_info: + - module_pattern: "layers.*.mlp_in" + C: 100 + - module_pattern: "layers.*.mlp_out" + C: 100 +identity_module_info: null +use_delta_component: true + +# --- Loss config --- +loss_metric_configs: + - classname: "ImportanceMinimalityLoss" + coeff: 1e-5 + pnorm: 2.0 + beta: 0 + p_anneal_start_frac: 0 + p_anneal_final_p: 0.5 + p_anneal_end_frac: 1 + eps: 1e-12 + - classname: "StochasticReconSubsetLoss" + coeff: 0.5 + routing: + type: uniform_k_subset + - classname: "PGDReconSubsetLoss" + coeff: 0.5 + init: random + step_size: 1 + n_steps: 1 + mask_scope: shared_across_batch + routing: + type: uniform_k_subset + - classname: "StochasticReconLoss" + coeff: 1.0 + - classname: "FaithfulnessLoss" + coeff: 10 +output_loss_type: mse + +# --- Training --- +batch_size: 2048 +eval_batch_size: 2048 +steps: 100_000 +lr_schedule: + start_val: 5e-3 + fn_type: cosine + final_val_frac: 0.1 + warmup_pct: 0.0 + +# --- Faithfulness Warmup --- +faithfulness_warmup_steps: 200 
+faithfulness_warmup_lr: 0.01 +faithfulness_warmup_weight_decay: 0.1 + +# --- Logging & Saving --- +train_log_freq: 100 +eval_freq: 500 +n_eval_steps: 100 +slow_eval_freq: 5_000 +slow_eval_on_first_step: true +save_freq: null +ci_alive_threshold: 0.1 +eval_metric_configs: + - classname: "CIHistograms" + n_batches_accum: 5 + - classname: "ComponentActivationDensity" + - classname: "PermutedCIPlots" + identity_patterns: ["layers.*.mlp_in"] + dense_patterns: ["layers.*.mlp_out"] + - classname: "IdentityCIError" + identity_ci: + - layer_pattern: "layers.*.mlp_in" + n_features: 100 + dense_ci: + - layer_pattern: "layers.*.mlp_out" + k: 50 + - classname: "CI_L0" + groups: null + - classname: "CIMeanPerComponent" + - classname: "StochasticHiddenActsReconLoss" + +# --- Pretrained model info --- +pretrained_model_class: "spd.experiments.resid_mlp.models.ResidMLP" +pretrained_model_path: "wandb:goodfire/spd-pre-Sep-2025/runs/pziyck78" + +# --- Task Specific --- +task_config: + task_name: resid_mlp + feature_probability: 0.01 + data_generation_type: "at_least_zero_active" diff --git a/spd/experiments/resid_mlp/resid_mlp2_config.yaml b/spd/experiments/resid_mlp/resid_mlp2_config.yaml index bae662b6f..dcc8abeba 100644 --- a/spd/experiments/resid_mlp/resid_mlp2_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp2_config.yaml @@ -7,9 +7,10 @@ wandb_run_name_prefix: "" # --- General --- seed: 0 n_mask_samples: 1 -ci_fn_type: shared_mlp -ci_fn_hidden_dims: - - 256 +ci_config: + mode: layerwise + fn_type: shared_mlp + hidden_dims: [256] sigmoid_type: leaky_hard module_info: - module_pattern: "layers.*.mlp_in" diff --git a/spd/experiments/resid_mlp/resid_mlp2_global_reverse_config.yaml b/spd/experiments/resid_mlp/resid_mlp2_global_reverse_config.yaml new file mode 100644 index 000000000..8c239ee60 --- /dev/null +++ b/spd/experiments/resid_mlp/resid_mlp2_global_reverse_config.yaml @@ -0,0 +1,114 @@ +# ResidualMLP 2 layers - Global Reverse Residual CI +# --- WandB --- +wandb_project: 
spd +wandb_run_name: null +wandb_run_name_prefix: "" + +# --- General --- +seed: 0 +n_mask_samples: 1 +ci_config: + mode: global + fn_type: global_reverse_residual + d_resid_ci_fn: 1024 + reader_hidden_dims: [1024, 1024] + block_groups: + # Process in reverse order: layer 1 before layer 0 + - name: "layer_1_mlp" + patterns: ["layers.1.mlp_*"] + - name: "layer_0_mlp" + patterns: ["layers.0.mlp_*"] +sampling: continuous +sigmoid_type: "leaky_hard" +module_info: + - module_pattern: "layers.*.mlp_in" + C: 400 + - module_pattern: "layers.*.mlp_out" + C: 400 +identity_module_info: null +use_delta_component: true + +# --- Loss config --- +loss_metric_configs: + - classname: "ImportanceMinimalityLoss" + coeff: 3e-5 + pnorm: 2.0 + beta: 0 + p_anneal_start_frac: 0 + p_anneal_final_p: 0.5 + p_anneal_end_frac: 1 + eps: 1e-12 + - classname: "StochasticReconSubsetLoss" + coeff: 0.5 + routing: + type: uniform_k_subset + - classname: "PGDReconSubsetLoss" + coeff: 0.5 + init: random + step_size: 1 + n_steps: 1 + mask_scope: shared_across_batch + routing: + type: uniform_k_subset + - classname: "StochasticReconLoss" + coeff: 1.0 + - classname: "FaithfulnessLoss" + coeff: 10 +output_loss_type: mse + +# --- Training --- +batch_size: 2048 +eval_batch_size: 2048 +steps: 50_000 +lr_schedule: + start_val: 5e-4 + fn_type: cosine + final_val_frac: 0.9 + warmup_pct: 0.0 + +# --- Faithfulness Warmup --- +faithfulness_warmup_steps: 200 +faithfulness_warmup_lr: 0.01 +faithfulness_warmup_weight_decay: 0 + +# --- Logging & Saving --- +train_log_freq: 50 +eval_freq: 500 +n_eval_steps: 100 +slow_eval_freq: 5_000 +slow_eval_on_first_step: true +save_freq: null +ci_alive_threshold: 0.1 +eval_metric_configs: + - classname: "CIHistograms" + n_batches_accum: 5 + - classname: "ComponentActivationDensity" + - classname: "PermutedCIPlots" + identity_patterns: ["layers.*.mlp_in"] + dense_patterns: ["layers.*.mlp_out"] + - classname: "IdentityCIError" + identity_ci: + - layer_pattern: "layers.*.mlp_in" + 
n_features: 100 + dense_ci: + - layer_pattern: "layers.*.mlp_out" + k: 25 + - classname: "CI_L0" + groups: null + - classname: "CIMeanPerComponent" + - classname: "StochasticHiddenActsReconLoss" + - classname: "PGDReconLoss" + init: random + step_size: 0.1 + n_steps: 20 + mask_scope: shared_across_batch + +# --- Pretrained model info --- +pretrained_model_class: "spd.experiments.resid_mlp.models.ResidMLP" +pretrained_model_path: "wandb:goodfire/spd-pre-Sep-2025/runs/any9ekl9" + +# --- Task Specific --- +task_config: + task_name: resid_mlp + feature_probability: 0.01 + data_generation_type: "at_least_zero_active" diff --git a/spd/experiments/resid_mlp/resid_mlp3_config.yaml b/spd/experiments/resid_mlp/resid_mlp3_config.yaml index dac4f9c10..1961a44d3 100644 --- a/spd/experiments/resid_mlp/resid_mlp3_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp3_config.yaml @@ -7,8 +7,10 @@ wandb_run_name_prefix: "" # --- General --- seed: 0 n_mask_samples: 1 -ci_fn_type: "mlp" -ci_fn_hidden_dims: [128] +ci_config: + mode: layerwise + fn_type: mlp + hidden_dims: [128] sigmoid_type: "leaky_hard" module_info: - module_pattern: "layers.*.mlp_in" diff --git a/spd/experiments/tms/tms_40-10-id_config.yaml b/spd/experiments/tms/tms_40-10-id_config.yaml index e3e40d5fc..0fbdfd07a 100644 --- a/spd/experiments/tms/tms_40-10-id_config.yaml +++ b/spd/experiments/tms/tms_40-10-id_config.yaml @@ -7,8 +7,10 @@ wandb_run_name_prefix: "" # --- General --- seed: 0 n_mask_samples: 1 -ci_fn_type: "mlp" -ci_fn_hidden_dims: [16] +ci_config: + mode: layerwise + fn_type: mlp + hidden_dims: [16] sigmoid_type: "leaky_hard" module_info: - module_pattern: "linear1" diff --git a/spd/experiments/tms/tms_40-10_config.yaml b/spd/experiments/tms/tms_40-10_config.yaml index a4aeb6a97..2a7cffb55 100644 --- a/spd/experiments/tms/tms_40-10_config.yaml +++ b/spd/experiments/tms/tms_40-10_config.yaml @@ -7,8 +7,10 @@ wandb_run_name_prefix: "" # --- General --- seed: 0 n_mask_samples: 1 -ci_fn_type: "mlp" 
-ci_fn_hidden_dims: [16] +ci_config: + mode: layerwise + fn_type: mlp + hidden_dims: [16] sigmoid_type: "leaky_hard" module_info: - module_pattern: "linear1" diff --git a/spd/experiments/tms/tms_5-2-id_config.yaml b/spd/experiments/tms/tms_5-2-id_config.yaml index c9b2234e8..cc654532e 100644 --- a/spd/experiments/tms/tms_5-2-id_config.yaml +++ b/spd/experiments/tms/tms_5-2-id_config.yaml @@ -7,8 +7,10 @@ wandb_run_name_prefix: "" # --- General --- seed: 0 n_mask_samples: 1 -ci_fn_type: "mlp" -ci_fn_hidden_dims: [16] +ci_config: + mode: layerwise + fn_type: mlp + hidden_dims: [16] sigmoid_type: "leaky_hard" module_info: - module_pattern: "linear1" diff --git a/spd/experiments/tms/tms_5-2_config.yaml b/spd/experiments/tms/tms_5-2_config.yaml index 34c92fa08..07bc9056a 100644 --- a/spd/experiments/tms/tms_5-2_config.yaml +++ b/spd/experiments/tms/tms_5-2_config.yaml @@ -7,8 +7,10 @@ wandb_run_name_prefix: "" # --- General --- seed: 0 n_mask_samples: 1 -ci_fn_type: "mlp" -ci_fn_hidden_dims: [16] +ci_config: + mode: layerwise + fn_type: mlp + hidden_dims: [16] sigmoid_type: "leaky_hard" module_info: - module_pattern: "linear1" diff --git a/spd/graph_interp/CLAUDE.md b/spd/graph_interp/CLAUDE.md new file mode 100644 index 000000000..1fae5e07b --- /dev/null +++ b/spd/graph_interp/CLAUDE.md @@ -0,0 +1,71 @@ +# Graph Interpretation Module + +Context-aware component labeling using network graph structure. Unlike standard autointerp (one-shot per component), this module uses dataset attributions to provide graph context: each component's prompt includes labels from already-labeled components connected via the attribution graph. + +## Usage + +```bash +# Via SLURM (standalone) +spd-graph-interp --config config.yaml + +# Direct execution +python -m spd.graph_interp.scripts.run --config_json '{...}' +``` + +Requires `OPENROUTER_API_KEY` env var. Requires both harvest data and dataset attributions to exist. + +## Three-Phase Pipeline + +1. 
**Output pass** (late → early): "What does this component DO?" Each component's prompt includes top-K downstream components (by attribution) with their labels. Late layers are labeled first so earlier layers see labeled downstream context.
+
+2. **Input pass** (early → late): "What TRIGGERS this component?" Each component's prompt includes top-K upstream components (by attribution) + co-firing components (Jaccard/PMI). Early layers are labeled first so later layers see labeled upstream context. Independent of the output pass.
+
+3. **Unification** (parallel): Synthesizes output + input labels into a single unified label per component.
+
+All three phases run in a single invocation. Resume is per-phase via completed key sets in the DB.
+
+## Data Storage
+
+```
+SPD_OUT_DIR/graph_interp/<run_id>/
+└── ti-YYYYMMDD_HHMMSS/
+    ├── interp.db      # SQLite: output_labels, input_labels, unified_labels, prompt_edges
+    └── config.yaml
+```
+
+## Database Schema
+
+- `output_labels`: component_key → label, confidence, reasoning, raw_response, prompt
+- `input_labels`: same schema as output_labels
+- `unified_labels`: same schema as output_labels
+- `prompt_edges`: directed filtered graph of (component, related_key, pass, attribution, related_label)
+- `config`: key-value store
+
+## Architecture
+
+| File | Purpose |
+|------|---------|
+| `config.py` | `GraphInterpConfig`, `GraphInterpSlurmConfig` |
+| `schemas.py` | `LabelResult`, `PromptEdge`, path helpers |
+| `db.py` | `GraphInterpDB` — SQLite via `open_nfs_sqlite` (NFS-safe, no WAL) |
+| `ordering.py` | Topological sort via `CanonicalWeight` from topology module |
+| `graph_context.py` | `RelatedComponent`, gather attributed + co-firing components |
+| `prompts.py` | Three prompt formatters (output, input, unification) |
+| `interpret.py` | Main three-phase execution loop |
+| `repo.py` | `GraphInterpRepo` — read-only access to results |
+| `scripts/run.py` | CLI entry point (called by SLURM) |
+| `scripts/run_slurm.py` | SLURM submission |
+| `scripts/run_slurm_cli.py` | Thin CLI wrapper for `spd-graph-interp` | + +## Dependencies + +- Harvest data (component stats, correlations, token stats) +- Dataset attributions (component-to-component attribution strengths) +- Reuses `map_llm_calls` from `spd/autointerp/llm_api.py` +- Reuses prompt helpers from `spd/autointerp/prompt_helpers.py` + +## SLURM Integration + +- 0 GPUs, 16 CPUs, 240GB memory (CPU-only, LLM API calls) +- Depends on both harvest merge AND attribution merge jobs +- Entry point: `spd-graph-interp` diff --git a/spd/graph_interp/__init__.py b/spd/graph_interp/__init__.py new file mode 100644 index 000000000..61e182fda --- /dev/null +++ b/spd/graph_interp/__init__.py @@ -0,0 +1 @@ +"""Graph interpretation: context-aware component labeling using graph structure.""" diff --git a/spd/graph_interp/config.py b/spd/graph_interp/config.py new file mode 100644 index 000000000..e6e7441d3 --- /dev/null +++ b/spd/graph_interp/config.py @@ -0,0 +1,26 @@ +"""Graph interpretation configuration.""" + +from openrouter.components import Effort + +from spd.base_config import BaseConfig +from spd.dataset_attributions.storage import AttrMetric +from spd.settings import DEFAULT_PARTITION_NAME + + +class GraphInterpConfig(BaseConfig): + model: str = "google/gemini-3-flash-preview" + reasoning_effort: Effort = "low" + attr_metric: AttrMetric = "attr_abs" + top_k_attributed: int = 8 + max_examples: int = 20 + label_max_words: int = 8 + cost_limit_usd: float | None = None + max_requests_per_minute: int = 500 + max_concurrent: int = 50 + limit: int | None = None + + +class GraphInterpSlurmConfig(BaseConfig): + config: GraphInterpConfig + partition: str = DEFAULT_PARTITION_NAME + time: str = "24:00:00" diff --git a/spd/graph_interp/db.py b/spd/graph_interp/db.py new file mode 100644 index 000000000..1a6a83e37 --- /dev/null +++ b/spd/graph_interp/db.py @@ -0,0 +1,195 @@ +"""SQLite database for graph interpretation data. 
NFS-hosted, single writer then read-only.""" + +import sqlite3 +from pathlib import Path + +from spd.autointerp.db import DONE_MARKER +from spd.graph_interp.schemas import LabelResult, PromptEdge +from spd.utils.sqlite import open_nfs_sqlite + +_SCHEMA = """\ +CREATE TABLE IF NOT EXISTS output_labels ( + component_key TEXT PRIMARY KEY, + label TEXT NOT NULL, + confidence TEXT NOT NULL, + reasoning TEXT NOT NULL, + raw_response TEXT NOT NULL, + prompt TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS input_labels ( + component_key TEXT PRIMARY KEY, + label TEXT NOT NULL, + confidence TEXT NOT NULL, + reasoning TEXT NOT NULL, + raw_response TEXT NOT NULL, + prompt TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS unified_labels ( + component_key TEXT PRIMARY KEY, + label TEXT NOT NULL, + confidence TEXT NOT NULL, + reasoning TEXT NOT NULL, + raw_response TEXT NOT NULL, + prompt TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS prompt_edges ( + component_key TEXT NOT NULL, + related_key TEXT NOT NULL, + pass TEXT NOT NULL, + attribution REAL NOT NULL, + related_label TEXT, + related_confidence TEXT, + PRIMARY KEY (component_key, related_key, pass) +); + +""" + + +_LABEL_TABLES = ("output_labels", "input_labels", "unified_labels") + + +class GraphInterpDB: + """NFS-hosted. Uses open_nfs_sqlite (no WAL). 
Single writer, then read-only.""" + + def __init__(self, db_path: Path, readonly: bool = False) -> None: + self._conn = open_nfs_sqlite(db_path, readonly) + if not readonly: + self._conn.executescript(_SCHEMA) + self._db_path = db_path + + def mark_done(self) -> None: + (self._db_path.parent / DONE_MARKER).touch() + + # -- Label CRUD (shared across output/input/unified) ----------------------- + + def _save_label(self, table: str, result: LabelResult) -> None: + assert table in _LABEL_TABLES + self._conn.execute( + f"INSERT OR REPLACE INTO {table} VALUES (?, ?, ?, ?, ?, ?)", + ( + result.component_key, + result.label, + result.confidence, + result.reasoning, + result.raw_response, + result.prompt, + ), + ) + self._conn.commit() + + def _get_label(self, table: str, component_key: str) -> LabelResult | None: + assert table in _LABEL_TABLES + row = self._conn.execute( + f"SELECT * FROM {table} WHERE component_key = ?", (component_key,) + ).fetchone() + if row is None: + return None + return _row_to_label_result(row) + + def _get_all_labels(self, table: str) -> dict[str, LabelResult]: + assert table in _LABEL_TABLES + rows = self._conn.execute(f"SELECT * FROM {table}").fetchall() + return {row["component_key"]: _row_to_label_result(row) for row in rows} + + # -- Output labels --------------------------------------------------------- + + def save_output_label(self, result: LabelResult) -> None: + self._save_label("output_labels", result) + + def get_output_label(self, component_key: str) -> LabelResult | None: + return self._get_label("output_labels", component_key) + + def get_all_output_labels(self) -> dict[str, LabelResult]: + return self._get_all_labels("output_labels") + + # -- Input labels ---------------------------------------------------------- + + def save_input_label(self, result: LabelResult) -> None: + self._save_label("input_labels", result) + + def get_input_label(self, component_key: str) -> LabelResult | None: + return self._get_label("input_labels", 
component_key) + + def get_all_input_labels(self) -> dict[str, LabelResult]: + return self._get_all_labels("input_labels") + + # -- Unified labels -------------------------------------------------------- + + def save_unified_label(self, result: LabelResult) -> None: + self._save_label("unified_labels", result) + + def get_unified_label(self, component_key: str) -> LabelResult | None: + return self._get_label("unified_labels", component_key) + + def get_all_unified_labels(self) -> dict[str, LabelResult]: + return self._get_all_labels("unified_labels") + + def get_completed_unified_keys(self) -> set[str]: + rows = self._conn.execute("SELECT component_key FROM unified_labels").fetchall() + return {row["component_key"] for row in rows} + + # -- Prompt edges ---------------------------------------------------------- + + def save_prompt_edges(self, edges: list[PromptEdge]) -> None: + rows = [ + ( + e.component_key, + e.related_key, + e.pass_name, + e.attribution, + e.related_label, + e.related_confidence, + ) + for e in edges + ] + self._conn.executemany( + "INSERT OR REPLACE INTO prompt_edges VALUES (?, ?, ?, ?, ?, ?)", + rows, + ) + self._conn.commit() + + def get_prompt_edges(self, component_key: str) -> list[PromptEdge]: + rows = self._conn.execute( + "SELECT * FROM prompt_edges WHERE component_key = ?", (component_key,) + ).fetchall() + return [_row_to_prompt_edge(row) for row in rows] + + def get_all_prompt_edges(self) -> list[PromptEdge]: + rows = self._conn.execute("SELECT * FROM prompt_edges").fetchall() + return [_row_to_prompt_edge(row) for row in rows] + + # -- Stats ----------------------------------------------------------------- + + def get_label_count(self, table: str) -> int: + assert table in _LABEL_TABLES + row = self._conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone() + assert row is not None + return row[0] + + def close(self) -> None: + self._conn.close() + + +def _row_to_label_result(row: sqlite3.Row) -> LabelResult: + return LabelResult( + 
component_key=row["component_key"], + label=row["label"], + confidence=row["confidence"], + reasoning=row["reasoning"], + raw_response=row["raw_response"], + prompt=row["prompt"], + ) + + +def _row_to_prompt_edge(row: sqlite3.Row) -> PromptEdge: + return PromptEdge( + component_key=row["component_key"], + related_key=row["related_key"], + pass_name=row["pass"], + attribution=row["attribution"], + related_label=row["related_label"], + related_confidence=row["related_confidence"], + ) diff --git a/spd/graph_interp/graph_context.py b/spd/graph_interp/graph_context.py new file mode 100644 index 000000000..df9b04953 --- /dev/null +++ b/spd/graph_interp/graph_context.py @@ -0,0 +1,60 @@ +"""Gather related components from attribution graph.""" + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Literal + +from spd.dataset_attributions.storage import DatasetAttributionEntry +from spd.graph_interp.ordering import parse_component_key +from spd.graph_interp.schemas import LabelResult +from spd.harvest.storage import CorrelationStorage + + +@dataclass +class RelatedComponent: + component_key: str + attribution: float + pmi: float | None + label: str | None + confidence: str | None + + +GetAttributed = Callable[[str, int, Literal["positive", "negative"]], list[DatasetAttributionEntry]] + + +def get_related_components( + component_key: str, + get_attributed: GetAttributed, + correlation_storage: CorrelationStorage, + labels_so_far: dict[str, LabelResult], + k: int, +) -> list[RelatedComponent]: + """Top-K components connected via attribution, enriched with PMI and labels.""" + my_layer, _ = parse_component_key(component_key) + + pos = get_attributed(component_key, k * 2, "positive") + neg = get_attributed(component_key, k * 2, "negative") + + candidates = pos + neg + candidates.sort(key=lambda e: abs(e.value), reverse=True) + candidates = candidates[:k] + + result = [] + for e in candidates: + r_layer, _ = 
parse_component_key(e.component_key) + assert r_layer != my_layer, ( + f"Same-layer component {e.component_key} in related list for {component_key}" + ) + + label = labels_so_far.get(e.component_key) + result.append( + RelatedComponent( + component_key=e.component_key, + attribution=e.value, + pmi=correlation_storage.pmi(component_key, e.component_key), + label=label.label if label else None, + confidence=label.confidence if label else None, + ) + ) + + return result diff --git a/spd/graph_interp/interpret.py b/spd/graph_interp/interpret.py new file mode 100644 index 000000000..f0c92725a --- /dev/null +++ b/spd/graph_interp/interpret.py @@ -0,0 +1,363 @@ +"""Main three-phase graph interpretation execution. + +Structure: + output_labels = scan(layers_reversed, step) + input_labels = scan(layers_forward, step) + unified = map(output_labels + input_labels, unify) + +Each scan folds over layers. Within a layer, components are labeled in parallel +via async LLM calls. The fold accumulator (labels_so_far) lets each component's +prompt include labels from previously-processed layers. 
+""" + +import asyncio +from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable +from pathlib import Path +from typing import Literal + +from spd.app.backend.app_tokenizer import AppTokenizer +from spd.autointerp.llm_api import CostTracker, LLMError, LLMJob, LLMResult, map_llm_calls +from spd.autointerp.schemas import ModelMetadata +from spd.dataset_attributions.storage import ( + AttrMetric, + DatasetAttributionEntry, + DatasetAttributionStorage, +) +from spd.graph_interp import graph_context +from spd.graph_interp.config import GraphInterpConfig +from spd.graph_interp.db import GraphInterpDB +from spd.graph_interp.graph_context import RelatedComponent, get_related_components +from spd.graph_interp.ordering import group_and_sort_by_layer +from spd.graph_interp.prompts import ( + LABEL_SCHEMA, + format_input_prompt, + format_output_prompt, + format_unification_prompt, +) +from spd.graph_interp.schemas import LabelResult, PromptEdge +from spd.harvest.analysis import TokenPRLift, get_input_token_stats, get_output_token_stats +from spd.harvest.repo import HarvestRepo +from spd.harvest.schemas import ComponentData +from spd.harvest.storage import CorrelationStorage, TokenStatsStorage +from spd.log import logger + +GetRelated = Callable[[str, dict[str, LabelResult]], list[RelatedComponent]] +Step = Callable[[list[str], dict[str, LabelResult]], Awaitable[dict[str, LabelResult]]] +MakePrompt = Callable[["ComponentData", "TokenPRLift", list[RelatedComponent]], str] + + +def run_graph_interp( + openrouter_api_key: str, + config: GraphInterpConfig, + harvest: HarvestRepo, + attribution_storage: DatasetAttributionStorage, + correlation_storage: CorrelationStorage, + token_stats: TokenStatsStorage, + model_metadata: ModelMetadata, + db_path: Path, + tokenizer_name: str, +) -> None: + logger.info("Loading tokenizer...") + app_tok = AppTokenizer.from_pretrained(tokenizer_name) + + logger.info("Loading component summaries...") + summaries = harvest.get_summary() 
+ alive = {k: s for k, s in summaries.items() if s.firing_density > 0.0} + all_keys = sorted(alive, key=lambda k: alive[k].firing_density, reverse=True) + if config.limit is not None: + all_keys = all_keys[: config.limit] + + layers = group_and_sort_by_layer(all_keys, model_metadata.layer_descriptions) + total = len(all_keys) + logger.info(f"Graph interp: {total} components across {len(layers)} layers") + + # -- Injected behaviours --------------------------------------------------- + + shared_cost = CostTracker(limit_usd=config.cost_limit_usd) + + async def llm_map( + jobs: Iterable[LLMJob], n_total: int | None = None + ) -> AsyncGenerator[LLMResult | LLMError]: + async for result in map_llm_calls( + openrouter_api_key=openrouter_api_key, + model=config.model, + reasoning_effort=config.reasoning_effort, + jobs=jobs, + max_tokens=8000, + max_concurrent=config.max_concurrent, + max_requests_per_minute=config.max_requests_per_minute, + cost_limit_usd=None, + response_schema=LABEL_SCHEMA, + n_total=n_total, + cost_tracker=shared_cost, + ): + yield result + + concrete_to_canon = model_metadata.layer_descriptions + canon_to_concrete = {v: k for k, v in concrete_to_canon.items()} + + def _translate_entries(entries: list[DatasetAttributionEntry]) -> list[DatasetAttributionEntry]: + for e in entries: + if e.layer in canon_to_concrete: + e.layer = canon_to_concrete[e.layer] + e.component_key = f"{e.layer}:{e.component_idx}" + return entries + + def _to_canon(concrete_key: str) -> str: + layer, idx = concrete_key.rsplit(":", 1) + return f"{concrete_to_canon[layer]}:{idx}" + + def _make_get_attributed( + method: Callable[..., list[DatasetAttributionEntry]], metric: AttrMetric + ) -> "graph_context.GetAttributed": + def get( + key: str, k: int, sign: Literal["positive", "negative"] + ) -> list[DatasetAttributionEntry]: + return _translate_entries(method(_to_canon(key), k=k, sign=sign, metric=metric)) + + return get + + def _get_related(get_attributed: 
"graph_context.GetAttributed") -> GetRelated: + def get(key: str, labels_so_far: dict[str, LabelResult]) -> list[RelatedComponent]: + return get_related_components( + key, + get_attributed, + correlation_storage, + labels_so_far, + config.top_k_attributed, + ) + + return get + + # -- Layer processor (shared for output and input passes) -------------------- + + def _make_process_layer( + get_related: GetRelated, + save_label: Callable[[LabelResult], None], + pass_name: Literal["output", "input"], + get_token_stats: Callable[[str], TokenPRLift | None], + make_prompt: MakePrompt, + ) -> Step: + async def process( + pending: list[str], + labels_so_far: dict[str, LabelResult], + ) -> dict[str, LabelResult]: + def jobs() -> Iterable[LLMJob]: + for key in pending: + component = harvest.get_component(key) + assert component is not None, f"Component {key} not found in harvest DB" + stats = get_token_stats(key) + assert stats is not None, f"No {pass_name} token stats for {key}" + + related = get_related(key, labels_so_far) + db.save_prompt_edges( + [ + PromptEdge( + component_key=key, + related_key=r.component_key, + pass_name=pass_name, + attribution=r.attribution, + related_label=r.label, + related_confidence=r.confidence, + ) + for r in related + ] + ) + yield LLMJob( + prompt=make_prompt(component, stats, related), + schema=LABEL_SCHEMA, + key=key, + ) + + return await _collect_labels(llm_map, jobs(), len(pending), save_label) + + return process + + # -- Scan (fold over layers) ----------------------------------------------- + + async def scan( + layer_order: list[tuple[str, list[str]]], + initial: dict[str, LabelResult], + step: Step, + ) -> dict[str, LabelResult]: + labels = dict(initial) + if labels: + logger.info(f"Resuming, {len(labels)} already completed") + + completed_so_far = 0 + for layer, keys in layer_order: + pending = [k for k in keys if k not in labels] + if not pending: + completed_so_far += len(keys) + continue + + new_labels = await step(pending, 
labels) + labels.update(new_labels) + + completed_so_far += len(keys) + logger.info(f"Completed layer {layer} ({completed_so_far}/{total})") + + return labels + + # -- Map (parallel over all components) ------------------------------------ + + async def map_unify( + output_labels: dict[str, LabelResult], + input_labels: dict[str, LabelResult], + ) -> None: + completed = db.get_completed_unified_keys() + keys = [k for k in all_keys if k not in completed] + if not keys: + logger.info("Unification: all labels already completed") + return + if completed: + logger.info(f"Unification: resuming, {len(completed)} already completed") + + unifiable_keys = [k for k in keys if k in output_labels and k in input_labels] + n_skipped = len(keys) - len(unifiable_keys) + + def jobs() -> Iterable[LLMJob]: + for key in unifiable_keys: + component = harvest.get_component(key) + assert component is not None, f"Component {key} not found in harvest DB" + prompt = format_unification_prompt( + output_label=output_labels[key], + input_label=input_labels[key], + component=component, + model_metadata=model_metadata, + app_tok=app_tok, + label_max_words=config.label_max_words, + max_examples=config.max_examples, + ) + yield LLMJob(prompt=prompt, schema=LABEL_SCHEMA, key=key) + + if n_skipped: + logger.warning(f"Skipping {n_skipped} components missing output or input labels") + logger.info(f"Unifying {len(unifiable_keys)} components") + new_labels = await _collect_labels( + llm_map, jobs(), len(unifiable_keys), db.save_unified_label + ) + logger.info(f"Unification: completed {len(new_labels)}/{len(keys)}") + + # -- Run ------------------------------------------------------------------- + + logger.info("Initializing DB and building scan steps...") + db = GraphInterpDB(db_path) + + metric = config.attr_metric + get_targets = _make_get_attributed(attribution_storage.get_top_targets, metric) + get_sources = _make_get_attributed(attribution_storage.get_top_sources, metric) + + def _output_prompt( + 
component: ComponentData, stats: TokenPRLift, related: list[RelatedComponent] + ) -> str: + return format_output_prompt( + component=component, + model_metadata=model_metadata, + app_tok=app_tok, + output_token_stats=stats, + related=related, + label_max_words=config.label_max_words, + max_examples=config.max_examples, + ) + + def _input_prompt( + component: ComponentData, stats: TokenPRLift, related: list[RelatedComponent] + ) -> str: + return format_input_prompt( + component=component, + model_metadata=model_metadata, + app_tok=app_tok, + input_token_stats=stats, + related=related, + label_max_words=config.label_max_words, + max_examples=config.max_examples, + ) + + label_output = _make_process_layer( + _get_related(get_targets), + db.save_output_label, + "output", + lambda key: get_output_token_stats(token_stats, key, app_tok, top_k=50), + _output_prompt, + ) + label_input = _make_process_layer( + _get_related(get_sources), + db.save_input_label, + "input", + lambda key: get_input_token_stats(token_stats, key, app_tok, top_k=20), + _input_prompt, + ) + + async def _run() -> None: + logger.section("Phase 1: Output pass (late → early)") + output_labels = await scan(list(reversed(layers)), db.get_all_output_labels(), label_output) + + logger.section("Phase 2: Input pass (early → late)") + input_labels = await scan(list(layers), db.get_all_input_labels(), label_input) + + logger.section("Phase 3: Unification") + await map_unify(output_labels, input_labels) + + logger.info( + f"Completed: {db.get_label_count('output_labels')} output, " + f"{db.get_label_count('input_labels')} input, " + f"{db.get_label_count('unified_labels')} unified labels -> {db_path}" + ) + db.mark_done() + + try: + asyncio.run(_run()) + finally: + db.close() + + +# -- Shared LLM call machinery ------------------------------------------------ + + +async def _collect_labels( + llm_map: Callable[[Iterable[LLMJob], int | None], AsyncGenerator[LLMResult | LLMError]], + jobs: Iterable[LLMJob], + 
n_total: int, + save_label: Callable[[LabelResult], None], +) -> dict[str, LabelResult]: + """Run LLM jobs, parse results, save to DB, return new labels.""" + new_labels: dict[str, LabelResult] = {} + n_errors = 0 + + async for outcome in llm_map(jobs, n_total): + match outcome: + case LLMResult(job=job, parsed=parsed, raw=raw): + result = _parse_label(job.key, parsed, raw, job.prompt) + save_label(result) + new_labels[job.key] = result + case LLMError(job=job, error=e): + n_errors += 1 + logger.error(f"Skipping {job.key}: {type(e).__name__}: {e}") + _check_error_rate(n_errors, len(new_labels)) + + return new_labels + + +def _parse_label(key: str, parsed: dict[str, object], raw: str, prompt: str) -> LabelResult: + assert len(parsed) == 3, f"Expected 3 fields, got {len(parsed)}" + label = parsed["label"] + confidence = parsed["confidence"] + reasoning = parsed["reasoning"] + assert isinstance(label, str) and isinstance(confidence, str) and isinstance(reasoning, str) + return LabelResult( + component_key=key, + label=label, + confidence=confidence, + reasoning=reasoning, + raw_response=raw, + prompt=prompt, + ) + + +def _check_error_rate(n_errors: int, n_done: int) -> None: + total = n_errors + n_done + if total > 10 and n_errors / total > 0.05: + raise RuntimeError( + f"Error rate {n_errors / total:.0%} ({n_errors}/{total}) exceeds 5% threshold" + ) diff --git a/spd/graph_interp/ordering.py b/spd/graph_interp/ordering.py new file mode 100644 index 000000000..03b2f0d3d --- /dev/null +++ b/spd/graph_interp/ordering.py @@ -0,0 +1,81 @@ +"""Layer ordering for graph interpretation. + +Uses the topology module's CanonicalWeight system for correct ordering +across all model architectures. Canonical addresses are provided by +ModelMetadata.layer_descriptions (concrete path → canonical string). 
+""" + +from spd.topology.canonical import ( + CanonicalWeight, + FusedAttnWeight, + GLUWeight, + LayerWeight, + MLPWeight, + SeparateAttnWeight, +) + +_SUBLAYER_ORDER = {"attn": 0, "attn_fused": 0, "glu": 1, "mlp": 1} + +_PROJECTION_ORDER: dict[type, dict[str, int]] = { + SeparateAttnWeight: {"q": 0, "k": 1, "v": 2, "o": 3}, + FusedAttnWeight: {"qkv": 0, "o": 1}, + GLUWeight: {"gate": 0, "up": 1, "down": 2}, + MLPWeight: {"up": 0, "down": 1}, +} + + +def canonical_sort_key(canonical: str) -> tuple[int, int, int]: + """Sort key for a canonical address string like '0.attn.q' or '1.mlp.down'.""" + weight = CanonicalWeight.parse(canonical) + assert isinstance(weight, LayerWeight), f"Expected LayerWeight, got {type(weight).__name__}" + + match weight.name: + case SeparateAttnWeight(weight=p): + sublayer_idx = _SUBLAYER_ORDER["attn"] + proj_idx = _PROJECTION_ORDER[SeparateAttnWeight][p] + case FusedAttnWeight(weight=p): + sublayer_idx = _SUBLAYER_ORDER["attn_fused"] + proj_idx = _PROJECTION_ORDER[FusedAttnWeight][p] + case GLUWeight(weight=p): + sublayer_idx = _SUBLAYER_ORDER["glu"] + proj_idx = _PROJECTION_ORDER[GLUWeight][p] + case MLPWeight(weight=p): + sublayer_idx = _SUBLAYER_ORDER["mlp"] + proj_idx = _PROJECTION_ORDER[MLPWeight][p] + + return weight.layer_idx, sublayer_idx, proj_idx + + +def parse_component_key(key: str) -> tuple[str, int]: + """Split 'h.1.mlp.c_fc:42' into ('h.1.mlp.c_fc', 42).""" + layer, idx_str = key.rsplit(":", 1) + return layer, int(idx_str) + + +def group_and_sort_by_layer( + component_keys: list[str], + layer_descriptions: dict[str, str], +) -> list[tuple[str, list[str]]]: + """Group component keys by layer, return [(layer, [keys])] in topological order. + + Args: + component_keys: Component keys like 'h.0.attn.q_proj:42'. + layer_descriptions: Mapping from concrete layer path to canonical address + (from ModelMetadata.layer_descriptions). 
+ """ + by_layer: dict[str, list[str]] = {} + for key in component_keys: + layer, _ = parse_component_key(key) + by_layer.setdefault(layer, []).append(key) + + def sort_key(layer: str) -> tuple[int, int, int]: + canonical = layer_descriptions[layer] + return canonical_sort_key(canonical) + + sorted_layers = sorted(by_layer.keys(), key=sort_key) + + result: list[tuple[str, list[str]]] = [] + for layer in sorted_layers: + keys = sorted(by_layer[layer], key=lambda k: parse_component_key(k)[1]) + result.append((layer, keys)) + return result diff --git a/spd/graph_interp/prompts.py b/spd/graph_interp/prompts.py new file mode 100644 index 000000000..0dc298e41 --- /dev/null +++ b/spd/graph_interp/prompts.py @@ -0,0 +1,234 @@ +"""Prompt formatters for graph interpretation. + +Three prompts: +1. Output pass (late→early): "What does this component DO?" — output tokens, says examples, downstream +2. Input pass (early→late): "What TRIGGERS this component?" — input tokens, fires-on examples, upstream +3. Unification: Synthesize output + input labels into unified label. 
+""" + +from spd.app.backend.app_tokenizer import AppTokenizer +from spd.autointerp.prompt_helpers import ( + build_fires_on_examples, + build_input_section, + build_output_section, + build_says_examples, + density_note, + human_layer_desc, + layer_position_note, +) +from spd.autointerp.schemas import ModelMetadata +from spd.graph_interp.graph_context import RelatedComponent +from spd.graph_interp.schemas import LabelResult +from spd.harvest.analysis import TokenPRLift +from spd.harvest.schemas import ComponentData +from spd.utils.markdown import Md + +LABEL_SCHEMA: dict[str, object] = { + "type": "object", + "properties": { + "label": {"type": "string"}, + "confidence": {"type": "string", "enum": ["low", "medium", "high"]}, + "reasoning": {"type": "string"}, + }, + "required": ["label", "confidence", "reasoning"], + "additionalProperties": False, +} + +JSON_INSTRUCTION = ( + 'Respond with JSON: {"label": "...", "confidence": "low|medium|high", "reasoning": "..."}' +) + +UNCLEAR_NOTE = 'Say "unclear" if the evidence is too weak.' 
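A minimal standalone sketch of the response contract that `LABEL_SCHEMA` and `JSON_INSTRUCTION` define — the `parse_label_response` helper below is illustrative only (it is not part of this diff), mirroring the three required fields and the `low|medium|high` confidence enum:

```python
# Illustrative sketch: validate an LLM response against the LABEL_SCHEMA
# contract (three required string fields, confidence restricted to an enum).
# parse_label_response is a hypothetical helper, not part of this module.
import json

_REQUIRED_FIELDS = {"label", "confidence", "reasoning"}
_ALLOWED_CONFIDENCE = {"low", "medium", "high"}


def parse_label_response(raw: str) -> dict[str, str]:
    """Parse a raw JSON response and check it matches the schema's shape."""
    parsed = json.loads(raw)
    assert set(parsed) == _REQUIRED_FIELDS, f"unexpected fields: {set(parsed)}"
    assert parsed["confidence"] in _ALLOWED_CONFIDENCE, parsed["confidence"]
    return parsed


example = parse_label_response(
    '{"label": "unclear", "confidence": "low", "reasoning": "weak evidence"}'
)
```

This mirrors the shape enforced by `_parse_label` in `interpret.py`, which additionally records the raw response and prompt alongside the parsed fields.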
+ + +def _component_header( + component: ComponentData, + model_metadata: ModelMetadata, +) -> Md: + canonical = model_metadata.layer_descriptions.get(component.layer, component.layer) + layer_desc = human_layer_desc(canonical, model_metadata.n_blocks) + position_note = layer_position_note(canonical, model_metadata.n_blocks) + dens_note = density_note(component.firing_density) + + rate_str = ( + f"~1 in {int(1 / component.firing_density)} tokens" + if component.firing_density > 0.0 + else "extremely rare" + ) + + md = Md() + md.h(2, "Context").bullets( + [ + f"Component: {layer_desc} (component {component.component_idx}), {model_metadata.n_blocks}-block model", + f"Firing rate: {component.firing_density * 100:.2f}% ({rate_str})", + ] + ) + context_notes = " ".join(filter(None, [position_note, dens_note])) + if context_notes: + md.p(context_notes) + return md + + +def format_output_prompt( + component: ComponentData, + model_metadata: ModelMetadata, + app_tok: AppTokenizer, + output_token_stats: TokenPRLift, + related: list[RelatedComponent], + label_max_words: int, + max_examples: int, +) -> str: + output_pmi = [ + (app_tok.get_tok_display(tid), pmi) for tid, pmi in component.output_token_pmi.top + ] + + md = Md() + md.p( + "You are analyzing a component in a neural network to understand its " + "OUTPUT FUNCTION — what it does when it fires." 
+ ) + md.extend(_component_header(component, model_metadata)) + + md.h(2, "Output tokens (what the model produces when this component fires)") + md.extend(build_output_section(output_token_stats, output_pmi)) + + md.h(2, "Activation examples — what the model produces") + md.extend(build_says_examples(component, app_tok, max_examples)) + + md.h(2, "Downstream components (what this component influences)") + md.p( + "These components in later layers are most influenced by this component (by gradient attribution):" + ) + md.extend(_format_related(related, model_metadata, app_tok)) + + md.h(2, "Task") + md.p( + f"Give a {label_max_words}-word-or-fewer label describing this component's " + "OUTPUT FUNCTION — what it does when it fires." + ) + md.p(UNCLEAR_NOTE).p(JSON_INSTRUCTION) + + return md.build() + + +def format_input_prompt( + component: ComponentData, + model_metadata: ModelMetadata, + app_tok: AppTokenizer, + input_token_stats: TokenPRLift, + related: list[RelatedComponent], + label_max_words: int, + max_examples: int, +) -> str: + input_pmi = [(app_tok.get_tok_display(tid), pmi) for tid, pmi in component.input_token_pmi.top] + + md = Md() + md.p( + "You are analyzing a component in a neural network to understand its " + "INPUT FUNCTION — what triggers it to fire." + ) + md.extend(_component_header(component, model_metadata)) + + md.h(2, "Input tokens (what causes this component to fire)") + md.extend(build_input_section(input_token_stats, input_pmi)) + + md.h(2, "Activation examples — where the component fires") + md.extend(build_fires_on_examples(component, app_tok, max_examples)) + + md.h(2, "Upstream components (what feeds into this component)") + md.p("These components in earlier layers most strongly attribute to this component:") + md.extend(_format_related(related, model_metadata, app_tok)) + + md.h(2, "Task") + md.p( + f"Give a {label_max_words}-word-or-fewer label describing this component's " + "INPUT FUNCTION — what conditions trigger it to fire." 
+ ) + md.p(UNCLEAR_NOTE).p(JSON_INSTRUCTION) + + return md.build() + + +def format_unification_prompt( + output_label: LabelResult, + input_label: LabelResult, + component: ComponentData, + model_metadata: ModelMetadata, + app_tok: AppTokenizer, + label_max_words: int, + max_examples: int, +) -> str: + md = Md() + md.p("A neural network component has been analyzed from two perspectives.") + md.extend(_component_header(component, model_metadata)) + + md.h(2, "Activation examples — where the component fires") + md.extend(build_fires_on_examples(component, app_tok, max_examples)) + + md.h(2, "Activation examples — what the model produces") + md.extend(build_says_examples(component, app_tok, max_examples)) + + md.h(2, "Two-perspective analysis") + md.p( + f'OUTPUT FUNCTION: "{output_label.label}" (confidence: {output_label.confidence})\n' + f" Reasoning: {output_label.reasoning}\n\n" + f'INPUT FUNCTION: "{input_label.label}" (confidence: {input_label.confidence})\n' + f" Reasoning: {input_label.reasoning}" + ) + + md.h(2, "Task") + md.p( + f"Synthesize these into a single unified label (max {label_max_words} words) " + "that captures the component's complete role. If input and output suggest the " + "same concept, unify them. If they describe genuinely different aspects " + "(e.g. fires on X, produces Y), combine both." 
+ ) + md.p(JSON_INSTRUCTION) + + return md.build() + + +def _format_related( + components: list[RelatedComponent], + model_metadata: ModelMetadata, + app_tok: AppTokenizer, +) -> Md: + visible = [n for n in components if n.label is not None or _is_token_entry(n.component_key)] + md = Md() + if not visible: + md.p("(no related components with labels found)") + return md + + max_attr = max(abs(n.attribution) for n in visible) + norm = max_attr if max_attr > 0 else 1.0 + + lines: list[str] = [] + for n in visible: + display = _component_display(n.component_key, model_metadata, app_tok) + rel_attr = n.attribution / norm + pmi_str = f", co-firing PMI: {n.pmi:.2f}" if n.pmi is not None else "" + line = f" {display} (relative attribution: {rel_attr:+.2f}{pmi_str})" + if n.label is not None: + line += f'\n label: "{n.label}" (confidence: {n.confidence})' + lines.append(line) + + md.p("\n".join(lines)) + return md + + +def _is_token_entry(key: str) -> bool: + layer = key.rsplit(":", 1)[0] + return layer in ("embed", "output") + + +def _component_display(key: str, model_metadata: ModelMetadata, app_tok: AppTokenizer) -> str: + layer, idx_str = key.rsplit(":", 1) + match layer: + case "embed": + return f'input token "{app_tok.get_tok_display(int(idx_str))}"' + case "output": + return f'output token "{app_tok.get_tok_display(int(idx_str))}"' + case _: + canonical = model_metadata.layer_descriptions.get(layer, layer) + desc = human_layer_desc(canonical, model_metadata.n_blocks) + return f"component from {desc}" diff --git a/spd/graph_interp/repo.py b/spd/graph_interp/repo.py new file mode 100644 index 000000000..a76590e3d --- /dev/null +++ b/spd/graph_interp/repo.py @@ -0,0 +1,96 @@ +"""Graph interpretation data repository. + +Owns SPD_OUT_DIR/graph_interp/<run_id>/ and provides read access +to output, input, and unified labels. + +Use GraphInterpRepo.open() to construct — returns None if no data exists. 
+""" + +from pathlib import Path +from typing import Any + +import yaml + +from spd.autointerp.db import DONE_MARKER +from spd.graph_interp.db import GraphInterpDB +from spd.graph_interp.schemas import LabelResult, PromptEdge, get_graph_interp_dir + + +class GraphInterpRepo: + """Read access to graph interpretation data for a single run.""" + + def __init__(self, db: GraphInterpDB, subrun_dir: Path, run_id: str) -> None: + self._db = db + self._subrun_dir = subrun_dir + self.subrun_id = subrun_dir.name + self.run_id = run_id + + @classmethod + def open(cls, run_id: str) -> "GraphInterpRepo | None": + """Open graph interp data for a run. Returns None if no data exists.""" + base_dir = get_graph_interp_dir(run_id) + if not base_dir.exists(): + return None + candidates = sorted( + [ + d + for d in base_dir.iterdir() + if d.is_dir() and d.name.startswith("ti-") and (d / DONE_MARKER).exists() + ], + key=lambda d: d.name, + ) + if not candidates: + return None + subrun_dir = candidates[-1] + db_path = subrun_dir / "interp.db" + if not db_path.exists(): + return None + return cls( + db=GraphInterpDB(db_path, readonly=True), + subrun_dir=subrun_dir, + run_id=run_id, + ) + + def get_config(self) -> dict[str, Any] | None: + config_path = self._subrun_dir / "config.yaml" + if not config_path.exists(): + return None + with open(config_path) as f: + return yaml.safe_load(f) + + # -- Labels ---------------------------------------------------------------- + + def get_all_output_labels(self) -> dict[str, LabelResult]: + return self._db.get_all_output_labels() + + def get_all_input_labels(self) -> dict[str, LabelResult]: + return self._db.get_all_input_labels() + + def get_all_unified_labels(self) -> dict[str, LabelResult]: + return self._db.get_all_unified_labels() + + def get_output_label(self, component_key: str) -> LabelResult | None: + return self._db.get_output_label(component_key) + + def get_input_label(self, component_key: str) -> LabelResult | None: + return 
self._db.get_input_label(component_key) + + def get_unified_label(self, component_key: str) -> LabelResult | None: + return self._db.get_unified_label(component_key) + + # -- Edges ----------------------------------------------------------------- + + def get_prompt_edges(self, component_key: str) -> list[PromptEdge]: + return self._db.get_prompt_edges(component_key) + + def get_all_prompt_edges(self) -> list[PromptEdge]: + return self._db.get_all_prompt_edges() + + # -- Stats ----------------------------------------------------------------- + + def get_label_counts(self) -> dict[str, int]: + return { + "output": self._db.get_label_count("output_labels"), + "input": self._db.get_label_count("input_labels"), + "unified": self._db.get_label_count("unified_labels"), + } diff --git a/spd/graph_interp/schemas.py b/spd/graph_interp/schemas.py new file mode 100644 index 000000000..ad391e270 --- /dev/null +++ b/spd/graph_interp/schemas.py @@ -0,0 +1,37 @@ +"""Data types and path helpers for graph interpretation.""" + +from dataclasses import dataclass +from pathlib import Path +from typing import Literal + +from spd.settings import SPD_OUT_DIR + +GRAPH_INTERP_DIR = SPD_OUT_DIR / "graph_interp" + + +def get_graph_interp_dir(decomposition_id: str) -> Path: + return GRAPH_INTERP_DIR / decomposition_id + + +def get_graph_interp_subrun_dir(decomposition_id: str, subrun_id: str) -> Path: + return get_graph_interp_dir(decomposition_id) / subrun_id + + +@dataclass +class LabelResult: + component_key: str + label: str + confidence: str + reasoning: str + raw_response: str + prompt: str + + +@dataclass +class PromptEdge: + component_key: str + related_key: str + pass_name: Literal["output", "input"] + attribution: float + related_label: str | None + related_confidence: str | None diff --git a/spd/graph_interp/scripts/__init__.py b/spd/graph_interp/scripts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/graph_interp/scripts/run.py 
b/spd/graph_interp/scripts/run.py new file mode 100644 index 000000000..2060e0e62 --- /dev/null +++ b/spd/graph_interp/scripts/run.py @@ -0,0 +1,96 @@ +"""CLI entry point for graph interpretation. + +Called by SLURM or directly: + python -m spd.graph_interp.scripts.run --config_json '{...}' +""" + +import os +from typing import Any + +from dotenv import load_dotenv + +from spd.adapters import adapter_from_id +from spd.adapters.spd import SPDAdapter +from spd.dataset_attributions.repo import AttributionRepo +from spd.graph_interp.config import GraphInterpConfig +from spd.graph_interp.interpret import run_graph_interp +from spd.graph_interp.schemas import get_graph_interp_subrun_dir +from spd.harvest.repo import HarvestRepo +from spd.log import logger + + +def main( + decomposition_id: str, + config_json: dict[str, Any], + subrun_id: str, + harvest_subrun_id: str, +) -> None: + assert isinstance(config_json, dict), f"Expected dict from fire, got {type(config_json)}" + config = GraphInterpConfig.model_validate(config_json) + + load_dotenv() + openrouter_api_key = os.environ.get("OPENROUTER_API_KEY") + assert openrouter_api_key, "OPENROUTER_API_KEY not set" + subrun_dir = get_graph_interp_subrun_dir(decomposition_id, subrun_id) + subrun_dir.mkdir(parents=True, exist_ok=True) + config.to_file(subrun_dir / "config.yaml") + db_path = subrun_dir / "interp.db" + logger.info(f"Graph interp run: {subrun_dir}") + + logger.info("Loading adapter and model metadata...") + adapter = adapter_from_id(decomposition_id) + assert isinstance(adapter, SPDAdapter) + logger.info("Loading harvest data...") + harvest = HarvestRepo(decomposition_id, subrun_id=harvest_subrun_id, readonly=True) + + logger.info("Loading dataset attributions...") + attributions = AttributionRepo.open(decomposition_id) + assert attributions is not None, f"Dataset attributions required for {decomposition_id}" + attribution_storage = attributions.get_attributions() + logger.info( + f" 
{attribution_storage.n_components} components, {attribution_storage.n_tokens_processed:,} tokens" + ) + + logger.info("Loading component correlations...") + correlations = harvest.get_correlations() + assert correlations is not None, f"Component correlations required for {decomposition_id}" + + logger.info("Loading token stats...") + token_stats = harvest.get_token_stats() + assert token_stats is not None, f"Token stats required for {decomposition_id}" + + logger.info("Data loading complete") + + run_graph_interp( + openrouter_api_key=openrouter_api_key, + config=config, + harvest=harvest, + attribution_storage=attribution_storage, + correlation_storage=correlations, + token_stats=token_stats, + model_metadata=adapter.model_metadata, + db_path=db_path, + tokenizer_name=adapter.tokenizer_name, + ) + + +def get_command( + decomposition_id: str, + config: GraphInterpConfig, + subrun_id: str, + harvest_subrun_id: str, +) -> str: + config_json = config.model_dump_json(exclude_none=True) + return ( + "python -m spd.graph_interp.scripts.run " + f"--decomposition_id {decomposition_id} " + f"--config_json '{config_json}' " + f"--subrun_id {subrun_id} " + f"--harvest_subrun_id {harvest_subrun_id} " + ) + + +if __name__ == "__main__": + import fire + + fire.Fire(main) diff --git a/spd/graph_interp/scripts/run_slurm.py b/spd/graph_interp/scripts/run_slurm.py new file mode 100644 index 000000000..915d2344e --- /dev/null +++ b/spd/graph_interp/scripts/run_slurm.py @@ -0,0 +1,64 @@ +"""SLURM launcher for graph interpretation. + +Submits a single CPU job that runs the three-phase interpretation pipeline. +Depends on both harvest merge and attribution merge jobs. 
+""" + +from dataclasses import dataclass +from datetime import datetime + +from spd.graph_interp.config import GraphInterpSlurmConfig +from spd.graph_interp.scripts import run +from spd.log import logger +from spd.utils.slurm import SlurmConfig, SubmitResult, generate_script, submit_slurm_job + + +@dataclass +class GraphInterpSubmitResult: + result: SubmitResult + + +def submit_graph_interp( + decomposition_id: str, + config: GraphInterpSlurmConfig, + dependency_job_ids: list[str], + harvest_subrun_id: str, + snapshot_branch: str | None = None, +) -> GraphInterpSubmitResult: + """Submit graph interpretation to SLURM.""" + subrun_id = "ti-" + datetime.now().strftime("%Y%m%d_%H%M%S") + cmd = run.get_command( + decomposition_id=decomposition_id, + config=config.config, + subrun_id=subrun_id, + harvest_subrun_id=harvest_subrun_id, + ) + + dependency_str = ":".join(dependency_job_ids) if dependency_job_ids else None + + slurm_config = SlurmConfig( + job_name="spd-graph-interp", + partition=config.partition, + n_gpus=0, + cpus_per_task=16, + mem="240G", + time=config.time, + snapshot_branch=snapshot_branch, + dependency_job_id=dependency_str, + comment=decomposition_id, + ) + script_content = generate_script(slurm_config, cmd) + result = submit_slurm_job(script_content, "spd-graph-interp") + + logger.section("Graph interp job submitted") + logger.values( + { + "Job ID": result.job_id, + "Decomposition ID": decomposition_id, + "Model": config.config.model, + "Depends on": ", ".join(dependency_job_ids), + "Log": result.log_pattern, + } + ) + + return GraphInterpSubmitResult(result=result) diff --git a/spd/graph_interp/scripts/run_slurm_cli.py b/spd/graph_interp/scripts/run_slurm_cli.py new file mode 100644 index 000000000..0ddf11982 --- /dev/null +++ b/spd/graph_interp/scripts/run_slurm_cli.py @@ -0,0 +1,30 @@ +"""CLI entry point for graph interp SLURM launcher. + +Thin wrapper for fast --help. Heavy imports deferred to run_slurm.py. 
+ +Usage: + spd-graph-interp --decomposition_id RUN_ID --config graph_interp_config.yaml --harvest_subrun_id h-YYYYMMDD_HHMMSS +""" + +import fire + + +def main(decomposition_id: str, config: str, harvest_subrun_id: str) -> None: + """Submit graph interpretation pipeline to SLURM. + + Args: + decomposition_id: ID of the target decomposition run. + config: Path to GraphInterpSlurmConfig YAML/JSON. + harvest_subrun_id: Harvest subrun to use (e.g. "h-20260306_120000"). + """ + from spd.graph_interp.config import GraphInterpSlurmConfig + from spd.graph_interp.scripts.run_slurm import submit_graph_interp + + slurm_config = GraphInterpSlurmConfig.from_file(config) + submit_graph_interp( + decomposition_id, slurm_config, dependency_job_ids=[], harvest_subrun_id=harvest_subrun_id + ) + + +def cli() -> None: + fire.Fire(main) diff --git a/spd/harvest/CLAUDE.md b/spd/harvest/CLAUDE.md index 235f4dffd..05a9bc906 100644 --- a/spd/harvest/CLAUDE.md +++ b/spd/harvest/CLAUDE.md @@ -29,61 +29,72 @@ The command: For environments without SLURM, run the worker script directly: ```bash -# Single GPU with specific number of batches -python -m spd.harvest.scripts.run --n_batches 1000 - -# Single GPU processing entire dataset (omit --n_batches) +# Single GPU (defaults from HarvestConfig, auto-generates subrun ID) python -m spd.harvest.scripts.run +# Single GPU with config file +python -m spd.harvest.scripts.run --config_path path/to/config.yaml + # Multi-GPU (run in parallel via shell, tmux, etc.) 
-python -m spd.harvest.scripts.run --n_batches 1000 --rank 0 --world_size 4 & -python -m spd.harvest.scripts.run --n_batches 1000 --rank 1 --world_size 4 & -python -m spd.harvest.scripts.run --n_batches 1000 --rank 2 --world_size 4 & -python -m spd.harvest.scripts.run --n_batches 1000 --rank 3 --world_size 4 & +# All workers and the merge step must share the same --subrun_id +SUBRUN="h-$(date +%Y%m%d_%H%M%S)" +python -m spd.harvest.scripts.run --config_json '{"n_batches": 1000}' --rank 0 --world_size 4 --subrun_id $SUBRUN & +python -m spd.harvest.scripts.run --config_json '{"n_batches": 1000}' --rank 1 --world_size 4 --subrun_id $SUBRUN & +python -m spd.harvest.scripts.run --config_json '{"n_batches": 1000}' --rank 2 --world_size 4 --subrun_id $SUBRUN & +python -m spd.harvest.scripts.run --config_json '{"n_batches": 1000}' --rank 3 --world_size 4 --subrun_id $SUBRUN & wait # Merge results after all workers complete -python -m spd.harvest.scripts.run --merge +python -m spd.harvest.scripts.run --merge --subrun_id $SUBRUN ``` Each worker processes batches where `batch_idx % world_size == rank`, then the merge step combines all partial results. ## Data Storage -Data is stored in `SPD_OUT_DIR/harvest/` (see `spd/settings.py`): +Each harvest invocation creates a timestamped sub-run directory. `HarvestRepo` automatically loads from the latest sub-run. ``` SPD_OUT_DIR/harvest/<run_id>/ -├── activation_contexts/ -│ ├── config.json -│ ├── summary.json # Lightweight {component_key: {layer, idx, mean_ci}} -│ └── components.jsonl # One ComponentData per line (quite big - usually ≈10gb) -├── correlations/ +├── h-20260211_120000/ # sub-run 1 +│ ├── harvest.db # SQLite DB: components table + config table (WAL mode) │ ├── component_correlations.pt -│ └── token_stats.pt -└── worker_states/ - └── worker_*.pt # Per-worker states (cleaned up after merge) +│ ├── token_stats.pt +│ └── worker_states/ # cleaned up after merge +│ └── worker_*.pt +├── h-20260211_140000/ # sub-run 2 +│ └── ... 
``` +Legacy layout (pre sub-run, `activation_contexts/` + `correlations/`) is no longer supported. + ## Architecture ### SLURM Launcher (`scripts/run_slurm.py`, `scripts/run_slurm_cli.py`) Entry point via `spd-harvest`. Submits array job + dependent merge job. +**Intruder evaluation** (`spd/harvest/intruder.py`) evaluates the quality of the *decomposition itself* — whether component activation patterns are coherent — without relying on LLM-generated labels. Intruder scores are stored in `harvest.db`, not `interp.db`. Intruder eval is submitted as a top-level postprocess stage (via `spd-postprocess`), not as part of the harvest pipeline. + ### Worker Script (`scripts/run.py`) -Internal script called by SLURM jobs. Supports: +Internal script called by SLURM jobs. Accepts config via `--config_path` (file) or `--config_json` (inline JSON). Supports: +- `--config_path`/`--config_json`: Provide `HarvestConfig` (defaults used if neither given) - `--rank R --world_size N`: Process subset of batches - `--merge`: Combine per-rank results into final files +- `--subrun_id`: Sub-run identifier (auto-generated if not provided) + +### Config (`config.py`) + +`HarvestConfig` (tuning params) and `HarvestSlurmConfig` (HarvestConfig + SLURM params). `wandb_path` is a runtime arg, not part of config. 
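As a rough sketch of this split, a `HarvestSlurmConfig` file might look like the following. Values are illustrative only: the authoritative field set lives in `spd/harvest/config.py`, and the `wandb_path` shown is a placeholder, not a real run.

```yaml
# Hypothetical harvest SLURM config (field names from spd/harvest/config.py)
config:                                  # HarvestConfig: tuning params
  method_config:
    type: SPDHarvestConfig
    wandb_path: entity/project/run_id    # placeholder wandb run path
    activation_threshold: 0.0
  n_batches: 1000
  batch_size: 32
n_gpus: 8                                # SLURM submission params
time: "12:00:00"
merge_time: "04:00:00"
merge_mem: "200G"
```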
### Harvest Logic (`harvest.py`) Main harvesting functions: -- `harvest_activation_contexts()`: Process batches for a single rank -- `merge_activation_contexts()`: Combine results from all ranks +- `harvest_activation_contexts(wandb_path, config, output_dir, ...)`: Process batches for a single rank +- `merge_activation_contexts(output_dir)`: Combine worker results from `output_dir/worker_states/` into `output_dir` -### Harvester (`lib/harvester.py`) +### Harvester (`harvester.py`) Core class that accumulates statistics in a single pass: - **Correlations**: Co-occurrence counts between components (for precision/recall/PMI) @@ -95,21 +106,21 @@ Key optimizations: - Subsampling: Caps firings per batch at 10k (plenty for k=20 examples per component) - All accumulation on GPU, only moves to CPU for final `build_results()` -### Reservoir Sampler (`lib/reservoir_sampler.py`) - -Implements reservoir sampling for uniform random sampling from a stream. Maintains a fixed-size buffer of examples that represents a uniform sample over all items seen. - ### Storage (`storage.py`) `CorrelationStorage` and `TokenStatsStorage` classes for loading/saving harvested data. -### Loaders (`loaders.py`) +### Database (`db.py`) + +`HarvestDB` class wrapping SQLite for component-level data. Two tables: +- `components`: keyed by `component_key`, stores layer/idx/mean_ci + JSON blobs for activation examples and PMI data +- `config`: key-value store for harvest config (ci_threshold, etc.) + +Uses WAL mode for concurrent reads. Serialization via `orjson`. + +### Repository (`repo.py`) -Functions for loading harvested data by run ID: -- `load_activation_contexts_summary(run_id)` -> dict[component_key, ComponentSummary] -- `load_component_activation_contexts(run_id, component_key)` -> ComponentData -- `load_correlations(run_id)` -> CorrelationStorage -- `load_token_stats(run_id)` -> TokenStatsStorage +`HarvestRepo` provides read-only access to all harvest data for a run. 
Automatically resolves the latest sub-run directory (by lexicographic sort of `h-YYYYMMDD_HHMMSS` names). Falls back to legacy layout if no sub-runs exist. Used by the app backend. ## Key Types (`schemas.py`) diff --git a/spd/harvest/analysis.py b/spd/harvest/analysis.py index 31a4bcf2c..739e78fe9 100644 --- a/spd/harvest/analysis.py +++ b/spd/harvest/analysis.py @@ -5,13 +5,13 @@ import math from dataclasses import dataclass -from typing import Literal +from typing import Literal, cast import torch from jaxtyping import Float from torch import Tensor -from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from spd.app.backend.app_tokenizer import AppTokenizer from spd.harvest.storage import CorrelationStorage, TokenStatsStorage Metric = Literal["precision", "recall", "jaccard", "pmi"] @@ -44,10 +44,6 @@ class TokenPRLift: bottom_pmi: list[tuple[str, float]] | None -def _build_key_to_idx(component_keys: list[str]) -> dict[str, int]: - return {k: i for i, k in enumerate(component_keys)} - - def get_correlated_components( storage: CorrelationStorage, component_key: str, @@ -56,8 +52,7 @@ def get_correlated_components( largest: bool = True, ) -> list[CorrelatedComponent]: """Get top-k or bottom-k correlated components.""" - key_to_idx = _build_key_to_idx(storage.component_keys) - i = key_to_idx[component_key] + i = storage.key_to_idx[component_key] count_this = int(storage.count_i[i].item()) if count_this == 0: @@ -113,58 +108,44 @@ def get_correlated_components( def has_component(storage: CorrelationStorage, component_key: str) -> bool: """Check if a component exists in the storage.""" - key_to_idx = _build_key_to_idx(storage.component_keys) - return component_key in key_to_idx + return component_key in storage.key_to_idx def get_input_token_stats( storage: TokenStatsStorage, component_key: str, - tokenizer: PreTrainedTokenizerBase, + tok: AppTokenizer, top_k: int, ) -> TokenPRLift | None: """Compute P/R/lift/PMI for input tokens.""" - key_to_idx = 
_build_key_to_idx(storage.component_keys) - idx = key_to_idx[component_key] + idx = storage.key_to_idx[component_key] - result = _compute_token_stats( + return _compute_token_stats( counts=storage.input_counts[idx], totals=storage.input_totals, n_tokens=storage.n_tokens, firing_count=storage.firing_counts[idx].item(), - tokenizer=tokenizer, + tok=tok, top_k=top_k, ) - if result is None: - return None - - # Input stats don't have bottom PMI - return TokenPRLift( - top_recall=result.top_recall, - top_precision=result.top_precision, - top_lift=result.top_lift, - top_pmi=result.top_pmi, - bottom_pmi=None, - ) def get_output_token_stats( storage: TokenStatsStorage, component_key: str, - tokenizer: PreTrainedTokenizerBase, + tok: AppTokenizer, top_k: int, ) -> TokenPRLift | None: """Compute P/R/lift/PMI for output tokens.""" - key_to_idx = _build_key_to_idx(storage.component_keys) - idx = key_to_idx[component_key] + idx = storage.key_to_idx[component_key] return _compute_token_stats( counts=storage.output_counts[idx], totals=storage.output_totals, n_tokens=storage.n_tokens, firing_count=storage.firing_counts[idx].item(), - tokenizer=tokenizer, top_k=top_k, + tok=tok, ) @@ -173,7 +154,7 @@ def _compute_token_stats( totals: Float[Tensor, " vocab"], n_tokens: int, firing_count: float, - tokenizer: PreTrainedTokenizerBase, + tok: AppTokenizer, top_k: int, ) -> TokenPRLift | None: """Compute P/R/lift/PMI from count tensors.""" @@ -201,13 +182,16 @@ def get_top_k(values: Tensor, k: int, largest: bool = True) -> list[tuple[str, f top_vals, top_idx = torch.topk( masked, min(k, int(valid_mask.sum().item())), largest=largest ) - result = [] - for idx, val in zip(top_idx.tolist(), top_vals.tolist(), strict=True): + + result: list[tuple[str, float]] = [] + + for idx, val in zip( + cast(list[int], top_idx.tolist()), cast(list[float], top_vals.tolist()), strict=True + ): if val == float("-inf"): continue assert math.isfinite(val), f"Unexpected non-finite score {val} for token {idx}" - 
token_str = tokenizer.decode([idx]) - result.append((token_str, round(val, 3 if abs(val) < 10 else 2))) + result.append((tok.get_tok_display(idx), round(val, 3 if abs(val) < 10 else 2))) return result return TokenPRLift( diff --git a/spd/harvest/config.py b/spd/harvest/config.py new file mode 100644 index 000000000..cc01b3cd0 --- /dev/null +++ b/spd/harvest/config.py @@ -0,0 +1,100 @@ +"""Harvest configuration. + +HarvestConfig: tuning params for the harvest pipeline. +HarvestSlurmConfig: HarvestConfig + SLURM submission params. +""" + +from typing import Annotated, Any, Literal, override + +from openrouter.components import Effort +from pydantic import Field, PositiveInt + +from spd.base_config import BaseConfig +from spd.settings import DEFAULT_PARTITION_NAME +from spd.utils.wandb_utils import parse_wandb_run_path + +# -- Method-specific harvest configs ------------------------------------------ + + +class SPDHarvestConfig(BaseConfig): + type: Literal["SPDHarvestConfig"] = "SPDHarvestConfig" + wandb_path: str + activation_threshold: float = 0.0 + + @property + def id(self) -> str: + _, _, run_id = parse_wandb_run_path(self.wandb_path) + return run_id + + @override + def model_post_init(self, __context: Any) -> None: + parse_wandb_run_path(self.wandb_path) + + +class CLTHarvestConfig(BaseConfig): + type: Literal["CLTHarvestConfig"] = "CLTHarvestConfig" + + wandb_path: str + + @property + def id(self) -> str: + return "clt" + + +class MOLTHarvestConfig(BaseConfig): + type: Literal["MOLTHarvestConfig"] = "MOLTHarvestConfig" + + wandb_path: str + + @property + def id(self) -> str: + return "molt" + + +DecompositionMethodHarvestConfig = SPDHarvestConfig | CLTHarvestConfig | MOLTHarvestConfig + + +# -- Pipeline configs ---------------------------------------------------------- + + +class IntruderEvalConfig(BaseConfig): + """Config for intruder detection eval (decomposition quality, not label quality).""" + + model: str = "google/gemini-3-flash-preview" + 
reasoning_effort: Effort = "none" + n_real: int = 4 + n_trials: int = 10 + density_tolerance: float = 0.05 + max_concurrent: int = 50 + limit: int | None = None + cost_limit_usd: float | None = None + max_requests_per_minute: int = 500 + + +class IntruderSlurmConfig(BaseConfig): + """Config for intruder eval SLURM submission.""" + + config: IntruderEvalConfig = IntruderEvalConfig() + partition: str = DEFAULT_PARTITION_NAME + time: str = "4:00:00" + + +class HarvestConfig(BaseConfig): + method_config: Annotated[DecompositionMethodHarvestConfig, Field(discriminator="type")] + n_batches: int | Literal["whole_dataset"] = 20_000 + batch_size: int = 32 + activation_examples_per_component: int = 400 + activation_context_tokens_per_side: int = 20 + pmi_token_top_k: int = 40 + max_examples_per_batch_per_component: int = 5 + + +class HarvestSlurmConfig(BaseConfig): + """Config for harvest SLURM submission.""" + + config: HarvestConfig + n_gpus: PositiveInt = 8 + partition: str = DEFAULT_PARTITION_NAME + time: str = "12:00:00" + merge_time: str = "04:00:00" + merge_mem: str = "200G" diff --git a/spd/harvest/db.py b/spd/harvest/db.py new file mode 100644 index 000000000..a7af0caa5 --- /dev/null +++ b/spd/harvest/db.py @@ -0,0 +1,186 @@ +"""SQLite database for component-level harvest data. 
NFS-hosted, write-once then read-only.""" + +import sqlite3 +from collections.abc import Iterable +from pathlib import Path + +import orjson + +from spd.harvest.config import HarvestConfig +from spd.harvest.schemas import ( + ActivationExample, + ComponentData, + ComponentSummary, + ComponentTokenPMI, +) +from spd.utils.sqlite import open_nfs_sqlite + +_SCHEMA = """\ +CREATE TABLE IF NOT EXISTS components ( + component_key TEXT PRIMARY KEY, + layer TEXT NOT NULL, + component_idx INTEGER NOT NULL, + firing_density REAL NOT NULL, + mean_activations TEXT NOT NULL, + activation_examples TEXT NOT NULL, + input_token_pmi TEXT NOT NULL, + output_token_pmi TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS config ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS scores ( + component_key TEXT NOT NULL, + score_type TEXT NOT NULL, + score REAL NOT NULL, + details TEXT NOT NULL, + PRIMARY KEY (component_key, score_type) +); +""" + + +def _serialize_component( + comp: ComponentData, +) -> tuple[str, str, int, float, bytes, bytes, bytes, bytes]: + return ( + comp.component_key, + comp.layer, + comp.component_idx, + comp.firing_density, + orjson.dumps(comp.mean_activations), + orjson.dumps([ex.model_dump() for ex in comp.activation_examples]), + orjson.dumps(comp.input_token_pmi.model_dump()), + orjson.dumps(comp.output_token_pmi.model_dump()), + ) + + +def _deserialize_component(row: sqlite3.Row) -> ComponentData: + return ComponentData( + component_key=row["component_key"], + layer=row["layer"], + component_idx=row["component_idx"], + firing_density=row["firing_density"], + mean_activations=orjson.loads(row["mean_activations"]), + activation_examples=[ + ActivationExample(**ex) for ex in orjson.loads(row["activation_examples"]) + ], + input_token_pmi=ComponentTokenPMI(**orjson.loads(row["input_token_pmi"])), + output_token_pmi=ComponentTokenPMI(**orjson.loads(row["output_token_pmi"])), + ) + + +class HarvestDB: + def __init__(self, db_path: Path, 
readonly: bool = False) -> None: + self._conn = open_nfs_sqlite(db_path, readonly) + if not readonly: + self._conn.executescript(_SCHEMA) + + def save_component(self, comp: ComponentData) -> None: + row = _serialize_component(comp) + self._conn.execute( + "INSERT OR REPLACE INTO components VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + row, + ) + self._conn.commit() + + def save_components_iter(self, components: Iterable[ComponentData]) -> int: + """Save components from an iterable, one at a time (constant memory).""" + n = 0 + for comp in components: + self._conn.execute( + "INSERT OR REPLACE INTO components VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + _serialize_component(comp), + ) + n += 1 + self._conn.commit() + return n + + def save_config(self, config: HarvestConfig) -> None: + data = config.model_dump() + rows = [(k, orjson.dumps(v).decode()) for k, v in data.items()] + self._conn.executemany( + "INSERT OR REPLACE INTO config VALUES (?, ?)", + rows, + ) + self._conn.commit() + + def get_component(self, component_key: str) -> ComponentData | None: + row = self._conn.execute( + "SELECT * FROM components WHERE component_key = ?", + (component_key,), + ).fetchone() + if row is None: + return None + return _deserialize_component(row) + + def get_components_bulk(self, keys: list[str]) -> dict[str, ComponentData]: + if not keys: + return {} + placeholders = ",".join("?" 
for _ in keys) + rows = self._conn.execute( + f"SELECT * FROM components WHERE component_key IN ({placeholders})", + keys, + ).fetchall() + return {row["component_key"]: _deserialize_component(row) for row in rows} + + def get_summary(self) -> dict[str, ComponentSummary]: + rows = self._conn.execute( + "SELECT component_key, layer, component_idx, firing_density, mean_activations FROM components" + ).fetchall() + return { + row["component_key"]: ComponentSummary( + layer=row["layer"], + component_idx=row["component_idx"], + firing_density=row["firing_density"], + mean_activations=orjson.loads(row["mean_activations"]), + ) + for row in rows + } + + def get_config_dict(self) -> dict[str, object]: + rows = self._conn.execute("SELECT key, value FROM config").fetchall() + return {row["key"]: orjson.loads(row["value"]) for row in rows} + + def get_activation_threshold(self) -> float: + row = self._conn.execute( + "SELECT value FROM config WHERE key = 'activation_threshold'" + ).fetchone() + assert row is not None, "activation_threshold not found in config table" + return orjson.loads(row["value"]) + + def has_data(self) -> bool: + row = self._conn.execute("SELECT EXISTS(SELECT 1 FROM components LIMIT 1)").fetchone() + assert row is not None + return bool(row[0]) + + def get_component_count(self) -> int: + row = self._conn.execute("SELECT COUNT(*) FROM components").fetchone() + assert row is not None + return row[0] + + def get_all_components(self) -> list[ComponentData]: + """Load all components.""" + rows = self._conn.execute("SELECT * FROM components").fetchall() + return [_deserialize_component(row) for row in rows] + + # -- Scores (e.g. 
intruder eval) ------------------------------------------ + + def save_score(self, component_key: str, score_type: str, score: float, details: str) -> None: + self._conn.execute( + "INSERT OR REPLACE INTO scores VALUES (?, ?, ?, ?)", + (component_key, score_type, score, details), + ) + self._conn.commit() + + def get_scores(self, score_type: str) -> dict[str, float]: + rows = self._conn.execute( + "SELECT component_key, score FROM scores WHERE score_type = ?", + (score_type,), + ).fetchall() + return {row["component_key"]: row["score"] for row in rows} + + def close(self) -> None: + self._conn.close() diff --git a/spd/harvest/harvest.py b/spd/harvest/harvest.py index 0c1d5daf1..f1d5b9bd7 100644 --- a/spd/harvest/harvest.py +++ b/spd/harvest/harvest.py @@ -1,9 +1,9 @@ -"""Harvest token stats and activation contexts for autointerp. +"""Generic harvest pipeline: single-pass collection of component statistics. Collects per-component statistics in a single pass over the data: - Input/output token PMI (pointwise mutual information) - Activation examples with context windows -- Firing counts and CI sums +- Firing counts and activation sums - Component co-occurrence counts Performance (SimpleStories, 600M tokens, batch_size=256): @@ -12,340 +12,143 @@ """ import itertools -import json import time -from dataclasses import asdict, dataclass +from collections.abc import Callable from pathlib import Path +from typing import Any import torch import tqdm -from jaxtyping import Float -from torch import Tensor +from torch.utils.data import DataLoader -from spd.data import train_loader_and_tokenizer -from spd.harvest.lib.harvester import Harvester, HarvesterState -from spd.harvest.schemas import ( - ActivationExample, - ComponentData, - ComponentSummary, - ComponentTokenPMI, -) -from spd.harvest.storage import CorrelationStorage, TokenStatsStorage +from spd.harvest.config import HarvestConfig +from spd.harvest.harvester import Harvester +from spd.harvest.repo import HarvestRepo 
+from spd.harvest.schemas import HarvestBatch from spd.log import logger -from spd.models.component_model import ComponentModel, SPDRunInfo -from spd.utils.distributed_utils import get_device -from spd.utils.general_utils import bf16_autocast, extract_batch_data +from spd.utils.general_utils import bf16_autocast -def _compute_u_norms(model: ComponentModel) -> dict[str, Float[Tensor, " C"]]: - """Compute ||U[c,:]|| for each component c in each layer. - - Component activations (v_i^T @ a) have a scale invariance: scaling V by α and U by 1/α - leaves the weight matrix unchanged but scales component activations by α. To make component - activations reflect actual output contribution, we multiply by the U row norms. - This gives a value proportional to the magnitude of the component's output vector. - """ - u_norms: dict[str, Float[Tensor, " C"]] = {} - for layer_name, component in model.components.items(): - # U has shape (C, d_out) for LinearComponents - u_norms[layer_name] = component.U.norm(dim=1) # [C] - return u_norms - - -def _normalize_component_acts( - component_acts: dict[str, Float[Tensor, "B S C"]], - u_norms: dict[str, Float[Tensor, " C"]], -) -> dict[str, Float[Tensor, "B S C"]]: - """Normalize component activations by U column norms (output magnitude).""" - normalized = {} - for layer_name, acts in component_acts.items(): - norms = u_norms[layer_name].to(acts.device) - normalized[layer_name] = acts * norms - return normalized - - -@dataclass -class HarvestConfig: - wandb_path: str - n_batches: int | None - batch_size: int - ci_threshold: float - activation_examples_per_component: int - activation_context_tokens_per_side: int - pmi_token_top_k: int - - -@dataclass -class HarvestResult: - """Result of harvest containing components, correlations, and token stats.""" - - components: list[ComponentData] - correlations: CorrelationStorage - token_stats: TokenStatsStorage - config: HarvestConfig - - def save(self, activation_contexts_dir: Path, correlations_dir: 
Path) -> None: - """Save harvest result to disk.""" - # Save activation contexts (JSONL) - activation_contexts_dir.mkdir(parents=True, exist_ok=True) - - config_path = activation_contexts_dir / "config.json" - config_path.write_text(json.dumps(asdict(self.config), indent=2)) - - components_path = activation_contexts_dir / "components.jsonl" - with open(components_path, "w") as f: - for comp in self.components: - f.write(json.dumps(asdict(comp)) + "\n") - logger.info(f"Saved {len(self.components)} components to {components_path}") - - # Save lightweight summary for fast /summary endpoint - summaries = { - comp.component_key: ComponentSummary( - layer=comp.layer, - component_idx=comp.component_idx, - mean_ci=comp.mean_ci, - ) - for comp in self.components - } - summary_path = activation_contexts_dir / "summary.json" - ComponentSummary.save_all(summaries, summary_path) - logger.info(f"Saved summary to {summary_path}") - - # Save correlations (.pt) - self.correlations.save(correlations_dir / "component_correlations.pt") - - # Save token stats (.pt) - self.token_stats.save(correlations_dir / "token_stats.pt") - - @staticmethod - def load_components(activation_contexts_dir: Path) -> list[ComponentData]: - """Load components from disk.""" - assert activation_contexts_dir.exists(), f"No harvest found at {activation_contexts_dir}" - - components_path = activation_contexts_dir / "components.jsonl" - components = [] - with open(components_path) as f: - for line in f: - data = json.loads(line) - data["activation_examples"] = [ - ActivationExample(**ex) for ex in data["activation_examples"] - ] - data["input_token_pmi"] = ComponentTokenPMI(**data["input_token_pmi"]) - data["output_token_pmi"] = ComponentTokenPMI(**data["output_token_pmi"]) - components.append(ComponentData(**data)) - - return components - - -def _build_harvest_result( - harvester: Harvester, +def harvest( + layers: list[tuple[str, int]], + vocab_size: int, + dataloader: DataLoader[Any], + harvest_fn: 
Callable[[torch.Tensor], HarvestBatch], config: HarvestConfig, -) -> HarvestResult: - """Build HarvestResult from a harvester.""" - logger.info("Building component results...") - components = harvester.build_results(pmi_top_k_tokens=config.pmi_token_top_k) - logger.info(f"Built {len(components)} components (skipped components with no firings)") - - # Build component keys list (same ordering as tensors) - component_keys = [ - f"{layer}:{c}" - for layer in harvester.layer_names - for c in range(harvester.c_per_layer[layer]) - ] - - correlations = CorrelationStorage( - component_keys=component_keys, - count_i=harvester.firing_counts.long().cpu(), - count_ij=harvester.count_ij.long().cpu(), - count_total=harvester.total_tokens_processed, - ) - - token_stats = TokenStatsStorage( - component_keys=component_keys, - vocab_size=harvester.vocab_size, - n_tokens=harvester.total_tokens_processed, - input_counts=harvester.input_token_counts.cpu(), - input_totals=harvester.input_token_totals.float().cpu(), - output_counts=harvester.output_token_prob_mass.cpu(), - output_totals=harvester.output_token_prob_totals.cpu(), - firing_counts=harvester.firing_counts.cpu(), - ) - - return HarvestResult( - components=components, - correlations=correlations, - token_stats=token_stats, - config=config, - ) - - -def harvest_activation_contexts( - config: HarvestConfig, - activation_contexts_dir: Path, - correlations_dir: Path, - rank: int | None = None, - world_size: int | None = None, + output_dir: Path, + *, + rank_world_size: tuple[int, int] | None, + device: torch.device | None = None, ) -> None: - """Single-pass harvest of token stats, activation contexts, and correlations. + """Single-pass harvest for any decomposition method. Args: + harvest_fn: Converts a raw dataloader batch into a HarvestBatch. + Responsible for moving data to the correct device. + layers: List of (layer_name, n_components) pairs. + vocab_size: Vocabulary size for token stats. 
+        dataloader: Iterable yielding raw batches.
         config: Harvest configuration.
-        activation_contexts_dir: Directory to save activation contexts.
-        correlations_dir: Directory to save correlations.
+        output_dir: Directory to save harvest outputs.
-        rank: Worker rank for parallel execution (0 to world_size-1).
-        world_size: Total number of workers. If specified with rank, only processes
-            batches where batch_idx % world_size == rank.
+        rank_world_size: (rank, world_size) pair for parallel execution; when set,
+            only batches where batch_idx % world_size == rank are processed.
+        device: Device for accumulator tensors.
     """
-    assert (rank is None) == (world_size is None), "rank and world_size must both be set or unset"
-
-    device = torch.device(get_device())
-    logger.info(f"Loading model on {device}")
-
-    run_info = SPDRunInfo.from_path(config.wandb_path)
-    model = ComponentModel.from_run_info(run_info).to(device)
-    model.eval()
-
-    spd_config = run_info.config
-    train_loader, tokenizer = train_loader_and_tokenizer(spd_config, config.batch_size)
-
-    layer_names = list(model.target_module_paths)
-    vocab_size = tokenizer.vocab_size
-    assert isinstance(vocab_size, int)
-
-    # Precompute U norms for normalizing component activations
-    u_norms = _compute_u_norms(model)
+    if device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     harvester = Harvester(
-        layer_names=layer_names,
-        c_per_layer=model.module_to_c,
+        layers=layers,
         vocab_size=vocab_size,
-        ci_threshold=config.ci_threshold,
         max_examples_per_component=config.activation_examples_per_component,
         context_tokens_per_side=config.activation_context_tokens_per_side,
+        max_examples_per_batch_per_component=config.max_examples_per_batch_per_component,
         device=device,
     )
 
-    train_iter = iter(train_loader)
+    train_iter = iter(dataloader)
     batches_processed = 0
     last_log_time = time.time()
 
-    batch_range = range(config.n_batches) if config.n_batches is not None else itertools.count()
-    for batch_idx in tqdm.tqdm(batch_range, desc="Harvesting", disable=rank is not None):
+    match config.n_batches:
+        case 
int(n_batches): + batch_range = range(n_batches) + case "whole_dataset": + batch_range = itertools.count() + + for batch_idx in tqdm.tqdm(batch_range, desc="Harvesting", disable=rank_world_size is not None): try: - batch_data = extract_batch_data(next(train_iter)) + batch_item = next(train_iter) except StopIteration: logger.info(f"Dataset exhausted at batch {batch_idx}. Processing complete.") break - # Skip batches not assigned to this rank - if world_size is not None and batch_idx % world_size != rank: - continue + if rank_world_size is not None: + r, w = rank_world_size + if batch_idx % w != r: + continue - batch = batch_data.to(device) with torch.no_grad(), bf16_autocast(): - out = model(batch, cache_type="input") - probs = torch.softmax(out.output, dim=-1) - - ci_dict = model.calc_causal_importances( - pre_weight_acts=out.cache, - detach_inputs=True, - sampling=spd_config.sampling, - ).lower_leaky - - ci: Float[Tensor, "B S n_comp"] = torch.cat( - [ci_dict[layer] for layer in layer_names], dim=2 - ) - expected_n_comp = sum(model.module_to_c[layer] for layer in layer_names) - assert ci.shape[2] == expected_n_comp - - component_acts = model.get_all_component_acts(out.cache) - normalized_acts = _normalize_component_acts(component_acts, u_norms) - subcomp_acts: Float[Tensor, "B S n_comp"] = torch.cat( - [normalized_acts[layer] for layer in layer_names], - dim=2, - ) - - harvester.process_batch(batch, ci, probs, subcomp_acts) + hb = harvest_fn(batch_item) + harvester.process_batch(hb.tokens, hb.firings, hb.activations, hb.output_probs) batches_processed += 1 now = time.time() - if rank is not None and now - last_log_time >= 10: - logger.info(f"[Worker {rank}] {batches_processed} batches") - last_log_time = now + + if rank_world_size is not None: + r, w = rank_world_size + if now - last_log_time >= 10: + logger.info(f"[Worker {r}] {batches_processed} batches") + last_log_time = now logger.info( - f"{'[Worker ' + str(rank) + '] ' if rank is not None else ''}" + 
f"{'[Worker ' + str(rank_world_size[0]) + '] ' if rank_world_size is not None else ''}" f"Processing complete. {batches_processed} batches, " f"{harvester.total_tokens_processed:,} tokens" ) - # Save results (with rank suffix if distributed) - if rank is not None: - # Distributed: save worker state - state = harvester.get_state() - state_dir = activation_contexts_dir.parent / "worker_states" + if rank_world_size is not None: + r, w = rank_world_size + state_dir = output_dir / "worker_states" state_dir.mkdir(parents=True, exist_ok=True) - state_path = state_dir / f"worker_{rank}.pt" - torch.save(state, state_path) - logger.info(f"[Worker {rank}] Saved state to {state_path}") + state_path = state_dir / f"worker_{r}.pt" + harvester.save(state_path) + logger.info(f"[Worker {r}] Saved state to {state_path}") else: - # Single GPU: save full result - result = _build_harvest_result(harvester, config) - result.save(activation_contexts_dir, correlations_dir) - logger.info(f"Saved results to {activation_contexts_dir} and {correlations_dir}") + HarvestRepo.save_results(harvester, config, output_dir) + logger.info(f"Saved results to {output_dir}") -def merge_activation_contexts(wandb_path: str) -> None: +def merge_harvest(output_dir: Path, config: HarvestConfig) -> None: """Merge partial harvest results from parallel workers. - Looks for worker_*.pt state files and merges them into final harvest results. - - Uses streaming merge to avoid OOM - loads one file at a time instead of all at once. + Looks for worker_*.pt state files in output_dir/worker_states/ and merges them + into final harvest results written to output_dir. 
""" - from spd.harvest.schemas import get_activation_contexts_dir, get_correlations_dir - from spd.utils.wandb_utils import parse_wandb_run_path - - _, _, run_id = parse_wandb_run_path(wandb_path) - activation_contexts_dir = get_activation_contexts_dir(run_id) - correlations_dir = get_correlations_dir(run_id) - state_dir = activation_contexts_dir.parent / "worker_states" + state_dir = output_dir / "worker_states" - # Find all worker state files worker_files = sorted(state_dir.glob("worker_*.pt")) assert worker_files, f"No worker state files found in {state_dir}" logger.info(f"Found {len(worker_files)} worker state files to merge") - # Load first file to initialize merged state - logger.info(f"Loading worker 0: {worker_files[0].name}") - merged_state: HarvesterState = torch.load(worker_files[0], weights_only=False) - logger.info(f"Loaded worker 0: {merged_state.total_tokens_processed:,} tokens") + first_worker_file, *rest_worker_files = worker_files - # Stream remaining files one at a time - for worker_file in tqdm.tqdm(worker_files[1:], desc="Merging worker states"): - state = torch.load(worker_file, weights_only=False) - merged_state.merge_into(state) - # state will be garbage collected here before loading the next file + logger.info(f"Loading worker 0: {first_worker_file.name}") + harvester = Harvester.load(first_worker_file, device=torch.device("cpu")) + logger.info(f"Loaded worker 0: {harvester.total_tokens_processed:,} tokens") - logger.info(f"Merge complete. 
Total tokens: {merged_state.total_tokens_processed:,}") + for worker_file in tqdm.tqdm(rest_worker_files, desc="Merging worker states"): + other = Harvester.load(worker_file, device=torch.device("cpu")) + harvester.merge(other) + del other - # Build harvester from merged state and generate results - harvester = Harvester.from_state(merged_state, torch.device("cpu")) - - # Load config from merged state (all workers use same config) - config = HarvestConfig( - wandb_path=wandb_path, - n_batches=None, # Not applicable for merge - batch_size=0, # Not applicable for merge - ci_threshold=merged_state.ci_threshold, - activation_examples_per_component=merged_state.max_examples_per_component, - activation_context_tokens_per_side=merged_state.context_tokens_per_side, - pmi_token_top_k=40, # Standard value - ) + logger.info(f"Merge complete. Total tokens: {harvester.total_tokens_processed:,}") - result = _build_harvest_result(harvester, config) - result.save(activation_contexts_dir, correlations_dir) - logger.info(f"Saved merged results to {activation_contexts_dir} and {correlations_dir}") + HarvestRepo.save_results(harvester, config, output_dir) + db_path = output_dir / "harvest.db" + assert db_path.exists() and db_path.stat().st_size > 0, f"Merge output is empty: {db_path}" + logger.info(f"Saved merged results to {output_dir}") - # Clean up worker state files for worker_file in worker_files: worker_file.unlink() + state_dir.rmdir() logger.info(f"Deleted {len(worker_files)} worker state files") diff --git a/spd/harvest/harvest_fn/__init__.py b/spd/harvest/harvest_fn/__init__.py new file mode 100644 index 000000000..de57541ba --- /dev/null +++ b/spd/harvest/harvest_fn/__init__.py @@ -0,0 +1,28 @@ +import torch + +from spd.adapters.base import DecompositionAdapter +from spd.adapters.spd import SPDAdapter +from spd.harvest.config import ( + CLTHarvestConfig, + DecompositionMethodHarvestConfig, + MOLTHarvestConfig, + SPDHarvestConfig, +) +from spd.harvest.harvest_fn.base import 
HarvestFn +from spd.harvest.harvest_fn.spd import SPDHarvestFn + + +def make_harvest_fn( + device: torch.device, + method_config: DecompositionMethodHarvestConfig, + adapter: DecompositionAdapter, +) -> HarvestFn: + match method_config, adapter: + case SPDHarvestConfig(), SPDAdapter(): + return SPDHarvestFn(method_config, adapter, device=device) + case CLTHarvestConfig(), _: + raise NotImplementedError("CLT harvest not implemented yet") + case MOLTHarvestConfig(), _: + raise NotImplementedError("MOLT harvest not implemented yet") + case _, _: + raise ValueError(f"Unsupported method config: {method_config} and adapter: {adapter}") diff --git a/spd/harvest/harvest_fn/base.py b/spd/harvest/harvest_fn/base.py new file mode 100644 index 000000000..a618d02e2 --- /dev/null +++ b/spd/harvest/harvest_fn/base.py @@ -0,0 +1,9 @@ +from typing import Protocol + +import torch + +from spd.harvest.schemas import HarvestBatch + + +class HarvestFn(Protocol): + def __call__(self, batch_item: torch.Tensor) -> HarvestBatch: ... 
diff --git a/spd/harvest/harvest_fn/spd.py b/spd/harvest/harvest_fn/spd.py new file mode 100644 index 000000000..861388066 --- /dev/null +++ b/spd/harvest/harvest_fn/spd.py @@ -0,0 +1,56 @@ +from typing import override + +import torch + +from spd.adapters.spd import SPDAdapter +from spd.harvest.config import SPDHarvestConfig +from spd.harvest.harvest_fn.base import HarvestFn +from spd.harvest.schemas import HarvestBatch +from spd.utils.general_utils import extract_batch_data + + +class SPDHarvestFn(HarvestFn): + def __init__(self, config: SPDHarvestConfig, adapter: SPDAdapter, device: torch.device): + self._adapter = adapter + self._activation_threshold = config.activation_threshold + self._device = device + + self._adapter.component_model.to(device).eval() + self._u_norms = { + layer_name: component.U.norm(dim=1).to(device) + for layer_name, component in self._adapter.component_model.components.items() + } + + @override + def __call__(self, batch_item: torch.Tensor) -> HarvestBatch: + model = self._adapter.component_model + + batch = extract_batch_data(batch_item).to(self._device) + + out = model(batch, cache_type="input") + probs = torch.softmax(out.output, dim=-1) + + ci_dict = model.calc_causal_importances( + pre_weight_acts=out.cache, + detach_inputs=True, + sampling=self._adapter.spd_run_info.config.sampling, + ).lower_leaky + + per_layer_acts = model.get_all_component_acts(out.cache) + + firings = {layer: ci > self._activation_threshold for layer, ci in ci_dict.items()} + + activations = { + layer: { + "causal_importance": ci_dict[layer], + "component_activation": per_layer_acts[layer] * self._u_norms[layer], + } + for layer in model.target_module_paths + } + + return HarvestBatch( + tokens=batch, + firings=firings, + activations=activations, + output_probs=probs, + ) diff --git a/spd/harvest/harvester.py b/spd/harvest/harvester.py new file mode 100644 index 000000000..de54198c3 --- /dev/null +++ b/spd/harvest/harvester.py @@ -0,0 +1,401 @@ +"""Harvester for 
collecting component statistics in a single pass. + +All accumulator state lives as tensors on `device` (GPU during harvesting, CPU during merge). +""" + +from collections import defaultdict +from collections.abc import Iterator +from pathlib import Path +from typing import Any + +import torch +import tqdm +from einops import einsum, rearrange, reduce, repeat +from jaxtyping import Bool, Float, Int +from torch import Tensor + +from spd.harvest.reservoir import ( + WINDOW_PAD_SENTINEL, + ActivationExamplesReservoir, + ActivationWindows, +) +from spd.harvest.sampling import sample_at_most_n_per_group, top_k_pmi +from spd.harvest.schemas import ComponentData, ComponentTokenPMI +from spd.log import logger + + +def extract_padding_firing_windows( + batch: Int[Tensor, "B S"], + firings: Bool[Tensor, "B S C"], + activations: dict[str, Float[Tensor, "B S C"]], + max_examples_per_batch_per_component: int, + context_tokens_per_side: int, +) -> ActivationWindows | None: + batch_idx, seq_idx, comp_idx = torch.where(firings) + if len(batch_idx) == 0: + return None + + keep = sample_at_most_n_per_group(comp_idx, max_examples_per_batch_per_component) + batch_idx, seq_idx, comp_idx = batch_idx[keep], seq_idx[keep], comp_idx[keep] + + seq_len = batch.shape[1] + offsets = torch.arange( + -context_tokens_per_side, context_tokens_per_side + 1, device=batch.device + ) + window_size = offsets.shape[0] + assert window_size == 2 * context_tokens_per_side + 1 + + window_positions: Int[Tensor, "n_firings window_size"] + window_positions = seq_idx.unsqueeze(1) + offsets.unsqueeze(0) + + in_bounds = (window_positions >= 0) & (window_positions < seq_len) + clamped = window_positions.clamp(0, seq_len - 1) + + batch_idx_rep = repeat(batch_idx, "n_firings -> n_firings window_size", window_size=window_size) + c_idx_rep = repeat(comp_idx, "n_firings -> n_firings window_size", window_size=window_size) + + token_windows = batch[batch_idx_rep, clamped] + token_windows[~in_bounds] = WINDOW_PAD_SENTINEL 
+ + firing_windows = firings[batch_idx_rep, clamped, c_idx_rep] + firing_windows[~in_bounds] = False + + activation_windows = {} + for act_type, act in activations.items(): + activation_windows[act_type] = act[batch_idx_rep, clamped, c_idx_rep] + activation_windows[act_type][~in_bounds] = 0.0 + + return ActivationWindows( + component_idx=comp_idx, + token_windows=token_windows, + firing_windows=firing_windows, + activation_windows=activation_windows, + ) + + +class Harvester: + """Accumulates component statistics in a single pass over data. + + All mutable state is stored as tensors on `device`. Workers on GPU accumulate + into GPU tensors; the merge job reconstructs on CPU. + """ + + def __init__( + self, + layers: list[tuple[str, int]], + vocab_size: int, + max_examples_per_component: int, + context_tokens_per_side: int, + max_examples_per_batch_per_component: int, + device: torch.device, + ): + self.layers = layers + self.vocab_size = vocab_size + self.max_examples_per_component = max_examples_per_component + self.context_tokens_per_side = context_tokens_per_side + self.max_examples_per_batch_per_component = max_examples_per_batch_per_component + self.device = device + + self.layer_offsets: dict[str, int] = {} + offset = 0 + for layer, c in layers: + self.layer_offsets[layer] = offset + offset += c + + n_components = offset + + window_size = 2 * context_tokens_per_side + 1 + + # Per-component firing stats + self.firing_counts = torch.zeros(n_components, device=device) + self.activation_sums = defaultdict[str, Tensor]( + lambda: torch.zeros(n_components, device=device) + ) + self.cooccurrence_counts: Float[Tensor, "C C"] = torch.zeros( + n_components, n_components, device=device, dtype=torch.float32 + ) + + # Per-(component, token) stats for PMI computation + # input: hard token counts at positions where component fires + # output: predicted probability mass at positions where component fires + self.input_cooccurrence: Int[Tensor, "C vocab"] = torch.zeros( + 
n_components, vocab_size, device=device, dtype=torch.long + ) + self.input_marginals: Int[Tensor, " vocab"] = torch.zeros( + vocab_size, device=device, dtype=torch.long + ) + self.output_cooccurrence: Float[Tensor, "C vocab"] = torch.zeros( + n_components, vocab_size, device=device + ) + self.output_marginals: Float[Tensor, " vocab"] = torch.zeros(vocab_size, device=device) + + self.reservoir = ActivationExamplesReservoir.create( + n_components, max_examples_per_component, window_size, device + ) + self.total_tokens_processed = 0 + + @property + def layer_names(self) -> list[str]: + return [layer for layer, _ in self.layers] + + @property + def c_per_layer(self) -> dict[str, int]: + return {layer: c for layer, c in self.layers} + + @property + def component_keys(self) -> list[str]: + return [f"{layer}:{i}" for layer, c in self.layers for i in range(c)] + + # -- Batch processing -------------------------------------------------- + + def process_batch( + self, + batch: Int[Tensor, "B S"], + firings: dict[str, Bool[Tensor, "B S C"]], + activations: dict[str, dict[str, Float[Tensor, "B S C"]]], + output_probs: Float[Tensor, "B S V"], + ) -> None: + self.total_tokens_processed += batch.numel() + + tokens_flat = rearrange(batch, "b s -> (b s)") + probs_flat = rearrange(output_probs, "b s v -> (b s) v") + + firings_cat = torch.cat([firings[layer] for layer in self.layer_names], dim=-1) + firings_flat = rearrange(firings_cat, "b s lc -> (b s) lc") + + act_types = list(activations[self.layer_names[0]].keys()) + activations_cat: dict[str, Float[Tensor, "B S LC"]] = {} + for act_type in act_types: + activations_cat[act_type] = torch.cat( + [activations[layer][act_type] for layer in self.layer_names], dim=-1 + ) + + self.firing_counts += reduce(firings_cat, "b s lc -> lc", "sum") + + for act_type, act in activations_cat.items(): + self.activation_sums[act_type] += reduce(act, "b s lc -> lc", "sum") + + firings_float = firings_flat.float() + self.cooccurrence_counts += 
einsum(firings_float, firings_float, "S c1, S c2 -> c1 c2") + self._accumulate_token_stats(tokens_flat, probs_flat, firings_float) + self._collect_activation_examples(batch, firings_cat, activations_cat) + + def _accumulate_token_stats( + self, + tokens_flat: Int[Tensor, " S"], + probs_flat: Float[Tensor, "S vocab"], + firing_flat: Float[Tensor, "S LC"], + ) -> None: + n_components = firing_flat.shape[1] + token_indices = repeat(tokens_flat, "S -> lc S", lc=n_components) + + # use scatter_add for inputs because inputs are one-hot / token indices + self.input_cooccurrence.scatter_add_( + dim=1, index=token_indices, src=rearrange(firing_flat, "S lc -> lc S").long() + ) + self.input_marginals.scatter_add_( + dim=0, + index=tokens_flat, + src=torch.ones(tokens_flat.shape[0], device=self.device, dtype=torch.long), + ) + + # however, for outputs we need to accumulate probability mass over vocab + self.output_cooccurrence += einsum(firing_flat, probs_flat, "S lc, S v -> lc v") + self.output_marginals += reduce(probs_flat, "S v -> v", "sum") + + def _collect_activation_examples( + self, + batch: Int[Tensor, "B S"], + firings: Bool[Tensor, "B S LC"], + activations: dict[str, Float[Tensor, "B S LC"]], + ) -> None: + res = extract_padding_firing_windows( + batch, + firings, + activations, + self.max_examples_per_batch_per_component, + self.context_tokens_per_side, + ) + if res is not None: + self.reservoir.add(res) + + def save(self, path: Path) -> None: + data: dict[str, object] = { + "layers": self.layers, + "vocab_size": self.vocab_size, + "max_examples_per_component": self.max_examples_per_component, + "context_tokens_per_side": self.context_tokens_per_side, + "max_examples_per_batch_per_component": self.max_examples_per_batch_per_component, + "total_tokens_processed": self.total_tokens_processed, + "reservoir": self.reservoir.state_dict(), + "firing_counts": self.firing_counts.cpu(), + "activation_sums": { + act_type: self.activation_sums[act_type].cpu() for act_type in 
self.activation_sums + }, + "cooccurrence_counts": self.cooccurrence_counts.cpu(), + "input_cooccurrence": self.input_cooccurrence.cpu(), + "input_marginals": self.input_marginals.cpu(), + "output_cooccurrence": self.output_cooccurrence.cpu(), + "output_marginals": self.output_marginals.cpu(), + } + torch.save(data, path) + + @staticmethod + def load(path: Path, device: torch.device) -> "Harvester": + d: dict[str, Any] = torch.load(path, weights_only=False) + h = Harvester( + layers=d["layers"], + vocab_size=d["vocab_size"], + max_examples_per_component=d["max_examples_per_component"], + context_tokens_per_side=d["context_tokens_per_side"], + max_examples_per_batch_per_component=d.get("max_examples_per_batch_per_component", 5), + device=device, + ) + h.total_tokens_processed = d["total_tokens_processed"] + h.firing_counts = d["firing_counts"].to(device) + h.activation_sums = {k: v.to(device) for k, v in d["activation_sums"].items()} + h.cooccurrence_counts = d["cooccurrence_counts"].to(device) + h.input_cooccurrence = d["input_cooccurrence"].to(device) + h.input_marginals = d["input_marginals"].to(device) + h.output_cooccurrence = d["output_cooccurrence"].to(device) + h.output_marginals = d["output_marginals"].to(device) + h.reservoir = ActivationExamplesReservoir.from_state_dict(d["reservoir"], device) + return h + + def merge(self, other: "Harvester") -> None: + assert other.layer_names == self.layer_names + assert other.c_per_layer == self.c_per_layer + assert other.vocab_size == self.vocab_size + + self.firing_counts += other.firing_counts + for act_type in self.activation_sums: + self.activation_sums[act_type] += other.activation_sums[act_type] + self.cooccurrence_counts += other.cooccurrence_counts + self.input_cooccurrence += other.input_cooccurrence + self.input_marginals += other.input_marginals + self.output_cooccurrence += other.output_cooccurrence + self.output_marginals += other.output_marginals + self.total_tokens_processed += 
other.total_tokens_processed + + self.reservoir.merge(other.reservoir) + + # -- Result building --------------------------------------------------- + + def build_results(self, pmi_top_k_tokens: int) -> Iterator[ComponentData]: + """Yield ComponentData objects one at a time (constant memory).""" + logger.info(" Moving tensors to CPU...") + mean_activations = { + act_type: (self.activation_sums[act_type] / self.total_tokens_processed).cpu() + for act_type in self.activation_sums + } + firing_counts = self.firing_counts.cpu() + input_cooccurrence = self.input_cooccurrence.cpu() + input_marginals = self.input_marginals.cpu() + output_cooccurrence = self.output_cooccurrence.cpu() + output_marginals = self.output_marginals.cpu() + + reservoir_cpu = self.reservoir.to(torch.device("cpu")) + + _log_base_rate_summary(firing_counts, input_marginals) + + for layer, layer_c in self.layers: + offset = self.layer_offsets[layer] + + for component_idx in tqdm.tqdm(range(layer_c), desc="Building components"): + flat_idx = offset + component_idx + + n_firings = float(firing_counts[flat_idx]) + if n_firings == 0: + continue + + yield ComponentData( + component_key=f"{layer}:{component_idx}", + layer=layer, + component_idx=component_idx, # as in, the index of the component within the layer + firing_density=n_firings / self.total_tokens_processed, + mean_activations={ + act_type: float(mean_activations[act_type][flat_idx].item()) + for act_type in mean_activations + }, + activation_examples=list(reservoir_cpu.examples(flat_idx)), + input_token_pmi=_compute_token_pmi( + input_cooccurrence[flat_idx], + input_marginals, + n_firings, + self.total_tokens_processed, + pmi_top_k_tokens, + ), + output_token_pmi=_compute_token_pmi( + output_cooccurrence[flat_idx], + output_marginals, + n_firings, + self.total_tokens_processed, + pmi_top_k_tokens, + ), + ) + + +# --------------------------------------------------------------------------- +# Helpers +# 
--------------------------------------------------------------------------- + + +def _log_base_rate_summary(firing_counts: Tensor, input_marginals: Tensor) -> None: + active_counts = firing_counts[firing_counts > 0] + if len(active_counts) == 0: + logger.info(" WARNING: No components fired above threshold!") + return + + sorted_counts = active_counts.sort().values + n_active = len(active_counts) + logger.info("\n === Base Rate Summary ===") + logger.info(f" Components with firings: {n_active} / {len(firing_counts)}") + logger.info( + f" Firing counts - min: {int(sorted_counts[0])}, " + f"median: {int(sorted_counts[n_active // 2])}, " + f"max: {int(sorted_counts[-1])}" + ) + + LOW_FIRING_THRESHOLD = 100 + n_sparse = int((active_counts < LOW_FIRING_THRESHOLD).sum()) + if n_sparse > 0: + logger.info( + f" WARNING: {n_sparse} components have <{LOW_FIRING_THRESHOLD} firings " + f"(stats may be noisy)" + ) + + active_tokens = input_marginals[input_marginals > 0] + sorted_token_counts = active_tokens.sort().values + n_tokens = len(active_tokens) + logger.info( + f" Tokens seen: {n_tokens} unique, " + f"occurrences - min: {int(sorted_token_counts[0])}, " + f"median: {int(sorted_token_counts[n_tokens // 2])}, " + f"max: {int(sorted_token_counts[-1])}" + ) + + RARE_TOKEN_THRESHOLD = 10 + n_rare = int((active_tokens < RARE_TOKEN_THRESHOLD).sum()) + if n_rare > 0: + logger.info( + f" Note: {n_rare} tokens have <{RARE_TOKEN_THRESHOLD} occurrences " + f"(high precision/recall with these may be spurious)" + ) + logger.info("") + + +def _compute_token_pmi( + token_mass_for_component: Tensor, + token_mass_totals: Tensor, + component_firing_count: float, + total_tokens: int, + top_k: int, +) -> ComponentTokenPMI: + top, bottom = top_k_pmi( + cooccurrence_counts=token_mass_for_component, + marginal_counts=token_mass_totals, + target_count=component_firing_count, + total_count=total_tokens, + top_k=top_k, + ) + return ComponentTokenPMI(top=top, bottom=bottom) diff --git 
a/spd/harvest/intruder.py b/spd/harvest/intruder.py
new file mode 100644
index 000000000..4fb9b2052
--- /dev/null
+++ b/spd/harvest/intruder.py
@@ -0,0 +1,243 @@
+"""Intruder detection scoring.
+
+Tests whether a component's activating examples are coherent by asking an LLM
+to identify an "intruder" example drawn from a different component. No labels needed.
+
+Based on: "Evaluating SAE interpretability without explanations" (2025).
+
+Usage:
+    python -m spd.autointerp.scoring.scripts.run_intruder --limit 100
+"""
+
+import bisect
+import json
+import random
+from collections import defaultdict
+from dataclasses import asdict, dataclass
+
+from spd.app.backend.app_tokenizer import AppTokenizer
+from spd.app.backend.utils import delimit_tokens
+from spd.autointerp.llm_api import LLMError, LLMJob, LLMResult, map_llm_calls
+from spd.harvest.config import IntruderEvalConfig
+from spd.harvest.db import HarvestDB
+from spd.harvest.schemas import ActivationExample, ComponentData
+from spd.log import logger
+
+INTRUDER_SCHEMA = {
+    "type": "object",
+    "properties": {
+        "intruder": {
+            "type": "integer",
+            "description": "1-indexed example number of the intruder",
+        },
+        "reasoning": {"type": "string", "description": "Brief explanation"},
+    },
+    "required": ["intruder", "reasoning"],
+}
+
+
+@dataclass
+class IntruderTrial:
+    correct_answer: int
+    predicted: int
+    is_correct: bool
+
+
+@dataclass
+class IntruderResult:
+    component_key: str
+    score: float
+    trials: list[IntruderTrial]
+    n_errors: int
+
+
+class DensityIndex:
+    """Index of components sorted by firing density for efficient similar-density lookup."""
+
+    def __init__(self, components: list[ComponentData], min_examples: int) -> None:
+        eligible = [c for c in components if len(c.activation_examples) >= min_examples]
+        eligible.sort(key=lambda c: c.firing_density)
+        self._components = eligible
+        self._densities = [c.firing_density for c in eligible]
+        self._key_to_idx = {c.component_key: i for i, c in 
enumerate(self._components)} + + def sample_similar( + self, + target: ComponentData, + rng: random.Random, + tolerance: float, + ) -> ComponentData: + """Sample a different component with similar firing density.""" + assert target.component_key in self._key_to_idx + target_density = self._densities[self._key_to_idx[target.component_key]] + + lo = bisect.bisect_left(self._densities, target_density - tolerance) + hi = bisect.bisect_right(self._densities, target_density + tolerance) + + candidates = [ + self._components[i] + for i in range(lo, hi) + if self._components[i].component_key != target.component_key + ] + + # Widen search if no candidates in tolerance band + if not candidates: + candidates = [c for c in self._components if c.component_key != target.component_key] + + return rng.choice(candidates) + + +def _format_example( + example: ActivationExample, + app_tok: AppTokenizer, +) -> str: + spans = app_tok.get_spans(example.token_ids) + tokens = [(span, firing) for span, firing in zip(spans, example.firings, strict=True)] + return delimit_tokens(tokens) + + +def _sample_intruder( + target: ComponentData, + density_index: DensityIndex, + rng: random.Random, + density_tolerance: float, +) -> ActivationExample: + """Sample an intruder example from a component with similar firing density.""" + donor = density_index.sample_similar(target, rng, tolerance=density_tolerance) + return rng.choice(donor.activation_examples) + + +def _build_prompt( + real_examples: list[ActivationExample], + intruder: ActivationExample, + intruder_position: int, + app_tok: AppTokenizer, +) -> str: + all_examples = list(real_examples) + all_examples.insert(intruder_position, intruder) + n_total = len(all_examples) + n_real = len(real_examples) + + examples_text = "" + for i, ex in enumerate(all_examples): + examples_text += f"Example {i + 1}: {_format_example(ex, app_tok)}\n\n" + + return f"""\ +Below are {n_total} text snippets from a neural network's training data.
{n_real} come from contexts \ +where the SAME component fires strongly. One is an INTRUDER from a DIFFERENT component. + +Tokens between <> are where the component fires most strongly. + +{examples_text}\ +Which example is the intruder? Identify what pattern the majority share, then pick \ +the example that does not fit. + +Respond with the intruder example number (1-{n_total}) and brief reasoning.""" + + +@dataclass +class _TrialGroundTruth: + component_key: str + correct_answer: int + + +async def run_intruder_scoring( + components: list[ComponentData], + model: str, + openrouter_api_key: str, + tokenizer_name: str, + score_db: HarvestDB, + eval_config: IntruderEvalConfig, + limit: int | None, + cost_limit_usd: float | None, +) -> list[IntruderResult]: + n_real = eval_config.n_real + n_trials = eval_config.n_trials + density_tolerance = eval_config.density_tolerance + + app_tok = AppTokenizer.from_pretrained(tokenizer_name) + + eligible = [c for c in components if len(c.activation_examples) >= n_real + 1] + if limit is not None: + eligible = eligible[:limit] + + density_index = DensityIndex(components, min_examples=n_real + 1) + + existing_scores = score_db.get_scores("intruder") + completed = set(existing_scores.keys()) + if completed: + logger.info(f"Resuming: {len(completed)} already scored") + + remaining = [c for c in eligible if c.component_key not in completed] + logger.info(f"Scoring {len(remaining)} components ({len(remaining) * n_trials} trials)") + + rng = random.Random() + jobs: list[LLMJob] = [] + ground_truth: dict[str, _TrialGroundTruth] = {} + + for i, component in enumerate(remaining): + if i > 0 and i % 1000 == 0: + logger.info(f"Building trials: {i}/{len(remaining)} components") + for trial_idx in range(n_trials): + real_examples = rng.sample(component.activation_examples, n_real) + intruder = _sample_intruder(component, density_index, rng, density_tolerance) + intruder_pos = rng.randint(0, n_real) + correct_answer = intruder_pos + 1 + + key = 
f"{component.component_key}/trial{trial_idx}" + jobs.append( + LLMJob( + prompt=_build_prompt(real_examples, intruder, intruder_pos, app_tok), + schema=INTRUDER_SCHEMA, + key=key, + ) + ) + ground_truth[key] = _TrialGroundTruth( + component_key=component.component_key, + correct_answer=correct_answer, + ) + logger.info(f"Built {len(jobs)} trials") + + component_trials: defaultdict[str, list[IntruderTrial]] = defaultdict(list) + component_errors: defaultdict[str, int] = defaultdict(int) + + async for outcome in map_llm_calls( + openrouter_api_key=openrouter_api_key, + model=model, + reasoning_effort=eval_config.reasoning_effort, + jobs=jobs, + max_tokens=300, + max_concurrent=eval_config.max_concurrent, + max_requests_per_minute=eval_config.max_requests_per_minute, + cost_limit_usd=cost_limit_usd, + response_schema=INTRUDER_SCHEMA, + ): + match outcome: + case LLMResult(job=job, parsed=parsed): + gt = ground_truth[job.key] + predicted = int(parsed["intruder"]) + component_trials[gt.component_key].append( + IntruderTrial( + correct_answer=gt.correct_answer, + predicted=predicted, + is_correct=predicted == gt.correct_answer, + ) + ) + case LLMError(job=job, error=e): + gt = ground_truth[job.key] + component_errors[gt.component_key] += 1 + logger.error(f"{job.key}: {type(e).__name__}: {e}") + + results: list[IntruderResult] = [] + for component in remaining: + ck = component.component_key + trials = component_trials.get(ck, []) + n_err = component_errors.get(ck, 0) + correct = sum(1 for t in trials if t.is_correct) + score = correct / len(trials) if trials else 0.0 + result = IntruderResult(component_key=ck, score=score, trials=trials, n_errors=n_err) + results.append(result) + score_db.save_score(ck, "intruder", score, json.dumps(asdict(result))) + + logger.info(f"Scored {len(results)} components") + return results diff --git a/spd/harvest/lib/harvester.py b/spd/harvest/lib/harvester.py deleted file mode 100644 index 22fd3cc5a..000000000 --- 
a/spd/harvest/lib/harvester.py +++ /dev/null @@ -1,456 +0,0 @@ -"""Harvester for collecting component statistics in a single pass.""" - -from dataclasses import dataclass -from typing import cast - -import torch -import tqdm -from einops import einsum, rearrange, reduce -from jaxtyping import Float, Int -from torch import Tensor - -from spd.harvest.lib.reservoir_sampler import ReservoirSampler, ReservoirState -from spd.harvest.lib.sampling import sample_at_most_n_per_group, top_k_pmi -from spd.harvest.schemas import ActivationExample, ComponentData, ComponentTokenPMI - -# Sentinel for padding token windows at sequence boundaries. -WINDOW_PAD_SENTINEL = -1 - -# Entry: (token_ids, ci_values_in_window, component_acts_in_window) -ActivationExampleTuple = tuple[list[int], list[float], list[float]] - - -@dataclass -class HarvesterState: - """Serializable state of a Harvester for parallel merging.""" - - layer_names: list[str] - c_per_layer: dict[str, int] # Maps layer name -> number of components - vocab_size: int - ci_threshold: float - max_examples_per_component: int - context_tokens_per_side: int - - # Tensor accumulators (on CPU) - firing_counts: Tensor - ci_sums: Tensor - count_ij: Tensor # Component co-occurrence matrix - input_token_counts: Tensor - input_token_totals: Tensor - output_token_prob_mass: Tensor - output_token_prob_totals: Tensor - total_tokens_processed: int - - # Reservoir states - reservoir_states: list[ReservoirState[ActivationExampleTuple]] - - def merge_into(self, other: "HarvesterState") -> None: - """Merge another HarvesterState into this one (in-place accumulation). - - This is the streaming merge primitive used to avoid OOM when merging many workers. 
- """ - assert other.layer_names == self.layer_names - assert other.c_per_layer == self.c_per_layer - assert other.vocab_size == self.vocab_size - assert other.ci_threshold == self.ci_threshold - assert other.max_examples_per_component == self.max_examples_per_component - assert other.context_tokens_per_side == self.context_tokens_per_side - assert len(other.reservoir_states) == len(self.reservoir_states) - - # Accumulate tensor stats - self.firing_counts += other.firing_counts - self.ci_sums += other.ci_sums - self.count_ij += other.count_ij - self.input_token_counts += other.input_token_counts - self.input_token_totals += other.input_token_totals - self.output_token_prob_mass += other.output_token_prob_mass - self.output_token_prob_totals += other.output_token_prob_totals - self.total_tokens_processed += other.total_tokens_processed - - # Merge reservoir states pairwise - for i in range(len(self.reservoir_states)): - merged = ReservoirState.merge([self.reservoir_states[i], other.reservoir_states[i]]) - self.reservoir_states[i] = merged - - -class Harvester: - """Accumulates component statistics in a single pass over data.""" - - def __init__( - self, - layer_names: list[str], - c_per_layer: dict[str, int], - vocab_size: int, - ci_threshold: float, - max_examples_per_component: int, - context_tokens_per_side: int, - device: torch.device, - ): - self.layer_names = layer_names - self.c_per_layer = c_per_layer - self.vocab_size = vocab_size - self.ci_threshold = ci_threshold - self.max_examples_per_component = max_examples_per_component - self.context_tokens_per_side = context_tokens_per_side - self.device = device - - # Precompute layer offsets for flat indexing - # layer_offsets[layer_name] gives the starting flat index for that layer's components - self.layer_offsets: dict[str, int] = {} - offset = 0 - for layer in layer_names: - self.layer_offsets[layer] = offset - offset += c_per_layer[layer] - - n_components = sum(c_per_layer[layer] for layer in layer_names) - 
- # Correlation accumulators - self.firing_counts = torch.zeros(n_components, device=device) - self.ci_sums = torch.zeros(n_components, device=device) - self.count_ij = torch.zeros(n_components, n_components, device=device, dtype=torch.float32) - - # Token stat accumulators - self.input_token_counts: Int[Tensor, "n_components vocab"] = torch.zeros( - n_components, vocab_size, device=device, dtype=torch.long - ) - self.input_token_totals: Int[Tensor, " vocab"] = torch.zeros( - vocab_size, device=device, dtype=torch.long - ) - self.output_token_prob_mass: Float[Tensor, "n_components vocab"] = torch.zeros( - n_components, vocab_size, device=device - ) - self.output_token_prob_totals: Float[Tensor, " vocab"] = torch.zeros( - vocab_size, device=device - ) - - # Reservoir samplers for activation examples - self.activation_example_samplers = [ - ReservoirSampler[ActivationExampleTuple](k=max_examples_per_component) - for _ in range(n_components) - ] - - self.total_tokens_processed = 0 - - def process_batch( - self, - batch: Int[Tensor, "B S"], - ci: Float[Tensor, "B S n_comp"], - output_probs: Float[Tensor, "B S V"], - subcomp_acts: Float[Tensor, "B S n_comp"], - ) -> None: - """Accumulate stats from a single batch. - - Args: - batch: Token IDs - ci: Causal importance values per component - output_probs: Output probabilities - subcomp_acts: Normalized subcomponent activations: (v_i^T @ a) * ||u_i||. 
- """ - self.total_tokens_processed += batch.numel() - - firing = (ci > self.ci_threshold).float() - - firing_flat = rearrange(firing, "b s c -> (b s) c") - batch_flat = rearrange(batch, "b s -> (b s)") - output_probs_flat = rearrange(output_probs, "b s v -> (b s) v") - - self._accumulate_firing_stats(ci, firing) - self._accumulate_cooccurrence_stats(firing_flat) - self._accumulate_input_token_stats(batch_flat, firing_flat) - self._accumulate_output_token_stats(output_probs_flat, firing_flat) - self._collect_activation_examples(batch, ci, subcomp_acts) - - def _accumulate_firing_stats( - self, - ci: Float[Tensor, "B S n_comp"], - firing: Float[Tensor, "B S n_comp"], - ) -> None: - self.firing_counts += reduce(firing, "b s c -> c", "sum") - self.ci_sums += reduce(ci, "b s c -> c", "sum") - - def _accumulate_cooccurrence_stats(self, firing_flat: Float[Tensor, "pos n_comp"]) -> None: - """Accumulate component-component co-occurrence counts.""" - self.count_ij += einsum(firing_flat, firing_flat, "pos c1, pos c2 -> c1 c2") - - def _accumulate_input_token_stats( - self, - batch_flat: Int[Tensor, " pos"], - firing_flat: Float[Tensor, "pos n_comp"], - ) -> None: - """Accumulate which input tokens caused each component to fire. - - Uses scatter_add_ to efficiently accumulate counts into a sparse [n_comp, vocab] matrix. - For each position, we add the firing indicator (0 or 1) to the count for that token. 
- - Equivalent to: for each pos, for each component c: - input_token_counts[c, batch_flat[pos]] += firing_flat[pos, c] - """ - n_components = firing_flat.shape[1] - # Broadcast token_ids to [n_comp, pos] so scatter_add_ can index into vocab dim - token_indices = batch_flat.unsqueeze(0).expand(n_components, -1) - # input_token_counts[c, token_indices[c, pos]] += firing_flat.T[c, pos] - self.input_token_counts.scatter_add_( - dim=1, index=token_indices, src=rearrange(firing_flat, "pos c -> c pos").long() - ) - # Count total occurrences of each token (denominator for precision) - self.input_token_totals.scatter_add_( - dim=0, - index=batch_flat, - src=torch.ones(batch_flat.shape[0], device=self.device, dtype=torch.long), - ) - - def _accumulate_output_token_stats( - self, - output_probs_flat: Float[Tensor, "pos vocab"], - firing_flat: Float[Tensor, "pos n_comp"], - ) -> None: - """Accumulate which output tokens each component predicts. - - Unlike input tokens (hard counts), we accumulate probability mass. - When component c fires, we add the full output probability distribution, - weighted by the firing indicator. - """ - # Sum of P(token | pos) for positions where component c fired - self.output_token_prob_mass += einsum(firing_flat, output_probs_flat, "pos c, pos v -> c v") - # Sum of P(token | pos) across all positions (for normalization) - self.output_token_prob_totals += reduce(output_probs_flat, "pos v -> v", "sum") - - def _collect_activation_examples( - self, - batch: Int[Tensor, "B S"], - ci: Float[Tensor, "B S n_comp"], - subcomp_acts: Float[Tensor, "B S n_comp"], - ) -> None: - """Reservoir sample activation examples from high-CI firings.""" - firing = ci > self.ci_threshold - batch_idx, seq_idx, component_idx = torch.where(firing) - if len(batch_idx) == 0: - return - - # Cap firings per component to ensure rare components get examples. - # With ~3000 batches and topk=1000 examples, we only need ~1 per component per batch. 
- MAX_FIRINGS_PER_COMPONENT = 5 - keep_mask = sample_at_most_n_per_group(component_idx, MAX_FIRINGS_PER_COMPONENT) - batch_idx = batch_idx[keep_mask] - seq_idx = seq_idx[keep_mask] - component_idx = component_idx[keep_mask] - - # Pad sequences so we can extract windows at boundaries without going out of bounds. - # E.g. if context_tokens_per_side=3, a firing at seq_idx=0 needs tokens at [-3, -2, -1, 0, 1, 2, 3] - # Padding with sentinel allows uniform window extraction; sentinels are filtered in display. - batch_padded = torch.nn.functional.pad( - batch, - (self.context_tokens_per_side, self.context_tokens_per_side), - value=WINDOW_PAD_SENTINEL, - ) - ci_padded = torch.nn.functional.pad( - ci, (0, 0, self.context_tokens_per_side, self.context_tokens_per_side), value=0.0 - ) - subcomp_acts_padded = torch.nn.functional.pad( - subcomp_acts, - (0, 0, self.context_tokens_per_side, self.context_tokens_per_side), - value=0.0, - ) - - # Build indices to extract [n_firings, window_size] windows via advanced indexing. 
- # For each firing, we want tokens at [seq_idx - k, ..., seq_idx, ..., seq_idx + k] - window_size = 2 * self.context_tokens_per_side + 1 - offsets = torch.arange( - -self.context_tokens_per_side, self.context_tokens_per_side + 1, device=self.device - ) - seq_idx_padded = seq_idx + self.context_tokens_per_side # Adjust for padding - window_seq_indices = seq_idx_padded.unsqueeze(1) + offsets # [n_firings, window_size] - batch_idx_expanded = batch_idx.unsqueeze(1).expand(-1, window_size) - component_idx_expanded = component_idx.unsqueeze(1).expand(-1, window_size) - - # Advanced indexing: token_windows[i, j] = batch_padded[batch_idx[i], window_seq_indices[i, j]] - token_windows = batch_padded[batch_idx_expanded, window_seq_indices] - ci_windows = ci_padded[batch_idx_expanded, window_seq_indices, component_idx_expanded] - component_act_windows = subcomp_acts_padded[ - batch_idx_expanded, window_seq_indices, component_idx_expanded - ] - - # Add to reservoir samplers - for comp_idx, tokens, ci_vals, component_acts in zip( - cast(list[int], component_idx.cpu().tolist()), - cast(list[list[int]], token_windows.cpu().tolist()), - cast(list[list[float]], ci_windows.cpu().tolist()), - cast(list[list[float]], component_act_windows.cpu().tolist()), - strict=True, - ): - self.activation_example_samplers[comp_idx].add((tokens, ci_vals, component_acts)) - - def get_state(self) -> HarvesterState: - """Extract serializable state for parallel merging.""" - return HarvesterState( - layer_names=self.layer_names, - c_per_layer=self.c_per_layer, - vocab_size=self.vocab_size, - ci_threshold=self.ci_threshold, - max_examples_per_component=self.max_examples_per_component, - context_tokens_per_side=self.context_tokens_per_side, - firing_counts=self.firing_counts.cpu(), - ci_sums=self.ci_sums.cpu(), - count_ij=self.count_ij.cpu(), - input_token_counts=self.input_token_counts.cpu(), - input_token_totals=self.input_token_totals.cpu(), - output_token_prob_mass=self.output_token_prob_mass.cpu(), 
- output_token_prob_totals=self.output_token_prob_totals.cpu(), - total_tokens_processed=self.total_tokens_processed, - reservoir_states=[s.get_state() for s in self.activation_example_samplers], - ) - - @staticmethod - def from_state(state: HarvesterState, device: torch.device) -> "Harvester": - """Reconstruct Harvester from state.""" - harvester = Harvester( - layer_names=state.layer_names, - c_per_layer=state.c_per_layer, - vocab_size=state.vocab_size, - ci_threshold=state.ci_threshold, - max_examples_per_component=state.max_examples_per_component, - context_tokens_per_side=state.context_tokens_per_side, - device=device, - ) - harvester.firing_counts = state.firing_counts.to(device) - harvester.ci_sums = state.ci_sums.to(device) - harvester.count_ij = state.count_ij.to(device) - harvester.input_token_counts = state.input_token_counts.to(device) - harvester.input_token_totals = state.input_token_totals.to(device) - harvester.output_token_prob_mass = state.output_token_prob_mass.to(device) - harvester.output_token_prob_totals = state.output_token_prob_totals.to(device) - harvester.total_tokens_processed = state.total_tokens_processed - harvester.activation_example_samplers = [ - ReservoirSampler.from_state(s) for s in state.reservoir_states - ] - return harvester - - def build_results(self, pmi_top_k_tokens: int) -> list[ComponentData]: - """Convert accumulated state into ComponentData objects.""" - print(" Moving tensors to CPU...") - mean_ci_per_component = (self.ci_sums / self.total_tokens_processed).cpu() - firing_counts = self.firing_counts.cpu() - input_token_counts = self.input_token_counts.cpu() - input_token_totals = self.input_token_totals.cpu() - output_token_prob_mass = self.output_token_prob_mass.cpu() - output_token_prob_totals = self.output_token_prob_totals.cpu() - - self._log_base_rate_summary(firing_counts, input_token_totals) - - n_total = sum(self.c_per_layer[layer] for layer in self.layer_names) - print( - f" Computing stats for {n_total} 
components across {len(self.layer_names)} layers..." - ) - components = [] - for layer_name in tqdm.tqdm(self.layer_names, desc="Building components"): - layer_offset = self.layer_offsets[layer_name] - layer_c = self.c_per_layer[layer_name] - - for component_idx in range(layer_c): - flat_idx = layer_offset + component_idx - mean_ci = float(mean_ci_per_component[flat_idx]) - - component_firing_count = float(firing_counts[flat_idx]) - if component_firing_count == 0: - continue - - # Build activation examples from reservoir (uniform random sample) - sampler = self.activation_example_samplers[flat_idx] - activation_examples = [ - ActivationExample( - token_ids=token_ids, ci_values=ci_values, component_acts=component_acts - ) - for token_ids, ci_values, component_acts in sampler.samples - ] - - input_token_pmi = _compute_token_pmi( - token_mass_for_component=input_token_counts[flat_idx], - token_mass_totals=input_token_totals, - component_firing_count=component_firing_count, - total_tokens=self.total_tokens_processed, - top_k=pmi_top_k_tokens, - ) - - output_token_pmi = _compute_token_pmi( - token_mass_for_component=output_token_prob_mass[flat_idx], - token_mass_totals=output_token_prob_totals, - component_firing_count=component_firing_count, - total_tokens=self.total_tokens_processed, - top_k=pmi_top_k_tokens, - ) - - components.append( - ComponentData( - component_key=f"{layer_name}:{component_idx}", - layer=layer_name, - component_idx=component_idx, - mean_ci=mean_ci, - activation_examples=activation_examples, - input_token_pmi=input_token_pmi, - output_token_pmi=output_token_pmi, - ) - ) - - return components - - def _log_base_rate_summary(self, firing_counts: Tensor, input_token_totals: Tensor) -> None: - """Log summary statistics about base rates.""" - active_counts = firing_counts[firing_counts > 0] - if len(active_counts) == 0: - print(" WARNING: No components fired above threshold!") - return - - sorted_counts = active_counts.sort().values - n_active = 
len(active_counts) - print("\n === Base Rate Summary ===") - print(f" Components with firings: {n_active} / {len(firing_counts)}") - print( - f" Firing counts - min: {int(sorted_counts[0])}, " - f"median: {int(sorted_counts[n_active // 2])}, " - f"max: {int(sorted_counts[-1])}" - ) - - LOW_FIRING_THRESHOLD = 100 - n_sparse = int((active_counts < LOW_FIRING_THRESHOLD).sum()) - if n_sparse > 0: - print( - f" WARNING: {n_sparse} components have <{LOW_FIRING_THRESHOLD} firings " - f"(stats may be noisy)" - ) - - active_tokens = input_token_totals[input_token_totals > 0] - sorted_token_counts = active_tokens.sort().values - n_tokens = len(active_tokens) - print( - f" Tokens seen: {n_tokens} unique, " - f"occurrences - min: {int(sorted_token_counts[0])}, " - f"median: {int(sorted_token_counts[n_tokens // 2])}, " - f"max: {int(sorted_token_counts[-1])}" - ) - - RARE_TOKEN_THRESHOLD = 10 - n_rare = int((active_tokens < RARE_TOKEN_THRESHOLD).sum()) - if n_rare > 0: - print( - f" Note: {n_rare} tokens have <{RARE_TOKEN_THRESHOLD} occurrences " - f"(high precision/recall with these may be spurious)" - ) - print() - - -def _compute_token_pmi( - token_mass_for_component: Tensor, - token_mass_totals: Tensor, - component_firing_count: float, - total_tokens: int, - top_k: int, -) -> ComponentTokenPMI: - """Compute PMI for tokens associated with a component.""" - top, bottom = top_k_pmi( - cooccurrence_counts=token_mass_for_component, - marginal_counts=token_mass_totals, - target_count=component_firing_count, - total_count=total_tokens, - top_k=top_k, - ) - return ComponentTokenPMI(top=top, bottom=bottom) diff --git a/spd/harvest/lib/reservoir_sampler.py b/spd/harvest/lib/reservoir_sampler.py deleted file mode 100644 index 2e48aea09..000000000 --- a/spd/harvest/lib/reservoir_sampler.py +++ /dev/null @@ -1,72 +0,0 @@ -import heapq -import random -from dataclasses import dataclass -from typing import Generic, TypeVar - -T = TypeVar("T") - - -@dataclass -class 
ReservoirState(Generic[T]): # noqa: UP046 - PEP 695 syntax breaks pickling - """Serializable state of a ReservoirSampler.""" - - k: int - samples: list[T] - n_seen: int - - @staticmethod - def merge(states: list["ReservoirState[T]"]) -> "ReservoirState[T]": - """Merge multiple reservoir states via weighted random sampling. - - Uses Efraimidis-Spirakis algorithm: each sample gets key = random()^(1/weight), - take k largest. O(n + k log n) vs O(k*n) for naive weighted sampling. - """ - assert len(states) > 0 - k = states[0].k - assert all(s.k == k for s in states) - - total_seen = sum(s.n_seen for s in states) - if total_seen == 0: - return ReservoirState(k=k, samples=[], n_seen=0) - - # Build weighted pool: each sample weighted by its reservoir's n_seen - weighted_samples: list[tuple[T, int]] = [] - for state in states: - for sample in state.samples: - weighted_samples.append((sample, state.n_seen)) - - if len(weighted_samples) <= k: - merged_samples = [s for s, _ in weighted_samples] - else: - # Efraimidis-Spirakis: key = random()^(1/weight), take k largest - keys_and_samples = [(random.random() ** (1.0 / w), s) for s, w in weighted_samples] - top_k = heapq.nlargest(k, keys_and_samples, key=lambda x: x[0]) - merged_samples = [s for _, s in top_k] - - return ReservoirState(k=k, samples=merged_samples, n_seen=total_seen) - - -class ReservoirSampler(Generic[T]): # noqa: UP046 - PEP 695 syntax breaks pickling - """Uniform random sampling from a stream via reservoir sampling.""" - - def __init__(self, k: int): - self.k = k - self.samples: list[T] = [] - self.n_seen = 0 - - def add(self, item: T) -> None: - self.n_seen += 1 - if len(self.samples) < self.k: - self.samples.append(item) - elif random.randint(1, self.n_seen) <= self.k: - self.samples[random.randrange(self.k)] = item - - def get_state(self) -> ReservoirState[T]: - return ReservoirState(k=self.k, samples=list(self.samples), n_seen=self.n_seen) - - @staticmethod - def from_state(state: ReservoirState[T]) -> 
"ReservoirSampler[T]": - sampler: ReservoirSampler[T] = ReservoirSampler(k=state.k) - sampler.samples = list(state.samples) - sampler.n_seen = state.n_seen - return sampler diff --git a/spd/harvest/loaders.py b/spd/harvest/loaders.py deleted file mode 100644 index 929316286..000000000 --- a/spd/harvest/loaders.py +++ /dev/null @@ -1,109 +0,0 @@ -"""Loaders for reading harvest output files.""" - -import json -import threading - -from spd.harvest.schemas import ( - ActivationExample, - ComponentData, - ComponentSummary, - ComponentTokenPMI, - get_activation_contexts_dir, - get_correlations_dir, -) -from spd.harvest.storage import CorrelationStorage, TokenStatsStorage - - -def load_activation_contexts_summary(wandb_run_id: str) -> dict[str, ComponentSummary] | None: - """Load lightweight summary of activation contexts (just metadata, not full examples).""" - ctx_dir = get_activation_contexts_dir(wandb_run_id) - path = ctx_dir / "summary.json" - if not path.exists(): - return None - return ComponentSummary.load_all(path) - - -# Cache for component indices (run_id -> {component_key -> byte_offset}) -_component_index_cache: dict[str, dict[str, int]] = {} -_component_index_lock = threading.Lock() - -_COMPONENT_KEY_PREFIX = '"component_key": "' - - -def _get_component_index(wandb_run_id: str) -> dict[str, int]: - """Get or build component index for a run. - - On first access, scans the components.jsonl file to build a byte offset - index, then caches it in memory for O(1) lookups. 
- """ - # Fast path: already cached - if wandb_run_id in _component_index_cache: - return _component_index_cache[wandb_run_id] - - # Slow path: build index under lock to prevent duplicate work - with _component_index_lock: - # Double-check after acquiring lock - if wandb_run_id in _component_index_cache: - return _component_index_cache[wandb_run_id] - - ctx_dir = get_activation_contexts_dir(wandb_run_id) - components_path = ctx_dir / "components.jsonl" - assert components_path.exists(), f"No activation contexts found at {components_path}" - - index: dict[str, int] = {} - with open(components_path) as f: - while True: - offset = f.tell() - line = f.readline() - if not line: - break - # Extract component_key from start of JSON line - # Format: {"component_key": "layer:idx", ...} - key_start = line.find(_COMPONENT_KEY_PREFIX) - assert key_start != -1, f"Malformed line in components.jsonl: {line[:100]}" - key_start += len(_COMPONENT_KEY_PREFIX) - key_end = line.find('"', key_start) - assert key_end != -1, f"Malformed line in components.jsonl: {line[:100]}" - component_key = line[key_start:key_end] - index[component_key] = offset - - _component_index_cache[wandb_run_id] = index - return index - - -def load_component_activation_contexts(wandb_run_id: str, component_key: str) -> ComponentData: - """Load a single component's activation contexts using index for O(1) lookup.""" - ctx_dir = get_activation_contexts_dir(wandb_run_id) - path = ctx_dir / "components.jsonl" - assert path.exists(), f"No activation contexts found at {path}" - - index = _get_component_index(wandb_run_id) - if component_key not in index: - raise ValueError(f"Component {component_key} not found in activation contexts") - - byte_offset = index[component_key] - with open(path) as f: - f.seek(byte_offset) - line = f.readline() - - data = json.loads(line) - data["activation_examples"] = [ActivationExample(**ex) for ex in data["activation_examples"]] - data["input_token_pmi"] = 
ComponentTokenPMI(**data["input_token_pmi"]) - data["output_token_pmi"] = ComponentTokenPMI(**data["output_token_pmi"]) - return ComponentData(**data) - - -def load_correlations(wandb_run_id: str) -> CorrelationStorage: - """Load component correlations from harvest output.""" - corr_dir = get_correlations_dir(wandb_run_id) - path = corr_dir / "component_correlations.pt" - assert path.exists() - return CorrelationStorage.load(path) - - -def load_token_stats(wandb_run_id: str) -> TokenStatsStorage: - """Load token statistics from harvest output.""" - corr_dir = get_correlations_dir(wandb_run_id) - path = corr_dir / "token_stats.pt" - assert path.exists() - return TokenStatsStorage.load(path) diff --git a/spd/harvest/repo.py b/spd/harvest/repo.py new file mode 100644 index 000000000..d1eb7019b --- /dev/null +++ b/spd/harvest/repo.py @@ -0,0 +1,148 @@ +"""Harvest data repository. + +Owns SPD_OUT_DIR/harvest/<decomposition_id>/ and provides read/write access to all +harvest artifacts. No in-memory caching -- reads go through on every call. +Component data backed by SQLite; correlations and token stats remain as .pt files.
+ +Layout: harvest/<decomposition_id>/h-YYYYMMDD_HHMMSS/{harvest.db, *.pt} +""" + +from pathlib import Path + +from spd.harvest.config import HarvestConfig +from spd.harvest.db import HarvestDB +from spd.harvest.harvester import Harvester +from spd.harvest.schemas import ( + ComponentData, + ComponentSummary, + get_harvest_dir, +) +from spd.harvest.storage import CorrelationStorage, TokenStatsStorage +from spd.log import logger + + +class HarvestRepo: + """Access to harvest data for a single harvest subrun of a decomposition.""" + + def __init__(self, decomposition_id: str, subrun_id: str, readonly: bool) -> None: + self.subrun_id = subrun_id + self._dir = get_harvest_dir(decomposition_id) / subrun_id + self._db = HarvestDB(self._dir / "harvest.db", readonly=readonly) + + @classmethod + def open_most_recent( + cls, + decomposition_id: str, + readonly: bool = True, + ) -> "HarvestRepo | None": + """Open harvest data. Returns None if no harvest data exists.""" + decomposition_subruns_dir = get_harvest_dir(decomposition_id) + if not decomposition_subruns_dir.exists(): + return None + + subrun_candidates = sorted( + [ + d + for d in decomposition_subruns_dir.iterdir() + if d.is_dir() and d.name.startswith("h-") + ], + key=lambda d: d.name, + ) + if not subrun_candidates: + return None + + subrun_dir = subrun_candidates[-1] + + db_path = subrun_dir / "harvest.db" + if not db_path.exists(): + logger.info(f"No harvest data found for {decomposition_id}") + return None + + logger.info(f"Opening harvest data for {decomposition_id} from {subrun_dir}") + subrun_id = subrun_dir.name + + return cls(decomposition_id=decomposition_id, subrun_id=subrun_id, readonly=readonly) + + @staticmethod + def save_results(harvester: Harvester, config: HarvestConfig, output_dir: Path) -> None: + """Build and save all harvest results to disk. + + Components are streamed to the DB one at a time to avoid holding all + ComponentData objects in memory simultaneously.
+ """ + output_dir.mkdir(parents=True, exist_ok=True) + + logger.info("Building and saving component results...") + db_path = output_dir / "harvest.db" + db = HarvestDB(db_path) + db.save_config(config) + components_iter = harvester.build_results(pmi_top_k_tokens=config.pmi_token_top_k) + n_saved = db.save_components_iter(components_iter) + db.close() + logger.info(f"Saved {n_saved} components to {db_path}") + + component_keys = harvester.component_keys + + correlations = CorrelationStorage( + component_keys=component_keys, + count_i=harvester.firing_counts.long().cpu(), + count_ij=harvester.cooccurrence_counts.long().cpu(), + count_total=harvester.total_tokens_processed, + ) + correlations.save(output_dir / "component_correlations.pt") + + token_stats = TokenStatsStorage( + component_keys=component_keys, + vocab_size=harvester.vocab_size, + n_tokens=harvester.total_tokens_processed, + input_counts=harvester.input_cooccurrence.cpu(), + input_totals=harvester.input_marginals.float().cpu(), + output_counts=harvester.output_cooccurrence.cpu(), + output_totals=harvester.output_marginals.cpu(), + firing_counts=harvester.firing_counts.cpu(), + ) + token_stats.save(output_dir / "token_stats.pt") + + # -- Provenance ------------------------------------------------------------ + + def get_config(self) -> dict[str, object]: + return self._db.get_config_dict() + + def get_component_count(self) -> int: + return self._db.get_component_count() + + # -- Activation contexts --------------------------------------------------- + + def get_summary(self) -> dict[str, ComponentSummary]: + return self._db.get_summary() + + def get_component(self, component_key: str) -> ComponentData | None: + return self._db.get_component(component_key) + + def get_components_bulk(self, component_keys: list[str]) -> dict[str, ComponentData]: + return self._db.get_components_bulk(component_keys) + + def get_all_components(self) -> list[ComponentData]: + return self._db.get_all_components() + + # -- 
Correlations & token stats (tensor data) ------------------------------ + + def get_correlations(self) -> CorrelationStorage | None: + path = self._dir / "component_correlations.pt" + if not path.exists(): + return None + return CorrelationStorage.load(path) + + def get_token_stats(self) -> TokenStatsStorage | None: + path = self._dir / "token_stats.pt" + if not path.exists(): + return None + return TokenStatsStorage.load(path) + + # -- Eval scores (e.g. intruder) ------------------------------------------- + + def save_score(self, component_key: str, score_type: str, score: float, details: str) -> None: + self._db.save_score(component_key, score_type, score, details) + + def get_scores(self, score_type: str) -> dict[str, float]: + return self._db.get_scores(score_type) diff --git a/spd/harvest/reservoir.py b/spd/harvest/reservoir.py new file mode 100644 index 000000000..94bcdad88 --- /dev/null +++ b/spd/harvest/reservoir.py @@ -0,0 +1,234 @@ +"""Activation examples reservoir backed by dense tensors. + +Stores [n_components, k, window] activation example windows using Algorithm R +for sampling and Efraimidis-Spirakis for merging parallel reservoirs. +""" + +import random +from collections import defaultdict +from collections.abc import Iterator +from dataclasses import dataclass + +import torch +from einops import rearrange, repeat +from jaxtyping import Bool, Float, Int +from torch import Tensor + +from spd.harvest.schemas import ActivationExample +from spd.utils.general_utils import runtime_cast + +WINDOW_PAD_SENTINEL = -1 + + +@dataclass +class ActivationWindows: + component_idx: Int[Tensor, " n_firings"] + token_windows: Int[Tensor, "n_firings window_size"] + firing_windows: Bool[Tensor, "n_firings window_size"] + activation_windows: dict[str, Float[Tensor, "n_firings window_size"]] + + +class ActivationExamplesReservoir: + """Fixed-capacity reservoir of activation example windows per component. 
+ + Each component slot holds up to `k` windows of size `w`, where each window + contains (token_ids, activation_values, component_acts) aligned by position. + + Use create() for fresh allocation, from_state_dict() for deserialization. + """ + + def __init__( + self, + n_components: int, + k: int, + window: int, + device: torch.device, + tokens: Int[Tensor, "C k w"], + firings: Bool[Tensor, "C k w"], + acts: dict[str, Float[Tensor, "C k w"]], + n_items: Int[Tensor, " C"], + n_seen: Int[Tensor, " C"], + ): + self.n_components = n_components + self.k = k + self.window = window + self.device = device + self.tokens = tokens + self.firings = firings + self.acts = acts + self.n_items = n_items + self.n_seen = n_seen + + @classmethod + def create( + cls, + n_components: int, + k: int, + window: int, + device: torch.device, + ) -> "ActivationExamplesReservoir": + return cls( + n_components=n_components, + k=k, + window=window, + device=device, + tokens=torch.full( + (n_components, k, window), WINDOW_PAD_SENTINEL, dtype=torch.long, device=device + ), + firings=torch.full((n_components, k, window), False, dtype=torch.bool, device=device), + acts=defaultdict(lambda: torch.zeros(n_components, k, window, device=device)), + n_items=torch.zeros(n_components, dtype=torch.long, device=device), + n_seen=torch.zeros(n_components, dtype=torch.long, device=device), + ) + + @classmethod + def from_state_dict( + cls, d: dict[str, object], device: torch.device + ) -> "ActivationExamplesReservoir": + tokens = runtime_cast(Tensor, d["tokens"]) + + acts = runtime_cast(dict, d["acts"]) + acts = {act_type: runtime_cast(Tensor, acts[act_type]).to(device) for act_type in acts} + + return cls( + n_components=tokens.shape[0], + k=runtime_cast(int, d["k"]), + window=runtime_cast(int, d["window"]), + device=device, + tokens=tokens.to(device), + firings=runtime_cast(Tensor, d["firings"]).to(device), + acts=acts, + n_items=runtime_cast(Tensor, d["n_items"]).to(device), + n_seen=runtime_cast(Tensor, 
d["n_seen"]).to(device), + ) + + def add(self, activation_windows: ActivationWindows) -> None: + """Add firing windows via Algorithm R. + + Bookkeeping on CPU (cheap integer ops), then batch-write to device. + """ + device = activation_windows.component_idx.device + comps = activation_windows.component_idx.cpu().tolist() + items_cpu = self.n_items.cpu() + seen_cpu = self.n_seen.cpu() + + write_comps: list[int] = [] + write_slots: list[int] = [] + write_srcs: list[int] = [] + + for i, c in enumerate(comps): + n = int(seen_cpu[c]) + if items_cpu[c] < self.k: + write_comps.append(c) + write_slots.append(int(items_cpu[c])) + write_srcs.append(i) + items_cpu[c] += 1 + else: + j = random.randint(0, n) + if j < self.k: + write_comps.append(c) + write_slots.append(j) + write_srcs.append(i) + seen_cpu[c] += 1 + + self.n_items.copy_(items_cpu) + self.n_seen.copy_(seen_cpu) + + if write_comps: + c_t = torch.tensor(write_comps, dtype=torch.long, device=device) + s_t = torch.tensor(write_slots, dtype=torch.long, device=device) + f_t = torch.tensor(write_srcs, dtype=torch.long, device=device) + + self.tokens[c_t, s_t] = activation_windows.token_windows[f_t] + self.firings[c_t, s_t] = activation_windows.firing_windows[f_t] + for act_type in activation_windows.activation_windows: + self.acts[act_type][c_t, s_t] = activation_windows.activation_windows[act_type][f_t] + + def merge(self, other: "ActivationExamplesReservoir") -> None: + """Merge other's reservoir into self via Efraimidis-Spirakis. + + Computes selection indices on small [C, 2k] tensors, then gathers + from self/other based on whether each selected index came from self or other. 
+ """ + assert other.n_components == self.n_components + assert other.k == self.k + device = self.device + n_comp = self.n_components + + idx = rearrange(torch.arange(self.k, device=device), "k -> 1 k") + valid_self = idx < rearrange(self.n_items, "c -> c 1") + valid_other = idx < rearrange(other.n_items, "c -> c 1") + valid = torch.cat([valid_self, valid_other], dim=1) + + weights = torch.zeros(n_comp, 2 * self.k, device=device) + weights[:, : self.k] = rearrange(self.n_seen.float(), "c -> c 1") + weights[:, self.k :] = rearrange(other.n_seen.float(), "c -> c 1") + weights[~valid] = 0.0 + + rand = torch.rand(n_comp, 2 * self.k, device=device).clamp(min=1e-30) + keys = rand.pow(1.0 / weights.clamp(min=1.0)) + keys[~valid] = -1.0 + + _, top_indices = keys.topk(self.k, dim=1) + + from_self = top_indices < self.k + self_indices = top_indices.clamp(max=self.k - 1) + other_indices = (top_indices - self.k).clamp(min=0) + + si = repeat(self_indices, "c k -> c k w", w=self.window) + oi = repeat(other_indices, "c k -> c k w", w=self.window) + mask = repeat(from_self, "c k -> c k w", w=self.window) + + self.tokens = torch.where(mask, self.tokens.gather(1, si), other.tokens.gather(1, oi)) + + self.firings = torch.where(mask, self.firings.gather(1, si), other.firings.gather(1, oi)) + + for act_type in self.acts: + self.acts[act_type] = torch.where( + mask, + self.acts[act_type].gather(1, si), + other.acts[act_type].gather(1, oi), + ) + + self.n_items = valid.sum(dim=1).clamp(max=self.k) + self.n_seen = self.n_seen + other.n_seen + + def examples(self, component: int) -> Iterator[ActivationExample]: + """Yield (token_ids, component_acts), sentinel-filtered.""" + n = int(self.n_items[component]) + for j in range(n): + toks = self.tokens[component, j] + firings = self.firings[component, j] + acts = {act_type: self.acts[act_type][component, j] for act_type in self.acts} + + mask = toks != WINDOW_PAD_SENTINEL # TODO(oli) not sure this is actually needed + + toks = 
toks[mask].tolist() + firings = firings[mask].tolist() + acts = {act_type: acts[act_type][mask].tolist() for act_type in acts} + + yield ActivationExample(token_ids=toks, firings=firings, activations=acts) + + def to(self, device: torch.device) -> "ActivationExamplesReservoir": + return ActivationExamplesReservoir( + n_components=self.n_components, + k=self.k, + window=self.window, + device=device, + tokens=self.tokens.to(device), + firings=self.firings.to(device), + acts={act_type: self.acts[act_type].to(device) for act_type in self.acts}, + n_items=self.n_items.to(device), + n_seen=self.n_seen.to(device), + ) + + def state_dict(self) -> dict[str, object]: + return { + "k": self.k, + "window": self.window, + "tokens": self.tokens.cpu(), + "firings": self.firings.cpu(), + "acts": {act_type: self.acts[act_type].cpu() for act_type in self.acts}, + "n_items": self.n_items.cpu(), + "n_seen": self.n_seen.cpu(), + } diff --git a/spd/harvest/lib/sampling.py b/spd/harvest/sampling.py similarity index 100% rename from spd/harvest/lib/sampling.py rename to spd/harvest/sampling.py diff --git a/spd/harvest/schemas.py b/spd/harvest/schemas.py index ee6af6b87..2faf1eefe 100644 --- a/spd/harvest/schemas.py +++ b/spd/harvest/schemas.py @@ -1,8 +1,12 @@ """Data types for harvest pipeline.""" -import json -from dataclasses import asdict, dataclass +from dataclasses import dataclass from pathlib import Path +from typing import Any + +from jaxtyping import Bool, Float, Int +from pydantic import BaseModel, model_validator +from torch import Tensor from spd.settings import SPD_OUT_DIR @@ -15,56 +19,74 @@ def get_harvest_dir(wandb_run_id: str) -> Path: return HARVEST_DATA_DIR / wandb_run_id -def get_activation_contexts_dir(wandb_run_id: str) -> Path: - """Get the activation contexts directory for a run.""" - return get_harvest_dir(wandb_run_id) / "activation_contexts" +def get_harvest_subrun_dir(decomposition_id: str, subrun_id: str) -> Path: + """Get the sub-run directory for a specific 
harvest invocation.""" + return get_harvest_dir(decomposition_id) / subrun_id -def get_correlations_dir(wandb_run_id: str) -> Path: - """Get the correlations directory for a run.""" - return get_harvest_dir(wandb_run_id) / "correlations" +@dataclass +class HarvestBatch: + """Output of a method-specific harvest function for a single batch. + The harvest loop calls the user-provided harvest_fn on each raw dataloader batch, + which returns one of these. The harvest loop then feeds it to the Harvester. -@dataclass -class ActivationExample: - token_ids: list[int] - ci_values: list[float] - component_acts: list[float] # Normalized component activations: (v_i^T @ a) * ||u_i|| + firings/activations are keyed by layer name. activations values are keyed by + activation type (e.g. "causal_importance", "component_activation" for SPD; + just "activation" for SAEs). + """ + tokens: Int[Tensor, "batch seq"] + firings: dict[str, Bool[Tensor, "batch seq c"]] + activations: dict[str, dict[str, Float[Tensor, "batch seq c"]]] + output_probs: Float[Tensor, "batch seq vocab"] -@dataclass -class ComponentTokenPMI: + +class ActivationExample(BaseModel): + """Activation example for a single component. 
No padding.""" + + token_ids: list[int] + firings: list[bool] + activations: dict[str, list[float]] + + @model_validator(mode="before") + @classmethod + def _strip_legacy_padding(cls, data: dict[str, Any]) -> dict[str, Any]: + """Strip -1 padding sentinels from old harvest data.""" + PAD = -1 + token_ids = data["token_ids"] + if any(t == PAD for t in token_ids): + mask = [t != PAD for t in token_ids] + data["token_ids"] = [v for v, k in zip(token_ids, mask, strict=True) if k] + data["firings"] = [v for v, k in zip(data["firings"], mask, strict=True) if k] + data["activations"] = { + act_type: [v for v, k in zip(vals, mask, strict=True) if k] + for act_type, vals in data["activations"].items() + } + return data + + +class ComponentTokenPMI(BaseModel): top: list[tuple[int, float]] bottom: list[tuple[int, float]] -@dataclass -class ComponentSummary: +class ComponentSummary(BaseModel): """Lightweight summary of a component (for /summary endpoint).""" layer: str component_idx: int - mean_ci: float + firing_density: float + mean_activations: dict[str, float] + """Key is activation type (e.g.
"causal_importance", "component_activation", etc.)""" - @staticmethod - def save_all(summaries: dict[str, "ComponentSummary"], path: Path) -> None: - """Save component summaries to JSON file.""" - data = {key: asdict(s) for key, s in summaries.items()} - path.write_text(json.dumps(data)) - @staticmethod - def load_all(path: Path) -> dict[str, "ComponentSummary"]: - """Load component summaries from JSON file.""" - data = json.loads(path.read_text()) - return {key: ComponentSummary(**val) for key, val in data.items()} - - -@dataclass -class ComponentData: +class ComponentData(BaseModel): component_key: str layer: str component_idx: int - mean_ci: float + mean_activations: dict[str, float] + firing_density: float activation_examples: list[ActivationExample] input_token_pmi: ComponentTokenPMI output_token_pmi: ComponentTokenPMI diff --git a/spd/harvest/scripts/run.py b/spd/harvest/scripts/run.py deleted file mode 100644 index a3f19b20e..000000000 --- a/spd/harvest/scripts/run.py +++ /dev/null @@ -1,92 +0,0 @@ -"""Worker script for harvest pipeline. - -Usage (non-SLURM): - # Single GPU - python -m spd.harvest.scripts.run --n_batches 1000 - - # Multi-GPU (run in parallel via shell, tmux, etc.) 
- python -m spd.harvest.scripts.run --n_batches 1000 --rank 0 --world_size 4 & - python -m spd.harvest.scripts.run --n_batches 1000 --rank 1 --world_size 4 & - python -m spd.harvest.scripts.run --n_batches 1000 --rank 2 --world_size 4 & - python -m spd.harvest.scripts.run --n_batches 1000 --rank 3 --world_size 4 & - wait - - # Merge results after all workers complete - python -m spd.harvest.scripts.run --merge - -Usage (SLURM submission): - spd-harvest --n_batches 1000 --n_gpus 8 -""" - -from spd.harvest.harvest import ( - HarvestConfig, - harvest_activation_contexts, - merge_activation_contexts, -) -from spd.harvest.schemas import get_activation_contexts_dir, get_correlations_dir -from spd.utils.wandb_utils import parse_wandb_run_path - - -def main( - wandb_path: str, - n_batches: int | None = None, - batch_size: int = 256, - ci_threshold: float = 1e-6, - activation_examples_per_component: int = 1000, - activation_context_tokens_per_side: int = 10, - pmi_token_top_k: int = 40, - rank: int | None = None, - world_size: int | None = None, - merge: bool = False, -) -> None: - """Harvest correlations and activation contexts, or merge results. - - Args: - wandb_path: WandB run path for the target decomposition run. - n_batches: Number of batches to process. If None, processes entire training dataset. - batch_size: Batch size for processing. - ci_threshold: CI threshold for component activation. - activation_examples_per_component: Number of activation examples per component. - activation_context_tokens_per_side: Number of tokens per side of the activation context. - pmi_token_top_k: Number of top- and bottom-k tokens by PMI to include. - rank: Worker rank for parallel execution (0 to world_size-1). - world_size: Total number of workers. If specified with rank, only processes - batches where batch_idx % world_size == rank. - merge: If True, merge partial results from workers. 
- """ - - _, _, run_id = parse_wandb_run_path(wandb_path) - - if merge: - assert rank is None and world_size is None, "Cannot specify rank/world_size with --merge" - print(f"Merging harvest results for {wandb_path}") - merge_activation_contexts(wandb_path) - return - - assert (rank is None) == (world_size is None), "rank and world_size must both be set or unset" - - config = HarvestConfig( - wandb_path=wandb_path, - n_batches=n_batches, - batch_size=batch_size, - ci_threshold=ci_threshold, - activation_examples_per_component=activation_examples_per_component, - activation_context_tokens_per_side=activation_context_tokens_per_side, - pmi_token_top_k=pmi_token_top_k, - ) - - activation_contexts_dir = get_activation_contexts_dir(run_id) - correlations_dir = get_correlations_dir(run_id) - - if world_size is not None: - print(f"Distributed harvest: {wandb_path} (rank {rank}/{world_size})") - else: - print(f"Single-GPU harvest: {wandb_path}") - - harvest_activation_contexts(config, activation_contexts_dir, correlations_dir, rank, world_size) - - -if __name__ == "__main__": - import fire - - fire.Fire(main) diff --git a/spd/harvest/scripts/run_intruder.py b/spd/harvest/scripts/run_intruder.py new file mode 100644 index 000000000..1870d1a06 --- /dev/null +++ b/spd/harvest/scripts/run_intruder.py @@ -0,0 +1,63 @@ +import asyncio +import os +from typing import Any + +from dotenv import load_dotenv + +from spd.adapters import adapter_from_id +from spd.harvest.config import IntruderEvalConfig +from spd.harvest.db import HarvestDB +from spd.harvest.intruder import run_intruder_scoring +from spd.harvest.repo import HarvestRepo +from spd.log import logger + + +def main( + decomposition_id: str, + config_json: dict[str, Any], + harvest_subrun_id: str, +) -> None: + assert isinstance(config_json, dict), f"Expected dict from fire, got {type(config_json)}" + load_dotenv() + openrouter_api_key = os.environ.get("OPENROUTER_API_KEY") + assert openrouter_api_key, "OPENROUTER_API_KEY not 
set" + + eval_config = IntruderEvalConfig.model_validate(config_json) + + tokenizer_name = adapter_from_id(decomposition_id).tokenizer_name + + harvest = HarvestRepo(decomposition_id, subrun_id=harvest_subrun_id, readonly=True) + score_db = HarvestDB(harvest._dir / "harvest.db") + + logger.info("Loading components from harvest DB...") + components = harvest.get_all_components() + logger.info(f"Loaded {len(components)} components") + + asyncio.run( + run_intruder_scoring( + components=components, + model=eval_config.model, + openrouter_api_key=openrouter_api_key, + tokenizer_name=tokenizer_name, + score_db=score_db, + eval_config=eval_config, + limit=eval_config.limit, + cost_limit_usd=eval_config.cost_limit_usd, + ) + ) + score_db.close() + + +def get_command(decomposition_id: str, config: IntruderEvalConfig, harvest_subrun_id: str) -> str: + config_json = config.model_dump_json(exclude_none=True) + return ( + f"python -m spd.harvest.scripts.run_intruder {decomposition_id} " + f"--config_json '{config_json}' " + f"--harvest_subrun_id {harvest_subrun_id}" + ) + + +if __name__ == "__main__": + import fire + + fire.Fire(main) diff --git a/spd/harvest/scripts/run_merge.py b/spd/harvest/scripts/run_merge.py new file mode 100644 index 000000000..f7777a66c --- /dev/null +++ b/spd/harvest/scripts/run_merge.py @@ -0,0 +1,36 @@ +"""Harvest merge: combines worker states into final harvest results. + +Usage: + python -m spd.harvest.scripts.run_merge --subrun_id h-20260211_120000 --config_json '...' 
+""" + +from typing import Any + +import fire + +from spd.harvest.config import HarvestConfig +from spd.harvest.harvest import merge_harvest +from spd.harvest.schemas import get_harvest_subrun_dir +from spd.log import logger + + +def main(subrun_id: str, config_json: dict[str, Any]) -> None: + assert isinstance(config_json, dict), f"Expected dict from fire, got {type(config_json)}" + config = HarvestConfig.model_validate(config_json) + output_dir = get_harvest_subrun_dir(config.method_config.id, subrun_id) + logger.info(f"Merging harvest results for (subrun {subrun_id})") + merge_harvest(output_dir, config) + + +def get_command(subrun_id: str, config: HarvestConfig) -> str: + config_json = config.model_dump_json(exclude_none=True) + cmd = ( + f"python -m spd.harvest.scripts.run_merge " + f"--subrun_id {subrun_id} " + f"--config_json '{config_json}'" + ) + return cmd + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/spd/harvest/scripts/run_slurm.py b/spd/harvest/scripts/run_slurm.py index d30874d04..7775bedac 100644 --- a/spd/harvest/scripts/run_slurm.py +++ b/spd/harvest/scripts/run_slurm.py @@ -1,8 +1,7 @@ """SLURM launcher for harvest pipeline. -Submits multi-GPU harvest jobs as a SLURM array, with a dependent merge job -that runs after all workers complete. Creates a git snapshot to ensure consistent -code across all workers even if jobs are queued. +Harvest is a functional unit: GPU workers -> merge. This module submits all +jobs in the unit with proper dependency chaining. 
Usage: spd-harvest --n_gpus 24 @@ -10,87 +9,76 @@ """ import secrets +from dataclasses import dataclass +from datetime import datetime +from spd.harvest.config import HarvestSlurmConfig +from spd.harvest.scripts import run_merge as harvest_merge +from spd.harvest.scripts import run_worker as harvest_worker from spd.log import logger -from spd.settings import DEFAULT_PARTITION_NAME from spd.utils.git_utils import create_git_snapshot from spd.utils.slurm import ( SlurmArrayConfig, SlurmConfig, + SubmitResult, generate_array_script, generate_script, submit_slurm_job, ) -from spd.utils.wandb_utils import wandb_path_to_url - - -def harvest( - wandb_path: str, - n_gpus: int, - n_batches: int | None = None, - batch_size: int = 256, - ci_threshold: float = 1e-6, - activation_examples_per_component: int = 1000, - activation_context_tokens_per_side: int = 10, - pmi_token_top_k: int = 40, - partition: str = DEFAULT_PARTITION_NAME, - time: str = "24:00:00", + + +@dataclass +class HarvestSubmitResult: + array_result: SubmitResult + merge_result: SubmitResult + subrun_id: str + + +def submit_harvest( + config: HarvestSlurmConfig, job_suffix: str | None = None, -) -> None: + snapshot_branch: str | None = None, + dependency_job_id: str | None = None, +) -> HarvestSubmitResult: """Submit multi-GPU harvest job to SLURM. Submits a job array where each task processes a subset of batches, then - submits a merge job that depends on all workers completing. Creates a git - snapshot to ensure consistent code across all workers. - - Args: - wandb_path: WandB run path for the target decomposition run. - n_batches: Total number of batches to process (divided among workers). - If None, processes entire training dataset. - n_gpus: Number of GPUs (each gets its own array task). - batch_size: Batch size for processing. - ci_threshold: CI threshold for component activation. - activation_examples_per_component: Number of activation examples per component. 
- activation_context_tokens_per_side: Number of tokens per side of the activation context. - pmi_token_top_k: Number of top- and bottom-k tokens by PMI to include. - partition: SLURM partition name. - time: Job time limit for worker jobs. - job_suffix: Optional suffix for SLURM job names (e.g., "v2" -> "spd-harvest-v2"). + submits a merge job that depends on all workers completing. """ - launch_id = f"harvest-{secrets.token_hex(4)}" - snapshot_branch, commit_hash = create_git_snapshot(snapshot_id=launch_id) - logger.info(f"Created git snapshot: {snapshot_branch} ({commit_hash[:8]})") + n_gpus = config.n_gpus + partition = config.partition + time = config.time + + if snapshot_branch is None: + run_id = f"harvest-{secrets.token_hex(4)}" + snapshot_branch, commit_hash = create_git_snapshot(snapshot_id=run_id) + logger.info(f"Created git snapshot: {snapshot_branch} ({commit_hash[:8]})") + else: + commit_hash = "shared" + + subrun_id = "h-" + datetime.now().strftime("%Y%m%d_%H%M%S") suffix = f"-{job_suffix}" if job_suffix else "" array_job_name = f"spd-harvest{suffix}" - # Build worker commands (SLURM arrays are 1-indexed, so task ID 1 -> rank 0, etc.) 
worker_commands = [] for rank in range(n_gpus): - n_batches_arg = f"--n_batches {n_batches} " if n_batches is not None else "" - cmd = ( - f"python -m spd.harvest.scripts.run " - f'"{wandb_path}" ' - f"{n_batches_arg}" - f"--batch_size {batch_size} " - f"--ci_threshold {ci_threshold} " - f"--activation_examples_per_component {activation_examples_per_component} " - f"--activation_context_tokens_per_side {activation_context_tokens_per_side} " - f"--pmi_token_top_k {pmi_token_top_k} " - f"--rank {rank} " - f"--world_size {n_gpus}" + cmd = harvest_worker.get_command( + config.config, + rank=rank, + world_size=n_gpus, + subrun_id=subrun_id, ) worker_commands.append(cmd) - wandb_url = wandb_path_to_url(wandb_path) - array_config = SlurmArrayConfig( job_name=array_job_name, partition=partition, - n_gpus=1, # 1 GPU per worker + n_gpus=1, time=time, snapshot_branch=snapshot_branch, - comment=wandb_url, + dependency_job_id=dependency_job_id, + comment=config.config.method_config.id, ) array_script = generate_array_script(array_config, worker_commands) array_result = submit_slurm_job( @@ -100,33 +88,40 @@ def harvest( n_array_tasks=n_gpus, ) - # Submit merge job with dependency on array completion - merge_cmd = f'python -m spd.harvest.scripts.run "{wandb_path}" --merge' + merge_command = harvest_merge.get_command( + subrun_id, + config.config, + ) merge_config = SlurmConfig( job_name="spd-harvest-merge", partition=partition, - n_gpus=0, # No GPU needed for merge - time="01:00:00", # Merge is quick + n_gpus=0, + time=config.merge_time, + mem=config.merge_mem, snapshot_branch=snapshot_branch, dependency_job_id=array_result.job_id, - comment=wandb_url, + comment=config.config.method_config.id, ) - merge_script = generate_script(merge_config, merge_cmd) + merge_script = generate_script(merge_config, merge_command) merge_result = submit_slurm_job(merge_script, "harvest_merge") logger.section("Harvest jobs submitted!") logger.values( { - "WandB path": wandb_path, - "N batches": 
n_batches, + "Sub-run ID": subrun_id, + "N batches": config.config.n_batches, "N GPUs": n_gpus, - "Batch size": batch_size, + "Batch size": config.config.batch_size, "Snapshot": f"{snapshot_branch} ({commit_hash[:8]})", "Array Job ID": array_result.job_id, "Merge Job ID": merge_result.job_id, "Worker logs": array_result.log_pattern, "Merge log": merge_result.log_pattern, - "Array script": str(array_result.script_path), - "Merge script": str(merge_result.script_path), } ) + + return HarvestSubmitResult( + array_result=array_result, + merge_result=merge_result, + subrun_id=subrun_id, + ) diff --git a/spd/harvest/scripts/run_slurm_cli.py b/spd/harvest/scripts/run_slurm_cli.py index 40045bb98..00f6d7c0c 100644 --- a/spd/harvest/scripts/run_slurm_cli.py +++ b/spd/harvest/scripts/run_slurm_cli.py @@ -3,66 +3,29 @@ Thin wrapper for fast --help. Heavy imports deferred to run_slurm.py. Usage: - spd-harvest --n_gpus 24 - spd-harvest --n_batches 1000 --n_gpus 8 # Only process 1000 batches + spd-harvest --config harvest_config.yaml + spd-harvest --config harvest_config.yaml --job_suffix v2 """ import fire -from spd.settings import DEFAULT_PARTITION_NAME - def harvest( - wandb_path: str, - n_gpus: int, - n_batches: int | None = None, - batch_size: int = 256, - ci_threshold: float = 1e-6, - activation_examples_per_component: int = 1000, - activation_context_tokens_per_side: int = 10, - pmi_token_top_k: int = 40, - partition: str = DEFAULT_PARTITION_NAME, - time: str = "24:00:00", + config: str, job_suffix: str | None = None, ) -> None: """Submit multi-GPU harvest job to SLURM. - Submits a job array where each GPU processes a subset of batches, - then a merge job that combines results after all workers complete. - - Examples: - spd-harvest wandb:spd/runs/abc123 --n_gpus 24 - spd-harvest wandb:spd/runs/abc123 --n_batches 1000 --n_gpus 8 # Only process 1000 batches - Args: - wandb_path: WandB run path for the target decomposition run. - n_batches: Total number of batches to process (divided among workers).
- If None, processes entire training dataset. - n_gpus: Number of GPUs (each gets its own array task). - batch_size: Batch size for processing. - ci_threshold: CI threshold for component activation. - activation_examples_per_component: Number of activation examples per component. - activation_context_tokens_per_side: Number of tokens per side of the activation context. - pmi_token_top_k: Number of top- and bottom-k tokens by PMI to include. - partition: SLURM partition name. - time: Job time limit for worker jobs. + config: Path to HarvestSlurmConfig YAML/JSON file. job_suffix: Optional suffix for SLURM job names (e.g., "v2" -> "spd-harvest-v2"). """ - from spd.harvest.scripts.run_slurm import harvest as harvest_impl + from spd.harvest.config import HarvestSlurmConfig + from spd.harvest.scripts.run_slurm import submit_harvest - harvest_impl( - wandb_path=wandb_path, - n_batches=n_batches, - n_gpus=n_gpus, - batch_size=batch_size, - ci_threshold=ci_threshold, - activation_examples_per_component=activation_examples_per_component, - activation_context_tokens_per_side=activation_context_tokens_per_side, - pmi_token_top_k=pmi_token_top_k, - partition=partition, - time=time, - job_suffix=job_suffix, - ) + slurm_config = HarvestSlurmConfig.from_file(config) + submit_harvest(config=slurm_config, job_suffix=job_suffix) def cli() -> None: diff --git a/spd/harvest/scripts/run_worker.py b/spd/harvest/scripts/run_worker.py new file mode 100644 index 000000000..9c3d1c582 --- /dev/null +++ b/spd/harvest/scripts/run_worker.py @@ -0,0 +1,72 @@ +"""Harvest worker: collects component statistics on a single GPU. + +Usage: + python -m spd.harvest.scripts.run_worker --config_json '{"n_batches": 100}' + python -m spd.harvest.scripts.run_worker --config_json '...'
--rank 0 --world_size 4 --subrun_id h-20260211_120000 +""" + +from datetime import datetime +from typing import Any + +import fire +import torch + +from spd.adapters import adapter_from_id +from spd.harvest.config import HarvestConfig +from spd.harvest.harvest import harvest +from spd.harvest.harvest_fn import make_harvest_fn +from spd.harvest.schemas import get_harvest_subrun_dir +from spd.log import logger +from spd.utils.distributed_utils import get_device + + +def main( + config_json: dict[str, Any], + rank: int | None = None, + world_size: int | None = None, + subrun_id: str | None = None, +) -> None: + assert isinstance(config_json, dict), f"Expected dict from fire, got {type(config_json)}" + assert (rank is not None) == (world_size is not None) + + if subrun_id is None: + subrun_id = "h-" + datetime.now().strftime("%Y%m%d_%H%M%S") + device = torch.device(get_device()) + + config = HarvestConfig.model_validate(config_json) + + adapter = adapter_from_id(config.method_config.id) + + output_dir = get_harvest_subrun_dir(adapter.decomposition_id, subrun_id) + + if rank is not None: + logger.info(f"Distributed harvest: rank {rank}/{world_size}, subrun {subrun_id}") + else: + logger.info(f"Single-GPU harvest: subrun {subrun_id}") + + harvest( + layers=adapter.layer_activation_sizes, + vocab_size=adapter.vocab_size, + dataloader=adapter.dataloader(config.batch_size), + harvest_fn=make_harvest_fn(device, config.method_config, adapter), + config=config, + output_dir=output_dir, + rank_world_size=(rank, world_size) if rank is not None and world_size is not None else None, + device=device, + ) + + +def get_command(config: HarvestConfig, rank: int, world_size: int, subrun_id: str) -> str: + config_json = config.model_dump_json(exclude_none=True) + cmd = ( + f"python -m spd.harvest.scripts.run_worker " + f"--config_json '{config_json}' " + f"--rank {rank} " + f"--world_size {world_size} " + f"--subrun_id {subrun_id}" + ) + return cmd + + +if __name__ == "__main__": + 
fire.Fire(main) diff --git a/spd/harvest/storage.py b/spd/harvest/storage.py index a176e3701..2111a0167 100644 --- a/spd/harvest/storage.py +++ b/spd/harvest/storage.py @@ -4,6 +4,7 @@ For query functionality, see harvest/analysis.py. """ +import math from dataclasses import dataclass from pathlib import Path @@ -26,6 +27,32 @@ class CorrelationStorage: count_total: int """Total tokens seen""" + _key_to_idx: dict[str, int] | None = None + + @property + def key_to_idx(self) -> dict[str, int]: + """Cached mapping from component key to index.""" + if self._key_to_idx is None: + self._key_to_idx = {k: i for i, k in enumerate(self.component_keys)} + return self._key_to_idx + + def pmi(self, key_a: str, key_b: str) -> float | None: + """Point-wise mutual information between two components. + + Returns None if either component is missing or they never co-fire. + """ + if key_a not in self.key_to_idx or key_b not in self.key_to_idx: + return None + i, j = self.key_to_idx[key_a], self.key_to_idx[key_b] + count_ij = self.count_ij[i][j].item() + if count_ij == 0: + return None + count_i = self.count_i[i].item() + count_j = self.count_i[j].item() + if count_i == 0 or count_j == 0: + return None + return math.log(count_ij * self.count_total / (count_i * count_j)) + def save(self, path: Path) -> None: path.parent.mkdir(parents=True, exist_ok=True) torch.save( @@ -70,6 +97,15 @@ class TokenStatsStorage: output_totals: Float[Tensor, " vocab"] firing_counts: Float[Tensor, " n_components"] + _key_to_idx: dict[str, int] | None = None + + @property + def key_to_idx(self) -> dict[str, int]: + """Cached mapping from component key to index.""" + if self._key_to_idx is None: + self._key_to_idx = {k: i for i, k in enumerate(self.component_keys)} + return self._key_to_idx + def save(self, path: Path) -> None: path.parent.mkdir(parents=True, exist_ok=True) torch.save( diff --git a/spd/investigate/CLAUDE.md b/spd/investigate/CLAUDE.md new file mode 100644 index 000000000..922734220 --- 
/dev/null +++ b/spd/investigate/CLAUDE.md @@ -0,0 +1,118 @@ +# Investigation Module + +Launch a Claude Code agent to investigate a specific research question about an SPD model decomposition. + +## Usage + +```bash +spd-investigate <wandb_path> "How does the model handle gendered pronouns?" +spd-investigate <wandb_path> "What circuit handles verb agreement?" --max_turns 30 --time 4:00:00 +``` + +For parallel investigations, run the command multiple times with different prompts. + +## Architecture + +``` +spd/investigate/ +├── __init__.py # Public exports +├── CLAUDE.md # This file +├── schemas.py # Pydantic models for outputs (BehaviorExplanation, InvestigationEvent) +├── agent_prompt.py # System prompt template with model info injection +└── scripts/ + ├── __init__.py + ├── run_slurm_cli.py # CLI entry point (spd-investigate) + ├── run_slurm.py # SLURM submission logic + └── run_agent.py # Worker script (runs in SLURM job) +``` + +## How It Works + +1. `spd-investigate` creates output dir, metadata, git snapshot, and submits a single SLURM job +2. The SLURM job runs `run_agent.py` which: + - Starts an isolated FastAPI backend with MCP support + - Loads the SPD run onto GPU + - Fetches model architecture info + - Generates the agent prompt (research question + model context + methodology) + - Launches Claude Code with MCP tools +3.
The agent investigates using MCP tools and writes findings to the output directory + +## MCP Tools + +The agent accesses all SPD functionality via MCP at `/mcp`: + +**Circuit Discovery:** +- `optimize_graph` — Find minimal circuit for a behavior (streams progress) +- `create_prompt` — Tokenize text and get next-token probabilities + +**Component Analysis:** +- `get_component_info` — Interpretation, token stats, correlations +- `probe_component` — Fast CI probing on custom text +- `get_component_activation_examples` — Training examples where a component fires +- `get_component_attributions` — Dataset-level component dependencies +- `get_attribution_strength` — Attribution between specific component pairs + +**Testing:** +- `run_ablation` — Test circuit with only selected components +- `search_dataset` — Search training data + +**Metadata:** +- `get_model_info` — Architecture details + +**Output:** +- `update_research_log` — Append to research log (PRIMARY OUTPUT) +- `save_graph_artifact` — Save graph for inline visualization +- `save_explanation` — Save complete behavior explanation +- `set_investigation_summary` — Set title/summary for UI + +## Output Structure + +``` +SPD_OUT_DIR/investigations/<inv_id>/ +├── metadata.json # Investigation config (wandb_path, prompt, etc.) +├── research_log.md # Human-readable progress log (PRIMARY OUTPUT) +├── events.jsonl # Structured progress events +├── explanations.jsonl # Complete behavior explanations +├── summary.json # Agent-provided title/summary for UI +├── artifacts/ # Graph artifacts for visualization +│ └── graph_001.json +├── app.db # Isolated SQLite database +├── backend.log # Backend subprocess output +├── claude_output.jsonl # Raw Claude Code output +├── agent_prompt.md # The prompt given to the agent +└── mcp_config.json # MCP server configuration +``` + +## Environment + +The backend runs with `SPD_INVESTIGATION_DIR` set to the investigation directory.
This controls: +- Database location: `<investigation_dir>/app.db` +- Events log: `<investigation_dir>/events.jsonl` +- Research log: `<investigation_dir>/research_log.md` + +## Configuration + +CLI arguments: +- `wandb_path` — Required. WandB run path for the SPD decomposition. +- `prompt` — Required. Research question or investigation directive. +- `--context_length` — Token context length (default: 128) +- `--max_turns` — Max Claude turns (default: 50, prevents runaway) +- `--partition` — SLURM partition (default: h200-reserved) +- `--time` — Job time limit (default: 8:00:00) +- `--job_suffix` — Optional suffix for job names + +## Monitoring + +```bash +# Watch research log +tail -f SPD_OUT_DIR/investigations/<inv_id>/research_log.md + +# Watch events +tail -f SPD_OUT_DIR/investigations/<inv_id>/events.jsonl + +# View explanations +cat SPD_OUT_DIR/investigations/<inv_id>/explanations.jsonl | jq . + +# Check SLURM job status +squeue --me +``` diff --git a/spd/investigate/__init__.py b/spd/investigate/__init__.py new file mode 100644 index 000000000..9e666dd7d --- /dev/null +++ b/spd/investigate/__init__.py @@ -0,0 +1,22 @@ +"""Investigation: SLURM-based agent investigation of model behaviors. + +This module provides infrastructure for launching a Claude Code agent to investigate +behaviors in an SPD model decomposition. Each investigation: +1. Starts an isolated app backend instance (separate database, unique port) +2. Receives a specific research question and detailed instructions +3. Investigates behaviors and writes findings to append-only JSONL files +""" + +from spd.investigate.schemas import ( + BehaviorExplanation, + ComponentInfo, + Evidence, + InvestigationEvent, +) + +__all__ = [ + "BehaviorExplanation", + "ComponentInfo", + "Evidence", + "InvestigationEvent", +] diff --git a/spd/investigate/agent_prompt.py b/spd/investigate/agent_prompt.py new file mode 100644 index 000000000..aae26f6da --- /dev/null +++ b/spd/investigate/agent_prompt.py @@ -0,0 +1,202 @@ +"""System prompt for SPD investigation agents.
+ +This module contains the detailed instructions given to the investigation agent. +The agent has access to SPD tools via MCP - tools are self-documenting. +""" + +from typing import Any + +AGENT_SYSTEM_PROMPT = """ +# SPD Behavior Investigation Agent + +You are a research agent investigating behaviors in a neural network model decomposition. +A researcher has given you a specific question to investigate. Your job is to answer it +thoroughly using the SPD analysis tools available to you. + +## Your Mission + +{prompt} + +## Available Tools (via MCP) + +You have access to SPD analysis tools. Use them directly - they have full documentation. + +**Circuit Discovery:** +- **optimize_graph**: Find the minimal circuit for a behavior (e.g., "boy" → "he") +- **create_prompt**: Tokenize text and get next-token probabilities + +**Component Analysis:** +- **get_component_info**: Get interpretation and token stats for a component +- **probe_component**: Fast CI probing - test if a component activates on specific text +- **get_component_activation_examples**: See training examples where a component fires +- **get_component_attributions**: Dataset-level component dependencies (sources and targets) +- **get_attribution_strength**: Query attribution strength between two specific components + +**Testing:** +- **run_ablation**: Test a circuit by running with only selected components +- **search_dataset**: Find examples in the training data + +**Metadata:** +- **get_model_info**: Get model architecture details +- **get_stored_graphs**: Retrieve previously computed graphs + +**Output:** +- **update_research_log**: Append to your research log (PRIMARY OUTPUT - use frequently!) 
+- **save_graph_artifact**: Save a graph for inline visualization in your research log +- **save_explanation**: Save a complete, validated behavior explanation +- **set_investigation_summary**: Set a title and summary for your investigation + +## Investigation Methodology + +### Step 1: Understand the Question + +Read the research question carefully. Think about what behaviors, components, or mechanisms +might be relevant. Use `get_model_info` if you need to understand the model architecture. + +### Step 2: Explore and Hypothesize + +- Use `create_prompt` to test prompts and see what the model predicts +- Use `search_dataset` to find relevant examples in the training data +- Use `probe_component` to quickly test whether specific components respond to your prompts +- Use `get_component_info` to understand what components do + +### Step 3: Find Circuits + +- Use `optimize_graph` to find the minimal circuit for specific behaviors +- Examine which components have high CI values +- Note the circuit size (fewer active components = cleaner mechanism) + +### Step 4: Understand Component Roles + +For each important component in a circuit: +1. Use `get_component_info` for interpretation and token associations +2. Use `probe_component` to test activation on different inputs +3. Use `get_component_activation_examples` to see training examples +4. Use `get_component_attributions` to understand information flow +5. Check correlated components for related functions + +### Step 5: Test with Ablations + +Form hypotheses and test them: +1. Use `run_ablation` with the circuit's components +2. Verify predictions match expectations +3. Try removing individual components to find critical ones + +### Step 6: Document Your Findings + +Use `update_research_log` frequently - this is how humans monitor your work! +When you have a complete explanation, use `save_explanation` to create a structured record. 
+ +## Scientific Principles + +- **Be skeptical**: Your first hypothesis is probably incomplete +- **Triangulate**: Don't rely on a single type of evidence +- **Document uncertainty**: Note what you're confident in vs. uncertain about +- **Consider alternatives**: What else could explain the behavior? + +## Output Format + +### Research Log (PRIMARY OUTPUT - Update frequently!) + +Use `update_research_log` with markdown content. Call it every few minutes to show progress: + +Example calls: +``` +update_research_log("## Hypothesis: Gendered Pronoun Circuit\\n\\nTesting prompt: 'The boy said that' → expecting ' he'\\n\\n") + +update_research_log("## Ablation Test\\n\\nResult: P(he) = 0.89 (vs 0.22 baseline)\\n\\nThis confirms the circuit is sufficient!\\n\\n") +``` + +### Including Graph Visualizations + +After running `optimize_graph`, embed the circuit visualization in your research log: + +1. Call `save_graph_artifact` with the graph_id returned by optimize_graph +2. Reference it in your research log using the `spd:graph` code block + +Example: +``` +save_graph_artifact(graph_id=42, caption="Circuit predicting 'he' after 'The boy'") + +update_research_log('''## Circuit Visualization + +```spd:graph +artifact: graph_001 +``` + +This circuit shows the key components involved in predicting "he"... +''') +``` + +### Saving Explanations + +When you have a complete explanation, use `save_explanation`: + +``` +save_explanation( + subject_prompt="The boy said that", + behavior_description="Predicts masculine pronoun 'he' after male subject", + components_involved=[ + {{"component_key": "h.0.mlp.c_fc:407", "role": "Male subject detector"}}, + {{"component_key": "h.3.attn.o_proj:262", "role": "Masculine pronoun promoter"}} + ], + explanation="Component h.0.mlp.c_fc:407 activates on male subjects...", + confidence="medium", + limitations=["Only tested on simple sentences"] +) +``` + +## Getting Started + +1. **Create your research log** with `update_research_log` +2. 
Understand the research question and plan your approach +3. Use analysis tools to explore the model +4. **Call `update_research_log` frequently** - humans are watching! +5. Use `save_explanation` for complete findings +6. **Call `set_investigation_summary`** with a title and summary when done + +Document what you learn, even if it's "this was more complicated than expected." +""" + + +def _format_model_info(model_info: dict[str, Any]) -> str: + target_config = model_info["target_model_config"] + topology = model_info["topology"] + block = topology["block_structure"][0] + + return "\n".join( + [ + f"- **Architecture**: {model_info['summary']}", + f"- **Layers**: {target_config['n_layer']}", + f"- **Hidden dim**: {target_config['n_embd']}", + f"- **Vocab size**: {target_config['vocab_size']}", + f"- **Attention projections**: {', '.join(block['attn_projections'])}", + f"- **FFN projections**: {', '.join(block['ffn_projections'])}", + ] + ) + + +def get_agent_prompt( + wandb_path: str, + prompt: str, + model_info: dict[str, Any], +) -> str: + """Generate the full agent prompt with runtime parameters filled in.""" + formatted_prompt = AGENT_SYSTEM_PROMPT.format(prompt=prompt) + + model_section = f""" +## Model Architecture + +{_format_model_info(model_info)} + +## Runtime Context + +- **Model Run**: {wandb_path} + +Use the MCP tools for ALL output: +- `update_research_log` → **PRIMARY OUTPUT** - Update frequently with your progress! +- `save_explanation` → Save complete, validated behavior explanations + +**Start by calling update_research_log to create your log, then investigate!** +""" + return formatted_prompt + model_section diff --git a/spd/investigate/schemas.py b/spd/investigate/schemas.py new file mode 100644 index 000000000..d4da1a896 --- /dev/null +++ b/spd/investigate/schemas.py @@ -0,0 +1,104 @@ +"""Schemas for investigation outputs. + +All agent outputs are append-only JSONL files. Each line is a JSON object +conforming to one of the schemas defined here. 
+""" + +from datetime import UTC, datetime +from typing import Any, Literal + +from pydantic import BaseModel, Field + + +class ComponentInfo(BaseModel): + """Information about a component involved in a behavior.""" + + component_key: str = Field( + ..., + description="Component key in format 'layer:component_idx' (e.g., 'h.0.mlp.c_fc:5')", + ) + role: str = Field( + ..., + description="The role this component plays in the behavior (e.g., 'stores subject gender')", + ) + interpretation: str | None = Field( + default=None, + description="Auto-interp label for this component if available", + ) + + +class Evidence(BaseModel): + """A piece of supporting evidence for an explanation.""" + + evidence_type: Literal["ablation", "attribution", "activation_pattern", "correlation", "other"] + description: str = Field( + ..., + description="Description of the evidence", + ) + details: dict[str, Any] = Field( + default_factory=dict, + description="Additional structured details (e.g., ablation results, attribution values)", + ) + + +class BehaviorExplanation(BaseModel): + """A candidate explanation for a behavior discovered by an agent. + + This is the primary output schema for agent investigations. Each explanation + describes a behavior (demonstrated by a subject prompt), the components involved, + and supporting evidence. 
+ """ + + subject_prompt: str = Field( + ..., + description="A prompt that demonstrates the behavior being explained", + ) + behavior_description: str = Field( + ..., + description="Clear description of the behavior (e.g., 'correctly predicts gendered pronoun')", + ) + components_involved: list[ComponentInfo] = Field( + ..., + description="List of components involved in this behavior and their roles", + ) + explanation: str = Field( + ..., + description="Explanation of how the components work together to produce the behavior", + ) + supporting_evidence: list[Evidence] = Field( + default_factory=list, + description="Evidence supporting this explanation (ablations, attributions, etc.)", + ) + confidence: Literal["high", "medium", "low"] = Field( + ..., + description="Agent's confidence in this explanation", + ) + alternative_hypotheses: list[str] = Field( + default_factory=list, + description="Alternative hypotheses that were considered but not fully supported", + ) + limitations: list[str] = Field( + default_factory=list, + description="Known limitations of this explanation", + ) + + +class InvestigationEvent(BaseModel): + """A generic event logged by an agent during investigation. + + Used for logging progress, observations, and other non-explanation events. 
+ """ + + event_type: Literal[ + "start", + "progress", + "observation", + "hypothesis", + "test_result", + "explanation", + "error", + "complete", + ] + timestamp: datetime = Field(default_factory=lambda: datetime.now(UTC)) + message: str + details: dict[str, Any] = Field(default_factory=dict) diff --git a/spd/investigate/scripts/__init__.py b/spd/investigate/scripts/__init__.py new file mode 100644 index 000000000..ff51f7654 --- /dev/null +++ b/spd/investigate/scripts/__init__.py @@ -0,0 +1 @@ +"""Investigation SLURM scripts.""" diff --git a/spd/investigate/scripts/run_agent.py b/spd/investigate/scripts/run_agent.py new file mode 100644 index 000000000..0ec623ccc --- /dev/null +++ b/spd/investigate/scripts/run_agent.py @@ -0,0 +1,282 @@ +"""Worker script that runs inside each SLURM job. + +This script: +1. Reads the research question from the investigation metadata +2. Starts the app backend with an isolated database +3. Loads the SPD run and fetches model architecture info +4. Configures MCP server for Claude Code +5. Launches Claude Code with the investigation question +6. 
Handles cleanup on exit +""" + +import json +import os +import signal +import socket +import subprocess +import sys +import time +from pathlib import Path +from types import FrameType +from typing import Any + +import fire +import requests + +from spd.investigate.agent_prompt import get_agent_prompt +from spd.investigate.schemas import InvestigationEvent +from spd.investigate.scripts.run_slurm import get_investigation_output_dir +from spd.log import logger + + +def write_mcp_config(inv_dir: Path, port: int) -> Path: + mcp_config = { + "mcpServers": { + "spd": { + "type": "http", + "url": f"http://localhost:{port}/mcp", + } + } + } + config_path = inv_dir / "mcp_config.json" + config_path.write_text(json.dumps(mcp_config, indent=2)) + return config_path + + +def write_claude_settings(inv_dir: Path) -> None: + claude_dir = inv_dir / ".claude" + claude_dir.mkdir(exist_ok=True) + settings = {"permissions": {"allow": ["mcp__spd__*"]}} + (claude_dir / "settings.json").write_text(json.dumps(settings, indent=2)) + + +def find_available_port(start_port: int = 8000, max_attempts: int = 100) -> int: + for offset in range(max_attempts): + port = start_port + offset + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(("localhost", port)) + return port + except OSError: + continue + raise RuntimeError( + f"Could not find available port in range {start_port}-{start_port + max_attempts}" + ) + + +def wait_for_backend(port: int, timeout: float = 120.0) -> None: + url = f"http://localhost:{port}/api/health" + start = time.time() + while time.time() - start < timeout: + try: + resp = requests.get(url, timeout=5) + if resp.status_code == 200: + return + except requests.exceptions.ConnectionError: + pass + time.sleep(1) + raise RuntimeError(f"Backend on port {port} failed to start within {timeout}s") + + +def load_run(port: int, wandb_path: str, context_length: int) -> None: + url = f"http://localhost:{port}/api/runs/load" + params = {"wandb_path": 
wandb_path, "context_length": context_length} + resp = requests.post(url, params=params, timeout=300) + assert resp.status_code == 200, ( + f"Failed to load run {wandb_path}: {resp.status_code} {resp.text}" + ) + + +def fetch_model_info(port: int) -> dict[str, Any]: + resp = requests.get(f"http://localhost:{port}/api/pretrain_info/loaded", timeout=30) + assert resp.status_code == 200, f"Failed to fetch model info: {resp.status_code} {resp.text}" + return resp.json() + + +def log_event(events_path: Path, event: InvestigationEvent) -> None: + with open(events_path, "a") as f: + f.write(event.model_dump_json() + "\n") + + +def run_agent(inv_id: str) -> None: + """Run a single investigation agent. All config read from metadata.json.""" + inv_dir = get_investigation_output_dir(inv_id) + assert inv_dir.exists(), f"Investigation directory does not exist: {inv_dir}" + + metadata: dict[str, Any] = json.loads((inv_dir / "metadata.json").read_text()) + wandb_path: str = metadata["wandb_path"] + prompt: str = metadata["prompt"] + context_length: int = metadata["context_length"] + max_turns: int = metadata["max_turns"] + + write_claude_settings(inv_dir) + + events_path = inv_dir / "events.jsonl" + (inv_dir / "explanations.jsonl").touch() + + log_event( + events_path, + InvestigationEvent( + event_type="start", + message=f"Investigation {inv_id} starting", + details={"wandb_path": wandb_path, "inv_id": inv_id, "prompt": prompt}, + ), + ) + + port = find_available_port() + logger.info(f"[{inv_id}] Using port {port}") + + log_event( + events_path, + InvestigationEvent( + event_type="progress", + message=f"Starting backend on port {port}", + details={"port": port}, + ), + ) + + # Start backend with investigation configuration + env = os.environ.copy() + env["SPD_INVESTIGATION_DIR"] = str(inv_dir) + + backend_cmd = [ + sys.executable, + "-m", + "spd.app.backend.server", + "--port", + str(port), + ] + + backend_log_path = inv_dir / "backend.log" + backend_log = open(backend_log_path, 
"w") # noqa: SIM115 - managed manually + backend_proc = subprocess.Popen( + backend_cmd, + env=env, + stdout=backend_log, + stderr=subprocess.STDOUT, + ) + + def cleanup(signum: int | None = None, _frame: FrameType | None = None) -> None: + logger.info(f"[{inv_id}] Cleaning up...") + if backend_proc.poll() is None: + backend_proc.terminate() + try: + backend_proc.wait(timeout=5) + except subprocess.TimeoutExpired: + backend_proc.kill() + backend_log.close() + if signum is not None: + sys.exit(1) + + signal.signal(signal.SIGTERM, cleanup) + signal.signal(signal.SIGINT, cleanup) + + try: + logger.info(f"[{inv_id}] Waiting for backend...") + wait_for_backend(port) + + logger.info(f"[{inv_id}] Backend ready, loading run...") + log_event( + events_path, + InvestigationEvent(event_type="progress", message="Backend ready, loading run"), + ) + + load_run(port, wandb_path, context_length) + + logger.info(f"[{inv_id}] Run loaded, fetching model info...") + model_info = fetch_model_info(port) + + logger.info(f"[{inv_id}] Launching Claude Code...") + log_event( + events_path, + InvestigationEvent( + event_type="progress", message="Run loaded, launching Claude Code agent" + ), + ) + + agent_prompt = get_agent_prompt( + wandb_path=wandb_path, + prompt=prompt, + model_info=model_info, + ) + + (inv_dir / "agent_prompt.md").write_text(agent_prompt) + + mcp_config_path = write_mcp_config(inv_dir, port) + logger.info(f"[{inv_id}] MCP config written to {mcp_config_path}") + + claude_output_path = inv_dir / "claude_output.jsonl" + claude_cmd = [ + "claude", + "--print", + "--verbose", + "--output-format", + "stream-json", + "--max-turns", + str(max_turns), + # MCP: only our backend, no inherited servers + "--mcp-config", + str(mcp_config_path), + # Permissions: only MCP tools, deny everything else + "--permission-mode", + "dontAsk", + "--allowedTools", + "mcp__spd__*", + # Isolation: skip all user/project settings (no plugins, no inherited config) + "--setting-sources", + "", + 
"--model", + "opus", + ] + + logger.info(f"[{inv_id}] Starting Claude Code (max_turns={max_turns})...") + logger.info(f"[{inv_id}] Monitor with: tail -f {claude_output_path}") + + with open(claude_output_path, "w") as output_file: + claude_proc = subprocess.Popen( + claude_cmd, + stdin=subprocess.PIPE, + stdout=output_file, + stderr=subprocess.STDOUT, + text=True, + cwd=str(inv_dir), + ) + + assert claude_proc.stdin is not None + claude_proc.stdin.write(agent_prompt) + claude_proc.stdin.close() + + claude_proc.wait() + + log_event( + events_path, + InvestigationEvent( + event_type="complete", + message="Investigation complete", + details={"exit_code": claude_proc.returncode}, + ), + ) + + logger.info(f"[{inv_id}] Investigation complete") + + except Exception as e: + log_event( + events_path, + InvestigationEvent( + event_type="error", + message=f"Agent failed: {e}", + details={"error_type": type(e).__name__}, + ), + ) + logger.error(f"[{inv_id}] Failed: {e}") + raise + finally: + cleanup() + + +def cli() -> None: + fire.Fire(run_agent) + + +if __name__ == "__main__": + cli() diff --git a/spd/investigate/scripts/run_slurm.py b/spd/investigate/scripts/run_slurm.py new file mode 100644 index 000000000..b42f8450e --- /dev/null +++ b/spd/investigate/scripts/run_slurm.py @@ -0,0 +1,88 @@ +"""SLURM submission logic for investigation jobs.""" + +import json +import secrets +import sys +from dataclasses import dataclass +from pathlib import Path + +from spd.log import logger +from spd.settings import DEFAULT_PARTITION_NAME, SPD_OUT_DIR +from spd.utils.git_utils import create_git_snapshot +from spd.utils.slurm import SlurmConfig, generate_script, submit_slurm_job +from spd.utils.wandb_utils import parse_wandb_run_path + + +@dataclass +class InvestigationResult: + inv_id: str + job_id: str + output_dir: Path + + +def get_investigation_output_dir(inv_id: str) -> Path: + return SPD_OUT_DIR / "investigations" / inv_id + + +def launch_investigation( + wandb_path: str, + prompt: 
str, + context_length: int, + max_turns: int, + time: str, + job_suffix: str | None, +) -> InvestigationResult: + """Launch a single investigation agent via SLURM. + + Creates a SLURM job that starts an isolated app backend, loads the SPD run, + and launches a Claude Code agent with the given research question. + """ + # Normalize wandb_path to canonical form (entity/project/run_id) + entity, project, run_id = parse_wandb_run_path(wandb_path) + canonical_wandb_path = f"{entity}/{project}/{run_id}" + + inv_id = f"inv-{secrets.token_hex(4)}" + output_dir = get_investigation_output_dir(inv_id) + output_dir.mkdir(parents=True, exist_ok=True) + + snapshot_branch, commit_hash = create_git_snapshot(inv_id) + + suffix = f"-{job_suffix}" if job_suffix else "" + job_name = f"spd-investigate{suffix}" + + metadata = { + "inv_id": inv_id, + "wandb_path": canonical_wandb_path, + "prompt": prompt, + "context_length": context_length, + "max_turns": max_turns, + "snapshot_branch": snapshot_branch, + "commit_hash": commit_hash, + } + (output_dir / "metadata.json").write_text(json.dumps(metadata, indent=2)) + + cmd = f"{sys.executable} -m spd.investigate.scripts.run_agent {inv_id}" + + slurm_config = SlurmConfig( + job_name=job_name, + partition=DEFAULT_PARTITION_NAME, + n_gpus=1, + time=time, + snapshot_branch=snapshot_branch, + ) + script = generate_script(slurm_config, cmd) + result = submit_slurm_job(script, "investigate") + + logger.section("Investigation submitted") + logger.values( + { + "Investigation ID": inv_id, + "Job ID": result.job_id, + "WandB path": canonical_wandb_path, + "Prompt": prompt[:100] + ("..." 
if len(prompt) > 100 else ""), + "Output directory": str(output_dir), + "Logs": result.log_pattern, + } + ) + + return InvestigationResult(inv_id=inv_id, job_id=result.job_id, output_dir=output_dir) diff --git a/spd/investigate/scripts/run_slurm_cli.py b/spd/investigate/scripts/run_slurm_cli.py new file mode 100644 index 000000000..6a8dd13af --- /dev/null +++ b/spd/investigate/scripts/run_slurm_cli.py @@ -0,0 +1,54 @@ +"""CLI entry point for investigation SLURM launcher. + +Usage: + spd-investigate <wandb_path> "<prompt>" + spd-investigate <wandb_path> @prompt.txt + spd-investigate <wandb_path> "<prompt>" --max_turns 30 +""" + +from pathlib import Path + +import fire + + +def _resolve_prompt(prompt: str) -> str: + """If prompt starts with @, read from that file path. Otherwise return as-is.""" + if prompt.startswith("@"): + path = Path(prompt[1:]) + assert path.exists(), f"Prompt file not found: {path}" + return path.read_text().strip() + return prompt + + +def main( + wandb_path: str, + prompt: str, + context_length: int = 128, + max_turns: int = 50, + time: str = "8:00:00", + job_suffix: str | None = None, +) -> None: + """Launch a single investigation agent for a specific question. + + Args: + wandb_path: WandB run path for the SPD decomposition to investigate. + prompt: The research question, or @filepath to read from a file. + context_length: Context length for prompts (default 128). + max_turns: Maximum agentic turns (default 50, prevents runaway). + time: Job time limit (default 8 hours). + job_suffix: Optional suffix for SLURM job names.
+ """ + from spd.investigate.scripts.run_slurm import launch_investigation + + launch_investigation( + wandb_path=wandb_path, + prompt=_resolve_prompt(prompt), + context_length=context_length, + max_turns=max_turns, + time=time, + job_suffix=job_suffix, + ) + + +def cli() -> None: + fire.Fire(main) diff --git a/spd/losses.py b/spd/losses.py index daef1773c..a35654bee 100644 --- a/spd/losses.py +++ b/spd/losses.py @@ -1,6 +1,5 @@ from typing import Literal -import torch from jaxtyping import Float, Int from torch import Tensor @@ -11,6 +10,8 @@ FaithfulnessLossConfig, ImportanceMinimalityLossConfig, LossMetricConfigType, + PersistentPGDReconLossConfig, + PersistentPGDReconSubsetLossConfig, PGDReconLayerwiseLossConfig, PGDReconLossConfig, PGDReconSubsetLossConfig, @@ -37,28 +38,27 @@ unmasked_recon_loss, ) from spd.models.component_model import CIOutputs, ComponentModel +from spd.persistent_pgd import PersistentPGDState -def compute_total_loss( +def compute_losses( loss_metric_configs: list[LossMetricConfigType], model: ComponentModel, batch: Int[Tensor, "..."], ci: CIOutputs, target_out: Tensor, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], - pre_weight_acts: dict[str, Float[Tensor, "..."]], current_frac_of_training: float, sampling: SamplingType, use_delta_component: bool, n_mask_samples: int, + ppgd_states: dict[ + PersistentPGDReconLossConfig | PersistentPGDReconSubsetLossConfig, PersistentPGDState + ], output_loss_type: Literal["mse", "kl"], -) -> tuple[Float[Tensor, ""], dict[str, float]]: - """Compute weighted total loss and per-term raw values using new loss primitives. - - Returns (total, terms_dict). terms_dict contains raw per-term values (no coeffs) and a weighted total. 
-    """
-    total = torch.tensor(0.0, device=batch.device)
-    terms: dict[str, float] = {}
+) -> dict[LossMetricConfigType, Float[Tensor, ""]]:
+    """Compute losses for each config and return a dict mapping config to loss tensor."""
+    losses: dict[LossMetricConfigType, Float[Tensor, ""]] = {}
 
     for cfg in loss_metric_configs:
         assert cfg.coeff is not None, "All loss metric configs must have a coeff"
@@ -179,15 +179,18 @@ def compute_total_loss(
                     sampling=sampling,
                     n_mask_samples=n_mask_samples,
                     batch=batch,
-                    pre_weight_acts=pre_weight_acts,
+                    ci=ci.lower_leaky,
+                    weight_deltas=weight_deltas if use_delta_component else None,
+                )
+            case PersistentPGDReconLossConfig() | PersistentPGDReconSubsetLossConfig():
+                loss = ppgd_states[cfg].compute_recon_loss(
+                    model=model,
+                    batch=batch,
+                    target_out=target_out,
                     ci=ci.lower_leaky,
                     weight_deltas=weight_deltas if use_delta_component else None,
                 )
-        terms[f"loss/{cfg.classname}"] = loss.item()
-
-        total = total + cfg.coeff * loss
-
-    terms["loss/total"] = total.item()
+        losses[cfg] = loss
 
-    return total, terms
+    return losses
diff --git a/spd/metrics/__init__.py b/spd/metrics/__init__.py
index 522485a95..dc7375c8c 100644
--- a/spd/metrics/__init__.py
+++ b/spd/metrics/__init__.py
@@ -1,4 +1,10 @@
 # Note that "... as ..." allows for these to be imported elsewhere (See PEP 484 on re-exporting)
+from .attn_patterns_recon_loss import (
+    CIMaskedAttnPatternsReconLoss as CIMaskedAttnPatternsReconLoss,
+)
+from .attn_patterns_recon_loss import (
+    StochasticAttnPatternsReconLoss as StochasticAttnPatternsReconLoss,
+)
 from .ce_and_kl_losses import CEandKLLosses as CEandKLLosses
 from .ci_histograms import CIHistograms as CIHistograms
 from .ci_l0 import CI_L0 as CI_L0
@@ -14,6 +20,13 @@
 from .component_activation_density import ComponentActivationDensity as ComponentActivationDensity
 from .faithfulness_loss import FaithfulnessLoss as FaithfulnessLoss
 from .faithfulness_loss import faithfulness_loss as faithfulness_loss
+from .hidden_acts_recon_loss import CIHiddenActsReconLoss as CIHiddenActsReconLoss
+from .hidden_acts_recon_loss import (
+    StochasticHiddenActsReconLoss as StochasticHiddenActsReconLoss,
+)
+from .hidden_acts_recon_loss import (
+    stochastic_hidden_acts_recon_loss as stochastic_hidden_acts_recon_loss,
+)
 from .identity_ci_error import IdentityCIError as IdentityCIError
 from .importance_minimality_loss import ImportanceMinimalityLoss as ImportanceMinimalityLoss
 from .importance_minimality_loss import importance_minimality_loss as importance_minimality_loss
@@ -26,12 +39,7 @@
 from .pgd_masked_recon_loss import pgd_recon_loss as pgd_recon_loss
 from .pgd_masked_recon_subset_loss import PGDReconSubsetLoss as PGDReconSubsetLoss
 from .pgd_masked_recon_subset_loss import pgd_recon_subset_loss as pgd_recon_subset_loss
-from .stochastic_hidden_acts_recon_loss import (
-    StochasticHiddenActsReconLoss as StochasticHiddenActsReconLoss,
-)
-from .stochastic_hidden_acts_recon_loss import (
-    stochastic_hidden_acts_recon_loss as stochastic_hidden_acts_recon_loss,
-)
+from .ppgd_eval_losses import PPGDReconEval as PPGDReconEval
 from .stochastic_recon_layerwise_loss import (
     StochasticReconLayerwiseLoss as StochasticReconLayerwiseLoss,
 )
diff --git a/spd/metrics/attn_patterns_recon_loss.py b/spd/metrics/attn_patterns_recon_loss.py
new file mode 100644
index 000000000..b02e7dea3
--- /dev/null
+++ b/spd/metrics/attn_patterns_recon_loss.py
@@ -0,0 +1,308 @@
+import math
+from fnmatch import fnmatch
+from typing import Any, ClassVar, override
+
+import torch
+import torch.nn.functional as F
+from jaxtyping import Float, Int
+from torch import Tensor, nn
+from torch.distributed import ReduceOp
+
+from spd.configs import SamplingType
+from spd.metrics.base import Metric
+from spd.models.component_model import CIOutputs, ComponentModel
+from spd.models.components import ComponentsMaskInfo, make_mask_infos
+from spd.routing import AllLayersRouter
+from spd.utils.component_utils import calc_stochastic_component_mask_info
+from spd.utils.distributed_utils import all_reduce
+from spd.utils.general_utils import get_obj_device
+
+
+def _resolve_paths(pattern: str, model: ComponentModel) -> list[str]:
+    """Resolve an fnmatch pattern against model target module paths."""
+    matches = [p for p in model.target_module_paths if fnmatch(p, pattern)]
+    assert matches, f"Pattern {pattern!r} matched no target module paths"
+    return sorted(matches)
+
+
+def _resolve_qk_paths(
+    model: ComponentModel,
+    q_proj_path: str | None,
+    k_proj_path: str | None,
+    c_attn_path: str | None,
+) -> tuple[list[str], list[str], bool]:
+    """Resolve Q/K projection paths, returning (q_paths, k_paths, is_combined).
+
+    For separate Q/K projections: returns matched paths paired by sorted order.
+    For combined c_attn: returns the same paths for both Q and K.
+    """
+    if c_attn_path is not None:
+        paths = _resolve_paths(c_attn_path, model)
+        return paths, paths, True
+    assert q_proj_path is not None and k_proj_path is not None
+    q_paths = _resolve_paths(q_proj_path, model)
+    k_paths = _resolve_paths(k_proj_path, model)
+    assert len(q_paths) == len(k_paths), f"Q/K path counts differ: {len(q_paths)} vs {len(k_paths)}"
+    return q_paths, k_paths, False
+
+
+def _resolve_attn_modules(
+    model: ComponentModel,
+    q_paths: list[str],
+) -> list[nn.Module | None]:
+    """Derive parent attention module from Q paths, returning it if it has RoPE support.
+
+    For each Q path (e.g. "h.0.attn.q_proj"), strips the last segment to get the parent
+    attention module (e.g. "h.0.attn"). Returns the module if it has `apply_rotary_pos_emb`,
+    otherwise None.
+    """
+    result: list[nn.Module | None] = []
+    for q_path in q_paths:
+        parent_path = q_path.rsplit(".", 1)[0]
+        attn_module = model.target_model.get_submodule(parent_path)
+        if hasattr(attn_module, "apply_rotary_pos_emb"):
+            result.append(attn_module)
+        else:
+            result.append(None)
+    return result
+
+
+def _compute_attn_patterns(
+    q: Float[Tensor, "batch seq d"],
+    k: Float[Tensor, "batch seq d"],
+    n_heads: int,
+    attn_module: nn.Module | None,
+) -> Float[Tensor, "batch n_heads seq seq"]:
+    """Compute causal attention patterns from Q and K projections.
+
+    If attn_module is provided (has RoPE), applies rotary positional embeddings to Q and K
+    before computing the dot-product attention.
+    """
+    B, S, D = q.shape
+    head_dim = D // n_heads
+    q = q.view(B, S, n_heads, head_dim).transpose(1, 2)
+    k = k.view(B, S, n_heads, head_dim).transpose(1, 2)
+
+    if attn_module is not None:
+        position_ids = torch.arange(S, device=q.device).unsqueeze(0)
+        cos = attn_module.rotary_cos[position_ids].to(q.dtype)  # pyright: ignore[reportIndexIssue]
+        sin = attn_module.rotary_sin[position_ids].to(q.dtype)  # pyright: ignore[reportIndexIssue]
+        q, k = attn_module.apply_rotary_pos_emb(q, k, cos, sin)  # pyright: ignore[reportCallIssue]
+
+    attn = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)
+    causal_mask = torch.triu(torch.ones(S, S, device=q.device, dtype=torch.bool), diagonal=1)
+    attn = attn.masked_fill(causal_mask, float("-inf"))
+    return F.softmax(attn, dim=-1)
+
+
+def _split_combined_qkv(
+    output: Float[Tensor, "... d"],
+) -> tuple[Float[Tensor, "..."], Float[Tensor, "..."]]:
+    """Split combined QKV output into Q and K projections."""
+    d = output.shape[-1] // 3
+    return output[..., :d], output[..., d : 2 * d]
+
+
+def _attn_patterns_recon_loss_update(
+    model: ComponentModel,
+    batch: Int[Tensor, "..."] | Float[Tensor, "..."],
+    pre_weight_acts: dict[str, Float[Tensor, "..."]],
+    mask_infos_list: list[dict[str, ComponentsMaskInfo]],
+    q_paths: list[str],
+    k_paths: list[str],
+    is_combined: bool,
+    n_heads: int,
+    attn_modules: list[nn.Module | None],
+) -> tuple[Float[Tensor, ""], int]:
+    """Shared update logic for both CI-masked and stochastic variants."""
+    # 1. Compute target attention patterns from pre_weight_acts
+    target_patterns: list[Float[Tensor, "batch n_heads seq seq"]] = []
+    for i, (q_path, k_path) in enumerate(zip(q_paths, k_paths, strict=True)):
+        if is_combined:
+            assert q_path == k_path
+            target_out = model.components[q_path](pre_weight_acts[q_path])
+            target_q, target_k = _split_combined_qkv(target_out)
+        else:
+            target_q = model.components[q_path](pre_weight_acts[q_path])
+            target_k = model.components[k_path](pre_weight_acts[k_path])
+        target_patterns.append(
+            _compute_attn_patterns(target_q, target_k, n_heads, attn_modules[i]).detach()
+        )
+
+    # 2. Compute masked attention patterns and KL divergence
+    device = get_obj_device(pre_weight_acts)
+    sum_kl = torch.tensor(0.0, device=device)
+    n_distributions = 0
+
+    for mask_infos in mask_infos_list:
+        comp_cache = model(batch, mask_infos=mask_infos, cache_type="input").cache
+
+        for i, (q_path, k_path) in enumerate(zip(q_paths, k_paths, strict=True)):
+            if is_combined:
+                assert q_path == k_path
+                masked_out = model.components[q_path](
+                    comp_cache[q_path],
+                    mask=mask_infos[q_path].component_mask,
+                    weight_delta_and_mask=mask_infos[q_path].weight_delta_and_mask,
+                )
+                masked_q, masked_k = _split_combined_qkv(masked_out)
+            else:
+                masked_q = model.components[q_path](
+                    comp_cache[q_path],
+                    mask=mask_infos[q_path].component_mask,
+                    weight_delta_and_mask=mask_infos[q_path].weight_delta_and_mask,
+                )
+                masked_k = model.components[k_path](
+                    comp_cache[k_path],
+                    mask=mask_infos[k_path].component_mask,
+                    weight_delta_and_mask=mask_infos[k_path].weight_delta_and_mask,
+                )
+
+            masked_patterns = _compute_attn_patterns(masked_q, masked_k, n_heads, attn_modules[i])
+            # KL(target || masked): sum over attention distribution dimension
+            kl = F.kl_div(
+                masked_patterns.clamp(min=1e-12).log(),
+                target_patterns[i],
+                reduction="sum",
+            )
+            sum_kl = sum_kl + kl
+            # Count: batch * n_heads * seq (one distribution per query position per head)
+            n_distributions += target_patterns[i].shape[0] * n_heads * target_patterns[i].shape[2]
+
+    return sum_kl, n_distributions
+
+
+def _attn_patterns_recon_loss_compute(
+    sum_kl: Float[Tensor, ""],
+    n_distributions: Int[Tensor, ""] | int,
+) -> Float[Tensor, ""]:
+    return sum_kl / n_distributions
+
+
+# --- CI-masked variant ---
+
+
+class CIMaskedAttnPatternsReconLoss(Metric):
+    """Attention pattern reconstruction loss using CI masks."""
+
+    metric_section: ClassVar[str] = "loss"
+
+    def __init__(
+        self,
+        model: ComponentModel,
+        device: str,
+        n_heads: int,
+        q_proj_path: str | None,
+        k_proj_path: str | None,
+        c_attn_path: str | None,
+    ) -> None:
+        self.model = model
+        self.n_heads = n_heads
+        self.q_paths, self.k_paths, self.is_combined = _resolve_qk_paths(
+            model, q_proj_path, k_proj_path, c_attn_path
+        )
+        self.attn_modules = _resolve_attn_modules(model, self.q_paths)
+        self.sum_kl = torch.tensor(0.0, device=device)
+        self.n_distributions = torch.tensor(0, device=device)
+
+    @override
+    def update(
+        self,
+        *,
+        batch: Int[Tensor, "..."] | Float[Tensor, "..."],
+        pre_weight_acts: dict[str, Float[Tensor, "..."]],
+        ci: CIOutputs,
+        **_: Any,
+    ) -> None:
+        mask_infos = make_mask_infos(ci.lower_leaky, weight_deltas_and_masks=None)
+        sum_kl, n_distributions = _attn_patterns_recon_loss_update(
+            model=self.model,
+            batch=batch,
+            pre_weight_acts=pre_weight_acts,
+            mask_infos_list=[mask_infos],
+            q_paths=self.q_paths,
+            k_paths=self.k_paths,
+            is_combined=self.is_combined,
+            n_heads=self.n_heads,
+            attn_modules=self.attn_modules,
+        )
+        self.sum_kl += sum_kl
+        self.n_distributions += n_distributions
+
+    @override
+    def compute(self) -> Float[Tensor, ""]:
+        sum_kl = all_reduce(self.sum_kl, op=ReduceOp.SUM)
+        n_distributions = all_reduce(self.n_distributions, op=ReduceOp.SUM)
+        return _attn_patterns_recon_loss_compute(sum_kl, n_distributions)
+
+
+# --- Stochastic variant ---
+
+
+class StochasticAttnPatternsReconLoss(Metric):
+    """Attention pattern reconstruction loss with stochastic masks."""
+
+    metric_section: ClassVar[str] = "loss"
+
+    def __init__(
+        self,
+        model: ComponentModel,
+        device: str,
+        sampling: SamplingType,
+        use_delta_component: bool,
+        n_mask_samples: int,
+        n_heads: int,
+        q_proj_path: str | None,
+        k_proj_path: str | None,
+        c_attn_path: str | None,
+    ) -> None:
+        self.model = model
+        self.sampling: SamplingType = sampling
+        self.use_delta_component = use_delta_component
+        self.n_mask_samples = n_mask_samples
+        self.n_heads = n_heads
+        self.q_paths, self.k_paths, self.is_combined = _resolve_qk_paths(
+            model, q_proj_path, k_proj_path, c_attn_path
+        )
+        self.attn_modules = _resolve_attn_modules(model, self.q_paths)
+        self.sum_kl = torch.tensor(0.0, device=device)
+        self.n_distributions = torch.tensor(0, device=device)
+
+    @override
+    def update(
+        self,
+        *,
+        batch: Int[Tensor, "..."] | Float[Tensor, "..."],
+        pre_weight_acts: dict[str, Float[Tensor, "..."]],
+        ci: CIOutputs,
+        weight_deltas: dict[str, Float[Tensor, "d_out d_in"]],
+        **_: Any,
+    ) -> None:
+        mask_infos_list = [
+            calc_stochastic_component_mask_info(
+                causal_importances=ci.lower_leaky,
+                component_mask_sampling=self.sampling,
+                weight_deltas=weight_deltas if self.use_delta_component else None,
+                router=AllLayersRouter(),
+            )
+            for _ in range(self.n_mask_samples)
+        ]
+        sum_kl, n_distributions = _attn_patterns_recon_loss_update(
+            model=self.model,
+            batch=batch,
+            pre_weight_acts=pre_weight_acts,
+            mask_infos_list=mask_infos_list,
+            q_paths=self.q_paths,
+            k_paths=self.k_paths,
+            is_combined=self.is_combined,
+            n_heads=self.n_heads,
+            attn_modules=self.attn_modules,
+        )
+        self.sum_kl += sum_kl
+        self.n_distributions += n_distributions
+
+    @override
+    def compute(self) -> Float[Tensor, ""]:
+        sum_kl = all_reduce(self.sum_kl, op=ReduceOp.SUM)
+        n_distributions = all_reduce(self.n_distributions, op=ReduceOp.SUM)
+        return _attn_patterns_recon_loss_compute(sum_kl, n_distributions)
diff --git a/spd/metrics/hidden_acts_recon_loss.py b/spd/metrics/hidden_acts_recon_loss.py
new file mode 100644
index 000000000..fa73ac51a
--- /dev/null
+++ b/spd/metrics/hidden_acts_recon_loss.py
@@ -0,0 +1,233 @@
+from typing import Any, ClassVar, override
+
+import torch
+import torch.nn.functional as F
+from jaxtyping import Float, Int
+from torch import Tensor
+from torch.distributed import ReduceOp
+
+from spd.configs import SamplingType
+from spd.metrics.base import Metric
+from spd.models.component_model import CIOutputs, ComponentModel
+from spd.models.components import ComponentsMaskInfo, make_mask_infos
+from spd.routing import AllLayersRouter
+from spd.utils.component_utils import calc_stochastic_component_mask_info
+from spd.utils.distributed_utils import all_reduce
+
+PerModuleMSE = dict[str, tuple[Float[Tensor, ""], int]]
+
+
+def calc_hidden_acts_mse(
+    model: ComponentModel,
+    batch: Int[Tensor, "..."] | Float[Tensor, "..."],
+    mask_infos: dict[str, ComponentsMaskInfo],
+    target_acts: dict[str, Float[Tensor, "..."]],
+) -> tuple[PerModuleMSE, Float[Tensor, "..."]]:
+    """Forward with mask_infos and compute per-module MSE against target output activations.
+
+    Returns the per-module MSE dict and the component model's output tensor.
+    """
+    result = model(batch, mask_infos=mask_infos, cache_type="output")
+    per_module: PerModuleMSE = {}
+    for layer_name, target in target_acts.items():
+        assert layer_name in result.cache, f"{layer_name} not in comp_cache"
+        mse = F.mse_loss(result.cache[layer_name], target, reduction="sum")
+        per_module[layer_name] = (mse, target.numel())
+    return per_module, result.output
+
+
+def _sum_per_module_mse(per_module: PerModuleMSE) -> tuple[Float[Tensor, ""], int]:
+    device = next(iter(per_module.values()))[0].device
+    total_mse = torch.tensor(0.0, device=device)
+    total_n = 0
+    for mse, n in per_module.values():
+        total_mse = total_mse + mse
+        total_n += n
+    return total_mse, total_n
+
+
+def _accumulate_per_module(accum: PerModuleMSE, per_module: PerModuleMSE) -> None:
+    for key, (mse, n) in per_module.items():
+        if key in accum:
+            prev_mse, prev_n = accum[key]
+            accum[key] = (prev_mse + mse, prev_n + n)
+        else:
+            accum[key] = (mse, n)
+
+
+def _stochastic_hidden_acts_recon_loss_update(
+    model: ComponentModel,
+    sampling: SamplingType,
+    n_mask_samples: int,
+    batch: Int[Tensor, "..."] | Float[Tensor, "..."],
+    ci: dict[str, Float[Tensor, "... C"]],
+    weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None,
+) -> PerModuleMSE:
+    assert ci, "Empty ci"
+
+    target_acts = model(batch, cache_type="output").cache
+
+    accum: PerModuleMSE = {}
+    stoch_mask_infos_list = [
+        calc_stochastic_component_mask_info(
+            causal_importances=ci,
+            component_mask_sampling=sampling,
+            weight_deltas=weight_deltas,
+            router=AllLayersRouter(),
+        )
+        for _ in range(n_mask_samples)
+    ]
+    for stoch_mask_infos in stoch_mask_infos_list:
+        per_module, _ = calc_hidden_acts_mse(
+            model=model,
+            batch=batch,
+            mask_infos=stoch_mask_infos,
+            target_acts=target_acts,
+        )
+        _accumulate_per_module(accum, per_module)
+
+    return accum
+
+
+def _hidden_acts_recon_loss_compute(
+    sum_mse: Float[Tensor, ""], n_examples: Int[Tensor, ""] | int
+) -> Float[Tensor, ""]:
+    return sum_mse / n_examples
+
+
+def compute_per_module_metrics(
+    class_name: str,
+    per_module_sum_mse: dict[str, Tensor],
+    per_module_n_examples: dict[str, Tensor],
+) -> dict[str, Float[Tensor, ""]]:
+    assert per_module_sum_mse, "No data accumulated"
+    keys = list(per_module_sum_mse.keys())
+    stacked_mse = torch.stack([per_module_sum_mse[k] for k in keys])
+    stacked_n = torch.stack([per_module_n_examples[k].float() for k in keys])
+    stacked_mse = all_reduce(stacked_mse, op=ReduceOp.SUM)
+    stacked_n = all_reduce(stacked_n, op=ReduceOp.SUM)
+
+    out: dict[str, Float[Tensor, ""]] = {}
+    for i, key in enumerate(keys):
+        out[f"{class_name}/{key}"] = stacked_mse[i] / stacked_n[i]
+    out[class_name] = stacked_mse.sum() / stacked_n.sum()
+    return out
+
+
+def stochastic_hidden_acts_recon_loss(
+    model: ComponentModel,
+    sampling: SamplingType,
+    n_mask_samples: int,
+    batch: Int[Tensor, "..."] | Float[Tensor, "..."],
+    ci: dict[str, Float[Tensor, "... C"]],
+    weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None,
+) -> Float[Tensor, ""]:
+    per_module = _stochastic_hidden_acts_recon_loss_update(
+        model=model,
+        sampling=sampling,
+        n_mask_samples=n_mask_samples,
+        batch=batch,
+        ci=ci,
+        weight_deltas=weight_deltas,
+    )
+    sum_mse, n_examples = _sum_per_module_mse(per_module)
+    return _hidden_acts_recon_loss_compute(sum_mse, n_examples)
+
+
+class StochasticHiddenActsReconLoss(Metric):
+    """Reconstruction loss between target and stochastic hidden activations when sampling with stochastic masks."""
+
+    slow: ClassVar[bool] = True
+    metric_section: ClassVar[str] = "loss"
+
+    def __init__(
+        self,
+        model: ComponentModel,
+        device: str,
+        sampling: SamplingType,
+        use_delta_component: bool,
+        n_mask_samples: int,
+    ) -> None:
+        self.model = model
+        self.sampling: SamplingType = sampling
+        self.use_delta_component: bool = use_delta_component
+        self.n_mask_samples: int = n_mask_samples
+        self.device = device
+        self.per_module_sum_mse: dict[str, Tensor] = {}
+        self.per_module_n_examples: dict[str, Tensor] = {}
+
+    @override
+    def update(
+        self,
+        *,
+        batch: Int[Tensor, "..."] | Float[Tensor, "..."],
+        ci: CIOutputs,
+        weight_deltas: dict[str, Float[Tensor, "d_out d_in"]],
+        **_: Any,
+    ) -> None:
+        per_module = _stochastic_hidden_acts_recon_loss_update(
+            model=self.model,
+            sampling=self.sampling,
+            n_mask_samples=self.n_mask_samples,
+            batch=batch,
+            ci=ci.lower_leaky,
+            weight_deltas=weight_deltas if self.use_delta_component else None,
+        )
+        for key, (mse, n) in per_module.items():
+            if key not in self.per_module_sum_mse:
+                self.per_module_sum_mse[key] = torch.tensor(0.0, device=self.device)
+                self.per_module_n_examples[key] = torch.tensor(0, device=self.device)
+            self.per_module_sum_mse[key] += mse.detach()
+            self.per_module_n_examples[key] += n
+
+    @override
+    def compute(self) -> dict[str, Float[Tensor, ""]]:
+        return compute_per_module_metrics(
+            class_name=type(self).__name__,
+            per_module_sum_mse=self.per_module_sum_mse,
+            per_module_n_examples=self.per_module_n_examples,
+        )
+
+
+class CIHiddenActsReconLoss(Metric):
+    """Reconstruction loss between target and component hidden activations when masking with CI values."""
+
+    slow: ClassVar[bool] = True
+    metric_section: ClassVar[str] = "loss"
+
+    def __init__(self, model: ComponentModel, device: str) -> None:
+        self.model = model
+        self.device = device
+        self.per_module_sum_mse: dict[str, Tensor] = {}
+        self.per_module_n_examples: dict[str, Tensor] = {}
+
+    @override
+    def update(
+        self,
+        *,
+        batch: Int[Tensor, "..."] | Float[Tensor, "..."],
+        ci: CIOutputs,
+        **_: Any,
+    ) -> None:
+        target_acts = self.model(batch, cache_type="output").cache
+        mask_infos = make_mask_infos(ci.lower_leaky, weight_deltas_and_masks=None)
+        per_module, _output = calc_hidden_acts_mse(
+            model=self.model,
+            batch=batch,
+            mask_infos=mask_infos,
+            target_acts=target_acts,
+        )
+        for key, (mse, n) in per_module.items():
+            if key not in self.per_module_sum_mse:
+                self.per_module_sum_mse[key] = torch.tensor(0.0, device=self.device)
+                self.per_module_n_examples[key] = torch.tensor(0, device=self.device)
+            self.per_module_sum_mse[key] += mse.detach()
+            self.per_module_n_examples[key] += n
+
+    @override
+    def compute(self) -> dict[str, Float[Tensor, ""]]:
+        return compute_per_module_metrics(
+            class_name=type(self).__name__,
+            per_module_sum_mse=self.per_module_sum_mse,
+            per_module_n_examples=self.per_module_n_examples,
+        )
diff --git a/spd/metrics/pgd_utils.py b/spd/metrics/pgd_utils.py
index 4635077bf..6d546c86a 100644
--- a/spd/metrics/pgd_utils.py
+++ b/spd/metrics/pgd_utils.py
@@ -10,30 +10,19 @@
 from spd.configs import PGDConfig, PGDInitStrategy, PGDMultiBatchConfig, SamplingType
 from spd.log import logger
 from spd.models.component_model import ComponentModel, OutputWithCache
-from spd.models.components import RoutingMasks, make_mask_infos
+from spd.models.components import ComponentsMaskInfo, RoutingMasks, make_mask_infos
 from spd.routing import Router
 from spd.utils.distributed_utils import all_reduce, broadcast_tensor
 from spd.utils.general_utils import calc_sum_recon_loss_lm, extract_batch_data
 
 
-def pgd_masked_recon_loss_update(
+def _init_adv_sources(
     model: ComponentModel,
-    batch: Int[Tensor, "..."] | Float[Tensor, "..."],
-    ci: dict[str, Float[Tensor, "... C"]],
+    batch_dims: tuple[int, ...],
+    device: torch.device | str,
     weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None,
-    target_out: Float[Tensor, "... vocab"],
-    output_loss_type: Literal["mse", "kl"],
-    router: Router,
     pgd_config: PGDConfig,
-) -> tuple[Float[Tensor, ""], int]:
-    """Central implementation of PGD masked reconstruction loss.
-
-    Optimizes adversarial stochastic masks and optionally weight deltas for the given objective function.
-    """
-    batch_dims = next(iter(ci.values())).shape[:-1]
-
-    routing_masks = router.get_masks(module_names=model.target_module_paths, mask_shape=batch_dims)
-
+) -> dict[str, Float[Tensor, "*batch_dims mask_c"]]:
     adv_sources: dict[str, Float[Tensor, "*batch_dims mask_c"]] = {}
     for module_name in model.target_module_paths:
         module_c = model.module_to_c[module_name]
@@ -41,32 +30,24 @@ def pgd_masked_recon_loss_update(
         match pgd_config.mask_scope:
             case "unique_per_datapoint":
                 shape = torch.Size([*batch_dims, mask_c])
-                source = _get_pgd_init_tensor(pgd_config.init, shape, batch.device)
+                source = get_pgd_init_tensor(pgd_config.init, shape, device)
             case "shared_across_batch":
                 singleton_batch_dims = [1 for _ in batch_dims]
                 shape = torch.Size([*singleton_batch_dims, mask_c])
-                source = broadcast_tensor(
-                    _get_pgd_init_tensor(pgd_config.init, shape, batch.device)
-                )
+                source = broadcast_tensor(get_pgd_init_tensor(pgd_config.init, shape, device))
         adv_sources[module_name] = source.requires_grad_(True)
+    return adv_sources
 
-    fwd_pass = partial(
-        _forward_with_adv_sources,
-        model=model,
-        batch=batch,
-        adv_sources=adv_sources,
-        ci=ci,
-        weight_deltas=weight_deltas,
-        routing_masks=routing_masks,
-        target_out=target_out,
-        output_loss_type=output_loss_type,
-        batch_dims=batch_dims,
-    )
 
+def _run_pgd_loop(
+    adv_sources: dict[str, Float[Tensor, "..."]],
+    pgd_config: PGDConfig,
+    fwd_fn: Callable[[], tuple[Float[Tensor, ""], int]],
+) -> tuple[Float[Tensor, ""], int]:
     for _ in range(pgd_config.n_steps):
         assert all(adv.grad is None for adv in adv_sources.values())
         with torch.enable_grad():
-            sum_loss, n_examples = fwd_pass()
+            sum_loss, n_examples = fwd_fn()
         loss = sum_loss / n_examples
         grads = torch.autograd.grad(loss, list(adv_sources.values()))
         match pgd_config.mask_scope:
@@ -82,7 +63,67 @@
             adv_sources[k].add_(pgd_config.step_size * adv_sources_grads[k].sign())
             adv_sources[k].clamp_(0.0, 1.0)
 
-    return fwd_pass()
+    return fwd_fn()
+
+
+def _construct_mask_infos_from_adv_sources(
+    adv_sources: dict[str, Float[Tensor, "*batch_dim_or_ones mask_c"]],
+    ci: dict[str, Float[Tensor, "... C"]],
+    weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None,
+    routing_masks: RoutingMasks,
+    batch_dims: tuple[int, ...],
+) -> dict[str, "ComponentsMaskInfo"]:
+    expanded_adv_sources = {k: v.expand(*batch_dims, -1) for k, v in adv_sources.items()}
+    adv_sources_components: dict[str, Float[Tensor, "*batch_dims C"]]
+    match weight_deltas:
+        case None:
+            weight_deltas_and_masks = None
+            adv_sources_components = expanded_adv_sources
+        case dict():
+            weight_deltas_and_masks = {
+                k: (weight_deltas[k], expanded_adv_sources[k][..., -1]) for k in weight_deltas
+            }
+            adv_sources_components = {k: v[..., :-1] for k, v in expanded_adv_sources.items()}
+
+    return make_mask_infos(
+        component_masks=interpolate_pgd_mask(ci, adv_sources_components),
+        weight_deltas_and_masks=weight_deltas_and_masks,
+        routing_masks=routing_masks,
+    )
+
+
+def pgd_masked_recon_loss_update(
+    model: ComponentModel,
+    batch: Int[Tensor, "..."] | Float[Tensor, "..."],
+    ci: dict[str, Float[Tensor, "... C"]],
+    weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None,
+    target_out: Float[Tensor, "... vocab"],
+    output_loss_type: Literal["mse", "kl"],
+    router: Router,
+    pgd_config: PGDConfig,
+) -> tuple[Float[Tensor, ""], int]:
+    """Central implementation of PGD masked reconstruction loss.
+
+    Optimizes adversarial stochastic masks and optionally weight deltas for the given objective function.
+    """
+    batch_dims = next(iter(ci.values())).shape[:-1]
+    routing_masks = router.get_masks(module_names=model.target_module_paths, mask_shape=batch_dims)
+    adv_sources = _init_adv_sources(model, batch_dims, batch.device, weight_deltas, pgd_config)
+
+    fwd_pass = partial(
+        _forward_with_adv_sources,
+        model=model,
+        batch=batch,
+        adv_sources=adv_sources,
+        ci=ci,
+        weight_deltas=weight_deltas,
+        routing_masks=routing_masks,
+        target_out=target_out,
+        output_loss_type=output_loss_type,
+        batch_dims=batch_dims,
+    )
+
+    return _run_pgd_loop(adv_sources, pgd_config, fwd_pass)
 
 
 CreateDataIter = Callable[
@@ -130,7 +171,7 @@
             mask_c = module_c if not use_delta_component else module_c + 1
             shape = torch.Size([*singleton_batch_dims, mask_c])
             adv_sources[module_name] = broadcast_tensor(
-                _get_pgd_init_tensor(pgd_config.init, shape, device)
+                get_pgd_init_tensor(pgd_config.init, shape, device)
             ).requires_grad_(True)
 
     fwd_bwd_fn = partial(
@@ -170,22 +211,12 @@
     output_loss_type: Literal["mse", "kl"],
     batch_dims: tuple[int, ...],
 ):
-    expanded_adv_sources = {k: v.expand(*batch_dims, -1) for k, v in adv_sources.items()}
-    adv_sources_components: dict[str, Float[Tensor, "*batch_dims C"]]
-    match weight_deltas:
-        case None:
-            weight_deltas_and_masks = None
-            adv_sources_components = expanded_adv_sources
-        case dict():
-            weight_deltas_and_masks = {
-                k: (weight_deltas[k], expanded_adv_sources[k][..., -1]) for k in weight_deltas
-            }
-            adv_sources_components = {k: v[..., :-1] for k, v in expanded_adv_sources.items()}
-
-    mask_infos = make_mask_infos(
-        component_masks=_interpolate_component_mask(ci, adv_sources_components),
-        weight_deltas_and_masks=weight_deltas_and_masks,
+    mask_infos = _construct_mask_infos_from_adv_sources(
+        adv_sources=adv_sources,
+        ci=ci,
+        weight_deltas=weight_deltas,
         routing_masks=routing_masks,
+        batch_dims=batch_dims,
     )
 
     out = model(batch, mask_infos=mask_infos)
@@ -270,11 +301,12 @@
     return pgd_step_accum_sum_loss, pgd_step_accum_n_examples, pgd_step_accum_grads
 
 
-def _get_pgd_init_tensor(
+def get_pgd_init_tensor(
     init: PGDInitStrategy,
     shape: tuple[int, ...],
     device: torch.device | str,
 ) -> Float[Tensor, "... shape"]:
+    """Create initial PGD source tensor (random, ones, or zeroes). Shared by training PGD and app eval PGD."""
     match init:
         case "random":
             return torch.rand(shape, device=device)
@@ -284,7 +316,7 @@
             return torch.zeros(shape, device=device)
 
 
-def _interpolate_component_mask(
+def interpolate_pgd_mask(
     ci: dict[str, Float[Tensor, "*batch_dims C"]],
     adv_sources_components: dict[str, Float[Tensor, "*batch_dims C"]],
 ) -> dict[str, Float[Tensor, "*batch_dims C"]]:
@@ -301,7 +333,6 @@
     component_masks: dict[str, Float[Tensor, "*batch_dims C"]] = {}
     for module_name in ci:
         adv_source = adv_sources_components[module_name]
-        assert torch.all(adv_source <= 1.0) and torch.all(adv_source >= 0.0)
         assert ci[module_name].shape[-1] == adv_source.shape[-1]
         scaled_noise_to_add = (1 - ci[module_name]) * adv_source
         component_masks[module_name] = ci[module_name] + scaled_noise_to_add
diff --git a/spd/metrics/ppgd_eval_losses.py b/spd/metrics/ppgd_eval_losses.py
new file mode 100644
index 000000000..e313d229b
--- /dev/null
+++ b/spd/metrics/ppgd_eval_losses.py
@@ -0,0 +1,97 @@
+from typing import Any, ClassVar, Literal, override
+
+import torch
+from jaxtyping import Float, Int
+from torch import Tensor
+from torch.distributed import ReduceOp
+
+from spd.metrics.base import Metric
+from spd.metrics.hidden_acts_recon_loss import calc_hidden_acts_mse, compute_per_module_metrics
+from spd.models.component_model import CIOutputs, ComponentModel
+from spd.persistent_pgd import PPGDSources, get_ppgd_mask_infos
+from spd.utils.distributed_utils import all_reduce
+from spd.utils.general_utils import calc_sum_recon_loss_lm
+
+
+class PPGDReconEval(Metric):
+    """Eval losses using persistent PGD masks: hidden activation MSE and output reconstruction.
+
+    Handles a single persistent PGD state, keyed by metric_name.
+    """
+
+    slow: ClassVar[bool] = True
+    metric_section: ClassVar[str] = "loss"
+
+    def __init__(
+        self,
+        model: ComponentModel,
+        device: str,
+        effective_sources: PPGDSources,
+        use_delta_component: bool,
+        output_loss_type: Literal["mse", "kl"],
+        metric_name: str,
+    ) -> None:
+        self.model = model
+        self.use_delta_component = use_delta_component
+        self.output_loss_type: Literal["mse", "kl"] = output_loss_type
+        self.device = device
+        self._effective_sources = effective_sources
+        self._metric_name = metric_name
+
+        self._module_sum_mse: dict[str, Tensor] = {}
+        self._module_n: dict[str, Tensor] = {}
+        self._output_sum_loss = torch.tensor(0.0, device=device)
+        self._output_n = torch.tensor(0, device=device)
+
+    @override
+    def update(
+        self,
+        *,
+        batch: Int[Tensor, "..."] | Float[Tensor, "..."],
+        ci: CIOutputs,
+        weight_deltas: dict[str, Float[Tensor, "d_out d_in"]],
+        target_out: Float[Tensor, "..."],
+        **_: Any,
+    ) -> None:
+        target_acts = self.model(batch, cache_type="output").cache
+        batch_dims = next(iter(ci.lower_leaky.values())).shape[:-1]
+
+        mask_infos = get_ppgd_mask_infos(
+            ci=ci.lower_leaky,
+            weight_deltas=weight_deltas if self.use_delta_component else None,
+            ppgd_sources=self._effective_sources,
+            routing_masks="all",
+            batch_dims=batch_dims,
+        )
+        per_module, comp_output = calc_hidden_acts_mse(
+            model=self.model,
+            batch=batch,
+            mask_infos=mask_infos,
+            target_acts=target_acts,
+        )
+        for key, (mse, n) in per_module.items():
+            if key not in self._module_sum_mse:
+                self._module_sum_mse[key] = torch.tensor(0.0, device=self.device)
+                self._module_n[key] = torch.tensor(0, device=self.device)
+            self._module_sum_mse[key] += mse.detach()
+            self._module_n[key] += n
+
+        output_loss = calc_sum_recon_loss_lm(
+            pred=comp_output, target=target_out, loss_type=self.output_loss_type
+        )
+        self._output_sum_loss += output_loss.detach()
+        self._output_n += target_out.numel()
+
+    @override
+    def compute(self) -> dict[str, Float[Tensor, ""]]:
+        out: dict[str, Float[Tensor, ""]] = {}
+        per_module = compute_per_module_metrics(
+            class_name=f"{self._metric_name}/hidden_acts",
+            per_module_sum_mse=self._module_sum_mse,
+            per_module_n_examples=self._module_n,
+        )
+        out.update(per_module)
+        sum_loss = all_reduce(self._output_sum_loss, op=ReduceOp.SUM)
+        n_examples = all_reduce(self._output_n.float(), op=ReduceOp.SUM)
+        out[f"{self._metric_name}/output_recon"] = sum_loss / n_examples
+        return out
diff --git a/spd/metrics/stochastic_hidden_acts_recon_loss.py b/spd/metrics/stochastic_hidden_acts_recon_loss.py
deleted file mode 100644
index 814e6e18c..000000000
--- a/spd/metrics/stochastic_hidden_acts_recon_loss.py
+++ /dev/null
@@ -1,129 +0,0 @@
-from typing import Any, ClassVar, override
-
-import torch
-from jaxtyping import Float, Int
-from torch import Tensor
-from torch.distributed import ReduceOp
-
-from spd.configs import SamplingType
-from spd.metrics.base import Metric
-from spd.models.component_model import CIOutputs, ComponentModel
-from spd.routing import AllLayersRouter
-from spd.utils.component_utils import calc_stochastic_component_mask_info
-from spd.utils.distributed_utils import all_reduce
-from spd.utils.general_utils import get_obj_device
-
-
-def _stochastic_hidden_acts_recon_loss_update(
-    model: ComponentModel,
-    sampling: SamplingType,
-    n_mask_samples: int,
-    batch: Int[Tensor, "..."] | Float[Tensor, "..."],
-    pre_weight_acts: dict[str, Float[Tensor, "..."]],
-    ci: dict[str, Float[Tensor, "... C"]],
-    weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None,
-) -> tuple[Float[Tensor, ""], int]:
-    assert ci, "Empty ci"
-    assert pre_weight_acts, "Empty pre_weight_acts"
-    device = get_obj_device(ci)
-    sum_mse = torch.tensor(0.0, device=device)
-    n_examples = 0
-
-    stoch_mask_infos_list = [
-        calc_stochastic_component_mask_info(
-            causal_importances=ci,
-            component_mask_sampling=sampling,
-            weight_deltas=weight_deltas,
-            router=AllLayersRouter(),
-        )
-        for _ in range(n_mask_samples)
-    ]
-    for stoch_mask_infos in stoch_mask_infos_list:
-        comp_pre_weight_acts = model(batch, mask_infos=stoch_mask_infos, cache_type="input").cache
-
-        # Calculate MSE between pre_weight_acts with and without components
-        for layer_name, target_acts in pre_weight_acts.items():
-            assert layer_name in comp_pre_weight_acts, f"{layer_name} not in comp_pre_weight_acts"
-            mse = torch.nn.functional.mse_loss(
-                comp_pre_weight_acts[layer_name], target_acts, reduction="sum"
-            )
-            sum_mse += mse
-            n_examples += target_acts.numel()
-
-    return sum_mse, n_examples
-
-
-def _stochastic_hidden_acts_recon_loss_compute(
-    sum_mse: Float[Tensor, ""], n_examples: Int[Tensor, ""] | int
-) -> Float[Tensor, ""]:
-    return sum_mse / n_examples
-
-
-def stochastic_hidden_acts_recon_loss(
-    model: ComponentModel,
-    sampling: SamplingType,
-    n_mask_samples: int,
-    batch: Int[Tensor, "..."] | Float[Tensor, "..."],
-    pre_weight_acts: dict[str, Float[Tensor, "..."]],
-    ci: dict[str, Float[Tensor, "... C"]],
-    weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None,
-) -> Float[Tensor, ""]:
-    sum_mse, n_examples = _stochastic_hidden_acts_recon_loss_update(
-        model,
-        sampling,
-        n_mask_samples,
-        batch,
-        pre_weight_acts,
-        ci,
-        weight_deltas,
-    )
-    return _stochastic_hidden_acts_recon_loss_compute(sum_mse, n_examples)
-
-
-class StochasticHiddenActsReconLoss(Metric):
-    """Reconstruction loss between target and stochastic hidden activations when sampling with stochastic masks."""
-
-    metric_section: ClassVar[str] = "loss"
-
-    def __init__(
-        self,
-        model: ComponentModel,
-        device: str,
-        sampling: SamplingType,
-        use_delta_component: bool,
-        n_mask_samples: int,
-    ) -> None:
-        self.model = model
-        self.sampling: SamplingType = sampling
-        self.use_delta_component: bool = use_delta_component
-        self.n_mask_samples: int = n_mask_samples
-        self.sum_mse = torch.tensor(0.0, device=device)
-        self.n_examples = torch.tensor(0, device=device)
-
-    @override
-    def update(
-        self,
-        *,
-        batch: Int[Tensor, "..."] | Float[Tensor, "..."],
-        pre_weight_acts: dict[str, Float[Tensor, "..."]],
-        ci: CIOutputs,
-        weight_deltas: dict[str, Float[Tensor, "d_out d_in"]],
-        **_: Any,
-    ) -> None:
-        sum_mse, n_examples = _stochastic_hidden_acts_recon_loss_update(
-            model=self.model,
-            sampling=self.sampling,
-            n_mask_samples=self.n_mask_samples,
-            batch=batch,
-            pre_weight_acts=pre_weight_acts,
-            ci=ci.lower_leaky,
-            weight_deltas=weight_deltas if self.use_delta_component else None,
-        )
-        self.sum_mse += sum_mse
-        self.n_examples += n_examples
-
-    @override
-    def compute(self) -> Float[Tensor, ""]:
-        sum_mse = all_reduce(self.sum_mse, op=ReduceOp.SUM)
-        n_examples = all_reduce(self.n_examples, op=ReduceOp.SUM)
-        return _stochastic_hidden_acts_recon_loss_compute(sum_mse, n_examples)
diff --git a/spd/models/component_model.py b/spd/models/component_model.py
index 4a2a20fb2..d17cb5e2f 100644
--- a/spd/models/component_model.py
+++ b/spd/models/component_model.py
@@ -1,3 +1,4 @@
+import
fnmatch from collections.abc import Callable, Generator, Sequence from contextlib import contextmanager from dataclasses import dataclass @@ -10,25 +11,51 @@ from torch.utils.hooks import RemovableHandle from transformers.pytorch_utils import Conv1D as RadfordConv1D -from spd.configs import Config, SamplingType +from spd.configs import CiConfig, Config, GlobalCiConfig, LayerwiseCiConfig, SamplingType from spd.identity_insertion import insert_identity_operations_ from spd.interfaces import LoadableModule, RunInfo from spd.models.components import ( Components, ComponentsMaskInfo, EmbeddingComponents, + GlobalCiFnWrapper, + GlobalReverseResidualCiFn, + GlobalSharedMLPCiFn, + GlobalSharedTransformerCiFn, Identity, + LayerwiseCiFnWrapper, LinearComponents, MLPCiFn, + TargetLayerConfig, VectorMLPCiFn, VectorSharedMLPCiFn, ) from spd.models.sigmoids import SIGMOID_TYPES, SigmoidType -from spd.spd_types import CiFnType, ModelPath -from spd.utils.general_utils import resolve_class, runtime_cast +from spd.spd_types import LayerwiseCiFnType, ModelPath +from spd.utils.general_utils import resolve_class from spd.utils.module_utils import ModulePathInfo, expand_module_patterns +def _validate_checkpoint_ci_config_compatibility( + state_dict: dict[str, Tensor], ci_config: CiConfig +) -> None: + """Validate that checkpoint CI weights match the config CI mode.""" + has_layerwise_ci_fns = any(k.startswith("ci_fn._ci_fns") for k in state_dict) + has_global_ci_fn = any(k.startswith("ci_fn._global_ci_fn") for k in state_dict) + + match ci_config: + case LayerwiseCiConfig(): + assert has_layerwise_ci_fns, ( + f"Config specifies layerwise CI but checkpoint has no ci_fn._ci_fns keys " + f"(has ci_fn._global_ci_fn: {has_global_ci_fn})" + ) + case GlobalCiConfig(): + assert has_global_ci_fn, ( + f"Config specifies global CI but checkpoint has no ci_fn._global_ci_fn keys " + f"(has ci_fn._ci_fns: {has_layerwise_ci_fns})" + ) + + @dataclass class SPDRunInfo(RunInfo[Config]): """Run info from 
training a ComponentModel (i.e. from an SPD run).""" @@ -75,8 +102,7 @@ def __init__( self, target_model: nn.Module, module_path_info: list[ModulePathInfo], - ci_fn_type: CiFnType, - ci_fn_hidden_dims: list[int], + ci_config: CiConfig, sigmoid_type: SigmoidType, pretrained_model_output_attr: str | None, ): @@ -101,15 +127,33 @@ def __init__( {k.replace(".", "-"): self.components[k] for k in sorted(self.components)} ) - self.ci_fns = ComponentModel._create_ci_fns( - target_model=target_model, - module_to_c=self.module_to_c, - ci_fn_type=ci_fn_type, - ci_fn_hidden_dims=ci_fn_hidden_dims, - ) - self._ci_fns = nn.ModuleDict( - {k.replace(".", "-"): self.ci_fns[k] for k in sorted(self.ci_fns)} - ) + match ci_config: + case LayerwiseCiConfig(): + raw_layerwise_ci_fns = { + path: ComponentModel._create_layerwise_ci_fn( + target_module=target_model.get_submodule(path), + C=C, + ci_fn_type=ci_config.fn_type, + ci_fn_hidden_dims=ci_config.hidden_dims, + ) + for path, C in self.module_to_c.items() + } + self.ci_fn = LayerwiseCiFnWrapper( + ci_fns=raw_layerwise_ci_fns, + components=self.components, + ci_fn_type=ci_config.fn_type, + ) + case GlobalCiConfig(): + raw_global_ci_fn = ComponentModel._create_global_ci_fn( + target_model=target_model, + module_to_c=self.module_to_c, + components=self.components, + ci_config=ci_config, + ) + self.ci_fn = GlobalCiFnWrapper( + global_ci_fn=raw_global_ci_fn, + components=self.components, + ) if sigmoid_type == "leaky_hard": self.lower_leaky_fn = SIGMOID_TYPES["lower_leaky_hard"] @@ -187,28 +231,39 @@ def _create_components( return components @staticmethod - def _create_ci_fn( + def _get_module_input_dim(target_module: nn.Module) -> int: + """Extract input dimension from a Linear-like module. + + For embedding layers, this should not be called - handle them separately. 
+ """ + match target_module: + case nn.Linear(): + return target_module.weight.shape[1] + case RadfordConv1D(): + return target_module.weight.shape[0] + case Identity(): + return target_module.d + case _: + raise ValueError( + f"Module {type(target_module)} not supported. " + "Embedding modules should be handled separately." + ) + + @staticmethod + def _create_layerwise_ci_fn( target_module: nn.Module, C: int, - ci_fn_type: CiFnType, + ci_fn_type: LayerwiseCiFnType, ci_fn_hidden_dims: list[int], ) -> nn.Module: - """Helper to create a causal importance function (ci_fn) based on ci_fn_type and module type.""" + """Helper to create a single layerwise CI function based on ci_fn_type and module type.""" if isinstance(target_module, nn.Embedding): assert ci_fn_type == "mlp", "Embedding modules only supported for ci_fn_type='mlp'" if ci_fn_type == "mlp": return MLPCiFn(C=C, hidden_dims=ci_fn_hidden_dims) - match target_module: - case nn.Linear(): - input_dim = target_module.weight.shape[1] - case RadfordConv1D(): - input_dim = target_module.weight.shape[0] - case Identity(): - input_dim = target_module.d - case _: - raise ValueError(f"Module {type(target_module)} not supported for {ci_fn_type=}") + input_dim = ComponentModel._get_module_input_dim(target_module) match ci_fn_type: case "vector_mlp": @@ -217,22 +272,107 @@ def _create_ci_fn( return VectorSharedMLPCiFn(C=C, input_dim=input_dim, hidden_dims=ci_fn_hidden_dims) @staticmethod - def _create_ci_fns( + def _create_global_ci_fn( target_model: nn.Module, module_to_c: dict[str, int], - ci_fn_type: CiFnType, - ci_fn_hidden_dims: list[int], - ) -> dict[str, nn.Module]: - ci_fns: dict[str, nn.Module] = {} + components: dict[str, Components], + ci_config: GlobalCiConfig, + ) -> GlobalSharedMLPCiFn | GlobalSharedTransformerCiFn | GlobalReverseResidualCiFn: + """Create a global CI function that takes all layer activations as input.""" + ci_fn_type = ci_config.fn_type + ci_fn_hidden_dims = ci_config.hidden_dims + + # Build 
layer_configs: layer_name -> (input_dim, C) + layer_configs: dict[str, tuple[int, int]] = {} for target_module_path, target_module_c in module_to_c.items(): target_module = target_model.get_submodule(target_module_path) - ci_fns[target_module_path] = ComponentModel._create_ci_fn( - target_module=target_module, - C=target_module_c, - ci_fn_type=ci_fn_type, - ci_fn_hidden_dims=ci_fn_hidden_dims, - ) - return ci_fns + component = components[target_module_path] + + # For embeddings, global CI uses component acts (C dimensions) + # For linear-like modules, use the actual input dimension + if isinstance(target_module, nn.Embedding): + assert isinstance(component, EmbeddingComponents) + input_dim = component.C + else: + input_dim = ComponentModel._get_module_input_dim(target_module) + + layer_configs[target_module_path] = (input_dim, target_module_c) + + match ci_fn_type: + case "global_shared_mlp": + assert ci_fn_hidden_dims is not None # validated by Pydantic + return GlobalSharedMLPCiFn( + layer_configs=layer_configs, hidden_dims=ci_fn_hidden_dims + ) + case "global_shared_transformer": + transformer_cfg = ci_config.simple_transformer_ci_cfg + assert transformer_cfg is not None # validated by Pydantic + + return GlobalSharedTransformerCiFn( + target_model_layer_configs={ + target_module_path: TargetLayerConfig(input_dim=input_dim, C=C) + for target_module_path, (input_dim, C) in layer_configs.items() + }, + d_model=transformer_cfg.d_model, + n_layers=transformer_cfg.n_blocks, + n_heads=transformer_cfg.attn_config.n_heads, + mlp_hidden_dims=transformer_cfg.mlp_hidden_dim, + max_len=transformer_cfg.attn_config.max_len, + rope_base=transformer_cfg.attn_config.rope_base, + ) + case "global_reverse_residual": + # block_groups, d_resid_ci_fn, reader_hidden_dims, transition_hidden_dim + # are validated by Pydantic + block_groups = ci_config.block_groups + d_resid_ci_fn = ci_config.d_resid_ci_fn + reader_hidden_dims = ci_config.reader_hidden_dims + transition_hidden_dim = 
ci_config.transition_hidden_dim + assert block_groups is not None # for type narrowing + assert d_resid_ci_fn is not None # for type narrowing + assert reader_hidden_dims is not None # for type narrowing + + # Build block_configs from block_groups + block_configs: list[tuple[str, list[str], list[int], list[int]]] = [] + all_matched_modules: set[str] = set() + + for block_group in block_groups: + matched_modules: list[str] = [] + for pattern in block_group.patterns: + matches = [name for name in module_to_c if fnmatch.fnmatch(name, pattern)] + assert matches, ( + f"Block pattern '{pattern}' in block '{block_group.name}' " + f"matched no modules. Available: {list(module_to_c.keys())}" + ) + for match in matches: + assert match not in matched_modules, ( + f"Module '{match}' matched multiple patterns in block " + f"'{block_group.name}'" + ) + matched_modules.extend(matches) + + for module in matched_modules: + assert module not in all_matched_modules, ( + f"Module '{module}' matched multiple block groups" + ) + all_matched_modules.add(module) + + input_dims = [layer_configs[m][0] for m in matched_modules] + c_values = [layer_configs[m][1] for m in matched_modules] + + block_configs.append((block_group.name, matched_modules, input_dims, c_values)) + + assert all_matched_modules == set(module_to_c.keys()), ( + f"Some modules not in any block group. " + f"Missing: {set(module_to_c.keys()) - all_matched_modules}" + ) + + return GlobalReverseResidualCiFn( + block_configs=block_configs, + d_resid_ci_fn=d_resid_ci_fn, + reader_hidden_dims=reader_hidden_dims, + transition_hidden_dim=transition_hidden_dim, + attn_config=ci_config.transition_attn_config, + ) def _extract_output(self, raw_output: Any) -> Tensor: """Extract the desired output from the model's raw output. 
@@ -270,7 +410,7 @@ def __call__( self, *args: Any, mask_infos: dict[str, ComponentsMaskInfo] | None = None, - cache_type: Literal["component_acts", "input"], + cache_type: Literal["component_acts", "input", "output"], **kwargs: Any, ) -> OutputWithCache: ... @@ -292,30 +432,20 @@ def forward( self, *args: Any, mask_infos: dict[str, ComponentsMaskInfo] | None = None, - cache_type: Literal["component_acts", "input", "none"] = "none", + cache_type: Literal["component_acts", "input", "output", "none"] = "none", **kwargs: Any, ) -> Tensor | OutputWithCache: - """Forward pass with optional component replacement and/or input caching. - - This method handles the following 4 cases: - 1. mask_infos is None and cache_type is "none": Regular forward pass. - 2. mask_infos is None and cache_type is "input" or "component_acts": Forward pass with - caching on all modules in self.target_module_paths. - 3. mask_infos is not None and cache_type is "input" or "component_acts": Forward pass with - component replacement and caching on the modules provided in mask_infos. - 4. mask_infos is not None and cache_type is "none": Forward pass with component replacement - on the modules provided in mask_infos and no caching. + """Forward pass with optional component replacement and/or input/output caching. Args: mask_infos: Dictionary mapping module names to ComponentsMaskInfo. If provided, those modules will be replaced with their components. - cache_type: If "input" or "component_acts", cache the inputs or component acts to the - modules provided in mask_infos. If "none", no caching is done. If mask_infos is None, - cache the inputs or component acts to all modules in self.target_module_paths. + cache_type: What to cache for each hooked module. "input" caches pre-weight + activations, "output" caches post-weight activations, "component_acts" caches + per-component activations, "none" disables caching. 
Returns: - OutputWithCache object if cache_type is "input" or "component_acts", otherwise the - model output tensor. + OutputWithCache object if cache_type is not "none", otherwise the model output tensor. """ if mask_infos is None and cache_type == "none": # No hooks needed. Do a regular forward pass of the target model. @@ -344,7 +474,7 @@ def forward( out = self._extract_output(raw_out) match cache_type: - case "input" | "component_acts": + case "input" | "output" | "component_acts": return OutputWithCache(output=out, cache=cache) case "none": return out @@ -358,7 +488,7 @@ def _components_and_cache_hook( module_name: str, components: Components | None, mask_info: ComponentsMaskInfo | None, - cache_type: Literal["component_acts", "input", "none"], + cache_type: Literal["component_acts", "input", "output", "none"], cache: dict[str, Tensor], ) -> Any | None: """Unified hook function that handles both component replacement and caching. @@ -402,12 +532,20 @@ def _components_and_cache_hook( for k, v in component_acts_cache.items(): cache[f"{module_name}_{k}"] = v - if mask_info.routing_mask == "all": - return components_out + final_out = ( + components_out + if mask_info.routing_mask == "all" + else torch.where(mask_info.routing_mask[..., None], components_out, output) + ) - return torch.where(mask_info.routing_mask[..., None], components_out, output) + if cache_type == "output": + cache[module_name] = final_out + return final_out # No component replacement - keep original output + if cache_type == "output": + assert isinstance(output, Tensor) + cache[module_name] = output return None @contextmanager @@ -470,10 +608,9 @@ def from_run_info(cls, run_info: RunInfo[Config]) -> "ComponentModel": comp_model = ComponentModel( target_model=target_model, module_path_info=module_path_info, - ci_fn_hidden_dims=config.ci_fn_hidden_dims, - ci_fn_type=config.ci_fn_type, - pretrained_model_output_attr=config.pretrained_model_output_attr, + ci_config=config.ci_config, 
sigmoid_type=config.sigmoid_type, + pretrained_model_output_attr=config.pretrained_model_output_attr, ) comp_model_weights = torch.load( @@ -482,6 +619,8 @@ def from_run_info(cls, run_info: RunInfo[Config]) -> "ComponentModel": handle_deprecated_state_dict_keys_(comp_model_weights) + _validate_checkpoint_ci_config_compatibility(comp_model_weights, config.ci_config) + comp_model.load_state_dict(comp_model_weights) return comp_model @@ -498,38 +637,33 @@ def calc_causal_importances( sampling: SamplingType, detach_inputs: bool = False, ) -> CIOutputs: - """Calculate causal importances. + """Calculate causal importances using the unified CI function interface. Args: pre_weight_acts: The activations before each layer in the target model. + sampling: The sampling type for stochastic masks. detach_inputs: Whether to detach the inputs to the causal importance function. Returns: - Tuple of (causal_importances, causal_importances_upper_leaky) dictionaries for each layer. + CIOutputs containing lower_leaky, upper_leaky, and pre_sigmoid CI values. """ + if detach_inputs: + pre_weight_acts = {k: v.detach() for k, v in pre_weight_acts.items()} + + ci_fn_outputs = self.ci_fn(pre_weight_acts) + return self._apply_sigmoid_to_ci_outputs(ci_fn_outputs, sampling) + + def _apply_sigmoid_to_ci_outputs( + self, + ci_fn_outputs: dict[str, Float[Tensor, "... 
C"]], + sampling: SamplingType, + ) -> CIOutputs: + """Apply sigmoid functions to CI function outputs.""" causal_importances_lower_leaky = {} causal_importances_upper_leaky = {} pre_sigmoid = {} - for target_module_name in pre_weight_acts: - input_activations = pre_weight_acts[target_module_name] - ci_fn = self.ci_fns[target_module_name] - - match ci_fn: - case MLPCiFn(): - ci_fn_input = self.components[target_module_name].get_component_acts( - input_activations - ) - case VectorMLPCiFn() | VectorSharedMLPCiFn(): - ci_fn_input = input_activations - case _: - raise ValueError(f"Unknown ci_fn type: {type(ci_fn)}") - - if detach_inputs: - ci_fn_input = ci_fn_input.detach() - - ci_fn_output = runtime_cast(Tensor, ci_fn(ci_fn_input)) - + for target_module_name, ci_fn_output in ci_fn_outputs.items(): if sampling == "binomial": ci_fn_output_for_lower_leaky = 1.05 * ci_fn_output - 0.05 * torch.rand_like( ci_fn_output @@ -538,11 +672,11 @@ def calc_causal_importances( ci_fn_output_for_lower_leaky = ci_fn_output lower_leaky_output = self.lower_leaky_fn(ci_fn_output_for_lower_leaky) - assert lower_leaky_output.all() <= 1.0 + assert (lower_leaky_output <= 1.0).all() causal_importances_lower_leaky[target_module_name] = lower_leaky_output upper_leaky_output = self.upper_leaky_fn(ci_fn_output) - assert upper_leaky_output.all() >= 0 + assert (upper_leaky_output >= 0).all() causal_importances_upper_leaky[target_module_name] = upper_leaky_output pre_sigmoid[target_module_name] = ci_fn_output @@ -601,6 +735,9 @@ def handle_deprecated_state_dict_keys_(state_dict: dict[str, Tensor]) -> None: ) # module path has "." replaced with "-" new_key = f"_components.{target_module_path.replace('.', '-')}.{new_key.split('.')[-1]}" + # Old checkpoints had _ci_fns.* at top level, now under ci_fn._ci_fns.* + if new_key.startswith("_ci_fns.") and not new_key.startswith("ci_fn."): + new_key = "ci_fn." 
+ new_key # replace if modified if new_key != key: state_dict[new_key] = state_dict.pop(key) diff --git a/spd/models/components.py b/spd/models/components.py index 2827df75c..bb60504cb 100644 --- a/spd/models/components.py +++ b/spd/models/components.py @@ -1,14 +1,19 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Literal, override +from typing import TYPE_CHECKING, Literal, override import einops import torch +import torch.nn.functional as F from jaxtyping import Bool, Float, Int from torch import Tensor, nn from spd.utils.module_utils import _NonlinearityType, init_param_ +if TYPE_CHECKING: + from spd.configs import AttnConfig + from spd.spd_types import LayerwiseCiFnType + class ParallelLinear(nn.Module): """C parallel linear layers""" @@ -42,6 +47,134 @@ def forward(self, x: Float[Tensor, "... d_in"]) -> Float[Tensor, "... d_out"]: return einops.einsum(x, self.W, "... d_in, d_in d_out -> ... d_out") + self.b +class RoPEEmbedding(nn.Module): + """Rotary Position Embedding (RoPE) for sequence modeling. + + Computes position-dependent rotations for query and key tensors to encode + relative position information. Supports arbitrary sequence lengths up to max_len. + """ + + def __init__(self, d_head: int, max_len: int = 2048, base: float = 10000.0): + super().__init__() + assert d_head % 2 == 0, f"RoPE requires even d_head, got {d_head}" + inv_freq = 1.0 / (base ** (torch.arange(0, d_head, 2).float() / d_head)) + self.register_buffer("inv_freq", inv_freq) + self.max_len = max_len + self.d_head = d_head + + @override + def forward( + self, + q: Float[Tensor, "... n_heads seq d_head"], + k: Float[Tensor, "... n_heads seq d_head"], + ) -> tuple[Float[Tensor, "... n_heads seq d_head"], Float[Tensor, "... 
n_heads seq d_head"]]: + """Apply rotary embeddings to Q and K tensors.""" + seq_len = q.shape[-2] + assert seq_len <= self.max_len, f"seq_len {seq_len} exceeds max_len {self.max_len}" + + assert isinstance(self.inv_freq, Tensor) + positions = torch.arange(seq_len, device=q.device, dtype=self.inv_freq.dtype) + angles = einops.einsum(positions, self.inv_freq, "seq, d -> seq d") + # Create full rotation: [cos, cos] and [sin, sin] interleaved + cos_emb = torch.cat([angles.cos(), angles.cos()], dim=-1) + sin_emb = torch.cat([angles.sin(), angles.sin()], dim=-1) + + q_rot = self._apply_rotation(q, cos_emb, sin_emb) + k_rot = self._apply_rotation(k, cos_emb, sin_emb) + return q_rot, k_rot + + def _apply_rotation( + self, + x: Float[Tensor, "... n_heads seq d_head"], + cos: Float[Tensor, "seq d_head"], + sin: Float[Tensor, "seq d_head"], + ) -> Float[Tensor, "... n_heads seq d_head"]: + """Apply rotation: x' = x * cos + rotate_half(x) * sin.""" + # Split into first half and second half + x1 = x[..., : self.d_head // 2] + x2 = x[..., self.d_head // 2 :] + # Rotate: [-x2, x1] + x_rotated = torch.cat([-x2, x1], dim=-1) + return x * cos + x_rotated * sin + + +class SelfAttention(nn.Module): + """Multi-head bidirectional self-attention with RoPE positional embeddings.""" + + def __init__(self, d_model: int, n_heads: int, max_len: int = 2048, rope_base: float = 10000.0): + super().__init__() + assert d_model % n_heads == 0, f"d_model={d_model} must be divisible by n_heads={n_heads}" + + self.d_model = d_model + self.n_heads = n_heads + self.d_head = d_model // n_heads + + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + self.out_proj = nn.Linear(d_model, d_model, bias=False) + + self.rope = RoPEEmbedding(self.d_head, max_len, rope_base) + + @override + def forward(self, x: Float[Tensor, "... seq d_model"]) -> Float[Tensor, "... 
seq d_model"]: + """Apply bidirectional self-attention with RoPE.""" + *batch_dims, seq_len, _ = x.shape + + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + + # Reshape to multi-head: (..., seq, d_model) -> (..., n_heads, seq, d_head) + q = q.view(*batch_dims, seq_len, self.n_heads, self.d_head).transpose(-3, -2) + k = k.view(*batch_dims, seq_len, self.n_heads, self.d_head).transpose(-3, -2) + v = v.view(*batch_dims, seq_len, self.n_heads, self.d_head).transpose(-3, -2) + + q, k = self.rope(q, k) + + # Bidirectional attention (no causal mask) + attn_out = torch.nn.functional.scaled_dot_product_attention( + q, k, v, dropout_p=0.0, is_causal=False + ) + + # Reshape back: (..., n_heads, seq, d_head) -> (..., seq, d_model) + attn_out = attn_out.transpose(-3, -2).contiguous().view(*batch_dims, seq_len, self.d_model) + return self.out_proj(attn_out) + + +class TransformerBlock(nn.Module): + """RMSNorm → self-attention → residual → RMSNorm → MLP → residual.""" + + def __init__( + self, + d_model: int, + n_heads: int, + mlp_hidden_dims: list[int], + max_len: int = 2048, + rope_base: float = 10000.0, + ): + super().__init__() + self.attn = SelfAttention( + d_model=d_model, n_heads=n_heads, max_len=max_len, rope_base=rope_base + ) + self.d_model = d_model + + mlp_layers = nn.Sequential() + in_dim = d_model + for hidden_dim in mlp_hidden_dims: + mlp_layers.append(Linear(in_dim, hidden_dim, nonlinearity="relu")) + mlp_layers.append(nn.GELU()) + in_dim = hidden_dim + mlp_layers.append(Linear(in_dim, d_model, nonlinearity="linear")) + self.mlp = mlp_layers + + @override + def forward(self, x: Float[Tensor, "... seq d_model"]) -> Float[Tensor, "... seq d_model"]: + x = x + self.attn(F.rms_norm(x, (self.d_model,))) + x = x + self.mlp(F.rms_norm(x, (self.d_model,))) + return x + + class MLPCiFn(nn.Module): """MLP-based function that creates a scalar output for each component.""" @@ -110,6 +243,283 @@ def forward(self, x: Float[Tensor, "... 
d_in"]) -> Float[Tensor, "... C"]: return self.layers(x) +class GlobalSharedMLPCiFn(nn.Module): + """Global CI function that concatenates all layer activations and outputs CI for all layers.""" + + def __init__( + self, + layer_configs: dict[str, tuple[int, int]], # layer_name -> (input_dim, C) + hidden_dims: list[int], + ): + super().__init__() + + self.layer_order = sorted(layer_configs.keys()) + self.layer_configs = layer_configs + self.split_sizes = [layer_configs[name][1] for name in self.layer_order] + + total_input_dim = sum(input_dim for input_dim, _ in layer_configs.values()) + total_C = sum(C for _, C in layer_configs.values()) + + self.layers = nn.Sequential() + for i in range(len(hidden_dims)): + in_dim = total_input_dim if i == 0 else hidden_dims[i - 1] + output_dim = hidden_dims[i] + self.layers.append(Linear(in_dim, output_dim, nonlinearity="relu")) + self.layers.append(nn.GELU()) + final_dim = hidden_dims[-1] if len(hidden_dims) > 0 else total_input_dim + self.layers.append(Linear(final_dim, total_C, nonlinearity="linear")) + + @override + def forward( + self, + input_acts: dict[str, Float[Tensor, "... d_in"]], + ) -> dict[str, Float[Tensor, "... 
C"]]: + inputs_list = [input_acts[name] for name in self.layer_order] + concatenated = torch.cat(inputs_list, dim=-1) + output = self.layers(concatenated) + split_outputs = torch.split(output, self.split_sizes, dim=-1) + return {name: split_outputs[i] for i, name in enumerate(self.layer_order)} + + +@dataclass +class TargetLayerConfig: + input_dim: int + C: int + + +class GlobalSharedTransformerCiFn(nn.Module): + """Global CI function that projects concatenated activations and attends over sequence.""" + + def __init__( + self, + target_model_layer_configs: dict[str, TargetLayerConfig], + d_model: int, + n_layers: int, + n_heads: int, + mlp_hidden_dims: list[int] | None = None, + max_len: int = 2048, + rope_base: float = 10000.0, + ): + super().__init__() + + self.layer_order = sorted(target_model_layer_configs.keys()) + self.target_model_layer_configs = target_model_layer_configs + self.split_sizes = [target_model_layer_configs[name].C for name in self.layer_order] + self.d_model = d_model + self.n_transformer_layers = n_layers + self.n_heads = n_heads + + if mlp_hidden_dims is None: + mlp_hidden_dims = [4 * d_model] + + total_input_dim = sum(config.input_dim for config in target_model_layer_configs.values()) + total_c = sum(config.C for config in target_model_layer_configs.values()) + + self._input_projector = Linear(total_input_dim, d_model, nonlinearity="relu") + self._output_head = Linear(d_model, total_c, nonlinearity="linear") + + self._blocks = nn.ModuleList( + [ + TransformerBlock( + d_model=d_model, + n_heads=n_heads, + mlp_hidden_dims=mlp_hidden_dims, + max_len=max_len, + rope_base=rope_base, + ) + for _ in range(n_layers) + ] + ) + + @override + def forward( + self, + input_acts: dict[str, Float[Tensor, "... d_in"]], + ) -> dict[str, Float[Tensor, "... 
C"]]: + inputs_list = [ + F.rms_norm(input_acts[name], (input_acts[name].shape[-1],)) for name in self.layer_order + ] + concatenated = torch.cat(inputs_list, dim=-1) + projected: Tensor = self._input_projector(concatenated) + + # The transformer blocks expect a sequence dimension, so we add an extra dimension to our + # activations if we only have 2D acts (e.g. in TMS and resid_mlp). + added_seq_dim = False + if projected.ndim < 3: + projected = projected.unsqueeze(-2) + added_seq_dim = True + + x = projected + for block in self._blocks: + x = block(x) + + output = self._output_head(x) + + if added_seq_dim: + output = output.squeeze(-2) + + split_outputs = torch.split(output, self.split_sizes, dim=-1) + outputs = {name: split_outputs[i] for i, name in enumerate(self.layer_order)} + + return outputs + + +class GlobalReverseResidualCiFn(nn.Module): + """Global CI function that processes blocks in reverse order with a residual stream. + + Architecture: + 1. Initialize residual stream to zeros (batch..., d_resid_ci_fn) + 2. Process blocks in order (typically: unembed → layer N MLP → layer N attn → ... → embed) + 3. For each block: + - Concat relevant activations from all modules in the block + - Project to d_resid_ci_fn and add to residual stream + - RMSNorm → Reader MLP outputs CI values for all modules in block + - Transition updates residual stream for next block (except after last block): + - If attn_config is provided: RMSNorm → attn → add → RMSNorm → MLP(GeLU) → add + - Otherwise: RMSNorm → MLP(GeLU) → add + - If transition_hidden_dim is provided: MLP d_resid → hidden → d_resid with GeLU + - Otherwise: linear d_resid → d_resid + + """ + + def __init__( + self, + block_configs: list[tuple[str, list[str], list[int], list[int]]], + d_resid_ci_fn: int, + reader_hidden_dims: list[int], + transition_hidden_dim: int | None, + attn_config: "AttnConfig | None" = None, + ): + """Initialize the reverse residual CI function. 
+ + Args: + block_configs: List of (block_name, module_names, input_dims, c_values) tuples. + Ordered in processing order (first block processed first). + d_resid_ci_fn: Dimension of the residual stream. + reader_hidden_dims: Hidden dimensions for reader MLPs. + transition_hidden_dim: Hidden dimension for transition MLPs. + attn_config: Optional config for self-attention in transitions. If provided, + transitions use a transformer block (attention → residual → MLP → residual). + """ + super().__init__() + + if attn_config is not None: + assert d_resid_ci_fn % attn_config.n_heads == 0, ( + f"d_resid_ci_fn ({d_resid_ci_fn}) must be divisible by " + f"attn_config.n_heads ({attn_config.n_heads})" + ) + d_head = d_resid_ci_fn // attn_config.n_heads + assert d_head % 2 == 0, ( + f"d_head ({d_head}) must be even for RoPE. " + f"d_resid_ci_fn={d_resid_ci_fn}, n_heads={attn_config.n_heads}" + ) + self.d_resid_ci_fn = d_resid_ci_fn + self.n_blocks = len(block_configs) + self.block_safe_names = [name.replace(".", "-") for name, _, _, _ in block_configs] + self.block_module_names = [modules for _, modules, _, _ in block_configs] + self.block_input_dims = [dims for _, _, dims, _ in block_configs] + self.block_c_values = [cs for _, _, _, cs in block_configs] + + self._inp_projectors = nn.ModuleDict() + self._readers = nn.ModuleDict() + self._reader_norms = nn.ModuleDict() + self._transitions = nn.ModuleDict() + self._transition_norms = nn.ModuleDict() + self._attn_transitions: nn.ModuleDict | None = ( + nn.ModuleDict() if attn_config is not None else None + ) + self._attn_norms: nn.ModuleDict | None = ( + nn.ModuleDict() if attn_config is not None else None + ) + + for block_idx, (_, _, input_dims, c_values) in enumerate(block_configs): + safe_name = self.block_safe_names[block_idx] + total_input_dim = sum(input_dims) + total_c = sum(c_values) + + self._inp_projectors[safe_name] = Linear( + total_input_dim, d_resid_ci_fn, nonlinearity="relu" + ) + + reader_layers = 
nn.Sequential() + for i in range(len(reader_hidden_dims)): + in_dim = d_resid_ci_fn if i == 0 else reader_hidden_dims[i - 1] + out_dim = reader_hidden_dims[i] + reader_layers.append(Linear(in_dim, out_dim, nonlinearity="relu")) + reader_layers.append(nn.GELU()) + final_dim = reader_hidden_dims[-1] if len(reader_hidden_dims) > 0 else d_resid_ci_fn + reader_layers.append(Linear(final_dim, total_c, nonlinearity="linear")) + self._readers[safe_name] = reader_layers + self._reader_norms[safe_name] = nn.RMSNorm(d_resid_ci_fn) + + if block_idx < self.n_blocks - 1: + if transition_hidden_dim is not None: + transition = nn.Sequential( + Linear(d_resid_ci_fn, transition_hidden_dim, nonlinearity="relu"), + nn.GELU(), + Linear(transition_hidden_dim, d_resid_ci_fn, nonlinearity="relu"), + ) + else: + transition = Linear(d_resid_ci_fn, d_resid_ci_fn, nonlinearity="relu") + self._transitions[safe_name] = transition + self._transition_norms[safe_name] = nn.RMSNorm(d_resid_ci_fn) + if attn_config is not None: + assert self._attn_transitions is not None + self._attn_transitions[safe_name] = SelfAttention( + d_model=d_resid_ci_fn, + n_heads=attn_config.n_heads, + max_len=attn_config.max_len, + rope_base=attn_config.rope_base, + ) + assert self._attn_norms is not None + self._attn_norms[safe_name] = nn.RMSNorm(d_resid_ci_fn) + + @override + def forward( + self, + input_acts: dict[str, Float[Tensor, "... d_in"]], + ) -> dict[str, Float[Tensor, "... C"]]: + first_tensor = next(iter(input_acts.values())) + batch_shape = first_tensor.shape[:-1] + device = first_tensor.device + dtype = first_tensor.dtype + + residual = torch.zeros(*batch_shape, self.d_resid_ci_fn, device=device, dtype=dtype) + + all_outputs: dict[str, Float[Tensor, "... 
C"]] = {} + + for block_idx in range(self.n_blocks): + safe_name = self.block_safe_names[block_idx] + module_names = self.block_module_names[block_idx] + c_values = self.block_c_values[block_idx] + + block_acts = [input_acts[name] for name in module_names] + concat_acts = torch.cat(block_acts, dim=-1) + + projection = self._inp_projectors[safe_name](concat_acts) + residual = residual + projection + + if block_idx < self.n_blocks - 1: + # With attention: norm → attn → residual add → norm → MLP → residual add + # Without attention: norm → MLP → residual add + if self._attn_transitions is not None: + assert self._attn_norms is not None + attn_out = self._attn_transitions[safe_name]( + self._attn_norms[safe_name](residual) + ) + residual = residual + attn_out + + mlp_out = self._transitions[safe_name](self._transition_norms[safe_name](residual)) + residual = residual + mlp_out + ci_output = self._readers[safe_name](self._reader_norms[safe_name](residual)) + + split_outputs = torch.split(ci_output, c_values, dim=-1) + for module_name, module_ci in zip(module_names, split_outputs, strict=True): + all_outputs[module_name] = module_ci + + return all_outputs + + WeightDeltaAndMask = tuple[Float[Tensor, "d_out d_in"], Float[Tensor, "..."]] @@ -360,3 +770,81 @@ def make_mask_infos( ) return result + + +class LayerwiseCiFnWrapper(nn.Module): + """Wraps a dict of per-layer CI functions with a unified interface. + + Calls each layer's CI function independently on its corresponding input activations. + """ + + def __init__( + self, + ci_fns: dict[str, nn.Module], + components: dict[str, Components], + ci_fn_type: "LayerwiseCiFnType", + ): + super().__init__() + self.layer_names = sorted(ci_fns.keys()) + self.components = components + self.ci_fn_type = ci_fn_type + + # Store as ModuleDict with "." 
replaced by "-" for state dict compatibility + self._ci_fns = nn.ModuleDict( + {name.replace(".", "-"): ci_fns[name] for name in self.layer_names} + ) + + @override + def forward( + self, + layer_acts: dict[str, Float[Tensor, "..."]], + ) -> dict[str, Float[Tensor, "... C"]]: + outputs: dict[str, Float[Tensor, "... C"]] = {} + + for layer_name in self.layer_names: + ci_fn = self._ci_fns[layer_name.replace(".", "-")] + input_acts = layer_acts[layer_name] + + # MLPCiFn expects component activations, others take raw input + if self.ci_fn_type == "mlp": + ci_fn_input = self.components[layer_name].get_component_acts(input_acts) + else: + ci_fn_input = input_acts + + outputs[layer_name] = ci_fn(ci_fn_input) + + return outputs + + +class GlobalCiFnWrapper(nn.Module): + """Wraps global CI functions with a unified interface. + + Transforms embedding layer inputs to component activations before calling + the underlying global CI function. + """ + + def __init__( + self, + global_ci_fn: GlobalSharedMLPCiFn | GlobalSharedTransformerCiFn | GlobalReverseResidualCiFn, + components: dict[str, Components], + ): + super().__init__() + self._global_ci_fn = global_ci_fn + self.components = components + + @override + def forward( + self, + layer_acts: dict[str, Float[Tensor, "..."]], + ) -> dict[str, Float[Tensor, "... C"]]: + transformed: dict[str, Float[Tensor, "..."]] = {} + + for layer_name, acts in layer_acts.items(): + component = self.components[layer_name] + if isinstance(component, EmbeddingComponents): + # Embeddings pass token IDs; convert to component activations + transformed[layer_name] = component.get_component_acts(acts) + else: + transformed[layer_name] = acts + + return self._global_ci_fn(transformed) diff --git a/spd/persistent_pgd.py b/spd/persistent_pgd.py new file mode 100644 index 000000000..1200829f1 --- /dev/null +++ b/spd/persistent_pgd.py @@ -0,0 +1,335 @@ +"""Persistent PGD: Persistent adversarial sources that evolve across training steps. 
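In a minimal, self-contained sketch (toy quadratic loss; names are illustrative, not this module's API), the amortized update — one projected sign-gradient ascent step per outer training step — looks like:

```python
import torch

torch.manual_seed(0)

# Persistent adversarial source: initialized once, then nudged by a single
# sign-gradient ascent step per outer training step (toy stand-in for the
# masked-forward reconstruction loss used in this module).
source = torch.rand(8, requires_grad=True)
step_size = 0.02

for _ in range(200):  # 200 outer "training steps", one cheap PGD step each
    loss = -(source - 0.7).pow(2).sum()  # toy loss, maximized at source == 0.7
    (grad,) = torch.autograd.grad(loss, source)
    with torch.no_grad():
        source.add_(step_size * grad.sign())  # gradient *ascent* on the loss
        source.clamp_(0.0, 1.0)               # project back onto the [0, 1] box

# After many amortized single steps, the source oscillates within one
# step size of the maximizer.
assert (source - 0.7).abs().max().item() <= step_size + 1e-6
```

The same budget spent as N fresh PGD steps per training step would cost N extra forward/backward passes every step; here the persistent source carries that progress forward instead.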
+ +Instead of reinitializing PGD sources each training step and running N optimization steps, +PersistentPGD maintains persistent sources that receive one gradient update per training step. +Over many steps, these sources converge to strong adversarial configurations. + +The key insight is that this amortizes PGD optimization across training steps - getting the +benefit of many PGD steps without the per-step computational cost. +""" + +from abc import ABC, abstractmethod +from typing import Literal, override + +import torch +from jaxtyping import Float, Int +from torch import Tensor +from torch.distributed import ReduceOp + +from spd.configs import ( + AdamPGDConfig, + BroadcastAcrossBatchScope, + PerBatchPerPositionScope, + PersistentPGDReconLossConfig, + PersistentPGDReconSubsetLossConfig, + PGDOptimizerConfig, + RepeatAcrossBatchScope, + SignPGDConfig, + SingleSourceScope, +) +from spd.models.component_model import ComponentModel +from spd.models.components import ComponentsMaskInfo, RoutingMasks, make_mask_infos +from spd.routing import AllLayersRouter, Router, get_subset_router +from spd.utils.distributed_utils import all_reduce, broadcast_tensor +from spd.utils.general_utils import calc_sum_recon_loss_lm, get_scheduled_value + +PPGDSources = dict[str, Float[Tensor, " source_c"]] + + +class PPGDOptimizer(ABC): + """Interface for persistent PGD optimizers.""" + + @abstractmethod + def init_state(self, sources: PPGDSources) -> None: + """Initialize any optimizer-specific state for the given sources.""" + + @abstractmethod + def step(self, sources: PPGDSources, grads: PPGDSources) -> None: + """Perform one update step on sources using gradients. 
Updates sources in-place.""" + + @abstractmethod + def set_lr(self, lr: float) -> None: + """Update the learning rate / step size.""" + + +class SignPGDOptimizer(PPGDOptimizer): + def __init__(self, cfg: SignPGDConfig) -> None: + self._step_size = cfg.lr_schedule.start_val + + @override + def init_state(self, sources: PPGDSources) -> None: + pass + + @override + def step(self, sources: PPGDSources, grads: PPGDSources) -> None: + for module_name in sources: + sources[module_name].add_(self._step_size * grads[module_name].sign()) + + @override + def set_lr(self, lr: float) -> None: + self._step_size = lr + + +class AdamPGDOptimizer(PPGDOptimizer): + def __init__(self, cfg: AdamPGDConfig) -> None: + self._lr = cfg.lr_schedule.start_val + self._beta1 = cfg.beta1 + self._beta2 = cfg.beta2 + self._eps = cfg.eps + self._step_count = 0 + self._m: PPGDSources = {} + self._v: PPGDSources = {} + + @override + def init_state(self, sources: PPGDSources) -> None: + for module_name, source in sources.items(): + self._m[module_name] = torch.zeros_like(source) + self._v[module_name] = torch.zeros_like(source) + + @override + def step(self, sources: PPGDSources, grads: PPGDSources) -> None: + self._step_count += 1 + bias_correction1 = 1 - self._beta1**self._step_count + bias_correction2 = 1 - self._beta2**self._step_count + for module_name, source in sources.items(): + grad = grads[module_name] + m = self._m[module_name] + v = self._v[module_name] + m.mul_(self._beta1).add_(grad, alpha=1 - self._beta1) + v.mul_(self._beta2).addcmul_(grad, grad, value=1 - self._beta2) + m_hat = m / bias_correction1 + v_hat = v / bias_correction2 + denom = v_hat.sqrt().add_(self._eps) + source.add_(self._lr * m_hat / denom) + + @override + def set_lr(self, lr: float) -> None: + self._lr = lr + + +def make_ppgd_optimizer(cfg: PGDOptimizerConfig) -> PPGDOptimizer: + match cfg: + case SignPGDConfig(): + return SignPGDOptimizer(cfg) + case AdamPGDConfig(): + return AdamPGDOptimizer(cfg) + + +class 
PersistentPGDState: + """Persistent state for persistent PGD optimization. + + Holds adversarial sources per module that persist across training steps. + Source shape depends on scope: shared across batch (SingleSource, BroadcastAcrossBatch), + repeated along batch dim (RepeatAcrossBatch), or per-batch-element-per-position with no + cross-rank synchronization (PerBatchPerPosition). + """ + + def __init__( + self, + module_to_c: dict[str, int], + batch_dims: tuple[int, ...], + device: torch.device | str, + use_delta_component: bool, + cfg: PersistentPGDReconLossConfig | PersistentPGDReconSubsetLossConfig, + output_loss_type: Literal["mse", "kl"], + ) -> None: + self.optimizer = make_ppgd_optimizer(cfg.optimizer) + self._skip_all_reduce = isinstance(cfg.scope, PerBatchPerPositionScope) + self._use_sigmoid_parameterization = cfg.use_sigmoid_parameterization + self._router = _get_router_for_ppgd_config(cfg, device) + self._n_warmup_steps = cfg.n_warmup_steps + self._output_loss_type: Literal["mse", "kl"] = output_loss_type + self._lr_schedule = cfg.optimizer.lr_schedule + + self.sources: PPGDSources = {} + + match cfg.scope: + case SingleSourceScope(): + source_leading_dims = [1] * len(batch_dims) + case BroadcastAcrossBatchScope(): + source_leading_dims = [1] + list(batch_dims[1:]) + case RepeatAcrossBatchScope(n_sources=n): + assert batch_dims[0] % n == 0, ( + f"n_sources={n} must divide the per-rank microbatch size " + f"{batch_dims[0]}, not the global batch size. " + f"Adjust n_sources or batch_size to satisfy this." 
+ ) + source_leading_dims = [n] + list(batch_dims[1:]) + case PerBatchPerPositionScope(): + source_leading_dims = list(batch_dims) + + init_fn = torch.randn if self._use_sigmoid_parameterization else torch.rand + for module_name, module_c in module_to_c.items(): + source_c = module_c + 1 if use_delta_component else module_c + source_shape = source_leading_dims + [source_c] + source_data = init_fn(source_shape, device=device) + if not self._skip_all_reduce: + broadcast_tensor(source_data) + self.sources[module_name] = source_data.requires_grad_(True) + + self.optimizer.init_state(self.sources) + + def get_grads(self, loss: Float[Tensor, ""], retain_graph: bool = True) -> PPGDSources: + grads = torch.autograd.grad(loss, list(self.sources.values()), retain_graph=retain_graph) + + if self._skip_all_reduce: + return dict(zip(self.sources.keys(), grads, strict=True)) + return { + k: all_reduce(g, op=ReduceOp.AVG) + for k, g in zip(self.sources.keys(), grads, strict=True) + } + + def step(self, grads: PPGDSources) -> None: + """Perform one PGD update step using the provided gradients. + + Updates sources in-place, then clamps to [0, 1] (or leaves unbounded when using sigmoid + parameterization, where sigmoid is applied when reading effective sources). + """ + with torch.no_grad(): + self.optimizer.step(self.sources, grads) + + if not self._use_sigmoid_parameterization: + for source in self.sources.values(): + source.clamp_(0.0, 1.0) + + def get_effective_sources(self) -> PPGDSources: + """Return sources in [0, 1] range. + + If using sigmoid parameterization, applies sigmoid to unconstrained values. Otherwise + returns raw sources (already clamped to [0, 1]). 
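A standalone illustration of the two read paths (shapes and values arbitrary; not this class's attributes):

```python
import torch

torch.manual_seed(0)

# Direct parameterization: sources live in [0, 1]; updates are projected by clamping.
direct = torch.rand(4)
direct = (direct + 0.5).clamp(0.0, 1.0)  # an overshooting update gets clamped back

# Sigmoid parameterization: sources stay unconstrained; sigmoid is applied on read,
# so no clamping is ever needed after an update.
unconstrained = torch.randn(4) * 5.0
effective = torch.sigmoid(unconstrained)  # always strictly inside (0, 1)

assert direct.min() >= 0.0 and direct.max() <= 1.0
assert effective.min() > 0.0 and effective.max() < 1.0
```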
+ """ + if self._use_sigmoid_parameterization: + return {k: torch.sigmoid(v) for k, v in self.sources.items()} + return self.sources + + def update_lr(self, step: int, total_steps: int) -> None: + lr = get_scheduled_value(step, total_steps, self._lr_schedule) + self.optimizer.set_lr(lr) + + def warmup( + self, + model: ComponentModel, + batch: Int[Tensor, "..."] | Float[Tensor, "..."], + target_out: Float[Tensor, "... vocab"], + ci: dict[str, Float[Tensor, "... C"]], + weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, + ) -> None: + """Run extra PGD steps to refine adversarial sources before the final loss computation. + + Each step computes the recon loss, extracts gradients, and updates sources in-place. + When n_warmup_steps=0 (default), this is a no-op. + """ + for _ in range(self._n_warmup_steps): + loss = self.compute_recon_loss(model, batch, target_out, ci, weight_deltas) + grads = self.get_grads(loss, retain_graph=False) + self.step(grads) + + def compute_recon_loss( + self, + model: ComponentModel, + batch: Int[Tensor, "..."] | Float[Tensor, "..."], + target_out: Float[Tensor, "... vocab"], + ci: dict[str, Float[Tensor, "... C"]], + weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, + ) -> Float[Tensor, ""]: + """Pure forward pass that returns the PPGD reconstruction loss. 
No source mutation.""" + batch_dims = next(iter(ci.values())).shape[:-1] + routing_masks = self._router.get_masks( + module_names=model.target_module_paths, mask_shape=batch_dims + ) + ppgd_sources = self.get_effective_sources() + sum_loss, n_examples = _compute_ppgd_recon_loss( + model=model, + ppgd_sources=ppgd_sources, + output_loss_type=self._output_loss_type, + batch=batch, + target_out=target_out, + ci=ci, + weight_deltas=weight_deltas, + routing_masks=routing_masks, + ) + return sum_loss / n_examples + + +def _get_router_for_ppgd_config( + cfg: PersistentPGDReconLossConfig | PersistentPGDReconSubsetLossConfig, + device: torch.device | str, +) -> Router: + match cfg: + case PersistentPGDReconLossConfig(): + return AllLayersRouter() + case PersistentPGDReconSubsetLossConfig(routing=routing): + return get_subset_router(routing, device) + + +def get_ppgd_mask_infos( + ci: dict[str, Float[Tensor, "... C"]], + weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, + ppgd_sources: dict[str, Float[Tensor, "*batch_dims source_c"]], + routing_masks: RoutingMasks, + batch_dims: tuple[int, ...], +) -> dict[str, ComponentsMaskInfo]: + """Get mask infos for persistent PGD.""" + + expanded_adv_sources: dict[str, Float[Tensor, "*batch_dims source_c"]] = {} + for module_name, source in ppgd_sources.items(): + B = batch_dims[0] + N = source.shape[0] + if N == 1 or N == B: + expanded_adv_sources[module_name] = source.expand(*batch_dims, -1) + else: + assert B % N == 0, f"source leading dim {N} must divide batch dim {B}" + repeat_dims = (B // N,) + (1,) * (source.ndim - 1) + expanded_adv_sources[module_name] = source.repeat(*repeat_dims) + + # Split into component sources and weight delta sources + adv_sources_components: dict[str, Float[Tensor, "*batch_dims C"]] + weight_deltas_and_masks: ( + dict[str, tuple[Float[Tensor, "d_out d_in"], Float[Tensor, "..."]]] | None + ) + match weight_deltas: + case None: + weight_deltas_and_masks = None + adv_sources_components = 
expanded_adv_sources + case dict(): + weight_deltas_and_masks = { + k: (weight_deltas[k], expanded_adv_sources[k][..., -1]) for k in weight_deltas + } + adv_sources_components = {k: v[..., :-1] for k, v in expanded_adv_sources.items()} + + component_masks = _interpolate_component_mask(ci, adv_sources_components) + + return make_mask_infos( + component_masks=component_masks, + weight_deltas_and_masks=weight_deltas_and_masks, + routing_masks=routing_masks, + ) + + +def _interpolate_component_mask( + ci: dict[str, Float[Tensor, "... C"]], + adv_sources: dict[str, Float[Tensor, "... C"]], +) -> dict[str, Float[Tensor, "... C"]]: + """Interpolate CI with adversarial sources: mask = ci + (1 - ci) * adv.""" + return {name: ci[name] + (1 - ci[name]) * adv_sources[name] for name in ci} + + +def _compute_ppgd_recon_loss( + model: ComponentModel, + ppgd_sources: PPGDSources, + output_loss_type: Literal["mse", "kl"], + batch: Int[Tensor, "..."] | Float[Tensor, "..."], + target_out: Float[Tensor, "... vocab"], + ci: dict[str, Float[Tensor, "... C"]], + weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, + routing_masks: RoutingMasks, +) -> tuple[Float[Tensor, ""], int]: + assert ci, "Empty ci" + batch_dims = next(iter(ci.values())).shape[:-1] + + mask_infos = get_ppgd_mask_infos(ci, weight_deltas, ppgd_sources, routing_masks, batch_dims) + out = model(batch, mask_infos=mask_infos) + loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=output_loss_type) + n_examples = out.shape.numel() if output_loss_type == "mse" else out.shape[:-1].numel() + + return loss, n_examples diff --git a/spd/postprocess/__init__.py b/spd/postprocess/__init__.py new file mode 100644 index 000000000..ae0795e97 --- /dev/null +++ b/spd/postprocess/__init__.py @@ -0,0 +1,166 @@ +"""Unified postprocessing pipeline for decomposition runs. + +Submits all postprocessing steps to SLURM with proper dependency chaining. 
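The chaining amounts to a topological order over jobs; a standalone stdlib sketch (job names are illustrative, not the real SLURM job names):

```python
from graphlib import TopologicalSorter

# Each job maps to the set of jobs it must wait on, mirroring the
# harvest -> merge -> {intruder, attributions, autointerp} fan-out.
waits_on = {
    "harvest_merge": {"harvest_array"},
    "intruder_eval": {"harvest_merge"},
    "attributions": {"harvest_merge"},
    "autointerp": {"harvest_merge"},
    "detection": {"autointerp"},
    "fuzzing": {"autointerp"},
}

order = list(TopologicalSorter(waits_on).static_order())
assert order.index("harvest_array") < order.index("harvest_merge")
assert order.index("harvest_merge") < order.index("autointerp")
assert order.index("autointerp") < order.index("detection")
```

SLURM enforces the same ordering at submission time via `--dependency=afterok:<job_id>`.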
+All steps always run — data accumulates (harvest upserts, autointerp resumes). + +Dependency graph: + harvest (GPU array -> merge, GPU, SPD-only) + ├── intruder eval (CPU, depends on harvest merge, label-free) + ├── attributions (GPU array -> merge, depends on harvest merge, SPD-only) + └── autointerp (CPU, LLM calls, resumes via completed keys) + ├── detection (CPU, label-dependent) + └── fuzzing (CPU, label-dependent) +""" + +import secrets +from datetime import datetime +from pathlib import Path + +import yaml + +from spd.autointerp.scripts.run_slurm import AutointerpSubmitResult, submit_autointerp +from spd.dataset_attributions.scripts.run_slurm import submit_attributions +from spd.graph_interp.scripts.run_slurm import GraphInterpSubmitResult, submit_graph_interp +from spd.harvest.config import SPDHarvestConfig +from spd.harvest.scripts import run_intruder +from spd.harvest.scripts.run_slurm import submit_harvest +from spd.log import logger +from spd.postprocess.config import PostprocessConfig +from spd.settings import SPD_OUT_DIR +from spd.utils.git_utils import create_git_snapshot +from spd.utils.slurm import SlurmConfig, SubmitResult, generate_script, submit_slurm_job + + +def postprocess(config: PostprocessConfig, dependency_job_id: str | None = None) -> Path: + """Submit all postprocessing jobs with SLURM dependency chaining. + + Args: + config: Postprocessing configuration. + dependency_job_id: SLURM job to wait for before starting harvest + (e.g. a training job that must complete first). + + Returns: + Path to the manifest YAML file. + """ + + snapshot_branch, commit_hash = create_git_snapshot(f"postprocess-{secrets.token_hex(4)}") + logger.info(f"Created git snapshot: {snapshot_branch} ({commit_hash[:8]})") + + decomp_cfg = config.harvest.config.method_config + + # === 1. 
Harvest (always runs, upserts into harvest.db) === + harvest_result = submit_harvest( + config.harvest, + snapshot_branch=snapshot_branch, + dependency_job_id=dependency_job_id, + ) + + # === 2. Autointerp (depends on harvest, resumes via completed keys) === + autointerp_result: AutointerpSubmitResult | None = None + if config.autointerp is not None: + autointerp_result = submit_autointerp( + decomposition_id=decomp_cfg.id, + config=config.autointerp, + dependency_job_id=harvest_result.merge_result.job_id, + snapshot_branch=snapshot_branch, + harvest_subrun_id=harvest_result.subrun_id, + ) + + # === 3. Intruder eval (depends on harvest merge, label-free) === + intruder_result: SubmitResult | None = None + if config.intruder is not None: + intruder_cmd = run_intruder.get_command( + decomposition_id=decomp_cfg.id, + config=config.intruder.config, + harvest_subrun_id=harvest_result.subrun_id, + ) + + intruder_slurm = SlurmConfig( + job_name="spd-intruder-eval", + partition=config.intruder.partition, + n_gpus=2, + time=config.intruder.time, + snapshot_branch=snapshot_branch, + dependency_job_id=harvest_result.merge_result.job_id, + ) + intruder_script = generate_script(intruder_slurm, intruder_cmd) + intruder_result = submit_slurm_job(intruder_script, "intruder_eval") + + logger.section("Intruder eval job submitted") + logger.values( + { + "Job ID": intruder_result.job_id, + "Depends on": f"harvest merge ({harvest_result.merge_result.job_id})", + "Log": intruder_result.log_pattern, + } + ) + + # === 4. Attributions (depends on harvest merge, SPD-only) === + attr_result = None + if config.attributions is not None: + assert isinstance(decomp_cfg, SPDHarvestConfig) + attr_result = submit_attributions( + wandb_path=decomp_cfg.wandb_path, + config=config.attributions, + harvest_subrun_id=harvest_result.subrun_id, + snapshot_branch=snapshot_branch, + dependency_job_id=harvest_result.merge_result.job_id, + ) + + # === 5. 
Graph interp (depends on harvest merge + attribution merge) === + graph_interp_result: GraphInterpSubmitResult | None = None + if config.graph_interp is not None: + assert attr_result is not None + graph_interp_result = submit_graph_interp( + decomposition_id=decomp_cfg.id, + config=config.graph_interp, + dependency_job_ids=[ + harvest_result.merge_result.job_id, + attr_result.merge_result.job_id, + ], + snapshot_branch=snapshot_branch, + harvest_subrun_id=harvest_result.subrun_id, + ) + + # === Write manifest === + manifest_id = "pp-" + datetime.now().strftime("%Y%m%d_%H%M%S") + manifest_dir = SPD_OUT_DIR / "postprocess" / manifest_id + manifest_dir.mkdir(parents=True, exist_ok=True) + manifest_path = manifest_dir / "manifest.yaml" + + jobs: dict[str, str] = { + "harvest_array": harvest_result.array_result.job_id, + "harvest_merge": harvest_result.merge_result.job_id, + "harvest_subrun": harvest_result.subrun_id, + } + if intruder_result is not None: + jobs["intruder_eval"] = intruder_result.job_id + if attr_result is not None: + jobs["attr_array"] = attr_result.array_result.job_id + jobs["attr_merge"] = attr_result.merge_result.job_id + jobs["attr_subrun"] = attr_result.subrun_id + if autointerp_result is not None: + jobs["interpret"] = autointerp_result.interpret_result.job_id + if autointerp_result.detection_result is not None: + jobs["detection"] = autointerp_result.detection_result.job_id + if autointerp_result.fuzzing_result is not None: + jobs["fuzzing"] = autointerp_result.fuzzing_result.job_id + if graph_interp_result is not None: + jobs["graph_interp"] = graph_interp_result.result.job_id + + manifest = { + "timestamp": datetime.now().isoformat(timespec="seconds"), + "decomposition": config.harvest.config.method_config.model_dump(), + "snapshot_branch": snapshot_branch, + "commit_hash": commit_hash, + "config": config.model_dump(), + "jobs": jobs, + } + + with open(manifest_path, "w") as f: + yaml.dump(manifest, f, default_flow_style=False, 
sort_keys=False) + + logger.section("Postprocess manifest saved") + logger.info(str(manifest_path)) + + return manifest_path diff --git a/spd/postprocess/cli.py b/spd/postprocess/cli.py new file mode 100644 index 000000000..2e6a6b3e7 --- /dev/null +++ b/spd/postprocess/cli.py @@ -0,0 +1,44 @@ +"""CLI entry point for unified postprocessing pipeline. + +Thin wrapper for fast --help. Heavy imports deferred to postprocess.py. + +Uses argparse instead of Fire because SLURM job IDs like "311644_1" get +parsed by Fire as integers (underscore is a numeric separator in Python). + +Usage: + spd-postprocess config.yaml + spd-postprocess config.yaml --dependency 311644_1 +""" + +import argparse + + +def main() -> None: + parser = argparse.ArgumentParser(description="Submit all postprocessing jobs for an SPD run.") + parser.add_argument("config", help="Path to PostprocessConfig YAML.") + parser.add_argument( + "--dependency", + help="SLURM job ID to wait for before starting (e.g. a training job).", + ) + parser.add_argument("--dry_run", action="store_true") + args = parser.parse_args() + + import yaml + + from spd.log import logger + from spd.postprocess import postprocess + from spd.postprocess.config import PostprocessConfig + + cfg = PostprocessConfig.from_file(args.config) + + if args.dry_run: + logger.info("Dry run: skipping submission\n\nConfig:\n") + logger.info(yaml.dump(cfg.model_dump(), indent=2, sort_keys=False)) + return + + manifest_path = postprocess(config=cfg, dependency_job_id=args.dependency) + logger.info(f"Manifest: {manifest_path}") + + +def cli() -> None: + main() diff --git a/spd/postprocess/config.py b/spd/postprocess/config.py new file mode 100644 index 000000000..858df94bb --- /dev/null +++ b/spd/postprocess/config.py @@ -0,0 +1,52 @@ +"""Postprocess pipeline configuration. + +PostprocessConfig composes sub-configs for harvest, attributions, autointerp, +and intruder eval. Set any section to null to skip that pipeline stage. 
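A hypothetical YAML along these lines (field names from this config; values illustrative, not a tested config):

```yaml
harvest:
  # ... HarvestSlurmConfig fields ...
autointerp:
  config:
    type: compact_skeptical
intruder: null        # skip intruder eval
attributions: null    # skip attributions
graph_interp: null    # must be null whenever attributions is null
```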
+""" + +from typing import Any, override + +from spd.autointerp.config import AutointerpSlurmConfig +from spd.base_config import BaseConfig +from spd.dataset_attributions.config import AttributionsSlurmConfig +from spd.graph_interp.config import GraphInterpSlurmConfig +from spd.harvest.config import HarvestSlurmConfig, IntruderSlurmConfig, SPDHarvestConfig + + +class PostprocessConfig(BaseConfig): + """Top-level config for the unified postprocessing pipeline. + + Composes sub-configs for each pipeline stage. Set a section to null + to skip that stage entirely. + + Dependency graph: + harvest (GPU array -> merge) + ├── intruder eval (CPU, depends on harvest merge, label-free) + └── autointerp (depends on harvest merge) + ├── interpret + │ ├── detection + │ └── fuzzing + attributions (GPU array -> merge, depends on harvest merge) + """ + + harvest: HarvestSlurmConfig + autointerp: AutointerpSlurmConfig | None + intruder: IntruderSlurmConfig | None + attributions: AttributionsSlurmConfig | None + graph_interp: GraphInterpSlurmConfig | None + + @override + def model_post_init(self, __context: Any) -> None: + expects_attributions = self.attributions is not None + is_not_spd = not isinstance(self.harvest.config.method_config, SPDHarvestConfig) + if expects_attributions and is_not_spd: + raise ValueError("Attributions only work for SPD decompositions") + if self.graph_interp is not None and self.attributions is None: + raise ValueError("Graph interp requires attributions") + + +if __name__ == "__main__": + import json + + with open("spd/postprocess/postprocess.schema.json", "w") as f: + json.dump(PostprocessConfig.model_json_schema(), f, indent=2) diff --git a/spd/postprocess/pile.yaml b/spd/postprocess/pile.yaml new file mode 100644 index 000000000..0a5de2183 --- /dev/null +++ b/spd/postprocess/pile.yaml @@ -0,0 +1,5 @@ +autointerp: + config: + type: compact_skeptical + forbidden_words: [] + cost_limit_usd: 100 \ No newline at end of file diff --git 
a/spd/pretrain/configs/pile_llama_simple_mlp-12L-768.yaml b/spd/pretrain/configs/pile_llama_simple_mlp-12L-768.yaml index 5336408ef..0d1c24e2c 100644 --- a/spd/pretrain/configs/pile_llama_simple_mlp-12L-768.yaml +++ b/spd/pretrain/configs/pile_llama_simple_mlp-12L-768.yaml @@ -1,7 +1,7 @@ wandb_project: spd dtype: bfloat16 batch_size: 1024 -num_iterations: 100_000 +num_iterations: 200_000 warmup_iters: 600 learning_rate: 3e-4 learning_rate_decay_frac: 0.1 @@ -26,7 +26,7 @@ model: flash_attention: false train_dataset_config: - name: danbraunai/pile-uncopyrighted-tok + name: danbraunai/pile-uncopyrighted-tok-shuffled is_tokenized: true hf_tokenizer_path: EleutherAI/gpt-neox-20b split: train @@ -36,7 +36,7 @@ train_dataset_config: column_name: input_ids val_dataset_config: - name: danbraunai/pile-uncopyrighted-tok + name: danbraunai/pile-uncopyrighted-tok-shuffled is_tokenized: true hf_tokenizer_path: EleutherAI/gpt-neox-20b split: val @@ -45,3 +45,5 @@ val_dataset_config: seed: 0 column_name: input_ids + + diff --git a/spd/pretrain/configs/pile_llama_simple_mlp-4L-768.yaml b/spd/pretrain/configs/pile_llama_simple_mlp-4L-768.yaml index 6aa5badc7..f5cc7abb7 100644 --- a/spd/pretrain/configs/pile_llama_simple_mlp-4L-768.yaml +++ b/spd/pretrain/configs/pile_llama_simple_mlp-4L-768.yaml @@ -33,7 +33,7 @@ model: vocab_size: 50277 train_dataset_config: - name: danbraunai/pile-uncopyrighted-tok + name: danbraunai/pile-uncopyrighted-tok-shuffled is_tokenized: true hf_tokenizer_path: EleutherAI/gpt-neox-20b split: train @@ -43,7 +43,7 @@ train_dataset_config: column_name: input_ids val_dataset_config: - name: danbraunai/pile-uncopyrighted-tok + name: danbraunai/pile-uncopyrighted-tok-shuffled is_tokenized: true hf_tokenizer_path: EleutherAI/gpt-neox-20b split: val diff --git a/spd/registry.py b/spd/registry.py index b5c92e538..fd8c8f5d5 100644 --- a/spd/registry.py +++ b/spd/registry.py @@ -37,42 +37,42 @@ class ExperimentConfig: 
decomp_script=Path("spd/experiments/tms/tms_decomposition.py"), config_path=Path("spd/experiments/tms/tms_5-2_config.yaml"), expected_runtime=4, - canonical_run="wandb:goodfire/spd/runs/nbejm03m", + canonical_run="wandb:goodfire/spd/runs/s-38e1a3e2", ), "tms_5-2-id": ExperimentConfig( task_name="tms", decomp_script=Path("spd/experiments/tms/tms_decomposition.py"), config_path=Path("spd/experiments/tms/tms_5-2-id_config.yaml"), expected_runtime=4, - canonical_run="wandb:goodfire/spd/runs/2orsxfx4", + canonical_run="wandb:goodfire/spd/runs/s-a1c0e9e2", ), "tms_40-10": ExperimentConfig( task_name="tms", decomp_script=Path("spd/experiments/tms/tms_decomposition.py"), config_path=Path("spd/experiments/tms/tms_40-10_config.yaml"), expected_runtime=5, - canonical_run="wandb:goodfire/spd/runs/nb25nhgw", + canonical_run="wandb:goodfire/spd/runs/s-7387fc20", ), "tms_40-10-id": ExperimentConfig( task_name="tms", decomp_script=Path("spd/experiments/tms/tms_decomposition.py"), config_path=Path("spd/experiments/tms/tms_40-10-id_config.yaml"), expected_runtime=5, - canonical_run="wandb:goodfire/spd/runs/eobwic8t", + canonical_run="wandb:goodfire/spd/runs/s-2a2b5a57", ), "resid_mlp1": ExperimentConfig( task_name="resid_mlp", decomp_script=Path("spd/experiments/resid_mlp/resid_mlp_decomposition.py"), config_path=Path("spd/experiments/resid_mlp/resid_mlp1_config.yaml"), expected_runtime=3, - canonical_run="wandb:goodfire/spd/runs/0d2lld8j", + canonical_run="wandb:goodfire/spd/runs/s-62fce8c4", ), "resid_mlp2": ExperimentConfig( task_name="resid_mlp", decomp_script=Path("spd/experiments/resid_mlp/resid_mlp_decomposition.py"), config_path=Path("spd/experiments/resid_mlp/resid_mlp2_config.yaml"), expected_runtime=5, - canonical_run="wandb:goodfire/spd/runs/q9uydy18", + canonical_run="wandb:goodfire/spd/runs/s-a9ad193d", ), "resid_mlp3": ExperimentConfig( task_name="resid_mlp", @@ -151,6 +151,12 @@ class ExperimentConfig: task_name="lm", 
decomp_script=Path("spd/experiments/lm/lm_decomposition.py"), config_path=Path("spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml"), + expected_runtime=1600, + ), + "ss_llama_simple_mlp-2L-wide_global_reverse": ExperimentConfig( + task_name="lm", + decomp_script=Path("spd/experiments/lm/lm_decomposition.py"), + config_path=Path("spd/experiments/lm/ss_llama_simple_mlp-2L-wide_global_reverse.yaml"), expected_runtime=480, ), "ss_llama_simple_mlp": ExperimentConfig( @@ -171,12 +177,24 @@ class ExperimentConfig: config_path=Path("spd/experiments/lm/pile_llama_simple_mlp-2L.yaml"), expected_runtime=720, ), + "pile_gpt2_simple-2L_global_reverse": ExperimentConfig( + task_name="lm", + decomp_script=Path("spd/experiments/lm/lm_decomposition.py"), + config_path=Path("spd/experiments/lm/pile_gpt2_simple-2L_global_reverse.yaml"), + expected_runtime=3000, + ), "pile_llama_simple_mlp-4L": ExperimentConfig( task_name="lm", decomp_script=Path("spd/experiments/lm/lm_decomposition.py"), config_path=Path("spd/experiments/lm/pile_llama_simple_mlp-4L.yaml"), expected_runtime=1440, ), + "pile_llama_simple_mlp-12L": ExperimentConfig( + task_name="lm", + decomp_script=Path("spd/experiments/lm/lm_decomposition.py"), + config_path=Path("spd/experiments/lm/pile_llama_simple_mlp-12L.yaml"), + expected_runtime=2880, + ), } diff --git a/spd/run_spd.py b/spd/run_spd.py index 37c439f05..d303a24c5 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -24,6 +24,8 @@ Config, LossMetricConfigType, MetricConfigType, + PersistentPGDReconLossConfig, + PersistentPGDReconSubsetLossConfig, PGDMultiBatchConfig, PGDMultiBatchReconLossConfig, PGDMultiBatchReconSubsetLossConfig, @@ -32,9 +34,10 @@ from spd.eval import evaluate, evaluate_multibatch_pgd from spd.identity_insertion import insert_identity_operations_ from spd.log import logger -from spd.losses import compute_total_loss +from spd.losses import compute_losses from spd.metrics import faithfulness_loss from spd.models.component_model import 
ComponentModel, OutputWithCache +from spd.persistent_pgd import PersistentPGDState from spd.settings import SPD_OUT_DIR from spd.utils.component_utils import calc_ci_l_zero from spd.utils.distributed_utils import ( @@ -62,14 +65,7 @@ def run_faithfulness_warmup( component_params: list[torch.nn.Parameter], config: Config, ) -> None: - """Run faithfulness warmup phase to improve initialization. - - Args: - component_model: The component model to warm up - component_params: List of component parameters to optimize - config: Configuration object containing warmup settings - """ - + """Run faithfulness warmup phase to improve initialization.""" logger.info("Starting faithfulness warmup phase...") assert component_params, "component_params is empty" @@ -156,10 +152,9 @@ def create_pgd_data_iter() -> ( model = ComponentModel( target_model=target_model, module_path_info=module_path_info, - ci_fn_type=config.ci_fn_type, - ci_fn_hidden_dims=config.ci_fn_hidden_dims, - pretrained_model_output_attr=config.pretrained_model_output_attr, + ci_config=config.ci_config, sigmoid_type=config.sigmoid_type, + pretrained_model_output_attr=config.pretrained_model_output_attr, ) model.to(device) @@ -200,10 +195,10 @@ def create_pgd_data_iter() -> ( tgt.V.data = src.U.data.T component_params: list[torch.nn.Parameter] = [] - ci_fn_params: list[torch.nn.Parameter] = [] for name in component_model.target_module_paths: component_params.extend(component_model.components[name].parameters()) - ci_fn_params.extend(component_model.ci_fns[name].parameters()) + + ci_fn_params = list(component_model.ci_fn.parameters()) assert len(component_params) > 0, "No parameters found in components to optimize" @@ -213,6 +208,14 @@ def create_pgd_data_iter() -> ( if config.faithfulness_warmup_steps > 0: run_faithfulness_warmup(component_model, component_params, config) + persistent_pgd_configs: list[ + PersistentPGDReconLossConfig | PersistentPGDReconSubsetLossConfig + ] = [ + cfg + for cfg in 
config.loss_metric_configs + if isinstance(cfg, PersistentPGDReconLossConfig | PersistentPGDReconSubsetLossConfig) + ] + eval_metric_configs = get_unique_metric_configs( loss_configs=config.loss_metric_configs, eval_configs=config.eval_metric_configs ) @@ -224,7 +227,6 @@ def create_pgd_data_iter() -> ( eval_metric_configs = [ cfg for cfg in eval_metric_configs if cfg not in multibatch_pgd_eval_configs ] - batch_dims: tuple[int, ...] | None = None sample_batch = extract_batch_data(next(train_iterator)) batch_dims = ( @@ -233,6 +235,20 @@ def create_pgd_data_iter() -> ( else sample_batch.shape # else it's a batch of token ids ) + ppgd_states: dict[ + PersistentPGDReconLossConfig | PersistentPGDReconSubsetLossConfig, PersistentPGDState + ] = { + ppgd_cfg: PersistentPGDState( + module_to_c=model.module_to_c, + batch_dims=batch_dims, + device=device, + use_delta_component=config.use_delta_component, + cfg=ppgd_cfg, + output_loss_type=config.output_loss_type, + ) + for ppgd_cfg in persistent_pgd_configs + } + for step in tqdm(range(config.steps + 1), ncols=0, disable=not is_main_process()): optimizer.zero_grad() @@ -242,11 +258,14 @@ def create_pgd_data_iter() -> ( for group in optimizer.param_groups: group["lr"] = step_lr + for ppgd_cfg in persistent_pgd_configs: + ppgd_states[ppgd_cfg].update_lr(step, config.steps) + weight_deltas = component_model.calc_weight_deltas() - step_log_data: defaultdict[str, float] = defaultdict(float) + batch_log_data: defaultdict[str, float] = defaultdict(float) - batch = extract_batch_data(next(train_iterator)).to(device) + batch = extract_batch_data(next(train_iterator)).to(device, non_blocking=True) with bf16_autocast(enabled=config.autocast_bf16): # NOTE: we need to call the wrapped_model at least once each step in order @@ -261,62 +280,83 @@ def create_pgd_data_iter() -> ( sampling=config.sampling, ) - total_loss, loss_terms = compute_total_loss( + for ppgd_cfg in persistent_pgd_configs: + ppgd_states[ppgd_cfg].warmup( + 
model=component_model, + batch=batch, + target_out=target_model_output.output, + ci=ci.lower_leaky, + weight_deltas=weight_deltas if config.use_delta_component else None, + ) + + losses = compute_losses( loss_metric_configs=config.loss_metric_configs, model=component_model, batch=batch, ci=ci, target_out=target_model_output.output, weight_deltas=weight_deltas, - pre_weight_acts=target_model_output.cache, current_frac_of_training=step / config.steps, sampling=config.sampling, use_delta_component=config.use_delta_component, n_mask_samples=config.n_mask_samples, + ppgd_states=ppgd_states, output_loss_type=config.output_loss_type, ) + total_loss = torch.tensor(0.0, device=device) + for loss_cfg, loss_val in losses.items(): + assert loss_cfg.coeff is not None + total_loss = total_loss + loss_cfg.coeff * loss_val + batch_log_data[f"train/loss/{loss_cfg.classname}"] = loss_val.item() + + batch_log_data["train/loss/total"] = total_loss.item() + + ppgd_grads = { + cfg: ppgd_states[cfg].get_grads(losses[cfg], retain_graph=True) + for cfg in persistent_pgd_configs + } + total_loss.backward() - for loss_name, loss_value in loss_terms.items(): - step_log_data[f"train/{loss_name}"] += loss_value + for ppgd_cfg in persistent_pgd_configs: + ppgd_states[ppgd_cfg].step(ppgd_grads[ppgd_cfg]) for layer_name, layer_ci in ci.lower_leaky.items(): l0_val = calc_ci_l_zero(layer_ci, config.ci_alive_threshold) - step_log_data[f"train/l0/{layer_name}"] += l0_val + batch_log_data[f"train/l0/{layer_name}"] = l0_val # --- Train Logging --- # if step % config.train_log_freq == 0: - avg_metrics = avg_metrics_across_ranks(step_log_data, device=device) - step_log_data = cast(defaultdict[str, float], avg_metrics) + avg_metrics = avg_metrics_across_ranks(batch_log_data, device=device) + batch_log_data = cast(defaultdict[str, float], avg_metrics) grad_norms = get_grad_norms_dict(component_model, device) dict_safe_update_( - step_log_data, {f"train/grad_norms/{k}": v for k, v in grad_norms.items()} + 
batch_log_data, {f"train/grad_norms/{k}": v for k, v in grad_norms.items()} ) - step_log_data["train/schedules/lr"] = step_lr + batch_log_data["train/schedules/lr"] = step_lr if is_main_process(): assert out_dir is not None tqdm.write(f"--- Step {step} ---") tqdm.write(f"LR: {step_lr:.6f}") - for name, value in step_log_data.items(): + for name, value in batch_log_data.items(): tqdm.write(f"{name}: {value:.15f}") - local_log(step_log_data, step, out_dir) + local_log(batch_log_data, step, out_dir) if config.wandb_project: - try_wandb(wandb.log, step_log_data, step=step) + try_wandb(wandb.log, batch_log_data, step=step) # --- Evaluation --- # if step % config.eval_freq == 0: - with torch.no_grad(): + with torch.no_grad(), bf16_autocast(enabled=config.autocast_bf16): slow_step: bool = ( config.slow_eval_on_first_step if step == 0 else step % config.slow_eval_freq == 0 ) - assert batch_dims is not None, "batch_dims is not set" multibatch_pgd_metrics = evaluate_multibatch_pgd( multibatch_pgd_eval_configs=multibatch_pgd_eval_configs, model=component_model, @@ -335,6 +375,7 @@ def create_pgd_data_iter() -> ( slow_step=slow_step, n_eval_steps=n_eval_steps, current_frac_of_training=step / config.steps, + ppgd_states=ppgd_states, ) dict_safe_update_(metrics, multibatch_pgd_metrics) diff --git a/spd/scripts/plot_component_activations.py b/spd/scripts/plot_component_activations.py index c4079ca24..0817d7709 100644 --- a/spd/scripts/plot_component_activations.py +++ b/spd/scripts/plot_component_activations.py @@ -11,62 +11,18 @@ """ import argparse -import json from collections import defaultdict from pathlib import Path import matplotlib.pyplot as plt import numpy as np -import torch - -from spd.harvest.schemas import ( - ActivationExample, - ComponentData, - ComponentTokenPMI, - get_activation_contexts_dir, -) -from spd.settings import SPD_OUT_DIR - - -def load_activation_contexts(run_id: str) -> dict[str, ComponentData]: - """Load all activation contexts.""" - ctx_dir = 
get_activation_contexts_dir(run_id) - path = ctx_dir / "components.jsonl" - assert path.exists(), f"No harvest data found for run {run_id}" - - components: dict[str, ComponentData] = {} - with open(path) as f: - for line in f: - data = json.loads(line) - data["activation_examples"] = [ - ActivationExample( - token_ids=ex["token_ids"], - ci_values=ex["ci_values"], - component_acts=ex.get("component_acts", [0.0] * len(ex["token_ids"])), - ) - for ex in data["activation_examples"] - ] - data["input_token_pmi"] = ComponentTokenPMI(**data["input_token_pmi"]) - data["output_token_pmi"] = ComponentTokenPMI(**data["output_token_pmi"]) - comp = ComponentData(**data) - components[comp.component_key] = comp - return components - - -def load_firing_counts(run_id: str) -> dict[str, int]: - """Load pre-calculated firing counts from harvest data.""" - token_stats_path = SPD_OUT_DIR / "harvest" / run_id / "correlations" / "token_stats.pt" - assert token_stats_path.exists(), f"No token stats found for run {run_id}" - - data = torch.load(token_stats_path) - component_keys = data["component_keys"] - firing_counts = data["firing_counts"] - - return {key: int(count) for key, count in zip(component_keys, firing_counts, strict=True)} + +from spd.harvest.repo import HarvestRepo +from spd.harvest.schemas import ComponentData def extract_activations( - contexts: dict[str, ComponentData], + components: list[ComponentData], ci_threshold: float, ) -> tuple[dict[str, dict[str, list[float]]], dict[str, dict[str, list[float]]]]: """Extract component activations, separating all vs above-threshold. 
@@ -79,10 +35,13 @@ def extract_activations( all_activations: dict[str, dict[str, list[float]]] = defaultdict(lambda: defaultdict(list)) filtered_activations: dict[str, dict[str, list[float]]] = defaultdict(lambda: defaultdict(list)) - for component_key, component_data in contexts.items(): + for component_data in components: layer = component_data.layer + component_key = component_data.component_key for example in component_data.activation_examples: - for ci_val, act_val in zip(example.ci_values, example.component_acts, strict=True): + ci_vals = example.activations["causal_importance"] + act_vals = example.activations["component_activation"] + for ci_val, act_val in zip(ci_vals, act_vals, strict=True): all_activations[layer][component_key].append(act_val) if ci_val > ci_threshold: filtered_activations[layer][component_key].append(act_val) @@ -183,15 +142,23 @@ def main(): output_dir_median.mkdir(parents=True, exist_ok=True) output_dir_freq.mkdir(parents=True, exist_ok=True) - print(f"Loading activation contexts for run {args.run_id}...") - contexts = load_activation_contexts(args.run_id) - print(f"Loaded {len(contexts)} components") + repo = HarvestRepo.open_most_recent(decomposition_id=args.run_id, readonly=True) + assert repo is not None, f"No harvest data for {args.run_id}" + + print(f"Loading components for run {args.run_id}...") + components = repo.get_all_components() + print(f"Loaded {len(components)} components") print("Loading firing counts...") - firing_counts = load_firing_counts(args.run_id) + token_stats = repo.get_token_stats() + assert token_stats is not None, f"No token stats found for run {args.run_id}" + firing_counts = { + key: int(count) + for key, count in zip(token_stats.component_keys, token_stats.firing_counts, strict=True) + } print("Extracting activations...") - all_by_layer, filtered_by_layer = extract_activations(contexts, args.ci_threshold) + all_by_layer, filtered_by_layer = extract_activations(components, args.ci_threshold) n_layers = 
len(filtered_by_layer) n_total = sum(sum(len(v) for v in layer.values()) for layer in filtered_by_layer.values()) @@ -201,7 +168,6 @@ def main(): print("No datapoints found above threshold. Try lowering --ci-threshold.") return - # Create plots ordered by median normalized activation print(f"Creating per-layer plots (ordered by median) in {output_dir_median}/...") for layer_name in sorted(all_by_layer.keys()): all_acts = all_by_layer[layer_name] @@ -215,7 +181,6 @@ def main(): create_layer_scatter_plot(normalized, ordered_keys, layer_name, args.run_id, output_path) print(f" {output_path}") - # Create plots ordered by CI activation frequency (with abs distance from midpoint) print(f"Creating per-layer plots (ordered by frequency) in {output_dir_freq}/...") for layer_name in sorted(all_by_layer.keys()): all_acts = all_by_layer[layer_name] @@ -223,7 +188,6 @@ def main(): normalized = normalize_per_component(all_acts, filtered_acts) if not normalized: continue - # Transform to absolute distance from midpoint abs_from_midpoint = {key: np.abs(acts - 0.5) for key, acts in normalized.items()} ordered_keys = order_by_frequency(abs_from_midpoint, firing_counts) safe_name = layer_name.replace(".", "_") diff --git a/spd/scripts/resid_mlp1_global_sweep.yaml b/spd/scripts/resid_mlp1_global_sweep.yaml new file mode 100644 index 000000000..e597742ef --- /dev/null +++ b/spd/scripts/resid_mlp1_global_sweep.yaml @@ -0,0 +1,9 @@ +# Sweep params for resid_mlp1 with global CI function +global: + ci_config: + hidden_dims: + values: [[2000, 1000, 1000, 1000]] + loss_metric_configs: + - classname: "ImportanceMinimalityLoss" + coeff: + values: [6e-6, 5.5e-6, 2.5e-6] diff --git a/spd/scripts/resid_mlp2_global_reverse_sweep.yaml b/spd/scripts/resid_mlp2_global_reverse_sweep.yaml new file mode 100644 index 000000000..abcd30e73 --- /dev/null +++ b/spd/scripts/resid_mlp2_global_reverse_sweep.yaml @@ -0,0 +1,11 @@ +# Sweep config for resid_mlp2_global_reverse +# Base config: ImpMin coeff=3e-5, LR 
start=5e-4 + +global: + lr_schedule: + start_val: + values: [1e-3, 5e-4, 2e-4] + loss_metric_configs: + - classname: "ImportanceMinimalityLoss" + coeff: + values: [1e-4, 5e-5, 3e-5, 1e-5] diff --git a/spd/scripts/ss_llama_simple_mlp-2L-wide_global_reverse_sweep.yaml b/spd/scripts/ss_llama_simple_mlp-2L-wide_global_reverse_sweep.yaml new file mode 100644 index 000000000..37726b174 --- /dev/null +++ b/spd/scripts/ss_llama_simple_mlp-2L-wide_global_reverse_sweep.yaml @@ -0,0 +1,8 @@ +# Sweep config for ss_llama_simple_mlp-2L-wide_global_reverse +# Sweeps over importance minimality coefficient + +global: + loss_metric_configs: + - classname: "ImportanceMinimalityLoss" + coeff: + values: [1e-2, 1e-3, 1e-4, 1e-5] diff --git a/spd/settings.py b/spd/settings.py index 9e3e37f7b..0dbd08d57 100644 --- a/spd/settings.py +++ b/spd/settings.py @@ -24,3 +24,6 @@ DEFAULT_PARTITION_NAME = "h200-reserved" DEFAULT_PROJECT_NAME = "spd" + +# Default run for the app to load on startup if set +SPD_APP_DEFAULT_RUN: str | None = os.environ.get("SPD_APP_DEFAULT_RUN") diff --git a/spd/spd_types.py b/spd/spd_types.py index 012dc554a..13c0f8996 100644 --- a/spd/spd_types.py +++ b/spd/spd_types.py @@ -45,9 +45,9 @@ def validate_path(v: str | Path) -> str | Path: Path, BeforeValidator(to_root_path), PlainSerializer(lambda x: str(from_root_path(x))) ] - Probability = Annotated[float, Ge(0), Le(1)] - TaskName = Literal["tms", "resid_mlp", "lm", "ih"] - -CiFnType = Literal["mlp", "vector_mlp", "shared_mlp"] +LayerwiseCiFnType = Literal["mlp", "vector_mlp", "shared_mlp"] +GlobalCiFnType = Literal[ + "global_shared_mlp", "global_reverse_residual", "global_shared_transformer" +] diff --git a/spd/topology/__init__.py b/spd/topology/__init__.py new file mode 100644 index 000000000..54d422477 --- /dev/null +++ b/spd/topology/__init__.py @@ -0,0 +1,22 @@ +"""Canonical transformer topology. + +Two layers: +- canonical.py: Pure data types for model-agnostic layer addressing. + No torch dependency. 
Used by serialization, database, frontend layout. +- topology.py: Bidirectional mapping between canonical and concrete module paths. + Depends on torch.nn and specific model classes. + +Canonical layer address format: + "embed" — embedding + "output" — unembed / logits + "{block}.attn.{p}" — separate attention (p: q | k | v | o) + "{block}.attn_fused.{p}" — fused attention (p: qkv | o) + "{block}.glu.{p}" — gated FFN / SwiGLU (p: up | down | gate) + "{block}.mlp.{p}" — simple FFN (p: up | down) + +Node key format: + "{layer_address}:{seq_pos}:{component_idx}" +""" + +from spd.topology.gradient_connectivity import get_sources_by_target as get_sources_by_target +from spd.topology.topology import TransformerTopology as TransformerTopology diff --git a/spd/topology/canonical.py b/spd/topology/canonical.py new file mode 100644 index 000000000..0e9ee7f9a --- /dev/null +++ b/spd/topology/canonical.py @@ -0,0 +1,110 @@ +"""Canonical weight types for model-agnostic layer addressing. + +Pure data types — no torch dependency. Safe to import anywhere. +""" + +from __future__ import annotations + +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Literal, override + +_EMBED_RE = re.compile(r"^embed$") +_OUTPUT_RE = re.compile(r"^output$") +_LAYER_RE = re.compile(r"^(?P<layer>\d+)\.(?P<sublayer>attn|attn_fused|glu|mlp)\.(?P<proj>[a-z]+)$") + + +class CanonicalWeight(ABC): + @abstractmethod + def canonical_str(self) -> str: ...
+ + @staticmethod + def parse(s: str) -> CanonicalWeight: + """Parse a canonical address string into a CanonicalWeight.""" + m_embed = _EMBED_RE.match(s) + m_output = _OUTPUT_RE.match(s) + m_layer = _LAYER_RE.match(s) + + matches = [m for m in (m_embed, m_output, m_layer) if m is not None] + assert len(matches) == 1, f"Invalid canonical address: {s!r}" + + if m_embed: + return Embed() + if m_output: + return Unembed() + + assert m_layer is not None + layer_idx = int(m_layer.group("layer")) + sublayer = m_layer.group("sublayer") + proj = m_layer.group("proj") + + cls, valid = _SUBLAYER_PROJECTIONS[sublayer] + assert proj in valid, f"Invalid projection {proj!r} for {sublayer!r} in {s!r}" + return LayerWeight(layer_idx, cls(weight=proj)) + + +@dataclass(frozen=True) +class Embed(CanonicalWeight): + @override + def canonical_str(self) -> str: + return "embed" + + +@dataclass(frozen=True) +class Unembed(CanonicalWeight): + @override + def canonical_str(self) -> str: + return "output" + + +@dataclass(frozen=True) +class SeparateAttnWeight: + weight: Literal["q", "k", "v", "o"] + + +@dataclass(frozen=True) +class FusedAttnWeight: + weight: Literal["qkv", "o"] + + +AttnWeight = SeparateAttnWeight | FusedAttnWeight + + +@dataclass(frozen=True) +class GLUWeight: + weight: Literal["up", "down", "gate"] + + +@dataclass(frozen=True) +class MLPWeight: + weight: Literal["up", "down"] + + +FFNWeight = GLUWeight | MLPWeight + + +@dataclass(frozen=True) +class LayerWeight(CanonicalWeight): + layer_idx: int + name: AttnWeight | FFNWeight + + @override + def canonical_str(self) -> str: + match self.name: + case SeparateAttnWeight(weight=p): + return f"{self.layer_idx}.attn.{p}" + case FusedAttnWeight(weight=p): + return f"{self.layer_idx}.attn_fused.{p}" + case GLUWeight(weight=p): + return f"{self.layer_idx}.glu.{p}" + case MLPWeight(weight=p): + return f"{self.layer_idx}.mlp.{p}" + + +_SUBLAYER_PROJECTIONS: dict[str, tuple[type, tuple[str, ...]]] = { + "attn": (SeparateAttnWeight, 
("q", "k", "v", "o")), + "attn_fused": (FusedAttnWeight, ("qkv", "o")), + "glu": (GLUWeight, ("up", "down", "gate")), + "mlp": (MLPWeight, ("up", "down")), +} diff --git a/spd/topology/gradient_connectivity.py b/spd/topology/gradient_connectivity.py new file mode 100644 index 000000000..31ba61b5a --- /dev/null +++ b/spd/topology/gradient_connectivity.py @@ -0,0 +1,102 @@ +"""Discover gradient connectivity between layers of a ComponentModel.""" + +from collections import defaultdict +from typing import Any + +import torch +from jaxtyping import Float +from torch import Tensor, nn + +from spd.configs import SamplingType +from spd.models.component_model import ComponentModel, OutputWithCache +from spd.models.components import make_mask_infos +from spd.topology.topology import TransformerTopology +from spd.utils.general_utils import bf16_autocast + + +def get_sources_by_target( + model: ComponentModel, + topology: TransformerTopology, + device: str, + sampling: SamplingType, +) -> dict[str, list[str]]: + """Find valid gradient connections grouped by target layer. + + Includes embedding as a source and unembed as a target, using the topology's + actual module paths (not pseudo-names). + + Returns: + Dict mapping out_layer -> list of in_layers that have gradient flow to it. 
+ """ + # Use a small dummy batch - we only need to trace gradient connections + batch: Float[Tensor, "batch seq"] = torch.zeros(2, 3, dtype=torch.long, device=device) + + with torch.no_grad(), bf16_autocast(): + output_with_cache: OutputWithCache = model(batch, cache_type="input") + + ci = model.calc_causal_importances( + pre_weight_acts=output_with_cache.cache, + sampling=sampling, + detach_inputs=False, + ) + + # Create masks so we can use all components + mask_infos = make_mask_infos( + component_masks={k: torch.ones_like(v) for k, v in ci.lower_leaky.items()}, + routing_masks="all", + ) + + embed_path = topology.path_schema.embedding_path + unembed_path = topology.path_schema.unembed_path + + # Hook to capture embedding output with gradients + embed_cache: dict[str, Tensor] = {} + + def embed_hook( + _module: nn.Module, _args: tuple[Any, ...], _kwargs: dict[Any, Any], output: Tensor + ) -> Any: + output.requires_grad_(True) + embed_cache[f"{embed_path}_post_detach"] = output + return output + + embed_handle = topology.embedding_module.register_forward_hook(embed_hook, with_kwargs=True) + + with torch.enable_grad(), bf16_autocast(): + comp_output_with_cache: OutputWithCache = model( + batch, + mask_infos=mask_infos, + cache_type="component_acts", + ) + + embed_handle.remove() + + cache = comp_output_with_cache.cache + cache[f"{embed_path}_post_detach"] = embed_cache[f"{embed_path}_post_detach"] + cache[f"{unembed_path}_pre_detach"] = comp_output_with_cache.output + + source_layers = [embed_path, *model.target_module_paths] # Don't include "output" as source + target_layers = [*model.target_module_paths, unembed_path] # Don't include embed as target + + # Test all distinct pairs for gradient flow + test_pairs = [] + for source_layer in source_layers: + for target_layer in target_layers: + if source_layer != target_layer: + test_pairs.append((source_layer, target_layer)) + + sources_by_target: dict[str, list[str]] = defaultdict(list) + for source_layer, 
target_layer in test_pairs: + out_pre_detach = cache[f"{target_layer}_pre_detach"] + in_post_detach = cache[f"{source_layer}_post_detach"] + out_value = out_pre_detach[0, 0, 0] + grads = torch.autograd.grad( + outputs=out_value, + inputs=in_post_detach, + retain_graph=True, + allow_unused=True, + ) + assert len(grads) == 1 + grad = grads[0] + if grad is not None: # pyright: ignore[reportUnnecessaryComparison] + sources_by_target[target_layer].append(source_layer) + return dict(sources_by_target) diff --git a/spd/topology/path_schemas.py b/spd/topology/path_schemas.py new file mode 100644 index 000000000..c9b7a05c2 --- /dev/null +++ b/spd/topology/path_schemas.py @@ -0,0 +1,236 @@ +"""Path schemas: bidirectional mapping between concrete module paths and canonical weights. + +Each model family gets a PathSchema subclass that declares its concrete naming conventions. +These are private implementation details — only TransformerTopology is public. +""" + +import re +from abc import ABC +from dataclasses import dataclass +from typing import Literal + +from torch import nn + +from spd.topology.canonical import ( + CanonicalWeight, + Embed, + FusedAttnWeight, + GLUWeight, + LayerWeight, + MLPWeight, + SeparateAttnWeight, + Unembed, +) + + +@dataclass +class _SeparateAttnPathSchema: + base: str + q: str + k: str + v: str + o: str + + def _lookup(self) -> dict[str, Literal["q", "k", "v", "o"]]: + return {self.q: "q", self.k: "k", self.v: "v", self.o: "o"} + + def _reverse(self) -> dict[str, str]: + return {"q": self.q, "k": self.k, "v": self.v, "o": self.o} + + def parse(self, projection_name: str, layer_idx: int) -> LayerWeight: + table = self._lookup() + assert projection_name in table, f"Unknown attn projection: {projection_name}" + return LayerWeight(layer_idx, SeparateAttnWeight(table[projection_name])) + + def render(self, w: SeparateAttnWeight) -> str: + return f"{self.base}.{self._reverse()[w.weight]}" + + +@dataclass +class _FusedAttnPathSchema: + base: str + qkv: 
str + o: str + + def _lookup(self) -> dict[str, Literal["qkv", "o"]]: + return {self.qkv: "qkv", self.o: "o"} + + def _reverse(self) -> dict[str, str]: + return {"qkv": self.qkv, "o": self.o} + + def parse(self, projection_name: str, layer_idx: int) -> LayerWeight: + table = self._lookup() + assert projection_name in table, f"Unknown fused attn projection: {projection_name}" + return LayerWeight(layer_idx, FusedAttnWeight(table[projection_name])) + + def render(self, w: FusedAttnWeight) -> str: + return f"{self.base}.{self._reverse()[w.weight]}" + + +@dataclass +class _GLUPathSchema: + base: str + gate: str + up: str + down: str + + def _lookup(self) -> dict[str, Literal["up", "down", "gate"]]: + return {self.gate: "gate", self.up: "up", self.down: "down"} + + def _reverse(self) -> dict[str, str]: + return {"gate": self.gate, "up": self.up, "down": self.down} + + def parse(self, projection_name: str, layer_idx: int) -> LayerWeight: + table = self._lookup() + assert projection_name in table, f"Unknown GLU projection: {projection_name}" + return LayerWeight(layer_idx, GLUWeight(table[projection_name])) + + def render(self, w: GLUWeight) -> str: + return f"{self.base}.{self._reverse()[w.weight]}" + + +@dataclass +class _FFNPathSchema: + base: str + up: str + down: str + + def _lookup(self) -> dict[str, Literal["up", "down"]]: + return {self.up: "up", self.down: "down"} + + def _reverse(self) -> dict[str, str]: + return {"up": self.up, "down": self.down} + + def parse(self, projection_name: str, layer_idx: int) -> LayerWeight: + table = self._lookup() + assert projection_name in table, f"Unknown MLP projection: {projection_name}" + return LayerWeight(layer_idx, MLPWeight(table[projection_name])) + + def render(self, w: MLPWeight) -> str: + return f"{self.base}.{self._reverse()[w.weight]}" + + +class _PathSchema(ABC): + embedding_path: str + blocks: str + attn: _SeparateAttnPathSchema | _FusedAttnPathSchema + mlp: _GLUPathSchema | _FFNPathSchema + unembed_path: str + 
_block_re: re.Pattern[str] | None = None + + def parse_target_path(self, path: str) -> CanonicalWeight: + if path == self.embedding_path: + return Embed() + if path == self.unembed_path: + return Unembed() + return self._parse_block_path(path) + + def render_canonical_weight(self, weight: CanonicalWeight) -> str: + match weight: + case Embed(): + return self.embedding_path + case Unembed(): + return self.unembed_path + case LayerWeight() as lw: + return self._render_layer_weight(lw) + case _: + raise ValueError(f"Unknown canonical weight: {weight!r}") + + def _parse_block_path(self, path: str) -> LayerWeight: + """Parse a block-level path like 'h.3.attn.q_proj' into a LayerWeight.""" + if self._block_re is None: + attn_base = re.escape(self.attn.base) + mlp_base = re.escape(self.mlp.base) + blocks = re.escape(self.blocks) + self._block_re = re.compile( + rf"^{blocks}\.(?P<idx>\d+)\." + rf"(?:(?P<attn>{attn_base})\.(?P<attn_proj>\w+)" + rf"|(?P<mlp>{mlp_base})\.(?P<mlp_proj>\w+))$" + ) + + m = self._block_re.match(path) + assert m is not None, f"Invalid block path: {path!r}" + + layer_idx = int(m.group("idx")) + if m.group("attn"): + return self.attn.parse(m.group("attn_proj"), layer_idx) + return self.mlp.parse(m.group("mlp_proj"), layer_idx) + + def _render_layer_weight(self, w: LayerWeight) -> str: + """Render a LayerWeight into a concrete path.""" + base = f"{self.blocks}.{w.layer_idx}" + match w.name: + case SeparateAttnWeight() as attn_w: + assert isinstance(self.attn, _SeparateAttnPathSchema) + return f"{base}.{self.attn.render(attn_w)}" + case FusedAttnWeight() as attn_w: + assert isinstance(self.attn, _FusedAttnPathSchema) + return f"{base}.{self.attn.render(attn_w)}" + case GLUWeight() as ffn_w: + assert isinstance(self.mlp, _GLUPathSchema) + return f"{base}.{self.mlp.render(ffn_w)}" + case MLPWeight() as ffn_w: + assert isinstance(self.mlp, _FFNPathSchema) + return f"{base}.{self.mlp.render(ffn_w)}" + + +class _LlamaSimplePathSchema(_PathSchema): + embedding_path = "wte" + blocks = "h" +
attn = _SeparateAttnPathSchema(base="attn", q="q_proj", k="k_proj", v="v_proj", o="o_proj") + mlp = _GLUPathSchema(base="mlp", gate="gate_proj", up="up_proj", down="down_proj") + unembed_path = "lm_head" + + +class _LlamaSimpleMLPPathSchema(_PathSchema): + embedding_path = "wte" + blocks = "h" + attn = _SeparateAttnPathSchema(base="attn", q="q_proj", k="k_proj", v="v_proj", o="o_proj") + mlp = _FFNPathSchema(base="mlp", up="c_fc", down="down_proj") + unembed_path = "lm_head" + + +class _GPT2SimplePathSchema(_PathSchema): + embedding_path = "wte" + blocks = "h" + attn = _SeparateAttnPathSchema(base="attn", q="q_proj", k="k_proj", v="v_proj", o="o_proj") + mlp = _FFNPathSchema(base="mlp", up="c_fc", down="down_proj") + unembed_path = "lm_head" + + +class _GPT2PathSchema(_PathSchema): + embedding_path = "wte" + blocks = "h_torch" + attn = _FusedAttnPathSchema(base="attn", qkv="c_attn", o="c_proj") + mlp = _FFNPathSchema(base="mlp", up="c_fc", down="c_proj") + unembed_path = "lm_head" + + +class _HFGpt2PathSchema(_PathSchema): + embedding_path = "transformer.wte" + blocks = "transformer.h" + attn = _FusedAttnPathSchema(base="attn", qkv="c_attn", o="c_proj") + mlp = _FFNPathSchema(base="mlp", up="c_fc", down="c_proj") + unembed_path = "lm_head" + + +def get_path_schema(model: nn.Module) -> _PathSchema: + from transformers.models.gpt2 import GPT2LMHeadModel + + from spd.pretrain.models import GPT2, GPT2Simple, LlamaSimple, LlamaSimpleMLP + + match model: + case LlamaSimple(): + return _LlamaSimplePathSchema() + case LlamaSimpleMLP(): + return _LlamaSimpleMLPPathSchema() + case GPT2Simple(): + return _GPT2SimplePathSchema() + case GPT2(): + return _GPT2PathSchema() + case GPT2LMHeadModel(): + return _HFGpt2PathSchema() + case _: + raise ValueError( + f"Unsupported model class {type(model).__name__}. Add a _PathSchema in path_schemas.py." 
+ ) diff --git a/spd/topology/topology.py b/spd/topology/topology.py new file mode 100644 index 000000000..e5c76a7be --- /dev/null +++ b/spd/topology/topology.py @@ -0,0 +1,74 @@ +"""TransformerTopology: the public interface for canonical ↔ concrete path mapping.""" + +from torch import nn + +from spd.topology.canonical import ( + CanonicalWeight, + Embed, + FusedAttnWeight, + LayerWeight, + SeparateAttnWeight, + Unembed, +) +from spd.topology.path_schemas import get_path_schema + + +class TransformerTopology: + """Bidirectional mapping between canonical weights and concrete module paths. + + Built from a target model (nn.Module). Independent of decomposition. + """ + + def __init__(self, target_model: nn.Module) -> None: + self.target_model = target_model + self.path_schema = get_path_schema(target_model) + + def canon_to_target(self, canonical: str) -> str: + return self.path_schema.render_canonical_weight(CanonicalWeight.parse(canonical)) + + def target_to_canon(self, target_module_path: str) -> str: + return self.path_schema.parse_target_path(target_module_path).canonical_str() + + def _get_module(self, canonical: CanonicalWeight) -> nn.Module: + target_path = self.path_schema.render_canonical_weight(canonical) + return self.target_model.get_submodule(target_path) + + @property + def embedding_module(self) -> nn.Embedding: + mod = self._get_module(Embed()) + assert isinstance(mod, nn.Embedding) + return mod + + @property + def unembed_module(self) -> nn.Linear: + mod = self._get_module(Unembed()) + assert isinstance(mod, nn.Linear) + return mod + + @property + def n_blocks(self) -> int: + blocks = self.target_model.get_submodule(self.path_schema.blocks) + assert isinstance(blocks, nn.ModuleList) + return len(blocks) + + def get_unembed_weight(self): + """Unembedding weight matrix transposed to [d_model, vocab].""" + return self.unembed_module.weight.T.detach() + + def is_cross_seq_pair(self, source_canonical: str, target_canonical: str) -> bool: + """True if 
source is k/v and target is o in the same block.""" + source = CanonicalWeight.parse(source_canonical) + target = CanonicalWeight.parse(target_canonical) + match source, target: + case ( + LayerWeight(layer_idx=si, name=SeparateAttnWeight(weight="k" | "v")), + LayerWeight(layer_idx=ti, name=SeparateAttnWeight(weight="o")), + ): + return si == ti + case ( + LayerWeight(layer_idx=si, name=FusedAttnWeight(weight="qkv")), + LayerWeight(layer_idx=ti, name=FusedAttnWeight(weight="o")), + ): + return si == ti + case _: + return False diff --git a/spd/utils/logging_utils.py b/spd/utils/logging_utils.py index fd39afeeb..47649ca82 100644 --- a/spd/utils/logging_utils.py +++ b/spd/utils/logging_utils.py @@ -56,14 +56,13 @@ def get_grad_norms_dict( comp_grad_norm_sq_sum += param_grad_sum_sq ci_fn_grad_norm_sq_sum: Float[Tensor, ""] = torch.zeros((), device=device) - for target_module_path, ci_fn in component_model.ci_fns.items(): - for local_param_name, local_param in ci_fn.named_parameters(): - ci_fn_grad = runtime_cast(Tensor, local_param.grad) - ci_fn_grad_sum_sq = ci_fn_grad.pow(2).sum() - key = f"ci_fns/{target_module_path}.{local_param_name}" - assert key not in out, f"Key {key} already exists in grad norms log" - out[key] = ci_fn_grad_sum_sq.sqrt().item() - ci_fn_grad_norm_sq_sum += ci_fn_grad_sum_sq + for local_param_name, local_param in component_model.ci_fn.named_parameters(): + ci_fn_grad = runtime_cast(Tensor, local_param.grad) + ci_fn_grad_sum_sq = ci_fn_grad.pow(2).sum() + key = f"ci_fns/{local_param_name}" + assert key not in out, f"Key {key} already exists in grad norms log" + out[key] = ci_fn_grad_sum_sq.sqrt().item() + ci_fn_grad_norm_sq_sum += ci_fn_grad_sum_sq out["summary/components"] = comp_grad_norm_sq_sum.sqrt().item() out["summary/ci_fns"] = ci_fn_grad_norm_sq_sum.sqrt().item() diff --git a/spd/utils/markdown.py b/spd/utils/markdown.py new file mode 100644 index 000000000..0098c4d8e --- /dev/null +++ b/spd/utils/markdown.py @@ -0,0 +1,44 @@ +"""Minimal 
Markdown document builder for prompt construction. + +Atomic unit is a block (paragraph, heading, list, etc.). +build() joins blocks with double newlines. +""" + + +class Md: + """Accumulates Markdown blocks with a fluent API. + + Each method appends a block and returns self for chaining. + Call .build() to get the final string (blocks joined by blank lines). + """ + + def __init__(self) -> None: + self._blocks: list[str] = [] + + def h(self, level: int, text: str) -> "Md": + self._blocks.append(f"{'#' * level} {text}") + return self + + def p(self, text: str) -> "Md": + self._blocks.append(text) + return self + + def bullets(self, items: list[str]) -> "Md": + self._blocks.append("\n".join(f"- {item}" for item in items)) + return self + + def labeled_list(self, label: str, items: list[str]) -> "Md": + lines = [label] + [f"- {item}" for item in items] + self._blocks.append("\n".join(lines)) + return self + + def numbered(self, items: list[str]) -> "Md": + self._blocks.append("\n".join(f"{i}. 
{item}" for i, item in enumerate(items, 1))) + return self + + def extend(self, other: "Md") -> "Md": + self._blocks.extend(other._blocks) + return self + + def build(self) -> str: + return "\n\n".join(self._blocks) diff --git a/spd/utils/slurm.py b/spd/utils/slurm.py index 1dc7c18a2..0f03bc121 100644 --- a/spd/utils/slurm.py +++ b/spd/utils/slurm.py @@ -40,6 +40,7 @@ class SlurmConfig: n_gpus: int = 1 n_nodes: int = 1 time: str = "72:00:00" + mem: str | None = None # Memory limit (e.g., "64G", "128G") cpus_per_task: int | None = None snapshot_branch: str | None = None dependency_job_id: str | None = None @@ -279,6 +280,9 @@ def _sbatch_header( if config.cpus_per_task is not None: lines.append(f"#SBATCH --cpus-per-task={config.cpus_per_task}") + if config.mem is not None: + lines.append(f"#SBATCH --mem={config.mem}") + if is_array and array_range: lines.append(f"#SBATCH --array={array_range}") diff --git a/spd/utils/sqlite.py b/spd/utils/sqlite.py new file mode 100644 index 000000000..a75fa1976 --- /dev/null +++ b/spd/utils/sqlite.py @@ -0,0 +1,33 @@ +"""SQLite connection helpers for NFS-mounted databases. + +Two environments exist in this codebase: + +1. **NFS databases** (harvest, autointerp, graph_interp, dataset_attributions): + - Live at SPD_OUT_DIR on shared NFS mount + - WAL mode MUST NOT be used — it requires POSIX advisory locking which + NFS doesn't support reliably, causing "database is locked" errors + - Readonly uses ?immutable=1 (no lock files created at all) + - Write mode uses default DELETE journal + +2. **App database** (prompt_attr.db): + - Lives at SPD_OUT_DIR/app/ on NFS (shared across team) + - Uses DELETE journal mode with fcntl.flock write locking + - Managed by PromptAttrDB in spd/app/backend/database.py +""" + +import sqlite3 +from pathlib import Path + + +def open_nfs_sqlite(path: Path, readonly: bool) -> sqlite3.Connection: + """Open a SQLite connection safe for NFS-mounted databases. 
+ + Readonly: ?immutable=1 URI (zero lock files, safe for concurrent readers). + Write: default DELETE journal (WAL breaks on NFS). + """ + if readonly: + conn = sqlite3.connect(f"file:{path}?immutable=1", uri=True, check_same_thread=False) + else: + conn = sqlite3.connect(str(path), check_same_thread=False) + conn.row_factory = sqlite3.Row + return conn diff --git a/spd/utils/wandb_utils.py b/spd/utils/wandb_utils.py index 391bc168d..9e0f086fc 100644 --- a/spd/utils/wandb_utils.py +++ b/spd/utils/wandb_utils.py @@ -15,7 +15,7 @@ from spd.base_config import BaseConfig from spd.log import logger from spd.registry import EXPERIMENT_REGISTRY -from spd.settings import REPO_ROOT +from spd.settings import DEFAULT_PROJECT_NAME, REPO_ROOT from spd.utils.general_utils import fetch_latest_checkpoint_name WORKSPACE_TEMPLATES = { @@ -31,7 +31,11 @@ # Regex patterns for parsing W&B run references # Run IDs can be 8 chars (e.g., "d2ec3bfe") or prefixed with char-dash (e.g., "s-d2ec3bfe") +DEFAULT_WANDB_ENTITY = "goodfire" +DEFAULT_WANDB_PROJECT = DEFAULT_PROJECT_NAME + _RUN_ID_PATTERN = r"(?:[a-z0-9]-)?[a-z0-9]{8}" +_BARE_RUN_ID_RE = re.compile(r"^(s-[a-z0-9]{8})$") _WANDB_PATH_RE = re.compile(rf"^([^/\s]+)/([^/\s]+)/({_RUN_ID_PATTERN})$") _WANDB_PATH_WITH_RUNS_RE = re.compile(rf"^([^/\s]+)/([^/\s]+)/runs/({_RUN_ID_PATTERN})$") _WANDB_URL_RE = re.compile( @@ -52,7 +56,12 @@ "PGDReconLoss": "PGDRecon", "PGDReconSubsetLoss": "PGDReconSub", "PGDReconLayerwiseLoss": "PGDReconLayer", - "StochasticHiddenActsReconLoss": "StochHiddenRecon", + "PersistentPGDReconLoss": "PersistPGDRecon", + "PersistentPGDReconSubsetLoss": "PersistPGDReconSub", + "StochasticHiddenActsReconLoss": "StochHiddenActRecon", + "CIHiddenActsReconLoss": "CIHiddenActRecon", + "StochasticAttnPatternsReconLoss": "StochAttnRecon", + "CIMaskedAttnPatternsReconLoss": "CIAttnRecon", "UnmaskedReconLoss": "UnmaskedRecon", # Eval metrics "CEandKLLosses": "CEandKL", @@ -66,6 +75,8 @@ "StochasticReconSubsetCEAndKL": 
"StochReconSubCEKL", "PGDMultiBatchReconLoss": "PGDMultiBatchRecon", "PGDMultiBatchReconSubsetLoss": "PGDMultiBatchReconSub", + "PersistentPGDReconEval": "PersistPGDReconEval", + "PersistentPGDReconSubsetEval": "PersistPGDReconSubEval", } @@ -167,6 +178,7 @@ def parse_wandb_run_path(input_path: str) -> tuple[str, str, str]: """Parse various W&B run reference formats into (entity, project, run_id). Accepts: + - "s-xxxxxxxx" (bare SPD run ID, assumes goodfire/spd) - "entity/project/runId" (compact form) - "entity/project/runs/runId" (with /runs/) - "wandb:entity/project/runId" (with wandb: prefix) @@ -185,6 +197,10 @@ def parse_wandb_run_path(input_path: str) -> tuple[str, str, str]: if s.startswith("wandb:"): s = s[6:] + # Bare run ID (e.g. "s-17805b61") → default entity/project + if m := _BARE_RUN_ID_RE.match(s): + return DEFAULT_WANDB_ENTITY, DEFAULT_WANDB_PROJECT, m.group(1) + # Try compact form: entity/project/runid if m := _WANDB_PATH_RE.match(s): return m.group(1), m.group(2), m.group(3) @@ -199,6 +215,7 @@ def parse_wandb_run_path(input_path: str) -> tuple[str, str, str]: raise ValueError( f"Invalid W&B run reference. 
Expected one of:\n" + f' - "s-xxxxxxxx" (bare run ID)\n' f' - "entity/project/xxxxxxxx"\n' f' - "entity/project/runs/xxxxxxxx"\n' f' - "wandb:entity/project/runs/xxxxxxxx"\n' @@ -293,8 +310,8 @@ def download_wandb_file(run: Run, wandb_run_dir: Path, file_name: str) -> Path: """ file_on_wandb = run.file(file_name) assert isinstance(file_on_wandb, File) - path = Path(file_on_wandb.download(exist_ok=True, replace=False, root=str(wandb_run_dir)).name) - return path + file_on_wandb.download(exist_ok=True, replace=False, root=str(wandb_run_dir)) + return wandb_run_dir / file_name def init_wandb( diff --git a/tests/app/test_app_tokenizer.py b/tests/app/test_app_tokenizer.py new file mode 100644 index 000000000..d91c71857 --- /dev/null +++ b/tests/app/test_app_tokenizer.py @@ -0,0 +1,91 @@ +"""Tests for AppTokenizer span reconstruction and display logic.""" + +import pytest +from transformers import AutoTokenizer +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +from spd.app.backend.app_tokenizer import AppTokenizer + +# Test strings covering various tokenization edge cases +BASIC_STRINGS = [ + "Hello, world!", + "The quick brown fox jumps over the lazy dog.", + "It's a beautiful day.", + "price is $19.99", + "foo bar", # multiple spaces + "line1\nline2", # newline +] + +UNICODE_STRINGS = [ + "café résumé naïve", + "日本語テスト", +] + + +@pytest.fixture(scope="module") +def gpt2_tokenizer() -> AppTokenizer: + tok = AutoTokenizer.from_pretrained("openai-community/gpt2") + assert isinstance(tok, PreTrainedTokenizerBase) + return AppTokenizer(tok) + + +class TestGetSpans: + """Test that get_spans produces strings that concatenate to the full decoded text.""" + + def test_empty(self, gpt2_tokenizer: AppTokenizer) -> None: + assert gpt2_tokenizer.get_spans([]) == [] + + @pytest.mark.parametrize("text", BASIC_STRINGS) + def test_round_trip_basic(self, gpt2_tokenizer: AppTokenizer, text: str) -> None: + token_ids = gpt2_tokenizer.encode(text) + spans = 
gpt2_tokenizer.get_spans(token_ids) + assert len(spans) == len(token_ids) + from spd.app.backend.app_tokenizer import escape_for_display + + assert "".join(spans) == escape_for_display(gpt2_tokenizer.decode(token_ids)) + + @pytest.mark.parametrize("text", UNICODE_STRINGS) + def test_round_trip_unicode(self, gpt2_tokenizer: AppTokenizer, text: str) -> None: + token_ids = gpt2_tokenizer.encode(text) + spans = gpt2_tokenizer.get_spans(token_ids) + assert len(spans) == len(token_ids) + # For unicode, some spans may be empty (multi-byte split), but concat must match + from spd.app.backend.app_tokenizer import escape_for_display + + assert "".join(spans) == escape_for_display(gpt2_tokenizer.decode(token_ids)) + + def test_single_token(self, gpt2_tokenizer: AppTokenizer) -> None: + token_ids = gpt2_tokenizer.encode("hi") + assert len(token_ids) == 1 + spans = gpt2_tokenizer.get_spans(token_ids) + assert spans == [gpt2_tokenizer.decode(token_ids)] + + +class TestGetTokDisplay: + """Test single-token display strings.""" + + def test_known_tokens(self, gpt2_tokenizer: AppTokenizer) -> None: + # Token 0 is "!" in GPT-2 + display = gpt2_tokenizer.get_tok_display(0) + assert isinstance(display, str) + assert len(display) > 0 + + def test_space_token(self, gpt2_tokenizer: AppTokenizer) -> None: + # " the" is a common GPT-2 token + token_ids = gpt2_tokenizer.encode(" the") + assert len(token_ids) == 1 + display = gpt2_tokenizer.get_tok_display(token_ids[0]) + assert "the" in display + + +class TestEncodeDecode: + """Test encode/decode round-trip.""" + + def test_encode_decode(self, gpt2_tokenizer: AppTokenizer) -> None: + text = "Hello, world!" 
+ token_ids = gpt2_tokenizer.encode(text) + decoded = gpt2_tokenizer.decode(token_ids) + assert decoded == text + + def test_vocab_size(self, gpt2_tokenizer: AppTokenizer) -> None: + assert gpt2_tokenizer.vocab_size == 50257 diff --git a/tests/app/test_server_api.py b/tests/app/test_server_api.py index 508f69ac4..1d55e3d45 100644 --- a/tests/app/test_server_api.py +++ b/tests/app/test_server_api.py @@ -13,16 +13,23 @@ import pytest from fastapi.testclient import TestClient -from spd.app.backend.compute import get_sources_by_target +from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.database import PromptAttrDB from spd.app.backend.routers import graphs as graphs_router -from spd.app.backend.routers import prompts as prompts_router +from spd.app.backend.routers import intervention as intervention_router from spd.app.backend.routers import runs as runs_router from spd.app.backend.server import app -from spd.app.backend.state import HarvestCache, RunState, StateManager -from spd.configs import Config, LMTaskConfig, ModulePatternInfoConfig, ScheduleConfig +from spd.app.backend.state import RunState, StateManager +from spd.configs import ( + Config, + LayerwiseCiConfig, + LMTaskConfig, + ModulePatternInfoConfig, + ScheduleConfig, +) from spd.models.component_model import ComponentModel from spd.pretrain.models.gpt2_simple import GPT2Simple, GPT2SimpleConfig +from spd.topology import TransformerTopology, get_sources_by_target from spd.utils.module_utils import expand_module_patterns DEVICE = "cpu" @@ -48,7 +55,7 @@ def app_with_state(): # Patch DEVICE in all router modules to use CPU for tests with ( mock.patch.object(graphs_router, "DEVICE", DEVICE), - mock.patch.object(prompts_router, "DEVICE", DEVICE), + mock.patch.object(intervention_router, "DEVICE", DEVICE), mock.patch.object(runs_router, "DEVICE", DEVICE), ): db = PromptAttrDB(db_path=Path(":memory:"), check_same_thread=False) @@ -83,8 +90,7 @@ def app_with_state(): config = Config( 
n_mask_samples=1, - ci_fn_type="shared_mlp", - ci_fn_hidden_dims=[16], + ci_config=LayerwiseCiConfig(fn_type="shared_mlp", hidden_dims=[16]), sampling="continuous", sigmoid_type="leaky_hard", module_info=[ @@ -115,28 +121,35 @@ def app_with_state(): model = ComponentModel( target_model=target_model, module_path_info=module_path_info, - ci_fn_type=config.ci_fn_type, - ci_fn_hidden_dims=config.ci_fn_hidden_dims, + ci_config=config.ci_config, pretrained_model_output_attr=config.pretrained_model_output_attr, sigmoid_type=config.sigmoid_type, ) model.eval() + topology = TransformerTopology(model.target_model) sources_by_target = get_sources_by_target( - model=model, device=DEVICE, sampling=config.sampling + model=model, topology=topology, device=DEVICE, sampling=config.sampling ) - # The model has vocab_size=4019, so create entries for all token IDs - token_strings = {i: f"tok_{i}" for i in range(model_config.vocab_size)} + from transformers import AutoTokenizer + from transformers.tokenization_utils_base import PreTrainedTokenizerBase + + hf_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + assert isinstance(hf_tokenizer, PreTrainedTokenizerBase) + tokenizer = AppTokenizer(hf_tokenizer) run_state = RunState( run=run, model=model, + topology=topology, context_length=1, - tokenizer=None, # pyright: ignore[reportArgumentType] + tokenizer=tokenizer, sources_by_target=sources_by_target, config=config, - token_strings=token_strings, - harvest=HarvestCache(run_id="test_run"), + harvest=None, + interp=None, + attributions=None, + graph_interp=None, ) manager = StateManager.get() @@ -161,7 +174,6 @@ def app_with_prompt(app_with_state: TestClient) -> tuple[TestClient, int]: prompt_id = manager.db.add_custom_prompt( run_id=manager.run_state.run.id, token_ids=[0, 2, 1], - active_components={}, # Empty for testing context_length=manager.run_state.context_length, ) return app_with_state, prompt_id @@ -222,35 +234,59 @@ def test_compute_graph(app_with_prompt: 
tuple[TestClient, int]): assert "outputProbs" in data -# ----------------------------------------------------------------------------- -# Streaming: Prompt Generation -# ----------------------------------------------------------------------------- - +def test_run_and_save_intervention_without_text(app_with_prompt: tuple[TestClient, int]): + """Run-and-save intervention should use graph-linked prompt tokens (no text in request).""" + client, prompt_id = app_with_prompt -@pytest.mark.slow -def test_generate_prompts_streaming(app_with_state: TestClient): - """Test streaming prompt generation with CI harvesting.""" - response = app_with_state.post("/api/prompts/generate", params={"n_prompts": 2}) + graph_response = client.post( + "/api/graphs", + params={"prompt_id": prompt_id, "normalize": "none", "ci_threshold": 0.0}, + ) + assert graph_response.status_code == 200 + events = [line for line in graph_response.text.strip().split("\n") if line.startswith("data:")] + final_data = json.loads(events[-1].replace("data: ", "")) + graph_data = final_data["data"] + graph_id = graph_data["id"] + + selected_nodes = [ + key + for key, ci in graph_data["nodeCiVals"].items() + if not key.startswith("embed:") and not key.startswith("output:") and ci > 0 + ] + assert len(selected_nodes) > 0 + + request = { + "graph_id": graph_id, + "selected_nodes": selected_nodes[:5], + "top_k": 5, + "adv_pgd": {"n_steps": 1, "step_size": 1.0}, + } + response = client.post("/api/intervention/run", json=request) assert response.status_code == 200 + body = response.json() + assert body["selected_nodes"] == request["selected_nodes"] + result = body["result"] + assert len(result["input_tokens"]) > 0 + assert len(result["ci"]) > 0 + assert len(result["stochastic"]) > 0 + assert len(result["adversarial"]) > 0 + assert result["ablated"] is None + assert "ci_loss" in result + assert "stochastic_loss" in result + assert "adversarial_loss" in result + assert result["ablated_loss"] is None - # Parse SSE stream 
- lines = response.text.strip().split("\n") - events = [line for line in lines if line.startswith("data:")] - # Should have progress events and a complete event - assert len(events) >= 1 - - # Final event should be complete - final_data = json.loads(events[-1].replace("data: ", "")) - assert final_data["type"] == "complete" +# ----------------------------------------------------------------------------- +# Streaming: Prompt Generation +# ----------------------------------------------------------------------------- # ----------------------------------------------------------------------------- -# Prompts and Search +# Prompts # ----------------------------------------------------------------------------- -@pytest.mark.slow def test_get_prompts_initially_empty(app_with_state: TestClient): """Test that prompts list is initially empty.""" response = app_with_state.get("/api/prompts") @@ -259,33 +295,25 @@ def test_get_prompts_initially_empty(app_with_state: TestClient): assert len(prompts) == 0 -@pytest.mark.slow -def test_get_prompts_after_generation(app_with_state: TestClient): - """Test getting prompts after generation.""" - # Generate some prompts first - app_with_state.post("/api/prompts/generate", params={"n_prompts": 2}) +def test_get_prompts_after_adding(app_with_state: TestClient): + """Test getting prompts after adding via database.""" + manager = StateManager.get() + assert manager.run_state is not None + manager.db.add_custom_prompt( + run_id=manager.run_state.run.id, + token_ids=[0, 2, 1], + context_length=manager.run_state.context_length, + ) + manager.db.add_custom_prompt( + run_id=manager.run_state.run.id, + token_ids=[1, 3, 2], + context_length=manager.run_state.context_length, + ) response = app_with_state.get("/api/prompts") assert response.status_code == 200 prompts = response.json() - assert len(prompts) >= 2 - - -@pytest.mark.slow -def test_search_prompts(app_with_state: TestClient): - """Test searching prompts by component keys.""" - # Generate 
prompts first - app_with_state.post("/api/prompts/generate", params={"n_prompts": 2}) - - # Search for a component that should exist (wte:0 is always active) - response = app_with_state.get( - "/api/prompts/search", - params={"components": "wte:0", "mode": "any"}, - ) - assert response.status_code == 200 - data = response.json() - assert "count" in data - assert data["count"] >= 0 + assert len(prompts) == 2 # ----------------------------------------------------------------------------- @@ -314,13 +342,16 @@ def test_compute_optimized_stream(app_with_prompt: tuple[TestClient, int]): "prompt_id": prompt_id, "label_token": 2, "imp_min_coeff": 0.01, - "ce_loss_coeff": 1.0, + "loss_type": "ce", + "loss_coeff": 1.0, + "loss_position": 2, "steps": 5, # Very few steps for testing "pnorm": 0.5, "beta": 0.5, "normalize": "none", "ci_threshold": 0.0, "output_prob_threshold": 0.01, + "mask_type": "stochastic", }, ) assert response.status_code == 200 diff --git a/tests/clustering/test_run_clustering_happy_path.py b/tests/clustering/test_run_clustering_happy_path.py index 3ee626bd0..717df25ec 100644 --- a/tests/clustering/test_run_clustering_happy_path.py +++ b/tests/clustering/test_run_clustering_happy_path.py @@ -19,7 +19,7 @@ def test_run_clustering_happy_path(monkeypatch: Any): monkeypatch.setattr("spd.utils.run_utils.SPD_OUT_DIR", temp_path) config = ClusteringRunConfig( - model_path="wandb:goodfire/spd/runs/zxbu57pt", # An ss_llama run + model_path="wandb:goodfire/spd/runs/s-a9ad193d", # A resid_mlp2 run batch_size=4, dataset_seed=0, ensemble_id=None, @@ -38,6 +38,5 @@ def test_run_clustering_happy_path(monkeypatch: Any): plot=100, artifact=100, ), - dataset_streaming=True, # tests in CI very slow without this, see https://github.com/goodfire-ai/spd/pull/199 ) main(config) diff --git a/tests/dataset_attributions/test_harvester.py b/tests/dataset_attributions/test_harvester.py deleted file mode 100644 index 96ebc5df8..000000000 --- 
a/tests/dataset_attributions/test_harvester.py +++ /dev/null @@ -1,265 +0,0 @@ -"""Tests for dataset attribution harvester logic.""" - -from pathlib import Path - -import torch - -from spd.dataset_attributions.storage import DatasetAttributionStorage - - -def _make_storage( - n_components: int = 2, - vocab_size: int = 3, - d_model: int = 4, - source_to_component: torch.Tensor | None = None, - source_to_out_residual: torch.Tensor | None = None, -) -> DatasetAttributionStorage: - """Helper to create storage with default values.""" - n_sources = vocab_size + n_components - if source_to_component is None: - source_to_component = torch.zeros(n_sources, n_components) - if source_to_out_residual is None: - source_to_out_residual = torch.zeros(n_sources, d_model) - - return DatasetAttributionStorage( - component_layer_keys=[f"layer1:{i}" for i in range(n_components)], - vocab_size=vocab_size, - d_model=d_model, - source_to_component=source_to_component, - source_to_out_residual=source_to_out_residual, - n_batches_processed=10, - n_tokens_processed=1000, - ci_threshold=0.0, - ) - - -class TestDatasetAttributionStorage: - """Tests for DatasetAttributionStorage. 
- - Storage structure: - - source_to_component: (n_sources, n_components) for component target attributions - - source_to_out_residual: (n_sources, d_model) for output target attributions (via w_unembed) - """ - - def test_has_source_and_target(self) -> None: - """Test has_source and has_target methods.""" - storage = _make_storage(n_components=2, vocab_size=3) - - # wte tokens can only be sources - assert storage.has_source("wte:0") - assert storage.has_source("wte:2") - assert not storage.has_source("wte:3") # Out of vocab - assert not storage.has_target("wte:0") # wte can't be target - - # Component layers can be both sources and targets - assert storage.has_source("layer1:0") - assert storage.has_source("layer1:1") - assert storage.has_target("layer1:0") - assert storage.has_target("layer1:1") - assert not storage.has_source("layer1:2") - assert not storage.has_target("layer1:2") - - # output tokens can only be targets - assert storage.has_target("output:0") - assert storage.has_target("output:2") - assert not storage.has_target("output:3") # Out of vocab - assert not storage.has_source("output:0") # output can't be source - - def test_get_attribution_component_target(self) -> None: - """Test get_attribution for component targets (no w_unembed needed).""" - # 2 component layers: layer1:0, layer1:1 - # vocab_size=2, d_model=4 - # n_sources = 2 + 2 = 4 - # source_to_component shape: (4, 2) - source_to_component = torch.tensor( - [ - [1.0, 2.0], # wte:0 -> components - [3.0, 4.0], # wte:1 -> components - [5.0, 6.0], # layer1:0 -> components - [7.0, 8.0], # layer1:1 -> components - ] - ) - storage = _make_storage( - n_components=2, vocab_size=2, source_to_component=source_to_component - ) - - # wte:0 -> layer1:0 - assert storage.get_attribution("wte:0", "layer1:0") == 1.0 - # wte:1 -> layer1:1 - assert storage.get_attribution("wte:1", "layer1:1") == 4.0 - # layer1:0 -> layer1:1 - assert storage.get_attribution("layer1:0", "layer1:1") == 6.0 - - def 
test_get_attribution_output_target(self) -> None: - """Test get_attribution for output targets (requires w_unembed).""" - # source_to_out_residual shape: (4, 4) for n_sources=4, d_model=4 - source_to_out_residual = torch.tensor( - [ - [1.0, 0.0, 0.0, 0.0], # wte:0 -> out_residual - [0.0, 1.0, 0.0, 0.0], # wte:1 -> out_residual - [0.0, 0.0, 1.0, 0.0], # layer1:0 -> out_residual - [0.0, 0.0, 0.0, 1.0], # layer1:1 -> out_residual - ] - ) - # w_unembed shape: (d_model=4, vocab=2) - w_unembed = torch.tensor( - [ - [1.0, 2.0], # d0 -> outputs - [3.0, 4.0], # d1 -> outputs - [5.0, 6.0], # d2 -> outputs - [7.0, 8.0], # d3 -> outputs - ] - ) - storage = _make_storage( - n_components=2, vocab_size=2, d_model=4, source_to_out_residual=source_to_out_residual - ) - - # wte:0 -> output:0 = out_residual[0] @ w_unembed[:, 0] = [1,0,0,0] @ [1,3,5,7] = 1.0 - assert storage.get_attribution("wte:0", "output:0", w_unembed=w_unembed) == 1.0 - # wte:1 -> output:1 = [0,1,0,0] @ [2,4,6,8] = 4.0 - assert storage.get_attribution("wte:1", "output:1", w_unembed=w_unembed) == 4.0 - # layer1:0 -> output:0 = [0,0,1,0] @ [1,3,5,7] = 5.0 - assert storage.get_attribution("layer1:0", "output:0", w_unembed=w_unembed) == 5.0 - - def test_get_top_sources_component_target(self) -> None: - """Test get_top_sources for component targets.""" - source_to_component = torch.tensor( - [ - [1.0, 2.0], # wte:0 - [5.0, 3.0], # wte:1 - [2.0, 4.0], # layer1:0 - [3.0, 1.0], # layer1:1 - ] - ) - storage = _make_storage( - n_components=2, vocab_size=2, source_to_component=source_to_component - ) - - # Top sources TO layer1:0 (column 0): wte:0=1.0, wte:1=5.0, layer1:0=2.0, layer1:1=3.0 - sources = storage.get_top_sources("layer1:0", k=2, sign="positive") - assert len(sources) == 2 - assert sources[0].component_key == "wte:1" - assert sources[0].value == 5.0 - assert sources[1].component_key == "layer1:1" - assert sources[1].value == 3.0 - - def test_get_top_sources_negative(self) -> None: - """Test get_top_sources with 
negative sign.""" - source_to_component = torch.tensor( - [ - [-1.0, 2.0], - [-5.0, 3.0], - [-2.0, 4.0], - [-3.0, 1.0], - ] - ) - storage = _make_storage( - n_components=2, vocab_size=2, source_to_component=source_to_component - ) - - sources = storage.get_top_sources("layer1:0", k=2, sign="negative") - assert len(sources) == 2 - # wte:1 has most negative (-5.0), then layer1:1 (-3.0) - assert sources[0].component_key == "wte:1" - assert sources[0].value == -5.0 - assert sources[1].component_key == "layer1:1" - assert sources[1].value == -3.0 - - def test_get_top_component_targets(self) -> None: - """Test get_top_component_targets (no w_unembed needed).""" - source_to_component = torch.tensor( - [ - [0.0, 0.0], - [0.0, 0.0], - [2.0, 4.0], # layer1:0 -> components - [0.0, 0.0], - ] - ) - storage = _make_storage( - n_components=2, vocab_size=2, source_to_component=source_to_component - ) - - targets = storage.get_top_component_targets("layer1:0", k=2, sign="positive") - assert len(targets) == 2 - assert targets[0].component_key == "layer1:1" - assert targets[0].value == 4.0 - assert targets[1].component_key == "layer1:0" - assert targets[1].value == 2.0 - - def test_get_top_targets_with_outputs(self) -> None: - """Test get_top_targets including outputs (requires w_unembed).""" - source_to_component = torch.tensor( - [ - [0.0, 0.0], - [0.0, 0.0], - [2.0, 4.0], # layer1:0 -> components - [0.0, 0.0], - ] - ) - # Make out_residual attribution that produces high output values - source_to_out_residual = torch.tensor( - [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0], # layer1:0 -> out_residual (sum=4 per output) - [0.0, 0.0, 0.0, 0.0], - ] - ) - # w_unembed that gives output:0=10, output:1=5 - w_unembed = torch.tensor( - [ - [2.5, 1.25], - [2.5, 1.25], - [2.5, 1.25], - [2.5, 1.25], - ] - ) - storage = _make_storage( - n_components=2, - vocab_size=2, - d_model=4, - source_to_component=source_to_component, - 
source_to_out_residual=source_to_out_residual, - ) - - targets = storage.get_top_targets("layer1:0", k=3, sign="positive", w_unembed=w_unembed) - assert len(targets) == 3 - # output:0 = 10.0, output:1 = 5.0, layer1:1 = 4.0 - assert targets[0].component_key == "output:0" - assert targets[0].value == 10.0 - assert targets[1].component_key == "output:1" - assert targets[1].value == 5.0 - assert targets[2].component_key == "layer1:1" - assert targets[2].value == 4.0 - - def test_save_and_load(self, tmp_path: Path) -> None: - """Test save and load roundtrip.""" - n_components = 2 - vocab_size = 3 - d_model = 4 - n_sources = vocab_size + n_components - - original = DatasetAttributionStorage( - component_layer_keys=["layer:0", "layer:1"], - vocab_size=vocab_size, - d_model=d_model, - source_to_component=torch.randn(n_sources, n_components), - source_to_out_residual=torch.randn(n_sources, d_model), - n_batches_processed=100, - n_tokens_processed=10000, - ci_threshold=0.01, - ) - - path = tmp_path / "test_attributions.pt" - original.save(path) - - loaded = DatasetAttributionStorage.load(path) - - assert loaded.component_layer_keys == original.component_layer_keys - assert loaded.vocab_size == original.vocab_size - assert loaded.d_model == original.d_model - assert loaded.n_batches_processed == original.n_batches_processed - assert loaded.n_tokens_processed == original.n_tokens_processed - assert loaded.ci_threshold == original.ci_threshold - assert torch.allclose(loaded.source_to_component, original.source_to_component) - assert torch.allclose(loaded.source_to_out_residual, original.source_to_out_residual) diff --git a/tests/dataset_attributions/test_storage.py b/tests/dataset_attributions/test_storage.py new file mode 100644 index 000000000..a972c1107 --- /dev/null +++ b/tests/dataset_attributions/test_storage.py @@ -0,0 +1,358 @@ +"""Tests for DatasetAttributionStorage.""" + +import math +from pathlib import Path + +import torch +from torch import Tensor + +from 
spd.dataset_attributions.storage import DatasetAttributionStorage + +VOCAB_SIZE = 4 +D_MODEL = 4 +LAYER_0 = "0.glu.up" +LAYER_1 = "1.glu.up" +C0 = 3 # components in layer 0 +C1 = 2 # components in layer 1 + + +def _make_storage(seed: int = 0, n_tokens: int = 640) -> DatasetAttributionStorage: + """Build storage for test topology. + + Sources by target: + "0.glu.up": ["embed"] -> embed edge (C0, VOCAB_SIZE) + "1.glu.up": ["embed", "0.glu.up"] -> embed edge (C1, VOCAB_SIZE) + regular (C1, C0) + "output": ["0.glu.up", "1.glu.up"] -> unembed (D_MODEL, C0), (D_MODEL, C1) + "output": ["embed"] -> embed_unembed (D_MODEL, VOCAB_SIZE) + """ + g = torch.Generator().manual_seed(seed) + + def rand(*shape: int) -> Tensor: + return torch.randn(*shape, generator=g) + + return DatasetAttributionStorage( + regular_attr={LAYER_1: {LAYER_0: rand(C1, C0)}}, + regular_attr_abs={LAYER_1: {LAYER_0: rand(C1, C0)}}, + embed_attr={LAYER_0: rand(C0, VOCAB_SIZE), LAYER_1: rand(C1, VOCAB_SIZE)}, + embed_attr_abs={LAYER_0: rand(C0, VOCAB_SIZE), LAYER_1: rand(C1, VOCAB_SIZE)}, + unembed_attr={LAYER_0: rand(D_MODEL, C0), LAYER_1: rand(D_MODEL, C1)}, + embed_unembed_attr=rand(D_MODEL, VOCAB_SIZE), + w_unembed=rand(D_MODEL, VOCAB_SIZE), + ci_sum={LAYER_0: rand(C0).abs() + 1.0, LAYER_1: rand(C1).abs() + 1.0}, + component_act_sq_sum={LAYER_0: rand(C0).abs() + 1.0, LAYER_1: rand(C1).abs() + 1.0}, + logit_sq_sum=rand(VOCAB_SIZE).abs() + 1.0, + embed_token_count=torch.randint(100, 1000, (VOCAB_SIZE,), generator=g), + ci_threshold=1e-6, + n_tokens_processed=n_tokens, + ) + + +class TestNComponents: + def test_counts_all_target_layers(self): + storage = _make_storage() + assert storage.n_components == C0 + C1 + + +class TestGetTopSources: + def test_component_target_returns_entries(self): + storage = _make_storage() + results = storage.get_top_sources(f"{LAYER_1}:0", k=5, sign="positive", metric="attr") + assert all(r.value > 0 for r in results) + assert len(results) <= 5 + + def 
test_component_target_includes_embed(self): + storage = _make_storage() + results = storage.get_top_sources(f"{LAYER_1}:0", k=20, sign="positive", metric="attr") + layers = {r.layer for r in results} + assert "embed" in layers or LAYER_0 in layers + + def test_output_target(self): + storage = _make_storage() + results = storage.get_top_sources("output:0", k=5, sign="positive", metric="attr") + assert len(results) <= 5 + + def test_output_target_attr_abs_returns_empty(self): + storage = _make_storage() + results = storage.get_top_sources("output:0", k=5, sign="positive", metric="attr_abs") + assert results == [] + + def test_target_only_in_embed_attr(self): + storage = _make_storage() + results = storage.get_top_sources(f"{LAYER_0}:0", k=5, sign="positive", metric="attr") + assert len(results) <= 5 + assert all(r.layer == "embed" for r in results) + + def test_attr_abs_metric(self): + storage = _make_storage() + results = storage.get_top_sources(f"{LAYER_1}:0", k=5, sign="positive", metric="attr_abs") + assert len(results) <= 5 + + def test_no_nan_in_results(self): + storage = _make_storage() + results = storage.get_top_sources(f"{LAYER_1}:0", k=20, sign="positive", metric="attr") + assert all(not torch.isnan(torch.tensor(r.value)) for r in results) + + +class TestGetTopTargets: + def test_component_source(self): + storage = _make_storage() + results = storage.get_top_targets( + f"{LAYER_0}:0", k=5, sign="positive", metric="attr", include_outputs=False + ) + assert len(results) <= 5 + assert all(r.value > 0 for r in results) + + def test_embed_source(self): + storage = _make_storage() + results = storage.get_top_targets( + "embed:0", k=5, sign="positive", metric="attr", include_outputs=False + ) + assert len(results) <= 5 + + def test_include_outputs(self): + storage = _make_storage() + results = storage.get_top_targets(f"{LAYER_0}:0", k=20, sign="positive", metric="attr") + assert len(results) > 0 + + def test_embed_source_with_outputs(self): + storage = 
_make_storage() + results = storage.get_top_targets("embed:0", k=20, sign="positive", metric="attr") + assert len(results) > 0 + + def test_attr_abs_skips_output_targets(self): + storage = _make_storage() + results = storage.get_top_targets(f"{LAYER_0}:0", k=20, sign="positive", metric="attr_abs") + assert all(r.layer != "output" for r in results) + + +class TestSaveLoad: + def test_roundtrip(self, tmp_path: Path): + original = _make_storage() + path = tmp_path / "attrs.pt" + original.save(path) + + loaded = DatasetAttributionStorage.load(path) + + assert loaded.ci_threshold == original.ci_threshold + assert loaded.n_tokens_processed == original.n_tokens_processed + assert loaded.n_components == original.n_components + + def test_roundtrip_query_consistency(self, tmp_path: Path): + original = _make_storage() + path = tmp_path / "attrs.pt" + original.save(path) + loaded = DatasetAttributionStorage.load(path) + + orig_results = original.get_top_sources(f"{LAYER_1}:0", k=5, sign="positive", metric="attr") + load_results = loaded.get_top_sources(f"{LAYER_1}:0", k=5, sign="positive", metric="attr") + + assert len(orig_results) == len(load_results) + for o, m in zip(orig_results, load_results, strict=True): + assert o.component_key == m.component_key + assert abs(o.value - m.value) < 1e-5 + + +class TestMerge: + def test_two_workers_additive(self, tmp_path: Path): + s1 = _make_storage(seed=0, n_tokens=320) + s2 = _make_storage(seed=42, n_tokens=320) + + p1 = tmp_path / "rank_0.pt" + p2 = tmp_path / "rank_1.pt" + s1.save(p1) + s2.save(p2) + + merged = DatasetAttributionStorage.merge([p1, p2]) + + assert merged.n_tokens_processed == 640 + + def test_single_file(self, tmp_path: Path): + original = _make_storage(seed=7, n_tokens=640) + path = tmp_path / "rank_0.pt" + original.save(path) + + merged = DatasetAttributionStorage.merge([path]) + + assert merged.n_tokens_processed == original.n_tokens_processed + + orig_results = 
original.get_top_sources(f"{LAYER_1}:0", k=5, sign="positive", metric="attr") + merge_results = merged.get_top_sources(f"{LAYER_1}:0", k=5, sign="positive", metric="attr") + for o, m in zip(orig_results, merge_results, strict=True): + assert o.component_key == m.component_key + assert abs(o.value - m.value) < 1e-5 + + +# --------------------------------------------------------------------------- +# Deterministic normalization and merge tests with hand-computed values +# --------------------------------------------------------------------------- + +# Minimal topology: 1 layer with 2 components, vocab=2, d_model=2 +_L = "0.up" +_NC = 2 +_V = 2 +_D = 2 +_N_TOKENS = 100 + + +def _deterministic_storage( + _regular_val: float = 10.0, + embed_val: float = 6.0, + ci_sum_val: float = 50.0, + act_sq_sum_val: float = 400.0, + embed_count_val: int = 200, + n_tokens: int = _N_TOKENS, +) -> DatasetAttributionStorage: + """Storage with uniform known values for hand-computation.""" + return DatasetAttributionStorage( + regular_attr={}, + regular_attr_abs={}, + embed_attr={_L: torch.full((_NC, _V), embed_val)}, + embed_attr_abs={_L: torch.full((_NC, _V), embed_val * 2)}, + unembed_attr={_L: torch.full((_D, _NC), 3.0)}, + embed_unembed_attr=torch.full((_D, _V), 1.0), + w_unembed=torch.eye(_D, _V), + ci_sum={_L: torch.full((_NC,), ci_sum_val)}, + component_act_sq_sum={_L: torch.full((_NC,), act_sq_sum_val)}, + logit_sq_sum=torch.full((_V,), 900.0), + embed_token_count=torch.full((_V,), embed_count_val, dtype=torch.long), + ci_threshold=0.0, + n_tokens_processed=n_tokens, + ) + + +class TestNormalizationCorrectness: + """Verify normalization produces exact expected values from known inputs. 
+ + Formula: normalized = raw / source_denom / target_rms + - source_denom: ci_sum[source] for components, embed_token_count[tok] for embed + - target_rms: sqrt(act_sq_sum[target] / n_tokens) for components, + sqrt(logit_sq_sum[tok] / n_tokens) for output + """ + + def test_embed_to_component_normalization(self): + s = _deterministic_storage() + results = s.get_top_sources(f"{_L}:0", k=_V, sign="positive", metric="attr") + + # raw = embed_attr[_L][0, :] = 6.0 for each vocab entry + # source_denom = embed_count = 200.0 + # target_rms = sqrt(400 / 100) = 2.0 + # normalized = 6.0 / 200.0 / 2.0 = 0.015 + assert len(results) == _V + for r in results: + assert r.layer == "embed" + assert abs(r.value - 0.015) < 1e-6 + + def test_embed_to_component_abs_metric(self): + s = _deterministic_storage() + results = s.get_top_sources(f"{_L}:0", k=_V, sign="positive", metric="attr_abs") + + # raw = embed_attr_abs[_L][0, :] = 12.0 + # same denoms: 200.0, 2.0 + # normalized = 12.0 / 200.0 / 2.0 = 0.03 + assert len(results) == _V + for r in results: + assert abs(r.value - 0.03) < 1e-6 + + def test_component_to_output_normalization(self): + s = _deterministic_storage() + results = s.get_top_sources("output:0", k=5, sign="positive", metric="attr") + + # unembed_attr[_L] = 3.0 * ones(2, 2), w_unembed = eye(2, 2) + # For output:0, w = w_unembed[:, 0] = [1, 0] + # unembed_attr[_L] has shape (d_model, n_components), so + # raw = w @ unembed_attr[_L] = [1,0] @ [[3,3],[3,3]] = [3, 3], shape (n_components,) + # source_denom = ci_sum = 50.0 + # target_rms = sqrt(900 / 100) = 3.0 + # normalized = 3.0 / 50.0 / 3.0 = 0.02 + component_results = [r for r in results if r.layer == _L] + assert len(component_results) == _NC + for r in component_results: + assert abs(r.value - 0.02) < 1e-6 + + def test_embed_to_output_normalization(self): + s = _deterministic_storage() + results = s.get_top_sources("output:0", 
k=10, sign="positive", metric="attr") + + # embed_unembed_attr = 1.0 * ones(2, 2), w = [1, 0] + # raw per embed token = w @ embed_unembed_attr = [1,0] @ [[1,1],[1,1]] = [1, 1] + # source_denom = embed_count = 200.0 + # target_rms = 3.0 + # normalized = 1.0 / 200.0 / 3.0 ≈ 0.001667 + embed_results = [r for r in results if r.layer == "embed"] + assert len(embed_results) == _V + for r in embed_results: + assert abs(r.value - 1.0 / 200.0 / 3.0) < 1e-6 + + def test_sign_filtering(self): + """Positive sign excludes negative values, negative sign excludes positive.""" + s = DatasetAttributionStorage( + regular_attr={}, + regular_attr_abs={}, + embed_attr={_L: torch.tensor([[5.0, -3.0]])}, + embed_attr_abs={_L: torch.tensor([[5.0, -3.0]])}, + unembed_attr={}, + embed_unembed_attr=torch.zeros(_D, _V), + w_unembed=torch.eye(_D, _V), + ci_sum={_L: torch.tensor([1.0])}, + component_act_sq_sum={_L: torch.tensor([100.0])}, + logit_sq_sum=torch.ones(_V), + embed_token_count=torch.ones(_V, dtype=torch.long), + ci_threshold=0.0, + n_tokens_processed=100, + ) + + pos = s.get_top_sources(f"{_L}:0", k=10, sign="positive", metric="attr") + neg = s.get_top_sources(f"{_L}:0", k=10, sign="negative", metric="attr") + + assert all(r.value > 0 for r in pos) + assert all(r.value < 0 for r in neg) + assert len(pos) == 1 # only embed:0 is positive + assert len(neg) == 1 # only embed:1 is negative + + +class TestMergeNumericCorrectness: + """Verify merge produces correct normalized values.""" + + def test_merge_equals_sum_of_parts(self, tmp_path: Path): + """Two workers with known values; merged queries should equal manual computation.""" + s1 = _deterministic_storage( + embed_val=4.0, ci_sum_val=20.0, act_sq_sum_val=100.0, embed_count_val=80, n_tokens=40 + ) + s2 = _deterministic_storage( + embed_val=8.0, ci_sum_val=30.0, act_sq_sum_val=500.0, embed_count_val=120, n_tokens=60 + ) + + p1, p2 = tmp_path / "r0.pt", tmp_path / "r1.pt" + s1.save(p1) + s2.save(p2) + merged = 
DatasetAttributionStorage.merge([p1, p2]) + + assert merged.n_tokens_processed == 100 + + # Merged raw: embed_attr[_L][0, 0] = 4.0 + 8.0 = 12.0 + # Merged embed_count[0] = 80 + 120 = 200 + # Merged act_sq_sum[_L][0] = 100 + 500 = 600 + # target_rms = sqrt(600 / 100) = sqrt(6) + # normalized = 12.0 / 200.0 / sqrt(6) + expected = 12.0 / 200.0 / math.sqrt(6) + + results = merged.get_top_sources(f"{_L}:0", k=_V, sign="positive", metric="attr") + assert len(results) == _V + for r in results: + assert abs(r.value - expected) < 1e-6 + + def test_merge_identity(self, tmp_path: Path): + """Merging a single file produces identical query results.""" + s = _deterministic_storage() + path = tmp_path / "single.pt" + s.save(path) + merged = DatasetAttributionStorage.merge([path]) + + for key in [f"{_L}:0", f"{_L}:1"]: + orig = s.get_top_sources(key, k=10, sign="positive", metric="attr") + mrgd = merged.get_top_sources(key, k=10, sign="positive", metric="attr") + assert len(orig) == len(mrgd) + for o, m in zip(orig, mrgd, strict=True): + assert o.component_key == m.component_key + assert abs(o.value - m.value) < 1e-6 diff --git a/tests/harvest/__init__.py b/tests/harvest/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/harvest/test_harvester.py b/tests/harvest/test_harvester.py new file mode 100644 index 000000000..7ad656ec3 --- /dev/null +++ b/tests/harvest/test_harvester.py @@ -0,0 +1,555 @@ +"""Tests for the Harvester class and extract_padding_firing_windows.""" + +import random +from pathlib import Path + +import pytest +import torch + +from spd.harvest.harvester import Harvester, extract_padding_firing_windows +from spd.harvest.reservoir import WINDOW_PAD_SENTINEL, ActivationWindows + +DEVICE = torch.device("cpu") + +LAYERS = [("layer_0", 4), ("layer_1", 4)] +N_TOTAL = 8 +VOCAB_SIZE = 10 +MAX_EXAMPLES = 5 +CONTEXT_TOKENS_PER_SIDE = 1 +WINDOW = 2 * CONTEXT_TOKENS_PER_SIDE + 1 # 3 + +ACT_TYPES = ["ci", "inner"] + + +def _make_harvester() -> 
Harvester: + return Harvester( + layers=LAYERS, + vocab_size=VOCAB_SIZE, + max_examples_per_component=MAX_EXAMPLES, + context_tokens_per_side=CONTEXT_TOKENS_PER_SIDE, + max_examples_per_batch_per_component=5, + device=DEVICE, + ) + + +def _make_activation_windows( + comp_idx: list[int], + token_windows: torch.Tensor, + firings: torch.Tensor | None = None, +) -> ActivationWindows: + n = len(comp_idx) + w = token_windows.shape[1] + if firings is None: + firings = torch.ones(n, w, dtype=torch.bool) + return ActivationWindows( + component_idx=torch.tensor(comp_idx), + token_windows=token_windows, + firing_windows=firings, + activation_windows={at: torch.ones(n, w) for at in ACT_TYPES}, + ) + + +class TestInit: + def test_tensor_shapes(self): + h = _make_harvester() + assert h.firing_counts.shape == (N_TOTAL,) + assert h.cooccurrence_counts.shape == (N_TOTAL, N_TOTAL) + assert h.input_cooccurrence.shape == (N_TOTAL, VOCAB_SIZE) + assert h.input_marginals.shape == (VOCAB_SIZE,) + assert h.output_cooccurrence.shape == (N_TOTAL, VOCAB_SIZE) + assert h.output_marginals.shape == (VOCAB_SIZE,) + assert h.reservoir.tokens.shape == (N_TOTAL, MAX_EXAMPLES, WINDOW) + assert h.reservoir.n_items.shape == (N_TOTAL,) + assert h.reservoir.n_seen.shape == (N_TOTAL,) + + def test_tensors_on_correct_device(self): + h = _make_harvester() + assert h.firing_counts.device == DEVICE + assert h.reservoir.tokens.device == DEVICE + assert h.cooccurrence_counts.device == DEVICE + + def test_layer_offsets(self): + h = _make_harvester() + assert h.layer_offsets == {"layer_0": 0, "layer_1": 4} + + def test_component_keys(self): + h = _make_harvester() + expected = [f"layer_0:{i}" for i in range(4)] + [f"layer_1:{i}" for i in range(4)] + assert h.component_keys == expected + + def test_tensors_initialized_to_zero(self): + h = _make_harvester() + assert h.firing_counts.sum() == 0 + assert h.cooccurrence_counts.sum() == 0 + assert h.reservoir.n_items.sum() == 0 + assert h.reservoir.n_seen.sum() == 0 + 
assert h.total_tokens_processed == 0 + + def test_reservoir_tokens_initialized_to_sentinel(self): + h = _make_harvester() + assert (h.reservoir.tokens == WINDOW_PAD_SENTINEL).all() + + +class TestReservoirAdd: + def test_fills_up_to_k(self): + h = _make_harvester() + k = h.reservoir.k + comp = 2 + + for i in range(k): + aw = _make_activation_windows([comp], torch.full((1, WINDOW), i, dtype=torch.long)) + h.reservoir.add(aw) + + assert h.reservoir.n_items[comp] == k + assert h.reservoir.n_seen[comp] == k + for i in range(k): + assert h.reservoir.tokens[comp, i, 0].item() == i + + def test_replacement_after_k(self): + h = _make_harvester() + k = h.reservoir.k + comp = 0 + + random.seed(42) + n_extra = 100 + for i in range(k + n_extra): + aw = _make_activation_windows([comp], torch.full((1, WINDOW), i, dtype=torch.long)) + h.reservoir.add(aw) + + assert h.reservoir.n_items[comp] == k + assert h.reservoir.n_seen[comp] == k + n_extra + + def test_n_items_never_exceeds_k(self): + h = _make_harvester() + k = h.reservoir.k + comp = 1 + + random.seed(0) + for i in range(k * 10): + aw = _make_activation_windows( + [comp], torch.full((1, WINDOW), i % VOCAB_SIZE, dtype=torch.long) + ) + h.reservoir.add(aw) + + assert h.reservoir.n_items[comp] == k + assert h.reservoir.n_seen[comp] == k * 10 + + def test_multiple_components_in_one_call(self): + h = _make_harvester() + aw = _make_activation_windows([0, 0, 3, 3, 3], torch.arange(5 * WINDOW).reshape(5, WINDOW)) + h.reservoir.add(aw) + + assert h.reservoir.n_items[0] == 2 + assert h.reservoir.n_seen[0] == 2 + assert h.reservoir.n_items[3] == 3 + assert h.reservoir.n_seen[3] == 3 + assert h.reservoir.n_items[1] == 0 + assert h.reservoir.n_items[2] == 0 + + def test_independent_component_tracking(self): + h = _make_harvester() + k = h.reservoir.k + + for i in range(k): + aw = _make_activation_windows([0], torch.full((1, WINDOW), i, dtype=torch.long)) + h.reservoir.add(aw) + + aw = _make_activation_windows([1], torch.full((1, WINDOW), 
99, dtype=torch.long)) + h.reservoir.add(aw) + + assert h.reservoir.n_items[0] == k + assert h.reservoir.n_seen[0] == k + assert h.reservoir.n_items[1] == 1 + assert h.reservoir.n_seen[1] == 1 + + +class TestSaveLoadRoundtrip: + def test_roundtrip_preserves_all_fields(self, tmp_path: Path): + h = _make_harvester() + + h.firing_counts[0] = 10.0 + h.firing_counts[3] = 5.0 + h.activation_sums["ci"][0] = 2.5 + h.cooccurrence_counts[0, 3] = 7.0 + h.input_cooccurrence[0, 2] = 15 + h.input_marginals[2] = 100 + h.output_cooccurrence[0, 5] = 0.3 + h.output_marginals[5] = 1.0 + h.total_tokens_processed = 500 + + aw = _make_activation_windows([0], torch.tensor([[1, 2, 3]])) + h.reservoir.add(aw) + + path = tmp_path / "harvester.pt" + h.save(path) + loaded = Harvester.load(path, device=DEVICE) + + assert loaded.layer_names == h.layer_names + assert loaded.c_per_layer == h.c_per_layer + assert loaded.vocab_size == h.vocab_size + assert loaded.max_examples_per_component == h.max_examples_per_component + assert loaded.context_tokens_per_side == h.context_tokens_per_side + assert loaded.total_tokens_processed == h.total_tokens_processed + assert loaded.layer_offsets == h.layer_offsets + + for field in [ + "firing_counts", + "cooccurrence_counts", + "input_cooccurrence", + "input_marginals", + "output_cooccurrence", + "output_marginals", + ]: + assert torch.equal(getattr(loaded, field), getattr(h, field).cpu()), field + + for act_type in h.activation_sums: + assert torch.equal(loaded.activation_sums[act_type], h.activation_sums[act_type].cpu()) + + assert torch.equal(loaded.reservoir.tokens, h.reservoir.tokens.cpu()) + assert torch.equal(loaded.reservoir.n_items, h.reservoir.n_items.cpu()) + assert torch.equal(loaded.reservoir.n_seen, h.reservoir.n_seen.cpu()) + + def test_load_to_specific_device(self, tmp_path: Path): + h = _make_harvester() + path = tmp_path / "harvester.pt" + h.save(path) + loaded = Harvester.load(path, device=torch.device("cpu")) + assert loaded.device == 
torch.device("cpu") + assert loaded.firing_counts.device == torch.device("cpu") + + +class TestMerge: + def test_accumulators_sum(self): + h1 = _make_harvester() + h2 = _make_harvester() + + h1.firing_counts[0] = 10.0 + h2.firing_counts[0] = 20.0 + h1.activation_sums["ci"][1] = 3.0 + h2.activation_sums["ci"][1] = 7.0 + h1.cooccurrence_counts[0, 1] = 5.0 + h2.cooccurrence_counts[0, 1] = 3.0 + h1.input_cooccurrence[0, 2] = 10 + h2.input_cooccurrence[0, 2] = 5 + h1.input_marginals[2] = 100 + h2.input_marginals[2] = 200 + h1.output_cooccurrence[0, 0] = 0.5 + h2.output_cooccurrence[0, 0] = 0.3 + h1.output_marginals[0] = 1.0 + h2.output_marginals[0] = 2.0 + h1.total_tokens_processed = 100 + h2.total_tokens_processed = 200 + + h1.merge(h2) + + assert h1.firing_counts[0] == 30.0 + assert h1.activation_sums["ci"][1] == 10.0 + assert h1.cooccurrence_counts[0, 1] == 8.0 + assert h1.input_cooccurrence[0, 2] == 15 + assert h1.input_marginals[2] == 300 + assert h1.output_cooccurrence[0, 0] == pytest.approx(0.8) + assert h1.output_marginals[0] == 3.0 + assert h1.total_tokens_processed == 300 + + def test_merge_asserts_matching_structure(self): + h1 = _make_harvester() + h_different = Harvester( + layers=[("other", 4)], + vocab_size=VOCAB_SIZE, + max_examples_per_component=MAX_EXAMPLES, + context_tokens_per_side=CONTEXT_TOKENS_PER_SIDE, + max_examples_per_batch_per_component=5, + device=DEVICE, + ) + with pytest.raises(AssertionError): + h1.merge(h_different) + + def test_merge_reservoir_both_underfilled(self): + h1 = _make_harvester() + h2 = _make_harvester() + + for i in range(2): + aw = _make_activation_windows([0], torch.full((1, WINDOW), i, dtype=torch.long)) + h1.reservoir.add(aw) + for i in range(2): + aw = _make_activation_windows([0], torch.full((1, WINDOW), 10 + i, dtype=torch.long)) + h2.reservoir.add(aw) + + h1.merge(h2) + assert h1.reservoir.n_items[0] == 4 + assert h1.reservoir.n_seen[0] == 4 + + def test_merge_reservoir_n_seen_sums(self): + h1 = _make_harvester() + 
h2 = _make_harvester() + k = MAX_EXAMPLES + + random.seed(42) + for i in range(k + 10): + aw = _make_activation_windows( + [0], torch.full((1, WINDOW), i % VOCAB_SIZE, dtype=torch.long) + ) + h1.reservoir.add(aw) + for i in range(k + 5): + aw = _make_activation_windows( + [0], torch.full((1, WINDOW), i % VOCAB_SIZE, dtype=torch.long) + ) + h2.reservoir.add(aw) + + seen_before = h1.reservoir.n_seen[0].item() + h2.reservoir.n_seen[0].item() + h1.merge(h2) + + assert h1.reservoir.n_items[0] == k + assert h1.reservoir.n_seen[0] == seen_before + + def test_merge_preserves_other_components(self): + h1 = _make_harvester() + h2 = _make_harvester() + + aw1 = _make_activation_windows([0], torch.full((1, WINDOW), 1, dtype=torch.long)) + h1.reservoir.add(aw1) + aw2 = _make_activation_windows([3], torch.full((1, WINDOW), 2, dtype=torch.long)) + h2.reservoir.add(aw2) + + h1.merge(h2) + assert h1.reservoir.n_items[0] == 1 + assert h1.reservoir.n_items[3] == 1 + + +class TestBuildResults: + def _make_harvester_with_firings(self) -> Harvester: + h = _make_harvester() + + h.total_tokens_processed = 100 + h.firing_counts[0] = 10.0 + h.firing_counts[1] = 5.0 + h.activation_sums["ci"][0] = 2.0 + h.activation_sums["ci"][1] = 1.0 + + h.input_cooccurrence[0, 0] = 8 + h.input_cooccurrence[1, 1] = 3 + h.input_marginals[0] = 50 + h.input_marginals[1] = 30 + h.output_cooccurrence[0, 0] = 5.0 + h.output_cooccurrence[1, 1] = 2.0 + h.output_marginals[0] = 20.0 + h.output_marginals[1] = 15.0 + + for i in range(3): + aw = _make_activation_windows([0], torch.tensor([[i, i + 1, i + 2]])) + h.reservoir.add(aw) + + aw = _make_activation_windows([1], torch.tensor([[5, 6, 7]])) + h.reservoir.add(aw) + + return h + + def test_yields_only_firing_components(self): + h = self._make_harvester_with_firings() + results = list(h.build_results(pmi_top_k_tokens=3)) + + keys = {r.component_key for r in results} + assert keys == {"layer_0:0", "layer_0:1"} + + def test_skips_zero_firing_components(self): + h = 
self._make_harvester_with_firings() + results = list(h.build_results(pmi_top_k_tokens=3)) + + keys = {r.component_key for r in results} + for cidx in range(2, 4): + assert f"layer_0:{cidx}" not in keys + for cidx in range(4): + assert f"layer_1:{cidx}" not in keys + + def test_component_data_structure(self): + h = self._make_harvester_with_firings() + results = list(h.build_results(pmi_top_k_tokens=3)) + + comp0 = next(r for r in results if r.component_key == "layer_0:0") + assert comp0.layer == "layer_0" + assert comp0.component_idx == 0 + assert comp0.firing_density == pytest.approx(10.0 / 100) + assert comp0.mean_activations["ci"] == pytest.approx(2.0 / 100) + assert len(comp0.activation_examples) == 3 + assert comp0.input_token_pmi is not None + assert comp0.output_token_pmi is not None + + def test_activation_examples_have_correct_data(self): + h = self._make_harvester_with_firings() + results = list(h.build_results(pmi_top_k_tokens=3)) + + comp0 = next(r for r in results if r.component_key == "layer_0:0") + ex = comp0.activation_examples[0] + assert len(ex.token_ids) > 0 + assert len(ex.firings) == len(ex.token_ids) + for act_type in ex.activations: + assert len(ex.activations[act_type]) == len(ex.token_ids) + + def test_second_layer_component_keys(self): + h = _make_harvester() + h.total_tokens_processed = 100 + h.firing_counts[5] = 8.0 + h.activation_sums["ci"][5] = 1.6 + h.input_marginals[0] = 50 + h.input_cooccurrence[5, 0] = 4 + h.output_marginals[0] = 10.0 + h.output_cooccurrence[5, 0] = 2.0 + + aw = _make_activation_windows([5], torch.tensor([[1, 2, 3]])) + h.reservoir.add(aw) + + results = list(h.build_results(pmi_top_k_tokens=3)) + assert len(results) == 1 + assert results[0].component_key == "layer_1:1" + assert results[0].layer == "layer_1" + assert results[0].component_idx == 1 + + def test_no_results_when_nothing_fires(self): + h = _make_harvester() + h.total_tokens_processed = 100 + results = list(h.build_results(pmi_top_k_tokens=3)) + assert 
results == [] + + def test_sentinel_tokens_stripped_from_examples(self): + h = _make_harvester() + h.total_tokens_processed = 100 + h.firing_counts[0] = 5.0 + h.activation_sums["ci"][0] = 1.0 + h.input_marginals[0] = 50 + h.input_cooccurrence[0, 0] = 3 + h.output_marginals[0] = 10.0 + h.output_cooccurrence[0, 0] = 1.0 + + h.reservoir.tokens[0, 0] = torch.tensor([WINDOW_PAD_SENTINEL, 5, 6]) + h.reservoir.firings[0, 0] = torch.tensor([False, True, True]) + for at in h.reservoir.acts: + h.reservoir.acts[at][0, 0] = torch.tensor([0.0, 0.8, 0.9]) + h.reservoir.n_items[0] = 1 + h.reservoir.n_seen[0] = 1 + + results = list(h.build_results(pmi_top_k_tokens=3)) + assert len(results) == 1 + ex = results[0].activation_examples[0] + assert WINDOW_PAD_SENTINEL not in ex.token_ids + assert len(ex.token_ids) == 2 + + +class TestProcessBatch: + def _make_batch_inputs( + self, B: int = 2, S: int = 4 + ) -> tuple[ + torch.Tensor, dict[str, torch.Tensor], dict[str, dict[str, torch.Tensor]], torch.Tensor + ]: + batch = torch.randint(0, VOCAB_SIZE, (B, S)) + firings = {layer: torch.zeros(B, S, c, dtype=torch.bool) for layer, c in LAYERS} + activations = {layer: {at: torch.zeros(B, S, c) for at in ACT_TYPES} for layer, c in LAYERS} + output_probs = torch.zeros(B, S, VOCAB_SIZE) + return batch, firings, activations, output_probs + + def test_updates_total_tokens(self): + h = _make_harvester() + B, S = 2, 4 + batch, firings, activations, output_probs = self._make_batch_inputs(B, S) + + h.process_batch(batch, firings, activations, output_probs) + assert h.total_tokens_processed == B * S + + def test_firing_counts_accumulate(self): + h = _make_harvester() + B, S = 1, 2 + batch, firings, activations, output_probs = self._make_batch_inputs(B, S) + firings["layer_0"][0, 0, 0] = True + firings["layer_0"][0, 1, 0] = True + + h.process_batch(batch, firings, activations, output_probs) + assert h.firing_counts[0] == 2.0 + assert h.firing_counts[1] == 0.0 + + def 
test_activation_sums_accumulate(self): + h = _make_harvester() + B, S = 1, 1 + batch, firings, activations, output_probs = self._make_batch_inputs(B, S) + activations["layer_0"]["ci"][0, 0, 2] = 0.75 + + h.process_batch(batch, firings, activations, output_probs) + assert h.activation_sums["ci"][2].item() == pytest.approx(0.75) + + def test_cooccurrence_counts(self): + h = _make_harvester() + B, S = 1, 1 + batch, firings, activations, output_probs = self._make_batch_inputs(B, S) + firings["layer_0"][0, 0, 0] = True + firings["layer_0"][0, 0, 2] = True + + h.process_batch(batch, firings, activations, output_probs) + assert h.cooccurrence_counts[0, 2] == 1.0 + assert h.cooccurrence_counts[2, 0] == 1.0 + assert h.cooccurrence_counts[0, 0] == 1.0 + assert h.cooccurrence_counts[2, 2] == 1.0 + + +class TestExtractPaddingFiringWindows: + def test_center_window(self): + batch = torch.tensor([[10, 11, 12, 13, 14]]) + firings = torch.zeros(1, 5, 2, dtype=torch.bool) + firings[0, 2, 0] = True + activations = {"ci": torch.zeros(1, 5, 2)} + activations["ci"][0, 2, 0] = 0.9 + + result = extract_padding_firing_windows(batch, firings, activations, 10, 1) + assert result is not None + assert result.token_windows.shape == (1, 3) + assert result.token_windows[0].tolist() == [11, 12, 13] + assert result.activation_windows["ci"][0, 1].item() == pytest.approx(0.9) + + def test_left_boundary_padding(self): + batch = torch.tensor([[10, 11, 12]]) + firings = torch.zeros(1, 3, 1, dtype=torch.bool) + firings[0, 0, 0] = True + activations = {"ci": torch.zeros(1, 3, 1)} + + result = extract_padding_firing_windows(batch, firings, activations, 10, 2) + assert result is not None + tok_w = result.token_windows + assert tok_w.shape == (1, 5) + assert tok_w[0, 0] == WINDOW_PAD_SENTINEL + assert tok_w[0, 1] == WINDOW_PAD_SENTINEL + assert tok_w[0, 2] == 10 + assert tok_w[0, 3] == 11 + assert tok_w[0, 4] == 12 + + def test_right_boundary_padding(self): + batch = torch.tensor([[10, 11, 12]]) + firings = 
torch.zeros(1, 3, 1, dtype=torch.bool) + firings[0, 2, 0] = True + activations = {"ci": torch.zeros(1, 3, 1)} + + result = extract_padding_firing_windows(batch, firings, activations, 10, 2) + assert result is not None + tok_w = result.token_windows + assert tok_w[0, 0] == 10 + assert tok_w[0, 1] == 11 + assert tok_w[0, 2] == 12 + assert tok_w[0, 3] == WINDOW_PAD_SENTINEL + assert tok_w[0, 4] == WINDOW_PAD_SENTINEL + + def test_no_firings_returns_none(self): + batch = torch.tensor([[0, 1, 2]]) + firings = torch.zeros(1, 3, 2, dtype=torch.bool) + activations = {"ci": torch.zeros(1, 3, 2)} + + result = extract_padding_firing_windows(batch, firings, activations, 10, 1) + assert result is None + + def test_multiple_firings(self): + batch = torch.tensor([[0, 1, 2, 3, 4]]) + firings = torch.zeros(1, 5, 3, dtype=torch.bool) + firings[0, 1, 0] = True + firings[0, 3, 2] = True + activations = {"ci": torch.zeros(1, 5, 3)} + + result = extract_padding_firing_windows(batch, firings, activations, 10, 1) + assert result is not None + assert result.token_windows.shape == (2, 3) + assert result.token_windows[0].tolist() == [0, 1, 2] + assert result.token_windows[1].tolist() == [2, 3, 4] diff --git a/tests/harvest/test_reservoir.py b/tests/harvest/test_reservoir.py new file mode 100644 index 000000000..0acfff041 --- /dev/null +++ b/tests/harvest/test_reservoir.py @@ -0,0 +1,212 @@ +"""Tests for ActivationExamplesReservoir.""" + +import random + +import pytest +import torch + +from spd.harvest.reservoir import ( + WINDOW_PAD_SENTINEL, + ActivationExamplesReservoir, + ActivationWindows, +) + +DEVICE = torch.device("cpu") +N_COMPONENTS = 4 +K = 3 +WINDOW = 3 + +ACT_TYPES = ["ci", "inner"] + + +def _make_reservoir() -> ActivationExamplesReservoir: + return ActivationExamplesReservoir.create(N_COMPONENTS, K, WINDOW, DEVICE) + + +def _make_activation_window( + comp: list[int], + tokens: torch.Tensor, + firings: torch.Tensor | None = None, +) -> ActivationWindows: + n = len(comp) + w = 
tokens.shape[1] + if firings is None: + firings = torch.ones(n, w, dtype=torch.bool) + return ActivationWindows( + component_idx=torch.tensor(comp), + token_windows=tokens, + firing_windows=firings, + activation_windows={at: torch.ones(n, w) * 0.5 for at in ACT_TYPES}, + ) + + +class TestAdd: + def test_fills_up_to_k(self): + r = _make_reservoir() + comp = 1 + + for i in range(K): + r.add(_make_activation_window([comp], torch.full((1, WINDOW), i, dtype=torch.long))) + + assert r.n_items[comp] == K + assert r.n_seen[comp] == K + for i in range(K): + assert r.tokens[comp, i, 0].item() == i + + def test_replacement_after_k(self): + r = _make_reservoir() + comp = 0 + random.seed(42) + + n_total = K + 50 + for i in range(n_total): + r.add(_make_activation_window([comp], torch.full((1, WINDOW), i, dtype=torch.long))) + + assert r.n_items[comp] == K + assert r.n_seen[comp] == n_total + + def test_written_data_matches_input(self): + r = _make_reservoir() + tokens = torch.tensor([[7, 8, 9]]) + firings = torch.tensor([[True, False, True]]) + aw = ActivationWindows( + component_idx=torch.tensor([2]), + token_windows=tokens, + firing_windows=firings, + activation_windows={"ci": torch.tensor([[0.1, 0.2, 0.3]])}, + ) + r.add(aw) + + assert torch.equal(r.tokens[2, 0], tokens[0]) + assert torch.equal(r.firings[2, 0], firings[0]) + assert torch.allclose(r.acts["ci"][2, 0], torch.tensor([0.1, 0.2, 0.3])) + + +class TestMerge: + def test_merge_combines_underfilled(self): + r1 = _make_reservoir() + r2 = _make_reservoir() + + r1.add(_make_activation_window([0], torch.full((1, WINDOW), 1, dtype=torch.long))) + r2.add(_make_activation_window([0], torch.full((1, WINDOW), 2, dtype=torch.long))) + + r1.merge(r2) + assert r1.n_items[0] == 2 + assert r1.n_seen[0] == 2 + + def test_merge_weighted_by_n_seen(self): + torch.manual_seed(0) + + n_trials = 200 + heavy_wins = 0 + for _ in range(n_trials): + r_heavy = _make_reservoir() + r_light = _make_reservoir() + + for _ in range(K): + 
r_heavy.add(
+                    _make_activation_window([0], torch.full((1, WINDOW), 1, dtype=torch.long))
+                )
+            r_heavy.n_seen[0] = 1000
+
+            for _ in range(K):
+                r_light.add(
+                    _make_activation_window([0], torch.full((1, WINDOW), 2, dtype=torch.long))
+                )
+            r_light.n_seen[0] = 1
+
+            r_heavy.merge(r_light)
+            from_heavy = (r_heavy.tokens[0, :, 0] == 1).sum().item()
+            if from_heavy == K:
+                heavy_wins += 1
+
+        assert heavy_wins > n_trials * 0.8
+
+    def test_merge_n_seen_sums(self):
+        r1 = _make_reservoir()
+        r2 = _make_reservoir()
+
+        for i in range(K + 5):
+            r1.add(_make_activation_window([0], torch.full((1, WINDOW), i % 10, dtype=torch.long)))
+        for i in range(K + 3):
+            r2.add(_make_activation_window([0], torch.full((1, WINDOW), i % 10, dtype=torch.long)))
+
+        total = r1.n_seen[0].item() + r2.n_seen[0].item()
+        r1.merge(r2)
+        assert r1.n_seen[0] == total
+        assert r1.n_items[0] == K
+
+
+class TestExamples:
+    def test_yields_correct_items(self):
+        r = _make_reservoir()
+        for i in range(2):
+            aw = ActivationWindows(
+                component_idx=torch.tensor([0]),
+                token_windows=torch.full((1, WINDOW), i + 10, dtype=torch.long),
+                firing_windows=torch.ones(1, WINDOW, dtype=torch.bool),
+                activation_windows={"ci": torch.ones(1, WINDOW) * (i + 1) * 0.1},
+            )
+            r.add(aw)
+
+        examples = list(r.examples(0))
+        assert len(examples) == 2
+        ex0 = examples[0]
+        assert ex0.token_ids == [10, 10, 10]
+        assert all(ex0.firings)
+        assert ex0.activations["ci"] == [pytest.approx(0.1)] * 3
+
+    def test_filters_sentinels(self):
+        r = _make_reservoir()
+        r.tokens[0, 0] = torch.tensor([WINDOW_PAD_SENTINEL, 5, 6])
+        r.firings[0, 0] = torch.tensor([False, True, True])
+        r.acts["ci"] = torch.zeros(N_COMPONENTS, K, WINDOW)
+        r.acts["ci"][0, 0] = torch.tensor([0.0, 0.8, 0.9])
+        r.n_items[0] = 1
+        r.n_seen[0] = 1
+
+        examples = list(r.examples(0))
+        assert len(examples) == 1
+        ex = examples[0]
+        assert ex.token_ids == [5, 6]
+        assert ex.firings == [True, True]
+        assert ex.activations["ci"] == [pytest.approx(0.8), pytest.approx(0.9)]
+
+    def test_empty_component_yields_nothing(self):
+        r = _make_reservoir()
+        assert list(r.examples(0)) == []
+
+
+class TestStateDictRoundtrip:
+    def test_roundtrip_preserves_data(self):
+        r = _make_reservoir()
+        for i in range(2):
+            aw = ActivationWindows(
+                component_idx=torch.tensor([1]),
+                token_windows=torch.full((1, WINDOW), i + 5, dtype=torch.long),
+                firing_windows=torch.ones(1, WINDOW, dtype=torch.bool),
+                activation_windows={"ci": torch.ones(1, WINDOW) * 0.5},
+            )
+            r.add(aw)
+
+        sd = r.state_dict()
+        restored = ActivationExamplesReservoir.from_state_dict(sd, device=DEVICE)
+
+        assert restored.k == r.k
+        assert restored.window == r.window
+        assert torch.equal(restored.tokens, r.tokens)
+        assert torch.equal(restored.firings, r.firings)
+        for at in r.acts:
+            assert torch.equal(restored.acts[at], r.acts[at])
+        assert torch.equal(restored.n_items, r.n_items)
+        assert torch.equal(restored.n_seen, r.n_seen)
+
+    def test_state_dict_on_cpu(self):
+        r = _make_reservoir()
+        r.add(_make_activation_window([0], torch.full((1, WINDOW), 1, dtype=torch.long)))
+
+        sd = r.state_dict()
+        assert isinstance(sd["tokens"], torch.Tensor) and sd["tokens"].device == torch.device("cpu")
+        assert isinstance(sd["n_items"], torch.Tensor) and sd["n_items"].device == torch.device(
+            "cpu"
+        )
diff --git a/tests/harvest/test_sampling.py b/tests/harvest/test_sampling.py
index 59cbb056a..621f6c48f 100644
--- a/tests/harvest/test_sampling.py
+++ b/tests/harvest/test_sampling.py
@@ -4,8 +4,7 @@
 
 import torch
 
-from spd.harvest.lib.reservoir_sampler import ReservoirSampler, ReservoirState
-from spd.harvest.lib.sampling import compute_pmi, sample_at_most_n_per_group, top_k_pmi
+from spd.harvest.sampling import compute_pmi, sample_at_most_n_per_group, top_k_pmi
 
 
 class TestSampleAtMostNPerGroup:
@@ -191,79 +190,3 @@ def test_all_zeros_returns_empty(self) -> None:
 
         assert top == []
         assert bottom == []
-
-
-class TestReservoirSampler:
-    def test_fills_up_to_k(self) -> None:
-        sampler: ReservoirSampler[int] = ReservoirSampler(k=5)
-        for i in range(3):
-            sampler.add(i)
-
-        assert len(sampler.samples) == 3
-        assert sampler.n_seen == 3
-
-    def test_caps_at_k(self) -> None:
-        sampler: ReservoirSampler[int] = ReservoirSampler(k=5)
-        for i in range(100):
-            sampler.add(i)
-
-        assert len(sampler.samples) == 5
-        assert sampler.n_seen == 100
-
-    def test_state_roundtrip(self) -> None:
-        sampler: ReservoirSampler[str] = ReservoirSampler(k=3)
-        sampler.add("a")
-        sampler.add("b")
-
-        state = sampler.get_state()
-        restored = ReservoirSampler.from_state(state)
-
-        assert restored.k == sampler.k
-        assert restored.samples == sampler.samples
-        assert restored.n_seen == sampler.n_seen
-
-
-class TestReservoirStateMerge:
-    def test_merge_underfilled_reservoirs(self) -> None:
-        state1: ReservoirState[str] = ReservoirState(k=5, samples=["a", "b"], n_seen=100)
-        state2: ReservoirState[str] = ReservoirState(k=5, samples=["c"], n_seen=100)
-
-        merged = ReservoirState.merge([state1, state2])
-
-        assert merged.k == 5
-        assert merged.n_seen == 200
-        assert set(merged.samples) == {"a", "b", "c"}
-
-    def test_merge_preserves_k(self) -> None:
-        state1: ReservoirState[int] = ReservoirState(k=3, samples=[1, 2, 3], n_seen=100)
-        state2: ReservoirState[int] = ReservoirState(k=3, samples=[4, 5, 6], n_seen=100)
-
-        merged = ReservoirState.merge([state1, state2])
-
-        assert merged.k == 3
-        assert len(merged.samples) == 3
-        assert merged.n_seen == 200
-
-    def test_merge_empty_states(self) -> None:
-        state1: ReservoirState[int] = ReservoirState(k=5, samples=[], n_seen=0)
-        state2: ReservoirState[int] = ReservoirState(k=5, samples=[], n_seen=0)
-
-        merged = ReservoirState.merge([state1, state2])
-
-        assert merged.samples == []
-        assert merged.n_seen == 0
-
-    def test_merge_weighted_by_n_seen(self) -> None:
-        # State1 saw way more samples, so its items should be more likely to appear
-        state1: ReservoirState[str] = ReservoirState(k=2, samples=["a", "b"], n_seen=1000)
-        state2: ReservoirState[str] = ReservoirState(k=2, samples=["c", "d"], n_seen=10)
-
-        # Run multiple merges and check that state1 items dominate
-        from_state1 = 0
-        n_trials = 100
-        for _ in range(n_trials):
-            merged = ReservoirState.merge([state1, state2])
-            from_state1 += sum(1 for s in merged.samples if s in ["a", "b"])
-
-        # With 1000:10 ratio, state1 items should appear ~99% of the time
-        assert from_state1 > n_trials * 2 * 0.9  # at least 90% from state1
diff --git a/tests/metrics/fixtures.py b/tests/metrics/fixtures.py
index fa32cc1e3..fe0bf14d2 100644
--- a/tests/metrics/fixtures.py
+++ b/tests/metrics/fixtures.py
@@ -7,6 +7,7 @@
 from jaxtyping import Float
 from torch import Tensor
 
+from spd.configs import LayerwiseCiConfig
 from spd.models.component_model import ComponentModel
 from spd.utils.module_utils import ModulePathInfo
 
@@ -38,7 +39,9 @@ def forward(self, x: Tensor) -> Tensor:
         return x
 
 
-def make_one_layer_component_model(weight: Float[Tensor, "d_out d_in"]) -> ComponentModel:
+def make_one_layer_component_model(
+    weight: Float[Tensor, "d_out d_in"], C: int = 1
+) -> ComponentModel:
     """Create a ComponentModel with a single linear layer for testing.
 
     Args:
@@ -55,9 +58,8 @@ def make_one_layer_component_model(weight: Float[Tensor, "d_out d_in"]) -> Compo
 
     comp_model = ComponentModel(
         target_model=target,
-        module_path_info=[ModulePathInfo(module_path="fc", C=1)],
-        ci_fn_hidden_dims=[2],
-        ci_fn_type="mlp",
+        module_path_info=[ModulePathInfo(module_path="fc", C=C)],
+        ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]),
         pretrained_model_output_attr=None,
         sigmoid_type="leaky_hard",
     )
@@ -93,8 +95,7 @@ def make_two_layer_component_model(
             ModulePathInfo(module_path="fc1", C=1),
             ModulePathInfo(module_path="fc2", C=1),
         ],
-        ci_fn_hidden_dims=[2],
-        ci_fn_type="mlp",
+        ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]),
         pretrained_model_output_attr=None,
         sigmoid_type="leaky_hard",
     )
diff --git a/tests/metrics/test_attn_patterns_recon_loss.py b/tests/metrics/test_attn_patterns_recon_loss.py
new file mode 100644
index 000000000..ebbc4e655
--- /dev/null
+++ b/tests/metrics/test_attn_patterns_recon_loss.py
@@ -0,0 +1,380 @@
+import torch
+
+from spd.configs import LayerwiseCiConfig
+from spd.metrics.attn_patterns_recon_loss import (
+    CIMaskedAttnPatternsReconLoss,
+    StochasticAttnPatternsReconLoss,
+    _compute_attn_patterns,
+)
+from spd.models.component_model import ComponentModel
+from spd.pretrain.models.gpt2 import GPT2, GPT2Config
+from spd.pretrain.models.gpt2_simple import GPT2Simple, GPT2SimpleConfig
+from spd.pretrain.models.llama_simple import LlamaSimple, LlamaSimpleConfig
+from spd.utils.module_utils import ModulePathInfo
+
+
+def _make_gpt2_component_model(n_embd: int = 16, n_head: int = 2) -> ComponentModel:
+    """Create a 1-layer GPT2Simple wrapped in ComponentModel with q_proj/k_proj decomposed."""
+    config = GPT2SimpleConfig(
+        model_type="GPT2Simple",
+        block_size=32,
+        vocab_size=64,
+        n_layer=1,
+        n_head=n_head,
+        n_embd=n_embd,
+        flash_attention=False,
+    )
+    target = GPT2Simple(config)
+    target.requires_grad_(False)
+
+    module_path_info = [
+        ModulePathInfo(module_path="h.0.attn.q_proj", C=n_embd),
+        ModulePathInfo(module_path="h.0.attn.k_proj", C=n_embd),
+    ]
+
+    comp_model = ComponentModel(
+        target_model=target,
+        module_path_info=module_path_info,
+        ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[8]),
+        pretrained_model_output_attr="idx_0",
+        sigmoid_type="leaky_hard",
+    )
+    return comp_model
+
+
+def _make_gpt2_c_attn_component_model(n_embd: int = 16, n_head: int = 2) -> ComponentModel:
+    """Create a 1-layer GPT2 wrapped in ComponentModel with combined c_attn decomposed."""
+    config = GPT2Config(
+        model_type="GPT2",
+        block_size=32,
+        vocab_size=64,
+        n_layer=1,
+        n_head=n_head,
+        n_embd=n_embd,
+        flash_attention=False,
+    )
+    target = GPT2(config)
+    target.requires_grad_(False)
+
+    module_path_info = [
+        ModulePathInfo(module_path="h_torch.0.attn.c_attn", C=n_embd),
+    ]
+
+    comp_model = ComponentModel(
+        target_model=target,
+        module_path_info=module_path_info,
+        ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[8]),
+        pretrained_model_output_attr="idx_0",
+        sigmoid_type="leaky_hard",
+    )
+    return comp_model
+
+
+class TestAttnPatternsReconLoss:
+    def test_identity_decomposition_kl_near_zero(self) -> None:
+        """With V=weight.T and U=eye, the component exactly reproduces the original weight,
+        so attention patterns should match and KL divergence should be ~0."""
+        torch.manual_seed(42)
+        n_embd = 16
+        n_head = 2
+        model = _make_gpt2_component_model(n_embd=n_embd, n_head=n_head)
+
+        for path in ["h.0.attn.q_proj", "h.0.attn.k_proj"]:
+            target_weight = model.target_weight(path)
+            with torch.no_grad():
+                model.components[path].V.copy_(target_weight.T)
+                model.components[path].U.copy_(torch.eye(n_embd))
+
+        batch = torch.randint(0, 64, (2, 8))
+        target_output = model(batch, cache_type="input")
+        pre_weight_acts = target_output.cache
+        ci = model.calc_causal_importances(
+            pre_weight_acts=pre_weight_acts, detach_inputs=False, sampling="continuous"
+        )
+
+        metric = CIMaskedAttnPatternsReconLoss(
+            model=model,
+            device="cpu",
+            n_heads=n_head,
+            q_proj_path="h.*.attn.q_proj",
+            k_proj_path="h.*.attn.k_proj",
+            c_attn_path=None,
+        )
+        metric.update(batch=batch, pre_weight_acts=pre_weight_acts, ci=ci)
+        loss = metric.compute()
+
+        assert loss.item() < 1e-4, f"Expected KL ≈ 0 with identity decomposition, got {loss.item()}"
+
+    def test_random_init_kl_positive(self) -> None:
+        """With random V/U init, attention patterns should differ and KL should be > 0."""
+        torch.manual_seed(42)
+        n_embd = 16
+        n_head = 2
+        model = _make_gpt2_component_model(n_embd=n_embd, n_head=n_head)
+
+        batch = torch.randint(0, 64, (2, 8))
+        target_output = model(batch, cache_type="input")
+        pre_weight_acts = target_output.cache
+        ci = model.calc_causal_importances(
+            pre_weight_acts=pre_weight_acts, detach_inputs=False, sampling="continuous"
+        )
+
+        metric = CIMaskedAttnPatternsReconLoss(
+            model=model,
+            device="cpu",
+            n_heads=n_head,
+            q_proj_path="h.*.attn.q_proj",
+            k_proj_path="h.*.attn.k_proj",
+            c_attn_path=None,
+        )
+        metric.update(batch=batch, pre_weight_acts=pre_weight_acts, ci=ci)
+        loss = metric.compute()
+
+        assert loss.item() > 0.01, f"Expected KL > 0 with random init, got {loss.item()}"
+
+    def test_stochastic_identity_decomposition_kl_near_zero(self) -> None:
+        """Stochastic variant with identity decomposition should give KL ≈ 0."""
+        torch.manual_seed(42)
+        n_embd = 16
+        n_head = 2
+        model = _make_gpt2_component_model(n_embd=n_embd, n_head=n_head)
+
+        for path in ["h.0.attn.q_proj", "h.0.attn.k_proj"]:
+            target_weight = model.target_weight(path)
+            with torch.no_grad():
+                model.components[path].V.copy_(target_weight.T)
+                model.components[path].U.copy_(torch.eye(n_embd))
+
+        batch = torch.randint(0, 64, (2, 8))
+        target_output = model(batch, cache_type="input")
+        pre_weight_acts = target_output.cache
+        ci = model.calc_causal_importances(
+            pre_weight_acts=pre_weight_acts, detach_inputs=False, sampling="continuous"
+        )
+
+        metric = StochasticAttnPatternsReconLoss(
+            model=model,
+            device="cpu",
+            sampling="continuous",
+            use_delta_component=False,
+            n_mask_samples=2,
+            n_heads=n_head,
+            q_proj_path="h.*.attn.q_proj",
+            k_proj_path="h.*.attn.k_proj",
+            c_attn_path=None,
+        )
+        weight_deltas = model.calc_weight_deltas()
+        metric.update(
+            batch=batch, pre_weight_acts=pre_weight_acts, ci=ci, weight_deltas=weight_deltas
+        )
+        loss = metric.compute()
+
+        assert loss.item() < 1e-4, f"Expected KL ≈ 0 with identity decomposition, got {loss.item()}"
+
+    def test_stochastic_random_init_kl_positive(self) -> None:
+        """Stochastic variant with random init should give KL > 0."""
+        torch.manual_seed(42)
+        n_embd = 16
+        n_head = 2
+        model = _make_gpt2_component_model(n_embd=n_embd, n_head=n_head)
+
+        batch = torch.randint(0, 64, (2, 8))
+        target_output = model(batch, cache_type="input")
+        pre_weight_acts = target_output.cache
+        ci = model.calc_causal_importances(
+            pre_weight_acts=pre_weight_acts, detach_inputs=False, sampling="continuous"
+        )
+
+        metric = StochasticAttnPatternsReconLoss(
+            model=model,
+            device="cpu",
+            sampling="continuous",
+            use_delta_component=False,
+            n_mask_samples=2,
+            n_heads=n_head,
+            q_proj_path="h.*.attn.q_proj",
+            k_proj_path="h.*.attn.k_proj",
+            c_attn_path=None,
+        )
+        weight_deltas = model.calc_weight_deltas()
+        metric.update(
+            batch=batch, pre_weight_acts=pre_weight_acts, ci=ci, weight_deltas=weight_deltas
+        )
+        loss = metric.compute()
+
+        assert loss.item() > 0.01, f"Expected KL > 0 with random init, got {loss.item()}"
+
+
+class TestCAttnPatternsReconLoss:
+    def test_c_attn_identity_decomposition_kl_near_zero(self) -> None:
+        """Combined c_attn path with identity decomposition should give KL ≈ 0."""
+        torch.manual_seed(42)
+        n_embd = 16
+        n_head = 2
+        model = _make_gpt2_c_attn_component_model(n_embd=n_embd, n_head=n_head)
+
+        path = "h_torch.0.attn.c_attn"
+        target_weight = model.target_weight(path)  # (3*n_embd, n_embd)
+        with torch.no_grad():
+            model.components[path].V.copy_(torch.eye(n_embd))  # (n_embd, C=n_embd)
+            model.components[path].U.copy_(target_weight.T)  # (C=n_embd, 3*n_embd)
+
+        batch = torch.randint(0, 64, (2, 8))
+        target_output = model(batch, cache_type="input")
+        pre_weight_acts = target_output.cache
+        ci = model.calc_causal_importances(
+            pre_weight_acts=pre_weight_acts, detach_inputs=False, sampling="continuous"
+        )
+
+        metric = CIMaskedAttnPatternsReconLoss(
+            model=model,
+            device="cpu",
+            n_heads=n_head,
+            q_proj_path=None,
+            k_proj_path=None,
+            c_attn_path="h_torch.*.attn.c_attn",
+        )
+        metric.update(batch=batch, pre_weight_acts=pre_weight_acts, ci=ci)
+        loss = metric.compute()
+
+        assert loss.item() < 1e-4, f"Expected KL ≈ 0 with identity decomposition, got {loss.item()}"
+
+    def test_c_attn_random_init_kl_positive(self) -> None:
+        """Combined c_attn path with random init should give KL > 0."""
+        torch.manual_seed(42)
+        n_embd = 16
+        n_head = 2
+        model = _make_gpt2_c_attn_component_model(n_embd=n_embd, n_head=n_head)
+
+        batch = torch.randint(0, 64, (2, 8))
+        target_output = model(batch, cache_type="input")
+        pre_weight_acts = target_output.cache
+        ci = model.calc_causal_importances(
+            pre_weight_acts=pre_weight_acts, detach_inputs=False, sampling="continuous"
+        )
+
+        metric = CIMaskedAttnPatternsReconLoss(
+            model=model,
+            device="cpu",
+            n_heads=n_head,
+            q_proj_path=None,
+            k_proj_path=None,
+            c_attn_path="h_torch.*.attn.c_attn",
+        )
+        metric.update(batch=batch, pre_weight_acts=pre_weight_acts, ci=ci)
+        loss = metric.compute()
+
+        assert loss.item() > 0.01, f"Expected KL > 0 with random init, got {loss.item()}"
+
+
+def _make_llama_component_model(n_embd: int = 16, n_head: int = 2) -> ComponentModel:
+    """Create a 1-layer LlamaSimple with RoPE, wrapped in ComponentModel with q_proj/k_proj."""
+    config = LlamaSimpleConfig(
+        model_type="LlamaSimple",
+        block_size=32,
+        vocab_size=64,
+        n_layer=1,
+        n_head=n_head,
+        n_embd=n_embd,
+        n_intermediate=n_embd * 4 * 2 // 3,
+        use_grouped_query_attention=True,
+        n_key_value_heads=n_head,
+        flash_attention=False,
+        n_ctx=32,
+        rotary_dim=n_embd // n_head,
+    )
+    target = LlamaSimple(config)
+    target.requires_grad_(False)
+
+    module_path_info = [
+        ModulePathInfo(module_path="h.0.attn.q_proj", C=n_embd),
+        ModulePathInfo(module_path="h.0.attn.k_proj", C=n_embd),
+    ]
+
+    return ComponentModel(
+        target_model=target,
+        module_path_info=module_path_info,
+        ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[8]),
+        pretrained_model_output_attr="idx_0",
+        sigmoid_type="leaky_hard",
+    )
+
+
+class TestRoPEAttnPatternsReconLoss:
+    def test_rope_identity_decomposition_kl_near_zero(self) -> None:
+        """Identity decomposition with auto-detected RoPE should give KL ~ 0."""
+        torch.manual_seed(42)
+        n_embd = 16
+        n_head = 2
+        model = _make_llama_component_model(n_embd=n_embd, n_head=n_head)
+
+        for path in ["h.0.attn.q_proj", "h.0.attn.k_proj"]:
+            target_weight = model.target_weight(path)
+            with torch.no_grad():
+                model.components[path].V.copy_(target_weight.T)
+                model.components[path].U.copy_(torch.eye(n_embd))
+
+        batch = torch.randint(0, 64, (2, 8))
+        target_output = model(batch, cache_type="input")
+        pre_weight_acts = target_output.cache
+        ci = model.calc_causal_importances(
+            pre_weight_acts=pre_weight_acts, detach_inputs=False, sampling="continuous"
+        )
+
+        metric = CIMaskedAttnPatternsReconLoss(
+            model=model,
+            device="cpu",
+            n_heads=n_head,
+            q_proj_path="h.*.attn.q_proj",
+            k_proj_path="h.*.attn.k_proj",
+            c_attn_path=None,
+        )
+        metric.update(batch=batch, pre_weight_acts=pre_weight_acts, ci=ci)
+        loss = metric.compute()
+
+        assert loss.item() < 1e-4, f"Expected KL ≈ 0 with identity decomposition, got {loss.item()}"
+
+    def test_rope_random_init_kl_positive(self) -> None:
+        """Random init with RoPE should give KL > 0."""
+        torch.manual_seed(42)
+        n_embd = 16
+        n_head = 2
+        model = _make_llama_component_model(n_embd=n_embd, n_head=n_head)
+
+        batch = torch.randint(0, 64, (2, 8))
+        target_output = model(batch, cache_type="input")
+        pre_weight_acts = target_output.cache
+        ci = model.calc_causal_importances(
+            pre_weight_acts=pre_weight_acts, detach_inputs=False, sampling="continuous"
+        )
+
+        metric = CIMaskedAttnPatternsReconLoss(
+            model=model,
+            device="cpu",
+            n_heads=n_head,
+            q_proj_path="h.*.attn.q_proj",
+            k_proj_path="h.*.attn.k_proj",
+            c_attn_path=None,
+        )
+        metric.update(batch=batch, pre_weight_acts=pre_weight_acts, ci=ci)
+        loss = metric.compute()
+
+        assert loss.item() > 0.01, f"Expected KL > 0 with random init, got {loss.item()}"
+
+    def test_rope_changes_patterns(self) -> None:
+        """Applying RoPE should produce different attention patterns than without."""
+        torch.manual_seed(42)
+        n_embd = 16
+        n_head = 2
+        model = _make_llama_component_model(n_embd=n_embd, n_head=n_head)
+        attn_module = model.target_model.get_submodule("h.0.attn")
+
+        q = torch.randn(2, 8, n_embd)
+        k = torch.randn(2, 8, n_embd)
+
+        patterns_without_rope = _compute_attn_patterns(q, k, n_head, attn_module=None)
+        patterns_with_rope = _compute_attn_patterns(q, k, n_head, attn_module=attn_module)
+
+        assert not torch.allclose(patterns_without_rope, patterns_with_rope, atol=1e-6), (
+            "RoPE should change attention patterns"
+        )
diff --git a/tests/metrics/test_recon_losses.py b/tests/metrics/test_recon_losses.py
new file mode 100644
index 000000000..931ca0f2a
--- /dev/null
+++ b/tests/metrics/test_recon_losses.py
@@ -0,0 +1,321 @@
+"""Sanity checks for stochastic, CI, PGD, and persistent PGD reconstruction losses."""
+
+from collections.abc import Callable
+
+import pytest
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from spd.configs import PGDConfig
+from spd.metrics import ci_masked_recon_loss, pgd_recon_loss, stochastic_recon_loss
+from spd.metrics.hidden_acts_recon_loss import (
+    CIHiddenActsReconLoss,
+    _sum_per_module_mse,
+    calc_hidden_acts_mse,
+)
+from spd.metrics.ppgd_eval_losses import PPGDReconEval
+from spd.models.component_model import CIOutputs, ComponentModel
+from spd.models.components import make_mask_infos
+from spd.persistent_pgd import PPGDSources, get_ppgd_mask_infos
+from tests.metrics.fixtures import (
+    OneLayerLinearModel,
+    TwoLayerLinearModel,
+    make_one_layer_component_model,
+    make_two_layer_component_model,
+)
+
+ReconLossFn = Callable[[ComponentModel, Tensor, Tensor, dict[str, Tensor]], Tensor]
+
+
+def _stochastic(
+    model: ComponentModel,
+    batch: Tensor,
+    target_out: Tensor,
+    ci: dict[str, Tensor],
+) -> Tensor:
+    return stochastic_recon_loss(
+        model=model,
+        sampling="continuous",
+        n_mask_samples=4,
+        output_loss_type="mse",
+        batch=batch,
+        target_out=target_out,
+        ci=ci,
+        weight_deltas=None,
+    )
+
+
+def _ci(
+    model: ComponentModel,
+    batch: Tensor,
+    target_out: Tensor,
+    ci: dict[str, Tensor],
+) -> Tensor:
+    return ci_masked_recon_loss(
+        model=model,
+        output_loss_type="mse",
+        batch=batch,
+        target_out=target_out,
+        ci=ci,
+    )
+
+
+def _pgd(
+    model: ComponentModel,
+    batch: Tensor,
+    target_out: Tensor,
+    ci: dict[str, Tensor],
+) -> Tensor:
+    return pgd_recon_loss(
+        model=model,
+        batch=batch,
+        target_out=target_out,
+        output_loss_type="mse",
+        ci=ci,
+        weight_deltas=None,
+        pgd_config=PGDConfig(
+            init="random",
+            step_size=0.1,
+            n_steps=5,
+            mask_scope="unique_per_datapoint",
+        ),
+    )
+
+
+LOSS_FNS = [_stochastic, _ci, _pgd]
+
+
+@pytest.mark.parametrize("loss_fn", LOSS_FNS, ids=["stochastic", "ci", "pgd"])
+class TestOutputReconLoss:
+    def test_perfect_init_zero_recon(self, loss_fn: ReconLossFn) -> None:
+        """V=W.T, U=I with CI=1 → all masks are 1 → output matches target → loss ≈ 0."""
+        torch.manual_seed(42)
+        d = 3
+        weight = torch.randn(d, d)
+        model = make_one_layer_component_model(weight=weight, C=d)
+
+        assert isinstance(model.target_model, OneLayerLinearModel)
+        target_weight = model.target_model.fc.weight.data
+        with torch.no_grad():
+            model.components["fc"].V.copy_(target_weight.T)
+            model.components["fc"].U.copy_(torch.eye(d))
+
+        batch = torch.randn(4, d)
+        target_out = model.target_model(batch)
+        ci = {"fc": torch.ones(4, d)}
+
+        loss = loss_fn(model, batch, target_out, ci)
+        assert loss < 1e-5, f"Expected ~0 loss with perfect init, got {loss}"
+
+    def test_random_init_high_recon(self, loss_fn: ReconLossFn) -> None:
+        """Random V and U should give substantially nonzero recon loss."""
+        torch.manual_seed(42)
+        d = 3
+        weight = torch.randn(d, d)
+        model = make_one_layer_component_model(weight=weight, C=d)
+
+        batch = torch.randn(4, d)
+        target_out = model.target_model(batch)
+        ci = {"fc": torch.ones(4, d)}
+
+        loss = loss_fn(model, batch, target_out, ci)
+        assert loss > 0.01, f"Expected high loss with random init, got {loss}"
+
+
+def test_output_recon_manual_calculation() -> None:
+    """Verify CI-masked output recon loss matches a manual forward pass computation."""
+    torch.manual_seed(42)
+
+    fc_weight = torch.randn(2, 2)
+    model = make_one_layer_component_model(weight=fc_weight)
+
+    V = model.components["fc"].V
+    U = model.components["fc"].U
+
+    batch = torch.randn(1, 2)
+    target_out = torch.randn(1, 2)
+    ci = {"fc": torch.tensor([[0.8]])}
+
+    # Manual: component_acts = batch @ V, masked = acts * ci, out = masked @ U
+    out = (batch @ V * ci["fc"]) @ U
+    expected_loss = torch.nn.functional.mse_loss(out, target_out, reduction="sum") / out.numel()
+
+    actual_loss = ci_masked_recon_loss(
+        model=model,
+        output_loss_type="mse",
+        batch=batch,
+        target_out=target_out,
+        ci=ci,
+    )
+
+    assert torch.allclose(actual_loss, expected_loss, rtol=1e-5), (
+        f"Expected {expected_loss}, got {actual_loss}"
+    )
+
+
+def test_per_module_recon_manual_calculation() -> None:
+    """Verify per-module recon loss matches manual computation for a two-layer model."""
+    torch.manual_seed(42)
+
+    fc1_weight = torch.randn(3, 2)
+    fc2_weight = torch.randn(2, 3)
+    model = make_two_layer_component_model(weight1=fc1_weight, weight2=fc2_weight)
+
+    V1, U1 = model.components["fc1"].V, model.components["fc1"].U
+    V2, U2 = model.components["fc2"].V, model.components["fc2"].U
+
+    batch = torch.randn(1, 2)
+    ci = {"fc1": torch.tensor([[0.8]]), "fc2": torch.tensor([[0.7]])}
+
+    # Target activations (output of each layer through the target model)
+    assert isinstance(model.target_model, TwoLayerLinearModel)
+    target_fc1 = batch @ model.target_model.fc1.weight.data.T
+    target_fc2 = target_fc1 @ model.target_model.fc2.weight.data.T
+
+    # Component activations with CI as masks (fc2 input is fc1's component output, not target)
+    comp_fc1 = batch @ (V1 * ci["fc1"]) @ U1
+    comp_fc2 = comp_fc1 @ (V2 * ci["fc2"]) @ U2
+
+    expected_fc1_mse = F.mse_loss(comp_fc1, target_fc1, reduction="sum")
+    expected_fc2_mse = F.mse_loss(comp_fc2, target_fc2, reduction="sum")
+    expected_total = (expected_fc1_mse + expected_fc2_mse) / (
+        target_fc1.numel() + target_fc2.numel()
+    )
+
+    # Actual computation
+    target_acts = model(batch, cache_type="output").cache
+    mask_infos = make_mask_infos(ci, weight_deltas_and_masks=None)
+    per_module, _ = calc_hidden_acts_mse(model, batch, mask_infos, target_acts)
+    sum_mse, n_examples = _sum_per_module_mse(per_module)
+    actual_total = sum_mse / n_examples
+
+    assert torch.allclose(actual_total, expected_total, rtol=1e-5)
+    fc1_mse, _ = per_module["fc1"]
+    assert torch.allclose(fc1_mse, expected_fc1_mse, rtol=1e-5)
+    fc2_mse, _ = per_module["fc2"]
+    assert torch.allclose(fc2_mse, expected_fc2_mse, rtol=1e-5)
+
+
+def test_per_module_recon_metric_keys() -> None:
+    """CIHiddenActsReconLoss.compute() returns per-module + total keys."""
+    torch.manual_seed(42)
+
+    model = make_two_layer_component_model(weight1=torch.randn(3, 2), weight2=torch.randn(2, 3))
+    batch = torch.randn(2, 2)
+
+    target_output = model(batch, cache_type="input")
+    ci = model.calc_causal_importances(pre_weight_acts=target_output.cache, sampling="continuous")
+
+    metric = CIHiddenActsReconLoss(model=model, device="cpu")
+    metric.update(batch=batch, ci=ci)
+    result = metric.compute()
+
+    assert set(result.keys()) == {
+        "CIHiddenActsReconLoss",
+        "CIHiddenActsReconLoss/fc1",
+        "CIHiddenActsReconLoss/fc2",
+    }
+    for v in result.values():
+        assert v.item() >= 0
+
+
+def _make_ci_outputs(ci: dict[str, Tensor]) -> CIOutputs:
+    return CIOutputs(
+        lower_leaky=ci,
+        upper_leaky=ci,
+        pre_sigmoid={k: torch.ones_like(v) * 10 for k, v in ci.items()},
+    )
+
+
+def test_ppgd_recon_eval_metric_keys() -> None:
+    """PPGDReconEval.compute() returns hidden_acts (total + per-module) and output_recon keys."""
+    torch.manual_seed(42)
+
+    model = make_two_layer_component_model(weight1=torch.randn(3, 2), weight2=torch.randn(2, 3))
+    batch = torch.randn(2, 2)
+    target_out = model.target_model(batch)
+    ci = {"fc1": torch.ones(2, 1), "fc2": torch.ones(2, 1)}
+    sources: PPGDSources = {k: torch.zeros(1, v.shape[-1]) for k, v in ci.items()}
+
+    metric = PPGDReconEval(
+        model=model,
+        device="cpu",
+        effective_sources=sources,
+        use_delta_component=False,
+        output_loss_type="mse",
+        metric_name="my_ppgd",
+    )
+    metric.update(
+        batch=batch,
+        ci=_make_ci_outputs(ci),
+        weight_deltas={},
+        target_out=target_out,
+    )
+    result = metric.compute()
+
+    assert set(result.keys()) == {
+        "my_ppgd/hidden_acts",
+        "my_ppgd/hidden_acts/fc1",
+        "my_ppgd/hidden_acts/fc2",
+        "my_ppgd/output_recon",
+    }
+    for v in result.values():
+        assert v.item() >= 0
+
+
+def test_ppgd_recon_eval_manual_calculation() -> None:
+    """Verify PPGD hidden-acts MSE and output recon match hand-computed values."""
+    torch.manual_seed(42)
+
+    fc1_weight = torch.randn(3, 2)
+    fc2_weight = torch.randn(2, 3)
+    model = make_two_layer_component_model(weight1=fc1_weight, weight2=fc2_weight)
+
+    V1, U1 = model.components["fc1"].V, model.components["fc1"].U
+    V2, U2 = model.components["fc2"].V, model.components["fc2"].U
+
+    batch = torch.randn(1, 2)
+    ci = {"fc1": torch.tensor([[0.8]]), "fc2": torch.tensor([[0.7]])}
+    adv_sources: PPGDSources = {"fc1": torch.tensor([[0.3]]), "fc2": torch.tensor([[0.5]])}
+
+    # mask = ci + (1 - ci) * source
+    mask_fc1 = ci["fc1"] + (1 - ci["fc1"]) * adv_sources["fc1"]  # 0.8 + 0.2*0.3 = 0.86
+    mask_fc2 = ci["fc2"] + (1 - ci["fc2"]) * adv_sources["fc2"]  # 0.7 + 0.3*0.5 = 0.85
+
+    # Target activations
+    assert isinstance(model.target_model, TwoLayerLinearModel)
+    target_fc1 = batch @ model.target_model.fc1.weight.data.T
+    target_fc2 = target_fc1 @ model.target_model.fc2.weight.data.T
+
+    # Component activations with PPGD masks
+    comp_fc1 = batch @ (V1 * mask_fc1) @ U1
+    comp_fc2 = comp_fc1 @ (V2 * mask_fc2) @ U2
+
+    expected_fc1_mse = F.mse_loss(comp_fc1, target_fc1, reduction="sum")
+    expected_fc2_mse = F.mse_loss(comp_fc2, target_fc2, reduction="sum")
+    expected_hidden_total = (expected_fc1_mse + expected_fc2_mse) / (
+        target_fc1.numel() + target_fc2.numel()
+    )
+    expected_output_recon = ((comp_fc2 - target_fc2) ** 2).sum() / comp_fc2.numel()
+
+    # Actual computation via get_ppgd_mask_infos + calc_hidden_acts_mse
+    target_acts = model(batch, cache_type="output").cache
+    mask_infos = get_ppgd_mask_infos(
+        ci=ci,
+        weight_deltas=None,
+        ppgd_sources=adv_sources,
+        routing_masks="all",
+        batch_dims=(1,),
+    )
+    per_module, comp_output = calc_hidden_acts_mse(model, batch, mask_infos, target_acts)
+    sum_mse, n_examples = _sum_per_module_mse(per_module)
+    actual_hidden_total = sum_mse / n_examples
+    actual_output_recon = ((comp_output - target_fc2) ** 2).sum() / comp_output.numel()
+
+    assert torch.allclose(actual_hidden_total, expected_hidden_total, rtol=1e-5)
+    assert torch.allclose(actual_output_recon, expected_output_recon, rtol=1e-5)
+    fc1_mse, _ = per_module["fc1"]
+    assert torch.allclose(fc1_mse, expected_fc1_mse, rtol=1e-5)
+    fc2_mse, _ = per_module["fc2"]
+    assert torch.allclose(fc2_mse, expected_fc2_mse, rtol=1e-5)
diff --git a/tests/metrics/test_stochastic_hidden_acts_recon.py b/tests/metrics/test_stochastic_hidden_acts_recon.py
deleted file mode 100644
index 021146ecf..000000000
--- a/tests/metrics/test_stochastic_hidden_acts_recon.py
+++ /dev/null
@@ -1,116 +0,0 @@
-from unittest.mock import patch
-
-import torch
-from torch import Tensor
-
-from spd.configs import SamplingType
-from spd.metrics import stochastic_hidden_acts_recon_loss
-from spd.models.components import ComponentsMaskInfo, make_mask_infos
-from spd.routing import Router
-from tests.metrics.fixtures import make_two_layer_component_model
-
-
-class TestStochasticHiddenActsReconLoss:
-    def test_manual_calculation(self: object) -> None:
-        """Test stochastic hidden acts recon loss with manual calculation.
-
-        For a two-layer model (batch -> fc1 -> hidden -> fc2 -> output):
-        - pre_weight_acts["fc1"] is the batch input (always same, MSE = 0)
-        - pre_weight_acts["fc2"] is the hidden activation (differs with stochastic masks)
-
-        This structure provides both a sanity check and meaningful test of the metric.
-        """
-        torch.manual_seed(42)
-
-        # Create 2-layer model: 2 -> 3 -> 2
-        fc1_weight = torch.randn(3, 2, dtype=torch.float32)
-        fc2_weight = torch.randn(2, 3, dtype=torch.float32)
-        model = make_two_layer_component_model(weight1=fc1_weight, weight2=fc2_weight)
-
-        V1 = model.components["fc1"].V
-        U1 = model.components["fc1"].U
-
-        batch = torch.randn(1, 2, dtype=torch.float32)
-
-        # Get target pre_weight_acts (activations before each weight matrix)
-        # fc1: input is batch
-        # fc2: input is output of fc1
-        target_pre_weight_acts = model(batch, cache_type="input").cache
-
-        ci = {
-            "fc1": torch.tensor([[0.8]], dtype=torch.float32),
-            "fc2": torch.tensor([[0.7]], dtype=torch.float32),
-        }
-
-        # Define deterministic masks for n_mask_samples=2
-        sample_masks_fc1 = [
-            torch.tensor([[0.9]], dtype=torch.float32),
-            torch.tensor([[0.7]], dtype=torch.float32),
-        ]
-        sample_masks_fc2 = [
-            torch.tensor([[0.85]], dtype=torch.float32),
-            torch.tensor([[0.65]], dtype=torch.float32),
-        ]
-
-        # Mock calc_stochastic_component_mask_info to return our deterministic masks
-        call_count = [0]
-
-        def mock_calc_stochastic_component_mask_info(
-            causal_importances: dict[str, Tensor],  # pyright: ignore[reportUnusedParameter]
-            component_mask_sampling: SamplingType,  # pyright: ignore[reportUnusedParameter]
-            weight_deltas: dict[str, Tensor] | None,  # pyright: ignore[reportUnusedParameter]
-            router: Router,  # pyright: ignore[reportUnusedParameter]
-        ) -> dict[str, ComponentsMaskInfo]:
-            idx = call_count[0] % len(sample_masks_fc1)
-            call_count[0] += 1
-            masks = {"fc1": sample_masks_fc1[idx], "fc2": sample_masks_fc2[idx]}
-
-            return make_mask_infos(
-                component_masks=masks, routing_masks="all", weight_deltas_and_masks=None
-            )
-
-        with patch(
-            "spd.metrics.stochastic_hidden_acts_recon_loss.calc_stochastic_component_mask_info",
-            side_effect=mock_calc_stochastic_component_mask_info,
-        ):
-            # Calculate expected loss manually
-            sum_mse = 0.0
-            n_examples = 0
-
-            for mask1 in sample_masks_fc1:
-                # Stochastic forward pass for fc1
-                # pre_weight_acts["fc1"] is always batch (same as target)
-                stoch_fc1_input = batch
-
-                # pre_weight_acts["fc2"] is output of masked fc1
-                stoch_fc2_input = batch @ (V1 * mask1 @ U1)
-
-                # MSE for fc1 input (should be 0 - good sanity check!)
-                mse_fc1 = torch.nn.functional.mse_loss(
-                    stoch_fc1_input, target_pre_weight_acts["fc1"], reduction="sum"
-                )
-                assert mse_fc1.item() == 0.0, f"MSE for fc1 input should be 0, got {mse_fc1.item()}"
-
-                # MSE for fc2 input (the actual meaningful comparison)
-                mse_fc2 = torch.nn.functional.mse_loss(
-                    stoch_fc2_input, target_pre_weight_acts["fc2"], reduction="sum"
-                )
-
-                sum_mse += mse_fc1.item() + mse_fc2.item()
-                n_examples += stoch_fc1_input.numel() + stoch_fc2_input.numel()
-
-            expected_loss = sum_mse / n_examples
-
-            actual_loss = stochastic_hidden_acts_recon_loss(
-                model=model,
-                sampling="continuous",
-                n_mask_samples=2,
-                batch=batch,
-                pre_weight_acts=target_pre_weight_acts,
-                ci=ci,
-                weight_deltas=None,
-            )
-
-            assert torch.allclose(actual_loss, torch.tensor(expected_loss), rtol=1e-5), (
-                f"Expected {expected_loss}, got {actual_loss}"
-            )
diff --git a/tests/metrics/test_stochastic_recon_loss.py b/tests/metrics/test_stochastic_recon_loss.py
deleted file mode 100644
index 594b55a7f..000000000
--- a/tests/metrics/test_stochastic_recon_loss.py
+++ /dev/null
@@ -1,90 +0,0 @@
-from unittest.mock import patch
-
-import torch
-from torch import Tensor
-
-from spd.configs import SamplingType
-from spd.metrics import stochastic_recon_loss
-from spd.models.components import ComponentsMaskInfo, make_mask_infos
-from spd.routing import Router
-from tests.metrics.fixtures import make_one_layer_component_model
-
-
-class TestStochasticReconLoss:
-    def test_manual_calculation(self: object) -> None:
-        """Test stochastic reconstruction with manual calculation.
-
-        Mocks calc_stochastic_component_mask_info to use deterministic masks.
-        """
-        torch.manual_seed(42)
-
-        fc_weight = torch.randn(2, 2, dtype=torch.float32)
-        model = make_one_layer_component_model(weight=fc_weight)
-
-        V = model.components["fc"].V
-        U = model.components["fc"].U
-
-        batch = torch.randn(1, 2, dtype=torch.float32)
-        target_out = torch.randn(1, 2, dtype=torch.float32)
-
-        ci = {"fc": torch.tensor([[0.8]], dtype=torch.float32)}
-
-        # Define deterministic masks for our samples
-        # n_mask_samples=2, so we'll have 2 samples
-        sample_masks = [
-            torch.tensor([[0.9]], dtype=torch.float32),
-            torch.tensor([[0.7]], dtype=torch.float32),
-        ]
-
-        # Mock calc_stochastic_component_mask_info to return our deterministic masks
-        call_count = [0]
-
-        def mock_calc_stochastic_component_mask_info(
-            causal_importances: dict[str, Tensor],  # pyright: ignore[reportUnusedParameter]
-            component_mask_sampling: SamplingType,  # pyright: ignore[reportUnusedParameter]
-            router: Router,  # pyright: ignore[reportUnusedParameter]
-            weight_deltas: dict[str, Tensor] | None,  # pyright: ignore[reportUnusedParameter]
-        ) -> dict[str, ComponentsMaskInfo]:
-            idx = call_count[0] % len(sample_masks)
-            call_count[0] += 1
-            masks = {"fc": sample_masks[idx]}
-
-            return make_mask_infos(
-                component_masks=masks,
-                routing_masks="all",
-                weight_deltas_and_masks=None,
-            )
-
-        with patch(
-            "spd.metrics.stochastic_recon_loss.calc_stochastic_component_mask_info",
-            side_effect=mock_calc_stochastic_component_mask_info,
-        ):
-            # Calculate expected loss manually
-            sum_loss = 0.0
-            n_examples = 0
-
-            for mask in sample_masks:
-                # Manually calculate forward pass: out = batch @ (V * mask @ U)
-                masked_component = V * mask @ U
-                out = batch @ masked_component
-                loss = torch.nn.functional.mse_loss(out, target_out, reduction="sum")
-                sum_loss += loss.item()
-                n_examples += out.numel()
-
-            expected_loss = sum_loss / n_examples
-
-            # Calculate actual loss
-            actual_loss = stochastic_recon_loss(
-                model=model,
-                sampling="continuous",
-                n_mask_samples=2,
-                output_loss_type="mse",
-                batch=batch,
-                target_out=target_out,
-                ci=ci,
-                weight_deltas=None,
-            )
-
-            assert torch.allclose(actual_loss, torch.tensor(expected_loss), rtol=1e-5), (
-                f"Expected {expected_loss}, got {actual_loss}"
-            )
diff --git a/tests/scripts_run/test_grid_search.py b/tests/scripts_run/test_grid_search.py
index a5f59c65d..dadf92d3a 100644
--- a/tests/scripts_run/test_grid_search.py
+++ b/tests/scripts_run/test_grid_search.py
@@ -323,6 +323,11 @@ def test_tms_config_with_loss_sweep(self):
             "C": 10,
             "n_mask_samples": 1,
             "target_module_patterns": ["linear1"],
+            "ci_config": {
+                "mode": "layerwise",
+                "fn_type": "mlp",
+                "hidden_dims": [16],
+            },
             "loss_metric_configs": [
                 {
                     "classname": "ImportanceMinimalityLoss",
@@ -377,6 +382,11 @@ def test_lm_config_with_loss_sweep(self):
             "C": 10,
             "n_mask_samples": 1,
             "target_module_patterns": ["transformer"],
+            "ci_config": {
+                "mode": "layerwise",
+                "fn_type": "vector_mlp",
+                "hidden_dims": [12],
+            },
             "loss_metric_configs": [
                 {
                     "classname": "ImportanceMinimalityLoss",
@@ -442,6 +452,11 @@ def test_full_sweep_workflow(self):
             "C": 10,
             "n_mask_samples": 1,
             "target_module_patterns": ["linear1"],
+            "ci_config": {
+                "mode": "layerwise",
+                "fn_type": "mlp",
+                "hidden_dims": [16],
+            },
             "loss_metric_configs": [
                 {
                     "classname": "ImportanceMinimalityLoss",
diff --git a/tests/test_component_model.py b/tests/test_component_model.py
index 22d60a8f0..6835f8a1d 100644
--- a/tests/test_component_model.py
+++ b/tests/test_component_model.py
@@ -10,7 +10,9 @@
 
 from spd.configs import (
     Config,
+    GlobalCiConfig,
     ImportanceMinimalityLossConfig,
+    LayerwiseCiConfig,
     ModulePatternInfoConfig,
     ScheduleConfig,
     TMSTaskConfig,
@@ -24,9 +26,13 @@
 from spd.models.components import (
     ComponentsMaskInfo,
     EmbeddingComponents,
+    GlobalCiFnWrapper,
+    GlobalSharedMLPCiFn,
+    GlobalSharedTransformerCiFn,
     LinearComponents,
     MLPCiFn,
     ParallelLinear,
+    TargetLayerConfig,
     VectorMLPCiFn,
     VectorSharedMLPCiFn,
     make_mask_infos,
@@ -89,8 +95,7 @@ def test_correct_parameters_require_grad():
             ModulePathInfo(module_path="conv1d1", C=10),
             ModulePathInfo(module_path="conv1d2", C=5),
         ],
-        ci_fn_type="mlp",
-        ci_fn_hidden_dims=[4],
+        ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[4]),
         pretrained_model_output_attr=None,
        sigmoid_type="leaky_hard",
     )
@@ -142,8 +147,7 @@ def test_from_run_info():
            ModulePatternInfoConfig(module_pattern="conv1d2", C=4),
        ],
        identity_module_info=[ModulePatternInfoConfig(module_pattern="linear1", C=4)],
-       ci_fn_type="mlp",
-       ci_fn_hidden_dims=[4],
+       ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[4]),
        batch_size=1,
        steps=1,
        lr_schedule=ScheduleConfig(start_val=1e-3),
@@ -172,8 +176,7 @@ def test_from_run_info():
     cm = ComponentModel(
         target_model=target_model,
         module_path_info=module_path_info,
-        ci_fn_type=config.ci_fn_type,
-        ci_fn_hidden_dims=config.ci_fn_hidden_dims,
+        ci_config=config.ci_config,
         pretrained_model_output_attr=config.pretrained_model_output_attr,
         sigmoid_type=config.sigmoid_type,
     )
@@ -279,8 +282,7 @@ def test_full_weight_delta_matches_target_behaviour():
     cm = ComponentModel(
         target_model=target_model,
         module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths],
-        ci_fn_type="mlp",
-        ci_fn_hidden_dims=[4],
+        ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[4]),
         pretrained_model_output_attr=None,
         sigmoid_type="leaky_hard",
     )
@@ -311,8 +313,7 @@ def test_input_cache_captures_pre_weight_input():
     cm = ComponentModel(
         target_model=target_model,
         module_path_info=[ModulePathInfo(module_path=p, C=2) for p in target_module_paths],
-        ci_fn_type="mlp",
-        ci_fn_hidden_dims=[2],
+        ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]),
         pretrained_model_output_attr=None,
         sigmoid_type="leaky_hard",
     )
@@ -346,8 +347,7 @@ def test_weight_deltas():
     cm = ComponentModel(
         target_model=target_model,
         module_path_info=[ModulePathInfo(module_path=p, C=3) for p in target_module_paths],
-        ci_fn_type="mlp",
-        ci_fn_hidden_dims=[2],
+
ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) @@ -381,8 +381,7 @@ def forward(self, x: Tensor) -> Tensor: cm = ComponentModel( target_model=model, module_path_info=[ModulePathInfo(module_path="linear", C=C)], - ci_fn_type="mlp", - ci_fn_hidden_dims=[2], + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) @@ -437,8 +436,7 @@ def forward(self, x: Tensor) -> Tensor: cm = ComponentModel( target_model=model, module_path_info=[ModulePathInfo(module_path="linear.pre_identity", C=C)], - ci_fn_type="mlp", - ci_fn_hidden_dims=[2], + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) @@ -487,8 +485,7 @@ def forward(self, x: Tensor) -> Tensor: cm = ComponentModel( target_model=model, module_path_info=[ModulePathInfo(module_path="linear", C=C)], - ci_fn_type="mlp", - ci_fn_hidden_dims=[2], + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) @@ -524,3 +521,854 @@ def forward(self, x: Tensor) -> Tensor: # but it should be the same for the second example (where it's not routed to components) assert torch.allclose(cm_routed_out[1], target_out[1]) + + +def test_checkpoint_ci_config_mismatch_global_to_layerwise(): + """Test that loading a global CI checkpoint with layerwise config fails.""" + target_model = SimpleTestModel() + target_model.eval() + target_model.requires_grad_(False) + + with tempfile.TemporaryDirectory() as tmp_dir: + base_dir = Path(tmp_dir) + base_model_dir = base_dir / "test_model" + base_model_dir.mkdir(parents=True, exist_ok=True) + comp_model_dir = base_dir / "comp_model" + comp_model_dir.mkdir(parents=True, exist_ok=True) + + base_model_path = base_model_dir / "model.pth" + save_file(target_model.state_dict(), base_model_path) + + # Create and save a component 
model with GLOBAL CI + config_global = Config( + pretrained_model_class="tests.test_component_model.SimpleTestModel", + pretrained_model_path=base_model_path, + pretrained_model_name=None, + module_info=[ + ModulePatternInfoConfig(module_pattern="linear1", C=4), + ModulePatternInfoConfig(module_pattern="linear2", C=4), + ], + ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[4]), + batch_size=1, + steps=1, + lr_schedule=ScheduleConfig(start_val=1e-3), + n_eval_steps=1, + eval_batch_size=1, + eval_freq=1, + slow_eval_freq=1, + loss_metric_configs=[ImportanceMinimalityLossConfig(coeff=1.0, pnorm=1.0, beta=0.5)], + output_loss_type="mse", + train_log_freq=1, + n_mask_samples=1, + task_config=TMSTaskConfig( + task_name="tms", + feature_probability=0.5, + data_generation_type="exactly_one_active", + ), + ) + + module_path_info = expand_module_patterns(target_model, config_global.all_module_info) + cm_global = ComponentModel( + target_model=target_model, + module_path_info=module_path_info, + ci_config=config_global.ci_config, + pretrained_model_output_attr=config_global.pretrained_model_output_attr, + sigmoid_type=config_global.sigmoid_type, + ) + + # Save global CI checkpoint + global_checkpoint_path = comp_model_dir / "global_model.pth" + save_file(cm_global.state_dict(), global_checkpoint_path) + save_file(config_global.model_dump(mode="json"), comp_model_dir / "final_config.yaml") + + # Now try to load it with LAYERWISE config - should fail + config_layerwise = Config( + pretrained_model_class="tests.test_component_model.SimpleTestModel", + pretrained_model_path=base_model_path, + pretrained_model_name=None, + module_info=[ + ModulePatternInfoConfig(module_pattern="linear1", C=4), + ModulePatternInfoConfig(module_pattern="linear2", C=4), + ], + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[4]), + batch_size=1, + steps=1, + lr_schedule=ScheduleConfig(start_val=1e-3), + n_eval_steps=1, + eval_batch_size=1, + eval_freq=1, + slow_eval_freq=1, 
+ loss_metric_configs=[ImportanceMinimalityLossConfig(coeff=1.0, pnorm=1.0, beta=0.5)], + output_loss_type="mse", + train_log_freq=1, + n_mask_samples=1, + task_config=TMSTaskConfig( + task_name="tms", + feature_probability=0.5, + data_generation_type="exactly_one_active", + ), + ) + + # Override the checkpoint path and config in the directory + save_file(config_layerwise.model_dump(mode="json"), comp_model_dir / "final_config.yaml") + + cm_run_info = SPDRunInfo.from_path(global_checkpoint_path) + # Update config to layerwise after loading run_info + cm_run_info.config = config_layerwise + + with pytest.raises( + AssertionError, + match="Config specifies layerwise CI but checkpoint has no ci_fn._ci_fns keys", + ): + ComponentModel.from_run_info(cm_run_info) + + +def test_checkpoint_ci_config_mismatch_layerwise_to_global(): + """Test that loading a layerwise CI checkpoint with global config fails.""" + target_model = SimpleTestModel() + target_model.eval() + target_model.requires_grad_(False) + + with tempfile.TemporaryDirectory() as tmp_dir: + base_dir = Path(tmp_dir) + base_model_dir = base_dir / "test_model" + base_model_dir.mkdir(parents=True, exist_ok=True) + comp_model_dir = base_dir / "comp_model" + comp_model_dir.mkdir(parents=True, exist_ok=True) + + base_model_path = base_model_dir / "model.pth" + save_file(target_model.state_dict(), base_model_path) + + # Create and save a component model with LAYERWISE CI + config_layerwise = Config( + pretrained_model_class="tests.test_component_model.SimpleTestModel", + pretrained_model_path=base_model_path, + pretrained_model_name=None, + module_info=[ + ModulePatternInfoConfig(module_pattern="linear1", C=4), + ModulePatternInfoConfig(module_pattern="linear2", C=4), + ], + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[4]), + batch_size=1, + steps=1, + lr_schedule=ScheduleConfig(start_val=1e-3), + n_eval_steps=1, + eval_batch_size=1, + eval_freq=1, + slow_eval_freq=1, + 
loss_metric_configs=[ImportanceMinimalityLossConfig(coeff=1.0, pnorm=1.0, beta=0.5)], + output_loss_type="mse", + train_log_freq=1, + n_mask_samples=1, + task_config=TMSTaskConfig( + task_name="tms", + feature_probability=0.5, + data_generation_type="exactly_one_active", + ), + ) + + module_path_info = expand_module_patterns(target_model, config_layerwise.all_module_info) + cm_layerwise = ComponentModel( + target_model=target_model, + module_path_info=module_path_info, + ci_config=config_layerwise.ci_config, + pretrained_model_output_attr=config_layerwise.pretrained_model_output_attr, + sigmoid_type=config_layerwise.sigmoid_type, + ) + + # Save layerwise CI checkpoint + layerwise_checkpoint_path = comp_model_dir / "layerwise_model.pth" + save_file(cm_layerwise.state_dict(), layerwise_checkpoint_path) + save_file(config_layerwise.model_dump(mode="json"), comp_model_dir / "final_config.yaml") + + # Now try to load it with GLOBAL config - should fail + config_global = Config( + pretrained_model_class="tests.test_component_model.SimpleTestModel", + pretrained_model_path=base_model_path, + pretrained_model_name=None, + module_info=[ + ModulePatternInfoConfig(module_pattern="linear1", C=4), + ModulePatternInfoConfig(module_pattern="linear2", C=4), + ], + ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[4]), + batch_size=1, + steps=1, + lr_schedule=ScheduleConfig(start_val=1e-3), + n_eval_steps=1, + eval_batch_size=1, + eval_freq=1, + slow_eval_freq=1, + loss_metric_configs=[ImportanceMinimalityLossConfig(coeff=1.0, pnorm=1.0, beta=0.5)], + output_loss_type="mse", + train_log_freq=1, + n_mask_samples=1, + task_config=TMSTaskConfig( + task_name="tms", + feature_probability=0.5, + data_generation_type="exactly_one_active", + ), + ) + + # Override the checkpoint path and config in the directory + save_file(config_global.model_dump(mode="json"), comp_model_dir / "final_config.yaml") + + cm_run_info = SPDRunInfo.from_path(layerwise_checkpoint_path) + # Update 
config to global after loading run_info + cm_run_info.config = config_global + + with pytest.raises( + AssertionError, + match="Config specifies global CI but checkpoint has no ci_fn._global_ci_fn keys", + ): + ComponentModel.from_run_info(cm_run_info) + + +# ============================================================================= +# Global CI Function Tests +# ============================================================================= + + +@pytest.mark.parametrize("hidden_dims", [[], [8], [16, 8]]) +def test_global_shared_mlp_ci_fn_shapes_and_values(hidden_dims: list[int]): + """Test GlobalSharedMLPCiFn produces correct output shapes and valid values.""" + layer_configs = { + "layer1": (10, 5), # (input_dim, C) + "layer2": (20, 3), + "layer3": (15, 7), + } + ci_fn = GlobalSharedMLPCiFn(layer_configs=layer_configs, hidden_dims=hidden_dims) + + inputs = { + "layer1": torch.randn(BATCH_SIZE, 10), + "layer2": torch.randn(BATCH_SIZE, 20), + "layer3": torch.randn(BATCH_SIZE, 15), + } + outputs = ci_fn(inputs) + + # Check shapes + assert outputs["layer1"].shape == (BATCH_SIZE, 5) + assert outputs["layer2"].shape == (BATCH_SIZE, 3) + assert outputs["layer3"].shape == (BATCH_SIZE, 7) + + # Check values are valid (not NaN, not Inf) + for name, out in outputs.items(): + assert torch.isfinite(out).all(), f"Output {name} contains NaN or Inf" + + +def test_global_shared_mlp_ci_fn_sorted_layer_order(): + """Test that GlobalSharedMLPCiFn uses sorted layer ordering for determinism.""" + layer_configs = { + "z_layer": (5, 2), + "a_layer": (10, 3), + "m_layer": (8, 4), + } + + ci_fn = GlobalSharedMLPCiFn(layer_configs=layer_configs, hidden_dims=[16]) + + # Layer order should be sorted alphabetically for deterministic concat/split + assert ci_fn.layer_order == ["a_layer", "m_layer", "z_layer"] + assert ci_fn.split_sizes == [3, 4, 2] # C values in sorted order + + +def test_global_shared_mlp_ci_fn_different_inputs_produce_different_outputs(): + """Test that different inputs 
produce different CI outputs (not constant function).""" + layer_configs = { + "layer1": (10, 5), + "layer2": (8, 3), + } + ci_fn = GlobalSharedMLPCiFn(layer_configs=layer_configs, hidden_dims=[16]) + + inputs1 = { + "layer1": torch.randn(BATCH_SIZE, 10), + "layer2": torch.randn(BATCH_SIZE, 8), + } + inputs2 = { + "layer1": torch.randn(BATCH_SIZE, 10), + "layer2": torch.randn(BATCH_SIZE, 8), + } + + outputs1 = ci_fn(inputs1) + outputs2 = ci_fn(inputs2) + + # Outputs should differ for different inputs + assert not torch.allclose(outputs1["layer1"], outputs2["layer1"]) + assert not torch.allclose(outputs1["layer2"], outputs2["layer2"]) + + +def test_global_shared_mlp_ci_fn_gradient_flow(): + """Test that gradients flow through GlobalSharedMLPCiFn and are meaningful.""" + layer_configs = { + "layer1": (10, 5), + "layer2": (8, 3), + } + ci_fn = GlobalSharedMLPCiFn(layer_configs=layer_configs, hidden_dims=[16]) + + inputs = { + "layer1": torch.randn(BATCH_SIZE, 10, requires_grad=True), + "layer2": torch.randn(BATCH_SIZE, 8, requires_grad=True), + } + outputs = ci_fn(inputs) + + loss = torch.stack([out.sum() for out in outputs.values()]).sum() + loss.backward() + + # Check gradients exist for inputs and are meaningful + assert inputs["layer1"].grad is not None + assert inputs["layer2"].grad is not None + assert torch.isfinite(inputs["layer1"].grad).all() + assert torch.isfinite(inputs["layer2"].grad).all() + assert inputs["layer1"].grad.abs().sum() > 0, "Input gradients should be non-zero" + assert inputs["layer2"].grad.abs().sum() > 0, "Input gradients should be non-zero" + + # Check gradients exist for parameters and are meaningful + for name, param in ci_fn.named_parameters(): + assert param.grad is not None, f"Param {name} has no gradient" + assert torch.isfinite(param.grad).all(), f"Param {name} has NaN/Inf gradient" + assert param.grad.abs().sum() > 0, f"Param {name} has zero gradient" + + +def test_global_shared_mlp_ci_fn_with_seq_dim(): + """Test 
GlobalSharedMLPCiFn with sequence dimension produces valid outputs.""" + seq_len = 5 + layer_configs = { + "layer1": (10, 4), + "layer2": (8, 3), + } + ci_fn = GlobalSharedMLPCiFn(layer_configs=layer_configs, hidden_dims=[16]) + + inputs = { + "layer1": torch.randn(BATCH_SIZE, seq_len, 10), + "layer2": torch.randn(BATCH_SIZE, seq_len, 8), + } + outputs = ci_fn(inputs) + + # Check shapes + assert outputs["layer1"].shape == (BATCH_SIZE, seq_len, 4) + assert outputs["layer2"].shape == (BATCH_SIZE, seq_len, 3) + + # Check values are valid + for name, out in outputs.items(): + assert torch.isfinite(out).all(), f"Output {name} contains NaN or Inf" + + +def test_global_shared_mlp_ci_fn_single_component(): + """Test GlobalSharedMLPCiFn with C=1 edge case.""" + layer_configs = { + "layer1": (10, 1), + "layer2": (8, 1), + } + ci_fn = GlobalSharedMLPCiFn(layer_configs=layer_configs, hidden_dims=[4]) + + inputs = { + "layer1": torch.randn(BATCH_SIZE, 10), + "layer2": torch.randn(BATCH_SIZE, 8), + } + outputs = ci_fn(inputs) + + assert outputs["layer1"].shape == (BATCH_SIZE, 1) + assert outputs["layer2"].shape == (BATCH_SIZE, 1) + assert torch.isfinite(outputs["layer1"]).all() + assert torch.isfinite(outputs["layer2"]).all() + + +def test_global_shared_mlp_ci_fn_single_layer(): + """Test GlobalSharedMLPCiFn with single layer edge case.""" + layer_configs = {"only_layer": (10, 5)} + ci_fn = GlobalSharedMLPCiFn(layer_configs=layer_configs, hidden_dims=[8]) + + inputs = {"only_layer": torch.randn(BATCH_SIZE, 10)} + outputs = ci_fn(inputs) + + assert outputs["only_layer"].shape == (BATCH_SIZE, 5) + assert torch.isfinite(outputs["only_layer"]).all() + + +def test_global_shared_transformer_ci_fn_shapes_and_values(): + """Test GlobalSharedTransformerCiFn produces correct output shapes and valid values.""" + layer_configs = { + "layer1": TargetLayerConfig(input_dim=10, C=5), + "layer2": TargetLayerConfig(input_dim=20, C=3), + "layer3": TargetLayerConfig(input_dim=15, C=7), + } + ci_fn 
= GlobalSharedTransformerCiFn( + target_model_layer_configs=layer_configs, + d_model=8, + n_layers=2, + n_heads=2, + mlp_hidden_dims=[16], + ) + + inputs = { + "layer1": torch.randn(BATCH_SIZE, 10), + "layer2": torch.randn(BATCH_SIZE, 20), + "layer3": torch.randn(BATCH_SIZE, 15), + } + outputs = ci_fn(inputs) + + # Check shapes + assert outputs["layer1"].shape == (BATCH_SIZE, 5) + assert outputs["layer2"].shape == (BATCH_SIZE, 3) + assert outputs["layer3"].shape == (BATCH_SIZE, 7) + + # Check values are valid (not NaN, not Inf) + for name, out in outputs.items(): + assert torch.isfinite(out).all(), f"Output {name} contains NaN or Inf" + + +def test_global_shared_transformer_ci_fn_with_seq_dim(): + """Test GlobalSharedTransformerCiFn with sequence dimension produces valid outputs.""" + seq_len = 5 + layer_configs = { + "layer1": TargetLayerConfig(input_dim=10, C=4), + "layer2": TargetLayerConfig(input_dim=8, C=3), + } + ci_fn = GlobalSharedTransformerCiFn( + target_model_layer_configs=layer_configs, + d_model=8, + n_layers=3, + n_heads=2, + mlp_hidden_dims=[16], + ) + + inputs = { + "layer1": torch.randn(BATCH_SIZE, seq_len, 10), + "layer2": torch.randn(BATCH_SIZE, seq_len, 8), + } + outputs = ci_fn(inputs) + + # Check shapes + assert outputs["layer1"].shape == (BATCH_SIZE, seq_len, 4) + assert outputs["layer2"].shape == (BATCH_SIZE, seq_len, 3) + + # Check values are valid + for name, out in outputs.items(): + assert torch.isfinite(out).all(), f"Output {name} contains NaN or Inf" + + +def test_component_model_with_global_ci(): + """Test ComponentModel instantiation and forward with global CI config.""" + target_model = tiny_target() + + target_module_paths = ["mlp", "out"] + C = 4 + cm = ComponentModel( + target_model=target_model, + module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], + ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), + pretrained_model_output_attr=None, + sigmoid_type="leaky_hard", + ) + + 
assert isinstance(cm.ci_fn, GlobalCiFnWrapper) + assert isinstance(cm.ci_fn._global_ci_fn, GlobalSharedMLPCiFn) + + # Forward pass should work and match target + token_ids = torch.randint( + low=0, high=target_model.embed.num_embeddings, size=(BATCH_SIZE,), dtype=torch.long + ) + out = cm(token_ids) + torch.testing.assert_close(out, target_model(token_ids)) + + +def test_component_model_global_ci_calc_causal_importances(): + """Test causal importance calculation with global CI produces valid bounded outputs.""" + target_model = tiny_target() + + target_module_paths = ["mlp", "out"] + C = 4 + cm = ComponentModel( + target_model=target_model, + module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], + ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), + pretrained_model_output_attr=None, + sigmoid_type="leaky_hard", + ) + + token_ids = torch.randint( + low=0, high=target_model.embed.num_embeddings, size=(BATCH_SIZE,), dtype=torch.long + ) + + _, cache = cm(token_ids, cache_type="input") + + ci_outputs = cm.calc_causal_importances( + pre_weight_acts=cache, + sampling="continuous", + detach_inputs=False, + ) + + for path in target_module_paths: + # Check shapes + assert ci_outputs.lower_leaky[path].shape == (BATCH_SIZE, C) + assert ci_outputs.upper_leaky[path].shape == (BATCH_SIZE, C) + assert ci_outputs.pre_sigmoid[path].shape == (BATCH_SIZE, C) + + # Check bounds (leaky sigmoids allow values slightly outside [0, 1]) + # lower_leaky: bounded to [0, 1], can be negative with small leak + # upper_leaky: bounded to [0, inf), can exceed 1 with small leak + assert (ci_outputs.lower_leaky[path] >= 0).all(), f"{path} lower_leaky < 0" + assert (ci_outputs.lower_leaky[path] <= 1.0).all(), f"{path} lower_leaky > 1" + assert (ci_outputs.upper_leaky[path] >= 0).all(), f"{path} upper_leaky < 0" + # upper_leaky can exceed 1.0 due to leaky behavior (1 + alpha*(x-1) when x>1) + + # Check values are finite + assert 
torch.isfinite(ci_outputs.pre_sigmoid[path]).all(), f"{path} pre_sigmoid has NaN/Inf" + + +def test_component_model_global_ci_different_inputs_different_ci(): + """Test that different inputs produce different CI values (CI is input-dependent).""" + target_model = tiny_target() + + target_module_paths = ["mlp", "out"] + C = 4 + cm = ComponentModel( + target_model=target_model, + module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], + ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), + pretrained_model_output_attr=None, + sigmoid_type="leaky_hard", + ) + + # Two different token inputs + token_ids_1 = torch.tensor([0, 1], dtype=torch.long) + token_ids_2 = torch.tensor([2, 3], dtype=torch.long) + + _, cache1 = cm(token_ids_1, cache_type="input") + _, cache2 = cm(token_ids_2, cache_type="input") + + ci1 = cm.calc_causal_importances(cache1, sampling="continuous") + ci2 = cm.calc_causal_importances(cache2, sampling="continuous") + + # CI values should differ for different inputs + for path in target_module_paths: + assert not torch.allclose(ci1.pre_sigmoid[path], ci2.pre_sigmoid[path]), ( + f"CI for {path} should differ for different inputs" + ) + + +def test_component_model_global_ci_binomial_sampling(): + """Test global CI with binomial sampling produces valid binary masks.""" + target_model = tiny_target() + + target_module_paths = ["mlp", "out"] + C = 4 + cm = ComponentModel( + target_model=target_model, + module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], + ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), + pretrained_model_output_attr=None, + sigmoid_type="leaky_hard", + ) + + token_ids = torch.randint(0, target_model.embed.num_embeddings, size=(BATCH_SIZE,)) + _, cache = cm(token_ids, cache_type="input") + + ci = cm.calc_causal_importances(cache, sampling="binomial") + + for path in target_module_paths: + assert ci.lower_leaky[path].shape == (BATCH_SIZE, C) 
+ assert torch.isfinite(ci.lower_leaky[path]).all() + assert torch.isfinite(ci.upper_leaky[path]).all() + + +def test_component_model_global_ci_with_embeddings(): + """Test global CI with embedding layers produces valid outputs.""" + target_model = tiny_target() + + target_module_paths = ["embed", "mlp", "out"] + C = 4 + cm = ComponentModel( + target_model=target_model, + module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], + ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), + pretrained_model_output_attr=None, + sigmoid_type="leaky_hard", + ) + + assert isinstance(cm.ci_fn, GlobalCiFnWrapper) + + token_ids = torch.randint( + low=0, high=target_model.embed.num_embeddings, size=(BATCH_SIZE,), dtype=torch.long + ) + + _, cache = cm(token_ids, cache_type="input") + + ci_outputs = cm.calc_causal_importances( + pre_weight_acts=cache, + sampling="continuous", + detach_inputs=False, + ) + + # Check all layers including embedding + for path in target_module_paths: + assert ci_outputs.lower_leaky[path].shape == (BATCH_SIZE, C) + assert (ci_outputs.lower_leaky[path] >= 0).all() + assert (ci_outputs.lower_leaky[path] <= 1.0).all() + assert torch.isfinite(ci_outputs.pre_sigmoid[path]).all() + + +def test_component_model_global_ci_gradient_flow(): + """Test gradient flow through global CI - gradients are non-zero and finite.""" + target_model = tiny_target() + + target_module_paths = ["mlp", "out"] + C = 4 + cm = ComponentModel( + target_model=target_model, + module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], + ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), + pretrained_model_output_attr=None, + sigmoid_type="leaky_hard", + ) + + token_ids = torch.randint( + low=0, high=target_model.embed.num_embeddings, size=(BATCH_SIZE,), dtype=torch.long + ) + + _, cache = cm(token_ids, cache_type="input") + + ci_outputs = cm.calc_causal_importances( + pre_weight_acts=cache, + 
sampling="continuous", + detach_inputs=False, + ) + + ci_loss = torch.stack([ci.sum() for ci in ci_outputs.lower_leaky.values()]).sum() + ci_loss.backward() + + # Check that global CI function has meaningful gradients + assert isinstance(cm.ci_fn, GlobalCiFnWrapper) + for name, param in cm.ci_fn._global_ci_fn.named_parameters(): + assert param.grad is not None, f"Param {name} has no gradient" + assert torch.isfinite(param.grad).all(), f"Param {name} has NaN/Inf gradient" + assert param.grad.abs().sum() > 0, f"Param {name} has zero gradient" + + +def test_component_model_global_ci_detach_inputs_blocks_gradients(): + """Test that detach_inputs=True detaches upstream activations; CI fn params still get gradients from the CI loss.""" + target_model = tiny_target() + + target_module_paths = ["mlp", "out"] + C = 4 + cm = ComponentModel( + target_model=target_model, + module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], + ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), + pretrained_model_output_attr=None, + sigmoid_type="leaky_hard", + ) + + token_ids = torch.randint( + low=0, high=target_model.embed.num_embeddings, size=(BATCH_SIZE,), dtype=torch.long + ) + + _, cache = cm(token_ids, cache_type="input") + + # With detach_inputs=True, gradients should still flow to CI fn params + # but from the CI loss, not from upstream + ci_outputs = cm.calc_causal_importances( + pre_weight_acts=cache, + sampling="continuous", + detach_inputs=True,  # Detach inputs + ) + + ci_loss = torch.stack([ci.sum() for ci in ci_outputs.lower_leaky.values()]).sum() + ci_loss.backward() + + # CI function should still get gradients (from its own computation) + assert isinstance(cm.ci_fn, GlobalCiFnWrapper) + for param in cm.ci_fn._global_ci_fn.parameters(): + assert param.grad is not None + + +def test_component_model_global_ci_masking_zeros(): + """Test that all-zero component masks change the output relative to all-ones masks.""" + target_model = tiny_target() + + target_module_paths = ["mlp", "out"] + C
= 4 + cm = ComponentModel( + target_model=target_model, + module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], + ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), + pretrained_model_output_attr=None, + sigmoid_type="leaky_hard", + ) + + token_ids = torch.randint( + low=0, high=target_model.embed.num_embeddings, size=(BATCH_SIZE,), dtype=torch.long + ) + + weight_deltas = cm.calc_weight_deltas() + + # All ones mask - should match target + all_ones_masks = {name: torch.ones(BATCH_SIZE, C) for name in target_module_paths} + weight_deltas_and_masks_ones = { + name: (weight_deltas[name], torch.ones(BATCH_SIZE)) for name in target_module_paths + } + mask_infos_ones = make_mask_infos( + all_ones_masks, weight_deltas_and_masks=weight_deltas_and_masks_ones + ) + out_ones = cm(token_ids, mask_infos=mask_infos_ones) + + # All zeros mask - should be different from all ones + all_zeros_masks = {name: torch.zeros(BATCH_SIZE, C) for name in target_module_paths} + weight_deltas_and_masks_zeros = { + name: (weight_deltas[name], torch.ones(BATCH_SIZE)) for name in target_module_paths + } + mask_infos_zeros = make_mask_infos( + all_zeros_masks, weight_deltas_and_masks=weight_deltas_and_masks_zeros + ) + out_zeros = cm(token_ids, mask_infos=mask_infos_zeros) + + # Outputs should differ + assert not torch.allclose(out_ones, out_zeros), ( + "Zero masks should produce different output than all-ones masks" + ) + + +def test_component_model_global_ci_partial_masking(): + """Test that partial (0.5) component masks produce finite, valid outputs.""" + target_model = tiny_target() + + target_module_paths = ["mlp", "out"] + C = 4 + cm = ComponentModel( + target_model=target_model, + module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], + ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), + pretrained_model_output_attr=None, + sigmoid_type="leaky_hard", + ) + + token_ids = torch.randint( + low=0, 
high=target_model.embed.num_embeddings, size=(BATCH_SIZE,), dtype=torch.long + ) + + weight_deltas = cm.calc_weight_deltas() + + # Partial mask (0.5 for all) + partial_masks = {name: torch.full((BATCH_SIZE, C), 0.5) for name in target_module_paths} + weight_deltas_and_masks = { + name: (weight_deltas[name], torch.ones(BATCH_SIZE)) for name in target_module_paths + } + mask_infos = make_mask_infos(partial_masks, weight_deltas_and_masks=weight_deltas_and_masks) + out_partial = cm(token_ids, mask_infos=mask_infos) + + # Should produce valid output + assert torch.isfinite(out_partial).all(), "Partial masking produced NaN/Inf" + + +def test_component_model_global_ci_weight_deltas_all_ones_matches_target(): + """Test that all-ones mask with weight deltas matches target model output.""" + target_model = tiny_target() + + target_module_paths = ["mlp", "out"] + C = 4 + cm = ComponentModel( + target_model=target_model, + module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], + ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), + pretrained_model_output_attr=None, + sigmoid_type="leaky_hard", + ) + + token_ids = torch.randint( + low=0, high=target_model.embed.num_embeddings, size=(BATCH_SIZE,), dtype=torch.long + ) + + weight_deltas = cm.calc_weight_deltas() + component_masks = {name: torch.ones(BATCH_SIZE, C) for name in target_module_paths} + weight_deltas_and_masks = { + name: (weight_deltas[name], torch.ones(BATCH_SIZE)) for name in target_module_paths + } + mask_infos = make_mask_infos(component_masks, weight_deltas_and_masks=weight_deltas_and_masks) + out = cm(token_ids, mask_infos=mask_infos) + + torch.testing.assert_close(out, target_model(token_ids)) + + +def test_global_ci_save_and_load(): + """Test saving and loading a model with global CI preserves functionality.""" + target_model = SimpleTestModel() + target_model.eval() + target_model.requires_grad_(False) + + with tempfile.TemporaryDirectory() as tmp_dir: + 
base_dir = Path(tmp_dir) + base_model_dir = base_dir / "test_model" + base_model_dir.mkdir(parents=True, exist_ok=True) + comp_model_dir = base_dir / "comp_model" + comp_model_dir.mkdir(parents=True, exist_ok=True) + + base_model_path = base_model_dir / "model.pth" + save_file(target_model.state_dict(), base_model_path) + + config = Config( + pretrained_model_class="tests.test_component_model.SimpleTestModel", + pretrained_model_path=base_model_path, + pretrained_model_name=None, + module_info=[ + ModulePatternInfoConfig(module_pattern="linear1", C=4), + ModulePatternInfoConfig(module_pattern="linear2", C=4), + ], + ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[8]), + batch_size=1, + steps=1, + lr_schedule=ScheduleConfig(start_val=1e-3), + n_eval_steps=1, + eval_batch_size=1, + eval_freq=1, + slow_eval_freq=1, + loss_metric_configs=[ImportanceMinimalityLossConfig(coeff=1.0, pnorm=1.0, beta=0.5)], + output_loss_type="mse", + train_log_freq=1, + n_mask_samples=1, + task_config=TMSTaskConfig( + task_name="tms", + feature_probability=0.5, + data_generation_type="exactly_one_active", + ), + ) + + module_path_info = expand_module_patterns(target_model, config.all_module_info) + cm = ComponentModel( + target_model=target_model, + module_path_info=module_path_info, + ci_config=config.ci_config, + pretrained_model_output_attr=config.pretrained_model_output_attr, + sigmoid_type=config.sigmoid_type, + ) + + assert isinstance(cm.ci_fn, GlobalCiFnWrapper) + + save_file(cm.state_dict(), comp_model_dir / "model.pth") + save_file(config.model_dump(mode="json"), comp_model_dir / "final_config.yaml") + + # Load and verify + cm_run_info = SPDRunInfo.from_path(comp_model_dir / "model.pth") + cm_loaded = ComponentModel.from_run_info(cm_run_info) + + assert isinstance(cm_loaded.ci_fn, GlobalCiFnWrapper) + + # Verify state dict matches + for k, v in cm_loaded.state_dict().items(): + torch.testing.assert_close(v, cm.state_dict()[k]) + + # Verify global CI function 
weights specifically + global_ci_fn = cm.ci_fn._global_ci_fn + global_ci_fn_loaded = cm_loaded.ci_fn._global_ci_fn + assert isinstance(global_ci_fn, GlobalSharedMLPCiFn) + assert isinstance(global_ci_fn_loaded, GlobalSharedMLPCiFn) + assert global_ci_fn_loaded.layer_order == global_ci_fn.layer_order + for p1, p2 in zip(global_ci_fn.parameters(), global_ci_fn_loaded.parameters(), strict=True): + torch.testing.assert_close(p1, p2) + + # Verify global CI function produces same outputs + test_acts = { + name: torch.randn(BATCH_SIZE, global_ci_fn.layer_configs[name][0]) + for name in global_ci_fn.layer_order + } + outputs_orig = global_ci_fn(test_acts) + outputs_loaded = global_ci_fn_loaded(test_acts) + for name in global_ci_fn.layer_order: + torch.testing.assert_close(outputs_orig[name], outputs_loaded[name]) diff --git a/tests/test_distributed.py b/tests/test_distributed.py index 3a43935c3..3b00866d7 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -18,8 +18,11 @@ "seed": 0, "C": 3, "n_mask_samples": 1, - "ci_fn_type": "vector_mlp", - "ci_fn_hidden_dims": [2], + "ci_config": { + "mode": "layerwise", + "fn_type": "vector_mlp", + "hidden_dims": [2], + }, "sigmoid_type": "leaky_hard", "target_module_patterns": ["model.layers.0.mlp.gate_proj"], # --- Loss metrics --- diff --git a/tests/test_gpt2.py b/tests/test_gpt2.py index b66cfa31d..33ce346dc 100644 --- a/tests/test_gpt2.py +++ b/tests/test_gpt2.py @@ -8,6 +8,7 @@ Config, FaithfulnessLossConfig, ImportanceMinimalityLossConfig, + LayerwiseCiConfig, LMTaskConfig, ModulePatternInfoConfig, ScheduleConfig, @@ -35,8 +36,7 @@ def test_gpt_2_decomposition_happy_path(tmp_path: Path) -> None: # General seed=0, n_mask_samples=1, - ci_fn_type="vector_mlp", - ci_fn_hidden_dims=[128], + ci_config=LayerwiseCiConfig(fn_type="vector_mlp", hidden_dims=[128]), module_info=[ ModulePatternInfoConfig(module_pattern="transformer.h.2.attn.c_attn", C=10), 
ModulePatternInfoConfig(module_pattern="transformer.h.3.mlp.c_fc", C=10), diff --git a/tests/test_ih_transformer.py b/tests/test_ih_transformer.py index 582a6aea5..f1781d02b 100644 --- a/tests/test_ih_transformer.py +++ b/tests/test_ih_transformer.py @@ -8,6 +8,7 @@ FaithfulnessLossConfig, IHTaskConfig, ImportanceMinimalityLossConfig, + LayerwiseCiConfig, ModulePatternInfoConfig, ScheduleConfig, StochasticHiddenActsReconLossConfig, @@ -50,8 +51,7 @@ def test_ih_transformer_decomposition_happy_path(tmp_path: Path) -> None: # General seed=0, n_mask_samples=1, - ci_fn_type="vector_mlp", - ci_fn_hidden_dims=[128], + ci_config=LayerwiseCiConfig(fn_type="vector_mlp", hidden_dims=[128]), module_info=[ ModulePatternInfoConfig(module_pattern="blocks.*.attn.q_proj", C=10), ModulePatternInfoConfig(module_pattern="blocks.*.attn.k_proj", C=10), diff --git a/tests/test_pgd_source_sync_distributed.py b/tests/test_pgd_source_sync_distributed.py index 8af4251be..17039db72 100644 --- a/tests/test_pgd_source_sync_distributed.py +++ b/tests/test_pgd_source_sync_distributed.py @@ -23,7 +23,7 @@ import torch.nn as nn from torch import Tensor -from spd.configs import PGDReconLossConfig +from spd.configs import LayerwiseCiConfig, PGDReconLossConfig from spd.metrics.pgd_utils import pgd_masked_recon_loss_update from spd.models.component_model import ComponentModel from spd.routing import AllLayersRouter @@ -56,8 +56,7 @@ def _make_component_model(fc_weight: Tensor) -> ComponentModel: return ComponentModel( target_model=target, module_path_info=[ModulePathInfo(module_path="fc", C=1)], - ci_fn_hidden_dims=[2], - ci_fn_type="mlp", + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) diff --git a/tests/test_resid_mlp.py b/tests/test_resid_mlp.py index f42bbe1d8..94ac00a29 100644 --- a/tests/test_resid_mlp.py +++ b/tests/test_resid_mlp.py @@ -4,6 +4,7 @@ Config, FaithfulnessLossConfig, ImportanceMinimalityLossConfig, + 
LayerwiseCiConfig, ModulePatternInfoConfig, ResidMLPTaskConfig, ScheduleConfig, @@ -43,8 +44,7 @@ def test_resid_mlp_decomposition_happy_path(tmp_path: Path) -> None: # General seed=0, n_mask_samples=1, - ci_fn_type="mlp", - ci_fn_hidden_dims=[8], + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[8]), loss_metric_configs=[ ImportanceMinimalityLossConfig( coeff=3e-3, diff --git a/tests/test_spd_losses.py b/tests/test_spd_losses.py index 69a546f6e..d7acba607 100644 --- a/tests/test_spd_losses.py +++ b/tests/test_spd_losses.py @@ -5,7 +5,15 @@ from jaxtyping import Float from torch import Tensor -from spd.configs import UniformKSubsetRoutingConfig +from spd.configs import ( + AdamPGDConfig, + LayerwiseCiConfig, + PersistentPGDReconLossConfig, + ScheduleConfig, + SignPGDConfig, + SingleSourceScope, + UniformKSubsetRoutingConfig, +) from spd.metrics import ( ci_masked_recon_layerwise_loss, ci_masked_recon_loss, @@ -17,6 +25,7 @@ stochastic_recon_subset_loss, ) from spd.models.component_model import ComponentModel +from spd.persistent_pgd import PersistentPGDState from spd.utils.module_utils import ModulePathInfo @@ -30,6 +39,23 @@ def forward(self, x: Tensor) -> Tensor: return self.fc(x) +class TinySeqModel(nn.Module): + """A simple sequence model that applies a linear layer to each position. 
+ + Input shape: (batch, seq_len, d_in) + Output shape: (batch, seq_len, d_out) + """ + + def __init__(self, d_in: int, d_out: int) -> None: + super().__init__() + self.fc = nn.Linear(d_in, d_out, bias=False) + + @override + def forward(self, x: Tensor) -> Tensor: + # x: (batch, seq_len, d_in) -> (batch, seq_len, d_out) + return self.fc(x) + + def _make_component_model(weight: Float[Tensor, "d_out d_in"]) -> ComponentModel: d_out, d_in = weight.shape target = TinyLinearModel(d_in=d_in, d_out=d_out) @@ -40,8 +66,26 @@ def _make_component_model(weight: Float[Tensor, "d_out d_in"]) -> ComponentModel comp_model = ComponentModel( target_model=target, module_path_info=[ModulePathInfo(module_path="fc", C=1)], - ci_fn_hidden_dims=[2], - ci_fn_type="mlp", + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), + pretrained_model_output_attr=None, + sigmoid_type="leaky_hard", + ) + + return comp_model + + +def _make_seq_component_model(weight: Float[Tensor, "d_out d_in"]) -> ComponentModel: + """Create a ComponentModel from TinySeqModel for 3D (batch, seq, hidden) shaped data.""" + d_out, d_in = weight.shape + target = TinySeqModel(d_in=d_in, d_out=d_out) + with torch.no_grad(): + target.fc.weight.copy_(weight) + target.requires_grad_(False) + + comp_model = ComponentModel( + target_model=target, + module_path_info=[ModulePathInfo(module_path="fc", C=1)], + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) @@ -656,3 +700,252 @@ def test_subset_stochastic_variability(self: object) -> None: # All should be valid assert all(loss >= 0.0 for loss in losses) + + +class TestPersistentPGDReconLoss: + def test_basic_forward_and_state_update(self: object) -> None: + """Test that persistent PGD computes loss and updates state.""" + fc_weight = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32) + model = _make_component_model(weight=fc_weight) + + # Use (batch, seq) shaped data to match 
PersistentPGD's expectations + batch = torch.tensor([[1.0, 2.0]], dtype=torch.float32) + target_out = torch.tensor([[1.0, 2.0]], dtype=torch.float32) + # CI needs (batch, seq, C) shape for PersistentPGD + ci = {"fc": torch.tensor([[[0.5], [0.5]]], dtype=torch.float32)} + + cfg = PersistentPGDReconLossConfig( + optimizer=SignPGDConfig(lr_schedule=ScheduleConfig(start_val=0.1)), + scope=SingleSourceScope(), + ) + + # Initialize state + state = PersistentPGDState( + module_to_c=model.module_to_c, + batch_dims=(1, 2), + device="cpu", + use_delta_component=False, + cfg=cfg, + output_loss_type="mse", + ) + + # Store initial mask values + initial_sources = {k: v.clone() for k, v in state.sources.items()} + + # Compute loss and gradients + loss = state.compute_recon_loss( + model=model, + batch=batch, + target_out=target_out, + ci=ci, + weight_deltas=None, + ) + grad = state.get_grads(loss) + + # Apply PGD step + state.step(grad) + + # Loss should be non-negative + assert loss >= 0.0 + + # After the PGD step, masks should remain well-formed + for k in state.sources: + # Shape is preserved (values may be unchanged if the gradient is exactly 0) + assert state.sources[k].shape == initial_sources[k].shape + # Masks should still be in [0, 1] + assert torch.all(state.sources[k] >= 0.0) + assert torch.all(state.sources[k] <= 1.0) + + def test_masks_persist_across_calls(self: object) -> None: + """Test that masks persist and accumulate updates across calls.""" + fc_weight = torch.tensor([[2.0, 0.0], [0.0, 2.0]], dtype=torch.float32) + model = _make_component_model(weight=fc_weight) + + batch = torch.tensor([[1.0, 1.0]], dtype=torch.float32) + target_out = torch.tensor([[2.0, 2.0]], dtype=torch.float32) + # CI needs (batch, seq, C) shape for PersistentPGD + ci = {"fc": torch.tensor([[[0.3], [0.3]]], dtype=torch.float32)} + + cfg = PersistentPGDReconLossConfig( + optimizer=SignPGDConfig(lr_schedule=ScheduleConfig(start_val=0.1)), + scope=SingleSourceScope(), + ) + + state =
PersistentPGDState( + module_to_c=model.module_to_c, + batch_dims=(1, 2), + device="cpu", + use_delta_component=False, + cfg=cfg, + output_loss_type="mse", + ) + + # Run multiple steps + sources_history = [] + for _ in range(5): + sources_history.append({k: v.clone() for k, v in state.sources.items()}) + loss = state.compute_recon_loss( + model=model, + batch=batch, + target_out=target_out, + ci=ci, + weight_deltas=None, + ) + grad = state.get_grads(loss) + state.step(grad) + assert loss >= 0.0 + + # Masks should have changed over time + # (they accumulate updates, so later masks differ from earlier ones) + for k in state.sources: + initial = sources_history[0][k] + final = state.sources[k] + # Should have changed from initial (very unlikely to be identical after 5 steps) + assert not torch.allclose(initial, final) + + def test_with_delta_component(self: object) -> None: + """Test persistent PGD with delta component enabled.""" + # Use sequence model for proper 3D shapes (batch, seq, hidden) + fc_weight = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32) + model = _make_seq_component_model(weight=fc_weight) + + # Input shape: (batch=1, seq=2, d_in=2) + batch = torch.tensor([[[1.0, 2.0], [0.5, 1.5]]], dtype=torch.float32) + target_out = torch.tensor([[[1.0, 2.0], [0.5, 1.5]]], dtype=torch.float32) + # CI shape: (batch=1, seq=2, C=1) + ci = {"fc": torch.tensor([[[0.5], [0.5]]], dtype=torch.float32)} + weight_deltas = model.calc_weight_deltas() + + # batch_dims for PersistentPGDState is (batch, seq) = (1, 2) + batch_dims = batch.shape[:2] + + cfg = PersistentPGDReconLossConfig( + optimizer=SignPGDConfig(lr_schedule=ScheduleConfig(start_val=0.1)), + scope=SingleSourceScope(), + ) + + # Initialize state with delta component + state = PersistentPGDState( + module_to_c=model.module_to_c, + batch_dims=batch_dims, + device="cpu", + use_delta_component=True, + cfg=cfg, + output_loss_type="mse", + ) + + # Masks should have C+1 elements when using delta component + 
assert state.sources["fc"].shape[-1] == model.module_to_c["fc"] + 1 + + loss = state.compute_recon_loss( + model=model, + batch=batch, + target_out=target_out, + ci=ci, + weight_deltas=weight_deltas, + ) + grad = state.get_grads(loss) + state.step(grad) + + assert loss >= 0.0 + + def test_batch_dimension(self: object) -> None: + """Test that masks broadcast correctly across batch dimension.""" + # Use sequence model for proper 3D shapes (batch, seq, hidden) + fc_weight = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32) + model = _make_seq_component_model(weight=fc_weight) + + # Batch of 3 examples, seq_len of 2, d_in of 2 + # Shape: (batch=3, seq=2, d_in=2) + batch = torch.tensor( + [ + [[1.0, 2.0], [0.5, 1.5]], + [[2.0, 3.0], [1.0, 2.0]], + [[0.5, 1.0], [0.25, 0.5]], + ], + dtype=torch.float32, + ) + target_out = torch.tensor( + [ + [[1.0, 2.0], [0.5, 1.5]], + [[2.0, 3.0], [1.0, 2.0]], + [[0.5, 1.0], [0.25, 0.5]], + ], + dtype=torch.float32, + ) + # CI needs (batch, seq, C) shape - (3, 2, 1) for 3 batch, 2 seq positions, 1 component + ci = { + "fc": torch.tensor( + [[[0.5], [0.5]], [[0.6], [0.6]], [[0.4], [0.4]]], dtype=torch.float32 + ) + } + + # batch_dims for PersistentPGDState is (batch, seq) = (3, 2) + batch_dims = batch.shape[:2] + + cfg = PersistentPGDReconLossConfig( + optimizer=SignPGDConfig(lr_schedule=ScheduleConfig(start_val=0.1)), + scope=SingleSourceScope(), + ) + + state = PersistentPGDState( + module_to_c=model.module_to_c, + batch_dims=batch_dims, + device="cpu", + use_delta_component=False, + cfg=cfg, + output_loss_type="mse", + ) + + # Masks should have shape (1, 1, C) for single_mask scope - single mask shared across batch + assert state.sources["fc"].shape == (1, 1, model.module_to_c["fc"]) + + loss = state.compute_recon_loss( + model=model, + batch=batch, + target_out=target_out, + ci=ci, + weight_deltas=None, + ) + grad = state.get_grads(loss) + state.step(grad) + + assert loss >= 0.0 + + def test_adam_optimizer_state(self: 
object) -> None: + """Test that Adam optimizer path updates internal state.""" + fc_weight = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32) + model = _make_component_model(weight=fc_weight) + + batch = torch.tensor([[1.0, 2.0]], dtype=torch.float32) + target_out = torch.tensor([[0.5, 1.5]], dtype=torch.float32) + # CI needs (batch, seq, C) shape for PersistentPGD + ci = {"fc": torch.tensor([[[0.4], [0.4]]], dtype=torch.float32)} + + cfg = PersistentPGDReconLossConfig( + optimizer=AdamPGDConfig( + lr_schedule=ScheduleConfig(start_val=0.05), beta1=0.9, beta2=0.999, eps=1e-8 + ), + scope=SingleSourceScope(), + ) + + state = PersistentPGDState( + module_to_c=model.module_to_c, + batch_dims=(1, 2), + device="cpu", + use_delta_component=False, + cfg=cfg, + output_loss_type="mse", + ) + + loss = state.compute_recon_loss( + model=model, + batch=batch, + target_out=target_out, + ci=ci, + weight_deltas=None, + ) + grad = state.get_grads(loss) + state.step(grad) + + assert loss >= 0.0 diff --git a/tests/test_tms.py b/tests/test_tms.py index a060e1bd0..aaba385b4 100644 --- a/tests/test_tms.py +++ b/tests/test_tms.py @@ -8,6 +8,7 @@ Config, FaithfulnessLossConfig, ImportanceMinimalityLossConfig, + LayerwiseCiConfig, ModulePatternInfoConfig, ScheduleConfig, StochasticReconLayerwiseLossConfig, @@ -47,8 +48,7 @@ def test_tms_decomposition_happy_path(tmp_path: Path) -> None: # General seed=0, n_mask_samples=1, - ci_fn_type="mlp", - ci_fn_hidden_dims=[8], + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[8]), module_info=[ ModulePatternInfoConfig(module_pattern="linear1", C=10), ModulePatternInfoConfig(module_pattern="linear2", C=10), diff --git a/typings/orjson/__init__.pyi b/typings/orjson/__init__.pyi new file mode 100644 index 000000000..6cce74fdc --- /dev/null +++ b/typings/orjson/__init__.pyi @@ -0,0 +1,4 @@ +from typing import Any + +def loads(data: bytes | bytearray | memoryview | str) -> Any: ... +def dumps(obj: Any) -> bytes: ... 
diff --git a/uv.lock b/uv.lock index b12d82509..becdf51ec 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = "==3.13.*" resolution-markers = [ "sys_platform == 'linux'", @@ -49,6 +49,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5d/28/a8a9fc6957b2cee8902414e41816b5ab5536ecf43c3b1843c10e82c559b2/aiohttp-3.13.2-cp313-cp313-win_amd64.whl", hash = "sha256:a88d13e7ca367394908f8a276b89d04a3652044612b9a408a0bb22a5ed976a1a", size = 452192, upload-time = "2025-10-28T20:57:34.166Z" }, ] +[[package]] +name = "aiolimiter" +version = "1.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f1/23/b52debf471f7a1e42e362d959a3982bdcb4fe13a5d46e63d28868807a79c/aiolimiter-1.2.1.tar.gz", hash = "sha256:e02a37ea1a855d9e832252a105420ad4d15011505512a1a1d814647451b5cca9", size = 7185, upload-time = "2024-12-08T15:31:51.496Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f3/ba/df6e8e1045aebc4778d19b8a3a9bc1808adb1619ba94ca354d9ba17d86c3/aiolimiter-1.2.1-py3-none-any.whl", hash = "sha256:d3f249e9059a20badcb56b61601a83556133655c11d1eb3dd3e04ff069e5f3c7", size = 6711, upload-time = "2024-12-08T15:31:49.874Z" }, +] + [[package]] name = "aiosignal" version = "1.4.0" @@ -815,6 +824,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e7/e7/80988e32bf6f73919a113473a604f5a8f09094de312b9d52b79c2df7612b/jupyter_core-5.9.1-py3-none-any.whl", hash = "sha256:ebf87fdc6073d142e114c72c9e29a9d7ca03fad818c5d300ce2adc1fb0743407", size = 29032, upload-time = "2025-10-16T19:19:16.783Z" }, ] +[[package]] +name = "kaleido" +version = "0.2.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/f7/0ccaa596ec341963adbb4f839774c36d5659e75a0812d946732b927d480e/kaleido-0.2.1-py2.py3-none-macosx_10_11_x86_64.whl", hash = 
"sha256:ca6f73e7ff00aaebf2843f73f1d3bacde1930ef5041093fe76b83a15785049a7", size = 85153681, upload-time = "2021-03-08T10:27:34.202Z" }, + { url = "https://files.pythonhosted.org/packages/45/8e/4297556be5a07b713bb42dde0f748354de9a6918dee251c0e6bdcda341e7/kaleido-0.2.1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:bb9a5d1f710357d5d432ee240ef6658a6d124c3e610935817b4b42da9c787c05", size = 85808197, upload-time = "2021-03-08T10:27:46.561Z" }, + { url = "https://files.pythonhosted.org/packages/ae/b3/a0f0f4faac229b0011d8c4a7ee6da7c2dca0b6fd08039c95920846f23ca4/kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:aa21cf1bf1c78f8fa50a9f7d45e1003c387bd3d6fe0a767cfbbf344b95bdc3a8", size = 79902476, upload-time = "2021-03-08T10:27:57.364Z" }, + { url = "https://files.pythonhosted.org/packages/a1/2b/680662678a57afab1685f0c431c2aba7783ce4344f06ec162074d485d469/kaleido-0.2.1-py2.py3-none-manylinux2014_aarch64.whl", hash = "sha256:845819844c8082c9469d9c17e42621fbf85c2b237ef8a86ec8a8527f98b6512a", size = 83711746, upload-time = "2021-03-08T10:28:08.847Z" }, + { url = "https://files.pythonhosted.org/packages/88/89/4b6f8bb3f9ab036fd4ad1cb2d628ab5c81db32ac9aa0641d7b180073ba43/kaleido-0.2.1-py2.py3-none-win32.whl", hash = "sha256:ecc72635860be616c6b7161807a65c0dbd9b90c6437ac96965831e2e24066552", size = 62312480, upload-time = "2021-03-08T10:28:18.204Z" }, + { url = "https://files.pythonhosted.org/packages/f7/9a/0408b02a4bcb3cf8b338a2b074ac7d1b2099e2b092b42473def22f7b625f/kaleido-0.2.1-py2.py3-none-win_amd64.whl", hash = "sha256:4670985f28913c2d063c5734d125ecc28e40810141bdb0a46f15b76c1d45f23c", size = 65945521, upload-time = "2021-03-08T10:28:26.823Z" }, +] + [[package]] name = "kiwisolver" version = "1.4.9" @@ -1215,6 +1237,29 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8f/f8/e86cc449d1ca5ced4e6a483b07932444c04a0146b4ed287956496ff5b61e/openrouter-0.1.1-py3-none-any.whl", hash = 
"sha256:37480230413a246f15af7056f479b2c1a9ca79c88086660bc56c3f69b37847d4", size = 240934, upload-time = "2025-12-04T15:50:36.646Z" }, ] +[[package]] +name = "orjson" +version = "3.11.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/04/b8/333fdb27840f3bf04022d21b654a35f58e15407183aeb16f3b41aa053446/orjson-3.11.5.tar.gz", hash = "sha256:82393ab47b4fe44ffd0a7659fa9cfaacc717eb617c93cde83795f14af5c2e9d5", size = 5972347, upload-time = "2025-12-06T15:55:39.458Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/10/43/61a77040ce59f1569edf38f0b9faadc90c8cf7e9bec2e0df51d0132c6bb7/orjson-3.11.5-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:3b01799262081a4c47c035dd77c1301d40f568f77cc7ec1bb7db5d63b0a01629", size = 245271, upload-time = "2025-12-06T15:54:40.878Z" }, + { url = "https://files.pythonhosted.org/packages/55/f9/0f79be617388227866d50edd2fd320cb8fb94dc1501184bb1620981a0aba/orjson-3.11.5-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:61de247948108484779f57a9f406e4c84d636fa5a59e411e6352484985e8a7c3", size = 129422, upload-time = "2025-12-06T15:54:42.403Z" }, + { url = "https://files.pythonhosted.org/packages/77/42/f1bf1549b432d4a78bfa95735b79b5dac75b65b5bb815bba86ad406ead0a/orjson-3.11.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:894aea2e63d4f24a7f04a1908307c738d0dce992e9249e744b8f4e8dd9197f39", size = 132060, upload-time = "2025-12-06T15:54:43.531Z" }, + { url = "https://files.pythonhosted.org/packages/25/49/825aa6b929f1a6ed244c78acd7b22c1481fd7e5fda047dc8bf4c1a807eb6/orjson-3.11.5-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ddc21521598dbe369d83d4d40338e23d4101dad21dae0e79fa20465dbace019f", size = 130391, upload-time = "2025-12-06T15:54:45.059Z" }, + { url = 
"https://files.pythonhosted.org/packages/42/ec/de55391858b49e16e1aa8f0bbbb7e5997b7345d8e984a2dec3746d13065b/orjson-3.11.5-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7cce16ae2f5fb2c53c3eafdd1706cb7b6530a67cc1c17abe8ec747f5cd7c0c51", size = 135964, upload-time = "2025-12-06T15:54:46.576Z" }, + { url = "https://files.pythonhosted.org/packages/1c/40/820bc63121d2d28818556a2d0a09384a9f0262407cf9fa305e091a8048df/orjson-3.11.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e46c762d9f0e1cfb4ccc8515de7f349abbc95b59cb5a2bd68df5973fdef913f8", size = 139817, upload-time = "2025-12-06T15:54:48.084Z" }, + { url = "https://files.pythonhosted.org/packages/09/c7/3a445ca9a84a0d59d26365fd8898ff52bdfcdcb825bcc6519830371d2364/orjson-3.11.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d7345c759276b798ccd6d77a87136029e71e66a8bbf2d2755cbdde1d82e78706", size = 137336, upload-time = "2025-12-06T15:54:49.426Z" }, + { url = "https://files.pythonhosted.org/packages/9a/b3/dc0d3771f2e5d1f13368f56b339c6782f955c6a20b50465a91acb79fe961/orjson-3.11.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75bc2e59e6a2ac1dd28901d07115abdebc4563b5b07dd612bf64260a201b1c7f", size = 138993, upload-time = "2025-12-06T15:54:50.939Z" }, + { url = "https://files.pythonhosted.org/packages/d1/a2/65267e959de6abe23444659b6e19c888f242bf7725ff927e2292776f6b89/orjson-3.11.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:54aae9b654554c3b4edd61896b978568c6daa16af96fa4681c9b5babd469f863", size = 141070, upload-time = "2025-12-06T15:54:52.414Z" }, + { url = "https://files.pythonhosted.org/packages/63/c9/da44a321b288727a322c6ab17e1754195708786a04f4f9d2220a5076a649/orjson-3.11.5-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:4bdd8d164a871c4ec773f9de0f6fe8769c2d6727879c37a9666ba4183b7f8228", size = 413505, upload-time = "2025-12-06T15:54:53.67Z" }, + { url = 
"https://files.pythonhosted.org/packages/7f/17/68dc14fa7000eefb3d4d6d7326a190c99bb65e319f02747ef3ebf2452f12/orjson-3.11.5-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:a261fef929bcf98a60713bf5e95ad067cea16ae345d9a35034e73c3990e927d2", size = 151342, upload-time = "2025-12-06T15:54:55.113Z" }, + { url = "https://files.pythonhosted.org/packages/c4/c5/ccee774b67225bed630a57478529fc026eda33d94fe4c0eac8fe58d4aa52/orjson-3.11.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:c028a394c766693c5c9909dec76b24f37e6a1b91999e8d0c0d5feecbe93c3e05", size = 141823, upload-time = "2025-12-06T15:54:56.331Z" }, + { url = "https://files.pythonhosted.org/packages/67/80/5d00e4155d0cd7390ae2087130637671da713959bb558db9bac5e6f6b042/orjson-3.11.5-cp313-cp313-win32.whl", hash = "sha256:2cc79aaad1dfabe1bd2d50ee09814a1253164b3da4c00a78c458d82d04b3bdef", size = 135236, upload-time = "2025-12-06T15:54:57.507Z" }, + { url = "https://files.pythonhosted.org/packages/95/fe/792cc06a84808dbdc20ac6eab6811c53091b42f8e51ecebf14b540e9cfe4/orjson-3.11.5-cp313-cp313-win_amd64.whl", hash = "sha256:ff7877d376add4e16b274e35a3f58b7f37b362abf4aa31863dadacdd20e3a583", size = 133167, upload-time = "2025-12-06T15:54:58.71Z" }, + { url = "https://files.pythonhosted.org/packages/46/2c/d158bd8b50e3b1cfdcf406a7e463f6ffe3f0d167b99634717acdaf5e299f/orjson-3.11.5-cp313-cp313-win_arm64.whl", hash = "sha256:59ac72ea775c88b163ba8d21b0177628bd015c5dd060647bbab6e22da3aad287", size = 126712, upload-time = "2025-12-06T15:54:59.892Z" }, +] + [[package]] name = "packaging" version = "25.0" @@ -1910,6 +1955,7 @@ name = "spd" version = "0.0.1" source = { editable = "." 
} dependencies = [ + { name = "aiolimiter" }, { name = "datasets" }, { name = "einops" }, { name = "fastapi" }, @@ -1917,9 +1963,11 @@ dependencies = [ { name = "httpx" }, { name = "ipykernel" }, { name = "jaxtyping" }, + { name = "kaleido" }, { name = "matplotlib" }, { name = "numpy" }, { name = "openrouter" }, + { name = "orjson" }, { name = "pydantic" }, { name = "python-dotenv" }, { name = "scipy" }, @@ -1949,6 +1997,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "aiolimiter", specifier = ">=1.2" }, { name = "datasets", specifier = ">=2.21.0" }, { name = "einops" }, { name = "fastapi" }, @@ -1956,9 +2005,11 @@ requires-dist = [ { name = "httpx", specifier = ">=0.28.0" }, { name = "ipykernel" }, { name = "jaxtyping" }, + { name = "kaleido", specifier = "==0.2.1" }, { name = "matplotlib" }, { name = "numpy" }, { name = "openrouter", specifier = ">=0.1.1" }, + { name = "orjson" }, { name = "pydantic", specifier = "<2.12" }, { name = "python-dotenv" }, { name = "scipy", specifier = ">=1.14.1" },