diff --git a/.gitignore b/.gitignore
index 8d3099ac1..3581e751a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,8 +9,6 @@
 scripts/outputs/
 **/out/
 neuronpedia_outputs/
 .env
-.mcp.json
-.cursor/
 .vscode/settings.json
 notebooks/
@@ -179,4 +177,7 @@
 cython_debug/
 #.idea/
 **/*.db
-**/*.db*
\ No newline at end of file
+**/*.db*
+*.schema.json
+
+.claude/worktrees
\ No newline at end of file
diff --git a/.mcp.json b/.mcp.json
new file mode 100644
index 000000000..700113020
--- /dev/null
+++ b/.mcp.json
@@ -0,0 +1,3 @@
+{
+  "mcpServers": {}
+}
\ No newline at end of file
diff --git a/CLAUDE.md b/CLAUDE.md
index 2407dde87..13bdcf4d0 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -3,14 +3,18 @@
 
 This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
 
 ## Environment Setup
+
 **IMPORTANT**: Always activate the virtual environment before running Python or git operations:
+
 ```bash
 source .venv/bin/activate
 ```
 
-Repo requires `.env` file with WandB credentials (see `.env.example`)
+If working in a worktree, make sure there's a local `.venv` first by running `uv sync` in the worktree directory. Do NOT `cd` to the main repo — all commands (including git) should run in the worktree.
+Repo requires `.env` file with WandB credentials (see `.env.example`)
 
 ## Project Overview
+
 SPD (Stochastic Parameter Decomposition) is a research framework for analyzing neural network components and their interactions through sparse parameter decomposition techniques.
- Target model parameters are decomposed as a sum of `parameter components` @@ -36,6 +40,8 @@ The codebase supports three experimental domains: TMS (Toy Model of Superpositio - `ss_llama_simple_mlp`, `ss_llama_simple_mlp-1L`, `ss_llama_simple_mlp-2L` - Llama MLP-only variants - `ss_gpt2`, `ss_gpt2_simple`, `ss_gpt2_simple_noln` - Simple Stories GPT-2 variants - `ss_gpt2_simple-1L`, `ss_gpt2_simple-2L` - GPT-2 simple layer variants + - `pile_llama_simple_mlp-2L`, `pile_llama_simple_mlp-4L`, `pile_llama_simple_mlp-12L` - Pile Llama MLP-only variants + - `pile_gpt2_simple-2L_global_reverse` - Pile GPT-2 with global reverse - `gpt2` - Standard GPT-2 - `ts` - TinyStories @@ -46,7 +52,7 @@ This repository implements methods from two key research papers on parameter dec **Stochastic Parameter Decomposition (SPD)** - [`papers/Stochastic_Parameter_Decomposition/spd_paper.md`](papers/Stochastic_Parameter_Decomposition/spd_paper.md) -- A version of this repository was used to run the experiments in this paper. But we continue to develop on the code, so it no longer is limited to the implementation used for this paper. +- A version of this repository was used to run the experiments in this paper. But we continue to develop on the code, so it no longer is limited to the implementation used for this paper. - Introduces the core SPD framework - Details the stochastic masking approach and optimization techniques used throughout the codebase - Useful reading for understanding the implementation details, though may be outdated. 
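The decomposition described above, where target model parameters are written as a sum of `parameter components`, can be sketched numerically. This is an illustrative toy with made-up shapes and variable names, not the repo's actual classes:

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, n_components = 4, 3, 5  # hypothetical sizes for illustration

# Parameter components: the target weight matrix is their sum.
components = rng.normal(size=(n_components, d_in, d_out))
W_target = components.sum(axis=0)

# A per-component mask gates subnetworks; an all-ones mask recovers the target.
mask = np.ones(n_components)
W_masked = np.einsum("c,cio->io", mask, components)
assert np.allclose(W_masked, W_target)
```

Zeroing entries of `mask` ablates individual components, which is the basic operation the stochastic masking machinery builds on.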
@@ -95,6 +101,7 @@ This repository implements methods from two key research papers on parameter dec ## Architecture Overview **Core SPD Framework:** + - `spd/run_spd.py` - Main SPD optimization logic called by all experiments - `spd/configs.py` - Pydantic config classes for all experiment types - `spd/registry.py` - Centralized experiment registry with all experiment configurations @@ -105,15 +112,17 @@ This repository implements methods from two key research papers on parameter dec - `spd/figures.py` - Figures for logging to WandB (e.g. CI histograms, Identity plots, etc.) **Terminology: Sources vs Masks:** + - **Sources** (`adv_sources`, `PPGDSources`, `self.sources`): The raw values that PGD optimizes adversarially. These are interpolated with CI to produce component masks: `mask = ci + (1 - ci) * source`. Used in both regular PGD (`spd/metrics/pgd_utils.py`) and persistent PGD (`spd/persistent_pgd.py`). - **Masks** (`component_masks`, `RoutingMasks`, `make_mask_infos`, `n_mask_samples`): The materialized per-component masks used during forward passes. These are produced from sources (in PGD) or from stochastic sampling, and are a general SPD concept across the whole codebase. 
**Experiment Structure:** Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: + - `models.py` - Experiment-specific model classes and pretrained loading - `*_decomposition.py` - Main SPD execution script -- `train_*.py` - Training script for target models +- `train_*.py` - Training script for target models - `*_config.yaml` - Configuration files - `plotting.py` - Visualization utilities @@ -127,7 +136,7 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: **Configuration System:** - YAML configs define all experiment parameters -- Pydantic models provide type safety and validation +- Pydantic models provide type safety and validation - WandB integration for experiment tracking and model storage - Supports both local paths and `wandb:project/runs/run_id` format for model loading - Centralized experiment registry (`spd/registry.py`) manages all experiment configurations @@ -137,8 +146,9 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: - `spd/harvest/` - Offline GPU pipeline for collecting component statistics (correlations, token stats, activation examples) - `spd/autointerp/` - LLM-based automated interpretation of components - `spd/dataset_attributions/` - Multi-GPU pipeline for computing component-to-component attribution strengths aggregated over training data -- Data stored at `SPD_OUT_DIR/{harvest,autointerp,dataset_attributions}//` -- See `spd/harvest/CLAUDE.md`, `spd/autointerp/CLAUDE.md`, and `spd/dataset_attributions/CLAUDE.md` for details +- `spd/graph_interp/` - Context-aware component labeling using graph structure (attributions + correlations) +- Data stored at `SPD_OUT_DIR/{harvest,autointerp,dataset_attributions,graph_interp}//` +- See `spd/harvest/CLAUDE.md`, `spd/autointerp/CLAUDE.md`, `spd/dataset_attributions/CLAUDE.md`, and `spd/graph_interp/CLAUDE.md` for details **Output Directory (`SPD_OUT_DIR`):** @@ -160,12 +170,14 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: ├── 
scripts/ # Standalone utility scripts ├── tests/ # Test suite ├── spd/ # Main source code +│ ├── investigate/ # Agent investigation (see investigate/CLAUDE.md) │ ├── app/ # Web visualization app (see app/CLAUDE.md) │ ├── autointerp/ # LLM interpretation (see autointerp/CLAUDE.md) │ ├── clustering/ # Component clustering (see clustering/CLAUDE.md) │ ├── dataset_attributions/ # Dataset attributions (see dataset_attributions/CLAUDE.md) │ ├── harvest/ # Statistics collection (see harvest/CLAUDE.md) │ ├── postprocess/ # Unified postprocessing pipeline (harvest + attributions + autointerp) +│ ├── graph_interp/ # Context-aware interpretation (see graph_interp/CLAUDE.md) │ ├── pretrain/ # Target model pretraining (see pretrain/CLAUDE.md) │ ├── experiments/ # Experiment implementations │ │ ├── tms/ # Toy Model of Superposition @@ -201,14 +213,17 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: | `spd-autointerp` | `spd/autointerp/scripts/run_slurm_cli.py` | Submit autointerp SLURM job | | `spd-attributions` | `spd/dataset_attributions/scripts/run_slurm_cli.py` | Submit dataset attribution SLURM job | | `spd-postprocess` | `spd/postprocess/cli.py` | Unified postprocessing pipeline (harvest + attributions + interpret + evals) | +| `spd-graph-interp` | `spd/graph_interp/scripts/run_slurm_cli.py` | Submit graph interpretation SLURM job | | `spd-clustering` | `spd/clustering/scripts/run_pipeline.py` | Clustering pipeline | | `spd-pretrain` | `spd/pretrain/scripts/run_slurm_cli.py` | Pretrain target models | +| `spd-investigate` | `spd/investigate/scripts/run_slurm_cli.py` | Launch investigation agent | ### Files to Skip When Searching Use `spd/` as the search root (not repo root) to avoid noise. **Always skip:** + - `.venv/` - Virtual environment - `__pycache__/`, `.pytest_cache/`, `.ruff_cache/` - Build artifacts - `node_modules/` - Frontend dependencies @@ -218,27 +233,37 @@ Use `spd/` as the search root (not repo root) to avoid noise. 
- `wandb/` - WandB local files **Usually skip unless relevant:** + - `tests/` - Test files (unless debugging test failures) - `papers/` - Research paper drafts ### Common Call Chains **Running Experiments:** + - `spd-run` → `spd/scripts/run.py` → `spd/utils/slurm.py` → SLURM → `spd/run_spd.py` - `spd-local` → `spd/scripts/run_local.py` → `spd/run_spd.py` directly **Harvest Pipeline:** + - `spd-harvest` → `spd/harvest/scripts/run_slurm_cli.py` → `spd/utils/slurm.py` → SLURM array → `spd/harvest/scripts/run.py` → `spd/harvest/harvest.py` **Autointerp Pipeline:** + - `spd-autointerp` → `spd/autointerp/scripts/run_slurm_cli.py` → `spd/utils/slurm.py` → `spd/autointerp/interpret.py` **Dataset Attributions Pipeline:** + - `spd-attributions` → `spd/dataset_attributions/scripts/run_slurm_cli.py` → `spd/utils/slurm.py` → SLURM array → `spd/dataset_attributions/harvest.py` **Clustering Pipeline:** + - `spd-clustering` → `spd/clustering/scripts/run_pipeline.py` → `spd/utils/slurm.py` → `spd/clustering/scripts/run_clustering.py` +**Investigation Pipeline:** + +- `spd-investigate` → `spd/investigate/scripts/run_slurm_cli.py` → `spd/utils/slurm.py` → SLURM → `spd/investigate/scripts/run_agent.py` → Claude Code + ## Common Usage Patterns ### Running Experiments Locally (`spd-local`) @@ -285,6 +310,28 @@ spd-autointerp # Submit SLURM job to interpret component Requires `OPENROUTER_API_KEY` env var. See `spd/autointerp/CLAUDE.md` for details. +### Agent Investigation (`spd-investigate`) + +Launch a Claude Code agent to investigate a specific question about an SPD model: + +```bash +spd-investigate "How does the model handle gendered pronouns?" +spd-investigate "What components are involved in verb agreement?" 
--time 4:00:00 +``` + +Each investigation: + +- Runs in its own SLURM job with 1 GPU +- Starts an isolated app backend instance +- Investigates the specific research question using SPD tools via MCP +- Writes findings to append-only JSONL files + +Output: `SPD_OUT_DIR/investigations//` + +For parallel investigations, run the command multiple times with different prompts. + +See `spd/investigate/CLAUDE.md` for details. + ### Unified Postprocessing (`spd-postprocess`) Run all postprocessing steps for a completed SPD run with a single command: @@ -295,6 +342,7 @@ spd-postprocess --config custom_config.yaml # Use custom config ``` Defaults are defined in `PostprocessConfig` (`spd/postprocess/config.py`). Pass a custom YAML/JSON config to override. Set any section to `null` to skip it: + - `attributions: null` — skip dataset attributions - `autointerp: null` — skip autointerp entirely (interpret + evals) - `autointerp.evals: null` — skip evals but still run interpret @@ -323,6 +371,7 @@ spd-run # Run all experiments ``` All `spd-run` executions: + - Submit jobs to SLURM - Create a git snapshot for reproducibility - Create W&B workspace views @@ -343,6 +392,7 @@ spd-run --experiments --sweep --n_agents [--cpu] ``` Examples: + ```bash spd-run --experiments tms_5-2 --sweep --n_agents 4 # Run TMS 5-2 sweep with 4 GPU agents spd-run --experiments resid_mlp2 --sweep --n_agents 3 --cpu # Run ResidualMLP2 sweep with 3 CPU agents @@ -364,6 +414,7 @@ spd-run --experiments tms_5-2 --sweep custom.yaml --n_agents 2 # Use custom swee - Default sweep parameters are loaded from `spd/scripts/sweep_params.yaml` - You can specify a custom sweep parameters file by passing its path to `--sweep` - Sweep parameters support both experiment-specific and global configurations: + ```yaml # Global parameters applied to all experiments global: @@ -376,7 +427,7 @@ spd-run --experiments tms_5-2 --sweep custom.yaml --n_agents 2 # Use custom swee # Experiment-specific parameters (override global) 
tms_5-2: seed: - values: [100, 200] # Overrides global seed + values: [100, 200] # Overrides global seed task_config: feature_probability: values: [0.05, 0.1] @@ -402,6 +453,7 @@ model = ComponentModel.from_run_info(run_info) # Local paths work too model = ComponentModel.from_pretrained("/path/to/checkpoint.pt") ``` + **Path Formats:** - WandB: `wandb:entity/project/run_id` or `wandb:entity/project/runs/run_id` @@ -415,12 +467,12 @@ Downloaded runs are cached in `SPD_OUT_DIR/runs/-/`. - This includes not setting off multiple sweeps/evals that total >8 GPUs - Monitor jobs with: `squeue --format="%.18i %.9P %.15j %.12u %.12T %.10M %.9l %.6D %b %R" --me` - ## Coding Guidelines & Software Engineering Principles **This is research code, not production. Prioritize simplicity and fail-fast over defensive programming.** Core principles: + - **Fail fast** - assert assumptions, crash on violations, don't silently recover - **No legacy support** - delete unused code, don't add fallbacks for old formats or migration shims - **Narrow types** - avoid `| None` unless null is semantically meaningful; use discriminated unions over bags of optional fields @@ -451,11 +503,12 @@ config = get_config(path) value = config.key ``` - ### Tests + - The point of tests in this codebase is to ensure that the code is working as expected, not to prevent production outages - there's no deployment here. Therefore, don't worry about lots of larger integration/end-to-end tests. These often require too much overhead for what it's worth in our case, and this codebase is interactively run so often that issues will likely be caught by the user at very little cost. ### Assertions and error handling + - If you have an invariant in your head, assert it. Are you afraid to assert? Sounds like your program might already be broken. Assert, assert, assert. Never soft fail. - Do not write: `if everythingIsOk: continueHappyPath()`. 
Instead do `assert everythingIsOk` - You should have a VERY good reason to handle an error gracefully. If your program isn't working like it should then it shouldn't be running, you should be fixing it. @@ -463,11 +516,13 @@ value = config.key - **Write for the golden path.** Never let edge cases bloat the code. Before handling them, just raise an exception. If an edge case becomes annoying enough, we'll handle it then — but write first and foremost for the common case. ### Control Flow + - Keep I/O as high up as possible. Make as many functions as possible pure. - Prefer `match` over `if/elif/else` chains when dispatching on conditions - more declarative and makes cases explicit - If you either have (a and b) or neither, don't make them both independently optional. Instead, put them in an optional tuple ### Types, Arguments, and Defaults + - Write your invariants into types as much as possible. - Use jaxtyping for tensor shapes (though for now we don't do runtime checking) - Always use the PEP 604 typing format of `|` for unions and `type | None` over `Optional`. @@ -483,12 +538,11 @@ value = config.key - Don't use `from __future__ import annotations` — use string quotes for forward references instead. ### Tensor Operations + - Try to use einops by default for clarity. - Assert shapes liberally - Document complex tensor manipulations - - ### Comments - Comments hide sloppy code. If you feel the need to write a comment, consider that you should instead @@ -498,10 +552,11 @@ value = config.key - separate an inlined computation into a meaningfully named variable - Don’t write dialogic / narrativised comments or code. Instead, write comments that describe the code as is, not the diff you're making. 
Examples of narrativising comments: - - `# the function now uses y instead of x` - - `# changed to be faster` - - `# we now traverse in reverse` + - `# the function now uses y instead of x` + - `# changed to be faster` + - `# we now traverse in reverse` - Here's an example of a bad diff, where the new comment makes reference to a change in code, not just the state of the code: + ``` 95 - # Reservoir states 96 - reservoir_states: list[ReservoirState] @@ -509,21 +564,15 @@ value = config.key 96 + reservoir: TensorReservoirState ``` - -### Fire CLI Gotchas - -This codebase uses `python-fire` for CLI entry points in SLURM worker scripts. Two known gotchas: - -- **JSON args become dicts.** Fire auto-parses JSON strings into Python dicts. So `--config_json '{"n_batches": 500}'` arrives as `dict`, not `str`. Use `model_validate()` (not `model_validate_json()`), and type the param as `dict[str, Any]`. -- **Numeric-looking strings become ints/floats.** Fire parses `1234_1` (SLURM array job ID format) as an integer. This is partly why we use string-prefixed IDs everywhere (`s-`, `h-`, `da-`, `a-`) — the prefix prevents Fire from coercing them. - ### Other Important Software Development Practices + - Don't add legacy fallbacks or migration code - just change it and let old data be manually migrated if needed. -- Delete unused code. +- Delete unused code. - If an argument is always x, strongly consider removing as an argument and just inlining - **Update CLAUDE.md files** when changing code structure, adding/removing files, or modifying key interfaces. Update the CLAUDE.md in the same directory (or nearest parent) as the changed files. ### GitHub + - To view github issues and PRs, use the github cli (e.g. `gh issue view 28` or `gh pr view 30`). - When making PRs, use the github template defined in `.github/pull_request_template.md`. - Before committing, ALWAYS ensure you are on the correct branch and do not use `git add .` to add all unstaged files. 
Instead, add only the individual files you changed, don't commit all files. diff --git a/find_clean_facts.py b/find_clean_facts.py deleted file mode 100644 index e22ca906d..000000000 --- a/find_clean_facts.py +++ /dev/null @@ -1,572 +0,0 @@ -#!/usr/bin/env python3 -""" -Find the cleanest (most monosemantic) facts from the SPD analysis. - -A fact is "clean" if the components that fire on it are monosemantic. - -For down_proj: A component is monosemantic if it responds to a single label. -For up_proj: A component is monosemantic if it responds to: - - A single label, OR - - A single input element at position 0, 1, or 2 - -We score each fact based on how monosemantic its firing components are. -""" - -import re -from collections import Counter, defaultdict - - -def parse_analysis_file(filepath: str): - """Parse the analysis.txt file to extract component and fact information.""" - - with open(filepath) as f: - lines = f.readlines() - - # Parse component-to-facts mapping (from the COMPONENT ACTIVATION ANALYSIS section) - up_proj_components = defaultdict(list) # component_id -> list of (fact_idx, input, label) - down_proj_components = defaultdict(list) - - # Parse the per-fact analysis (from PER-FACT COMPONENT ANALYSIS section) - up_proj_per_fact = {} # fact_idx -> {inputs, label, components} - down_proj_per_fact = {} - - current_module = None - current_section = None # 'component_analysis' or 'per_fact' - current_component = None - - i = 0 - while i < len(lines): - line = lines[i].strip() - - # Detect section changes - if "COMPONENT ACTIVATION ANALYSIS" in line: - current_section = "component_analysis" - elif "PER-FACT COMPONENT ANALYSIS" in line: - current_section = "per_fact" - elif "SUMMARY STATISTICS" in line: - current_section = "summary" - - # Detect module changes - if "MODULE: block.mlp.up_proj" in line: - current_module = "up_proj" - elif "MODULE: block.mlp.down_proj" in line: - current_module = "down_proj" - - # Parse component activation analysis section - if 
current_section == "component_analysis" and current_module: - # Parse component header: [Rank X] Component Y (mean CI=Z): N facts above threshold - comp_match = re.match(r"\[Rank \d+\] Component (\d+)", line) - if comp_match: - current_component = int(comp_match.group(1)) - - # Parse fact line: Fact X: input=[a, b, c] → label=Y (CI=Z) - fact_match = re.match( - r"Fact\s+(\d+): input=\[(\d+), (\d+), (\d+)\] → label=(\d+)", line - ) - if fact_match and current_component is not None: - fact_idx = int(fact_match.group(1)) - inputs = [ - int(fact_match.group(2)), - int(fact_match.group(3)), - int(fact_match.group(4)), - ] - label = int(fact_match.group(5)) - - if current_module == "up_proj": - up_proj_components[current_component].append((fact_idx, inputs, label)) - else: - down_proj_components[current_component].append((fact_idx, inputs, label)) - - # Parse per-fact analysis section - if current_section == "per_fact" and current_module: - # Parse fact line - fact_match = re.match( - r"Fact\s+(\d+): input=\[(\d+), (\d+), (\d+)\] → label=(\d+)", line - ) - if fact_match: - fact_idx = int(fact_match.group(1)) - inputs = [ - int(fact_match.group(2)), - int(fact_match.group(3)), - int(fact_match.group(4)), - ] - label = int(fact_match.group(5)) - - # Look for components in the next lines - components = [] - j = i + 1 - while j < len(lines): - next_line = lines[j].strip() - - # Check if we've hit the next fact or section - if ( - next_line.startswith("Fact ") - or next_line.startswith("===") - or next_line.startswith("MODULE:") - ): - break - - # Parse component activations like C206(1.000) - comp_matches = re.findall(r"C(\d+)\(([\d.]+)\)", next_line) - for comp_id, ci_score in comp_matches: - components.append((int(comp_id), float(ci_score))) - - j += 1 - - if current_module == "up_proj": - up_proj_per_fact[fact_idx] = { - "inputs": inputs, - "label": label, - "components": components, - } - else: - down_proj_per_fact[fact_idx] = { - "inputs": inputs, - "label": label, - 
"components": components, - } - - i += 1 - - return up_proj_components, down_proj_components, up_proj_per_fact, down_proj_per_fact - - -def compute_component_monosemanticity(component_facts: list) -> dict: - """ - Compute monosemanticity scores for a component. - """ - if not component_facts: - return None - - labels = [f[2] for f in component_facts] - pos0_vals = [f[1][0] for f in component_facts] - pos1_vals = [f[1][1] for f in component_facts] - pos2_vals = [f[1][2] for f in component_facts] - - label_counts = Counter(labels) - pos0_counts = Counter(pos0_vals) - pos1_counts = Counter(pos1_vals) - pos2_counts = Counter(pos2_vals) - - n = len(component_facts) - - dominant_label, dominant_label_count = label_counts.most_common(1)[0] - dominant_pos0, dominant_pos0_count = pos0_counts.most_common(1)[0] - dominant_pos1, dominant_pos1_count = pos1_counts.most_common(1)[0] - dominant_pos2, dominant_pos2_count = pos2_counts.most_common(1)[0] - - return { - "n_facts": n, - "n_unique_labels": len(label_counts), - "dominant_label": dominant_label, - "dominant_label_ratio": dominant_label_count / n, - "n_unique_pos0": len(pos0_counts), - "dominant_pos0": dominant_pos0, - "dominant_pos0_ratio": dominant_pos0_count / n, - "n_unique_pos1": len(pos1_counts), - "dominant_pos1": dominant_pos1, - "dominant_pos1_ratio": dominant_pos1_count / n, - "n_unique_pos2": len(pos2_counts), - "dominant_pos2": dominant_pos2, - "dominant_pos2_ratio": dominant_pos2_count / n, - } - - -def is_component_monosemantic(stats: dict, threshold: float = 0.9) -> tuple[bool, str]: - """ - Determine if a component is monosemantic based on its statistics. 
- Returns (is_monosemantic, reason) - """ - if stats is None: - return False, "no_data" - - # Check if it responds to a single label - if stats["dominant_label_ratio"] >= threshold: - return True, f"label_{stats['dominant_label']}" - - # Check if it responds to a single input element - if stats["dominant_pos0_ratio"] >= threshold: - return True, f"pos0_{stats['dominant_pos0']}" - if stats["dominant_pos1_ratio"] >= threshold: - return True, f"pos1_{stats['dominant_pos1']}" - if stats["dominant_pos2_ratio"] >= threshold: - return True, f"pos2_{stats['dominant_pos2']}" - - return False, "polysemantic" - - -def compute_monosemanticity_score(stats: dict) -> float: - """ - Compute a monosemanticity score from 0 to 1. - Higher score = more monosemantic. - """ - if stats is None: - return 0.0 - - # The score is the maximum of all the dominant ratios - return max( - stats["dominant_label_ratio"], - stats["dominant_pos0_ratio"], - stats["dominant_pos1_ratio"], - stats["dominant_pos2_ratio"], - ) - - -def score_fact( - fact_info: dict, - up_proj_mono_scores: dict, - down_proj_mono_scores: dict, - up_proj_stats: dict, - down_proj_stats: dict, -) -> tuple[float, dict]: - """ - Score a fact based on how monosemantic its firing components are. 
- Returns (score, details) - """ - up_components = fact_info.get("up_proj_components", []) - down_components = fact_info.get("down_proj_components", []) - - if not up_components and not down_components: - return 0.0, { - "reason": "no_components", - "up_proj_components": [], - "down_proj_components": [], - "n_components": 0, - } - - # For each component, get its monosemanticity score - up_scores = [] - for comp_id, ci_score in up_components: - mono_score = up_proj_mono_scores.get(comp_id, 0.0) - stats = up_proj_stats.get(comp_id) - is_mono, reason = ( - is_component_monosemantic(stats, threshold=0.9) if stats else (False, "unknown") - ) - up_scores.append((comp_id, mono_score, ci_score, is_mono, reason)) - - down_scores = [] - for comp_id, ci_score in down_components: - mono_score = down_proj_mono_scores.get(comp_id, 0.0) - stats = down_proj_stats.get(comp_id) - is_mono, reason = ( - is_component_monosemantic(stats, threshold=0.9) if stats else (False, "unknown") - ) - down_scores.append((comp_id, mono_score, ci_score, is_mono, reason)) - - # Compute fact score as minimum monosemanticity of all components - all_mono_scores = [s[1] for s in up_scores] + [s[1] for s in down_scores] - - if not all_mono_scores: - return 0.0, { - "reason": "no_scores", - "up_proj_components": up_scores, - "down_proj_components": down_scores, - "n_components": 0, - } - - min_score = min(all_mono_scores) - mean_score = sum(all_mono_scores) / len(all_mono_scores) - - # Count how many components are monosemantic - n_mono = sum(1 for s in up_scores + down_scores if s[3]) - total = len(up_scores) + len(down_scores) - - return min_score, { - "up_proj_components": up_scores, - "down_proj_components": down_scores, - "min_mono_score": min_score, - "mean_mono_score": mean_score, - "n_components": total, - "n_mono_components": n_mono, - "mono_ratio": n_mono / total if total > 0 else 0, - } - - -def main(): - print("Parsing analysis.txt...") - up_proj_comps, down_proj_comps, up_proj_facts, 
down_proj_facts = parse_analysis_file( - "analysis.txt" - ) - - print(f"\nFound {len(up_proj_comps)} up_proj components with facts") - print(f"Found {len(down_proj_comps)} down_proj components with facts") - print(f"Found {len(up_proj_facts)} facts with up_proj info") - print(f"Found {len(down_proj_facts)} facts with down_proj info") - - # Sample check - if up_proj_facts: - sample_fact = list(up_proj_facts.items())[0] - print(f"\nSample up_proj fact: {sample_fact}") - if down_proj_facts: - sample_fact = list(down_proj_facts.items())[0] - print(f"Sample down_proj fact: {sample_fact}") - - # Compute monosemanticity for each component - print("\nComputing component monosemanticity...") - - up_proj_stats = {} - up_proj_mono_scores = {} - for comp_id, facts in up_proj_comps.items(): - stats = compute_component_monosemanticity(facts) - up_proj_stats[comp_id] = stats - up_proj_mono_scores[comp_id] = compute_monosemanticity_score(stats) - - down_proj_stats = {} - down_proj_mono_scores = {} - for comp_id, facts in down_proj_comps.items(): - stats = compute_component_monosemanticity(facts) - down_proj_stats[comp_id] = stats - down_proj_mono_scores[comp_id] = compute_monosemanticity_score(stats) - - # Print some example monosemantic components - print("\n" + "=" * 80) - print("MONOSEMANTIC UP_PROJ COMPONENTS (threshold >= 0.9)") - print("=" * 80) - mono_up = [] - for comp_id, stats in up_proj_stats.items(): - is_mono, reason = is_component_monosemantic(stats, threshold=0.9) - if is_mono: - mono_up.append((comp_id, stats, reason)) - - mono_up.sort(key=lambda x: compute_monosemanticity_score(x[1]), reverse=True) - for comp_id, stats, reason in mono_up[:20]: - print( - f" Component {comp_id}: {reason}, score={compute_monosemanticity_score(stats):.3f}, n_facts={stats['n_facts']}" - ) - print(f" ... 
and {len(mono_up) - 20} more" if len(mono_up) > 20 else "") - - print(f"\nTotal monosemantic up_proj components: {len(mono_up)} / {len(up_proj_stats)}") - - print("\n" + "=" * 80) - print("MONOSEMANTIC DOWN_PROJ COMPONENTS (threshold >= 0.9)") - print("=" * 80) - mono_down = [] - for comp_id, stats in down_proj_stats.items(): - is_mono, reason = is_component_monosemantic(stats, threshold=0.9) - if is_mono: - mono_down.append((comp_id, stats, reason)) - - mono_down.sort(key=lambda x: compute_monosemanticity_score(x[1]), reverse=True) - for comp_id, stats, reason in mono_down[:20]: - print( - f" Component {comp_id}: {reason}, score={compute_monosemanticity_score(stats):.3f}, n_facts={stats['n_facts']}" - ) - print(f" ... and {len(mono_down) - 20} more" if len(mono_down) > 20 else "") - - print(f"\nTotal monosemantic down_proj components: {len(mono_down)} / {len(down_proj_stats)}") - - # Combine up_proj and down_proj info for each fact - print("\n" + "=" * 80) - print("SCORING FACTS BY MONOSEMANTICITY") - print("=" * 80) - - all_facts = set(up_proj_facts.keys()) | set(down_proj_facts.keys()) - fact_scores = [] - - for fact_idx in all_facts: - up_info = up_proj_facts.get(fact_idx, {}) - down_info = down_proj_facts.get(fact_idx, {}) - - # Get the inputs and label from either source - inputs = up_info.get("inputs") or down_info.get("inputs", []) - label = up_info.get("label", down_info.get("label", -1)) - - combined_info = { - "inputs": inputs, - "label": label, - "up_proj_components": up_info.get("components", []), - "down_proj_components": down_info.get("components", []), - } - - score, details = score_fact( - combined_info, - up_proj_mono_scores, - down_proj_mono_scores, - up_proj_stats, - down_proj_stats, - ) - - fact_scores.append( - { - "fact_idx": fact_idx, - "inputs": inputs, - "label": label, - "score": score, - "details": details, - } - ) - - # Sort by score (highest = cleanest), then by mono ratio, then by fewer components - fact_scores.sort( - key=lambda x: ( 
- x["score"], - x["details"].get("mono_ratio", 0), - -x["details"].get("n_components", 999), - ), - reverse=True, - ) - - # Print top cleanest facts - print("\nTOP 50 CLEANEST FACTS (highest monosemanticity score):") - print("-" * 80) - - for i, fs in enumerate(fact_scores[:50]): - up_comps = fs["details"].get("up_proj_components", []) - down_comps = fs["details"].get("down_proj_components", []) - - up_str = ", ".join([f"C{c[0]}({c[4]})" for c in up_comps]) - down_str = ", ".join([f"C{c[0]}({c[4]})" for c in down_comps]) - - print(f"\n{i + 1}. Fact {fs['fact_idx']}: input={fs['inputs']} → label={fs['label']}") - print(f" Score: {fs['score']:.3f}, mono_ratio: {fs['details'].get('mono_ratio', 0):.2f}") - print(f" Up_proj ({len(up_comps)}): {up_str if up_str else 'none'}") - print(f" Down_proj ({len(down_comps)}): {down_str if down_str else 'none'}") - - # Find facts where ALL components are monosemantic - print("\n" + "=" * 80) - print("FACTS WHERE ALL COMPONENTS ARE MONOSEMANTIC") - print("=" * 80) - - all_mono_facts = [ - fs - for fs in fact_scores - if fs["details"].get("n_components", 0) > 0 and fs["details"].get("mono_ratio", 0) == 1.0 - ] - - print(f"\nFound {len(all_mono_facts)} facts where ALL components are monosemantic:\n") - - for i, fs in enumerate(all_mono_facts[:30]): - up_comps = fs["details"].get("up_proj_components", []) - down_comps = fs["details"].get("down_proj_components", []) - - up_str = ", ".join([f"C{c[0]}({c[4]})" for c in up_comps]) - down_str = ", ".join([f"C{c[0]}({c[4]})" for c in down_comps]) - - print(f"{i + 1}. Fact {fs['fact_idx']}: input={fs['inputs']} → label={fs['label']}") - print(f" Up_proj: {up_str if up_str else 'none'}") - print(f" Down_proj: {down_str if down_str else 'none'}") - print() - - if len(all_mono_facts) > 30: - print(f" ... 
and {len(all_mono_facts) - 30} more") - - # Also show facts with only 1 component firing in up_proj - print("\n" + "=" * 80) - print("FACTS WITH ONLY 1 UP_PROJ COMPONENT FIRING") - print("=" * 80) - - single_comp_facts = [ - fs for fs in fact_scores if len(fs["details"].get("up_proj_components", [])) == 1 - ] - single_comp_facts.sort(key=lambda x: x["score"], reverse=True) - - print(f"\nFound {len(single_comp_facts)} facts with only 1 up_proj component:\n") - - for i, fs in enumerate(single_comp_facts[:30]): - up_comps = fs["details"].get("up_proj_components", []) - down_comps = fs["details"].get("down_proj_components", []) - - comp_id = up_comps[0][0] - comp_stats = up_proj_stats.get(comp_id) - is_mono, reason = ( - is_component_monosemantic(comp_stats, threshold=0.9) - if comp_stats - else (False, "unknown") - ) - - down_str = ", ".join([f"C{c[0]}({c[4]})" for c in down_comps]) - - print(f"{i + 1}. Fact {fs['fact_idx']}: input={fs['inputs']} → label={fs['label']}") - print( - f" Up_proj C{comp_id}: mono_score={fs['score']:.3f}, is_mono={is_mono}, reason={reason}" - ) - print(f" Down_proj: {down_str if down_str else 'none'}") - if comp_stats: - print( - f" Component stats: dominant_label={comp_stats['dominant_label']} ({comp_stats['dominant_label_ratio']:.1%})" - ) - print() - - # Print summary - print("\n" + "=" * 80) - print("SUMMARY") - print("=" * 80) - - # Count facts with at least one component - facts_with_components = [fs for fs in fact_scores if fs["details"].get("n_components", 0) > 0] - print(f"\nTotal facts with at least one component: {len(facts_with_components)}") - - score_thresholds = [1.0, 0.95, 0.9, 0.8, 0.5, 0.0] - for thresh in score_thresholds: - count = sum(1 for fs in facts_with_components if fs["score"] >= thresh) - print(f" Facts with monosemanticity score >= {thresh}: {count}") - - # Save results to a file - print("\n\nSaving detailed results to clean_facts_ranking.txt...") - with open("clean_facts_ranking.txt", "w") as f: - 
f.write("FACTS RANKED BY MONOSEMANTICITY SCORE\n") - f.write("=" * 80 + "\n\n") - f.write("A fact is 'clean' if all components that fire on it are monosemantic.\n") - f.write( - "Monosemantic = responds to a single label or single input position value (>= 90%).\n\n" - ) - - f.write(f"Total facts with at least one component: {len(facts_with_components)}\n") - f.write(f"Facts where ALL components are monosemantic: {len(all_mono_facts)}\n\n") - - f.write("=" * 80 + "\n") - f.write("CLEANEST FACTS (all components monosemantic)\n") - f.write("=" * 80 + "\n\n") - - for i, fs in enumerate(all_mono_facts): - up_comps = fs["details"].get("up_proj_components", []) - down_comps = fs["details"].get("down_proj_components", []) - - f.write(f"Rank {i + 1}: Fact {fs['fact_idx']}\n") - f.write(f" Input: {fs['inputs']} → Label: {fs['label']}\n") - f.write(f" Monosemanticity Score: {fs['score']:.4f}\n") - f.write(f" Up_proj components ({len(up_comps)}):\n") - for comp_id, mono_score, ci_score, _is_mono, reason in up_comps: - f.write( - f" C{comp_id}: CI={ci_score:.3f}, mono={mono_score:.3f}, reason={reason}\n" - ) - f.write(f" Down_proj components ({len(down_comps)}):\n") - for comp_id, mono_score, ci_score, _is_mono, reason in down_comps: - f.write( - f" C{comp_id}: CI={ci_score:.3f}, mono={mono_score:.3f}, reason={reason}\n" - ) - f.write("\n") - - f.write("\n" + "=" * 80 + "\n") - f.write("ALL FACTS RANKED\n") - f.write("=" * 80 + "\n\n") - - for i, fs in enumerate(facts_with_components): - up_comps = fs["details"].get("up_proj_components", []) - down_comps = fs["details"].get("down_proj_components", []) - - f.write(f"Rank {i + 1}: Fact {fs['fact_idx']}\n") - f.write(f" Input: {fs['inputs']} → Label: {fs['label']}\n") - f.write(f" Min Monosemanticity Score: {fs['score']:.4f}\n") - f.write( - f" Mono ratio: {fs['details'].get('mono_ratio', 0):.2f} ({fs['details'].get('n_mono_components', 0)}/{fs['details'].get('n_components', 0)})\n" - ) - f.write(f" Up_proj components 
({len(up_comps)}):\n") - for comp_id, mono_score, ci_score, is_mono, reason in up_comps: - mono_marker = "✓" if is_mono else "✗" - f.write( - f" {mono_marker} C{comp_id}: CI={ci_score:.3f}, mono={mono_score:.3f}, reason={reason}\n" - ) - f.write(f" Down_proj components ({len(down_comps)}):\n") - for comp_id, mono_score, ci_score, is_mono, reason in down_comps: - mono_marker = "✓" if is_mono else "✗" - f.write( - f" {mono_marker} C{comp_id}: CI={ci_score:.3f}, mono={mono_score:.3f}, reason={reason}\n" - ) - f.write("\n") - - print("Done!") - - -if __name__ == "__main__": - main() diff --git a/package-lock.json b/package-lock.json deleted file mode 100644 index ca5f8195a..000000000 --- a/package-lock.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "name": "spd", - "lockfileVersion": 3, - "requires": true, - "packages": {} -} diff --git a/spd/app/CLAUDE.md b/spd/app/CLAUDE.md index 0f95d54a8..7e86aa0fd 100644 --- a/spd/app/CLAUDE.md +++ b/spd/app/CLAUDE.md @@ -4,17 +4,19 @@ Web-based visualization and analysis tool for exploring neural network component - **Backend**: Python FastAPI (`backend/`) - **Frontend**: Svelte 5 + TypeScript (`frontend/`) -- **Database**: SQLite at `.data/app/prompt_attr.db` (relative to repo root) +- **Database**: SQLite at `SPD_OUT_DIR/app/prompt_attr.db` (shared across team via NFS) - **TODOs**: See `TODO.md` for open work items ## Project Context This is a **rapidly iterated research tool**. Key implications: -- **Please do not code for backwards compatibility**: Schema changes don't need migrations, expect state can be deleted, etc. -- **Database is disposable**: Delete `.data/app/prompt_attr.db` if schema changes break things +- **Database is persistent shared state**: Lives at `SPD_OUT_DIR/app/prompt_attr.db` on NFS, shared across the team. Do not delete. Uses DELETE journal mode (NFS-safe) with `fcntl.flock` write locking for concurrent access. 
+ - **Schema changes require manual migration**: Update the `CREATE TABLE IF NOT EXISTS` statements to match the desired schema, then manually `ALTER TABLE` the real DB (back it up first). No automatic migration framework — just SQL. + - Keep the CREATE TABLE statements as the source of truth for the schema. - **Prefer simplicity**: Avoid over-engineering for hypothetical future needs - **Fail loud and fast**: The users are a small team of highly technical people. Errors are good. We want to know immediately if something is wrong. No soft failing, assert, assert, assert +- **Token display**: Always ship token strings rendered server-side via `AppTokenizer`, never raw token IDs. For embed/output layers, `component_idx` is a token ID — resolve it to a display string in the backend response. ## Running the App @@ -34,14 +36,14 @@ This launches both backend (FastAPI/uvicorn) and frontend (Vite) dev servers. backend/ ├── server.py # FastAPI app, CORS, routers ├── state.py # Singleton StateManager + HarvestRepo (lazy-loaded harvest data) -├── compute.py # Core attribution computation +├── compute.py # Core attribution computation + intervention evaluation ├── app_tokenizer.py # AppTokenizer: wraps HF tokenizers for display/encoding ├── (topology lives at spd/topology.py — TransformerTopology) ├── schemas.py # Pydantic API models ├── dependencies.py # FastAPI dependency injection ├── utils.py # Logging/timing utilities ├── database.py # SQLite interface -├── optim_cis.py # Sparse CI optimization +├── optim_cis.py # Sparse CI optimization, loss configs, PGD └── routers/ ├── runs.py # Load W&B runs + GET /api/model_info ├── graphs.py # Compute attribution graphs @@ -51,7 +53,8 @@ backend/ ├── correlations.py # Component correlations + token stats + interpretations ├── clusters.py # Component clustering ├── dataset_search.py # Dataset search (reads dataset from run config) - └── agents.py # Various useful endpoints that AI agents should look at when helping + ├── agents.py # 
Various useful endpoints that AI agents should look at when helping + └── mcp.py # MCP (Model Context Protocol) endpoint for Claude Code ``` Note: Activation contexts, correlations, and token stats are now loaded from pre-harvested data (see `spd/harvest/`). The app no longer computes these on-the-fly. @@ -90,7 +93,7 @@ frontend/src/ ├── ActivationContextsPagedTable.svelte ├── DatasetSearchTab.svelte # Dataset search UI ├── DatasetSearchResults.svelte - ├── ClusterPathInput.svelte # Cluster path selector + ├── ClusterPathInput.svelte # Cluster path selector (dropdown populated from registry.ts) ├── ComponentProbeInput.svelte # Component probe UI ├── TokenHighlights.svelte # Token highlighting ├── prompt-attr/ @@ -154,8 +157,13 @@ Edge(source: Node, target: Node, strength: float, is_cross_seq: bool) # strength = gradient * activation # is_cross_seq = True for k/v → o_proj (attention pattern) -PromptAttributionResult(edges: list[Edge], output_probs: Tensor[seq, vocab], node_ci_vals: dict[str, float]) -# node_ci_vals maps "layer:seq:c_idx" → CI value +PromptAttributionResult(edges, ci_masked_out_logits, target_out_logits, node_ci_vals, node_subcomp_acts) + +TokenPrediction(token, token_id, prob, logit, target_prob, target_logit) + +InterventionResult(input_tokens, ci, stochastic, adversarial, ci_loss, stochastic_loss, adversarial_loss) +# ci/stochastic/adversarial are list[list[TokenPrediction]] (per-position top-k) +# losses are evaluated using the graph's implied loss context ``` ### Frontend Types (`promptAttributionsTypes.ts`) @@ -211,13 +219,31 @@ Finds sparse CI mask that: - Minimizes L0 (active component count) - Uses importance minimality + CE loss (or KL loss) -### Intervention Forward +### Interventions (`compute.py → compute_intervention`) + +A single unified function evaluates a node selection under three masking regimes: + +- **CI**: mask = selection (binary on/off) +- **Stochastic**: mask = selection + (1-selection) × Uniform(0,1) +- **Adversarial**: PGD 
optimizes alive-but-unselected components to maximize loss; non-alive get Uniform(0,1) + +Returns `InterventionResult` with top-k `TokenPrediction`s per position for each regime, plus per-regime loss values. + +**Loss context**: Every graph has an implied loss that interventions evaluate against: -`compute_intervention_forward()`: +- **Standard/manual graphs** → `MeanKLLossConfig` (mean KL divergence from target across all positions) +- **Optimized graphs** → the graph's optimization loss (CE for a specific token at a position, or KL at a position) -1. Build component masks (all zeros) -2. Set mask=1.0 for selected nodes -3. Forward pass → top-k predictions per position +This loss is used for two things: (1) what PGD maximizes during adversarial evaluation, and (2) the `ci_loss`/`stochastic_loss`/`adversarial_loss` metrics reported in `InterventionResult`. + +**Alive masks**: `compute_intervention` recomputes the model's natural CI (one forward pass + `calc_causal_importances`) and binarizes at 0 to get alive masks. This ensures the alive set is always the full model's CI — not the graph's potentially sparse optimized CI. PGD can only manipulate alive-but-unselected components. + +**Training PGD vs Eval PGD**: The PGD settings in the graph optimization config (`adv_pgd_n_steps`, +`adv_pgd_step_size`) are a _training_ regularizer — they make CI optimization robust. The PGD in +`compute_intervention` is an _eval_ metric — it measures worst-case performance for a given node +selection. Eval PGD defaults are in `compute.py` (`DEFAULT_EVAL_PGD_CONFIG`). + +**Base intervention run**: Created automatically during graph computation. Uses all interventable nodes with CI > 0. Persisted as an `intervention_run` so predictions are available synchronously. 
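The three masking regimes described above can be sketched as plain mask constructions over a selection tensor. This is an illustrative, self-contained sketch, not the repo's actual API: `build_masks` and its argument names are hypothetical, and the real implementation lives in `compute_intervention` / `run_adv_pgd` where the adversarial values come from PGD.

```python
import torch


def build_masks(
    selection: torch.Tensor,   # [seq, C] binary {0,1}: selected nodes
    alive: torch.Tensor,       # [seq, C] bool: natural CI > 0
    pgd_values: torch.Tensor,  # [seq, C] in [0,1]: adversarially optimized values
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical sketch of the three intervention masking regimes."""
    # CI: selected components fully on, everything else off
    ci_mask = selection.clone()

    # Stochastic: selected stay at 1; unselected get Uniform(0,1) noise
    stochastic_mask = selection + (1 - selection) * torch.rand_like(selection)

    # Adversarial: selected stay at 1; alive-but-unselected slots take the
    # PGD-optimized values; non-alive slots get Uniform(0,1), as in stochastic
    adversarial_mask = selection.clone()
    adv_slots = alive & (selection == 0)
    adversarial_mask[adv_slots] = pgd_values[adv_slots]
    non_alive = ~alive
    adversarial_mask[non_alive] = torch.rand_like(adversarial_mask)[non_alive]

    return ci_mask, stochastic_mask, adversarial_mask
```

Note that because the alive set is recomputed from the full model's CI, PGD only ever perturbs components the model actually uses but the selection left out.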
--- @@ -245,9 +271,14 @@ POST /api/graphs ### Intervention ``` -POST /api/intervention {text, nodes: ["h.0.attn.q_proj:3:5", ...]} - → compute_intervention_forward() - ← InterventionResponse with top-k predictions +POST /api/intervention/run {graph_id, selected_nodes, top_k, adv_pgd} + → compute_intervention(active_nodes, graph_alive_masks, loss_config) + ← InterventionRunSummary {id, selected_nodes, result: InterventionResult} + +InterventionResult = { + input_tokens, ci, stochastic, adversarial, // TokenPrediction[][] per regime + ci_loss, stochastic_loss, adversarial_loss // loss under each regime +} ``` ### Component Correlations & Interpretations @@ -281,14 +312,14 @@ GET /api/dataset/results?page=1&page_size=20 ## Database Schema -Located at `.data/app/prompt_attr.db`. Delete this file if schema changes cause issues. +Located at `SPD_OUT_DIR/app/prompt_attr.db` (shared via NFS). Uses DELETE journal mode with `fcntl.flock` write locking for safe concurrent access from multiple backends. 
-| Table | Key | Purpose | -| ------------------ | ---------------------------------- | ------------------------------------------------- | -| `runs` | `wandb_path` | W&B run references | -| `prompts` | `(run_id, context_length)` | Token sequences | -| `graphs` | `(prompt_id, optimization_params)` | Attribution edges + output probs + node CI values | -| `intervention_runs`| `graph_id` | Saved intervention results | +| Table | Key | Purpose | +| ------------------- | ---------------------------------- | -------------------------------------------------------- | +| `runs` | `wandb_path` | W&B run references | +| `prompts` | `(run_id, context_length)` | Token sequences | +| `graphs` | `(prompt_id, optimization_params)` | Attribution edges + CI/target logits + node CI values | +| `intervention_runs` | `graph_id` | Saved `InterventionResult` JSON (single `result` column) | Note: Activation contexts, correlations, token stats, and interpretations are loaded from pre-harvested data at `SPD_OUT_DIR/{harvest,autointerp}/` (see `spd/harvest/` and `spd/autointerp/`). diff --git a/spd/app/TODO.md b/spd/app/TODO.md deleted file mode 100644 index a3c7eb1aa..000000000 --- a/spd/app/TODO.md +++ /dev/null @@ -1,3 +0,0 @@ -# App TODOs - -- Audit SQLite access pragma stuff — `immutable=1` in `HarvestDB` causes "database disk image is malformed" errors when the app reads a harvest DB mid-write (WAL not yet checkpointed). Investigate whether to check for WAL file existence, use normal locking mode, or add another safeguard. See `spd/harvest/db.py:79`. 
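The NFS-safe write pattern described above (DELETE journal mode plus an `fcntl.flock` around writes) can be sketched as follows. This is a minimal illustration, not the repo's `database.py`: `locked_write` and the sidecar `.lock` file convention are assumptions made for the example.

```python
import fcntl
import sqlite3
from contextlib import contextmanager
from pathlib import Path


@contextmanager
def locked_write(db_path: Path):
    """Serialize writers with an exclusive flock on a sidecar lock file.

    SQLite's own POSIX byte-range locks are unreliable on NFS, so an advisory
    flock is taken before opening a write transaction. DELETE journal mode is
    used because WAL relies on shared memory and does not work across NFS.
    """
    lock_file = db_path.with_suffix(".lock")
    with open(lock_file, "w") as lf:
        fcntl.flock(lf, fcntl.LOCK_EX)  # blocks until other writers finish
        conn = sqlite3.connect(db_path)
        try:
            conn.execute("PRAGMA journal_mode=DELETE")
            with conn:  # one transaction per locked section; commits on exit
                yield conn
        finally:
            conn.close()
            fcntl.flock(lf, fcntl.LOCK_UN)
```

Usage: `with locked_write(db_path) as conn: conn.execute(...)`. Readers can open the database normally; only writers need the lock.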
diff --git a/spd/app/backend/app_tokenizer.py b/spd/app/backend/app_tokenizer.py index 0d79cd9ba..acfa4d7eb 100644 --- a/spd/app/backend/app_tokenizer.py +++ b/spd/app/backend/app_tokenizer.py @@ -53,6 +53,12 @@ def vocab_size(self) -> int: assert isinstance(size, int) return size + @property + def eos_token_id(self) -> int: + eos = self._tok.eos_token_id + assert isinstance(eos, int) + return eos + def encode(self, text: str) -> list[int]: return self._tok.encode(text, add_special_tokens=False) diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index 8992e0e06..c99de5a02 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -12,12 +12,26 @@ import torch from jaxtyping import Bool, Float +from pydantic import BaseModel from torch import Tensor, nn from spd.app.backend.app_tokenizer import AppTokenizer -from spd.app.backend.optim_cis import OptimCIConfig, OptimizationMetrics, optimize_ci_values +from spd.app.backend.optim_cis import ( + AdvPGDConfig, + CELossConfig, + CISnapshotCallback, + LogitLossConfig, + LossConfig, + OptimCIConfig, + OptimizationMetrics, + compute_recon_loss, + optimize_ci_values, + optimize_ci_values_batched, + run_adv_pgd, +) from spd.configs import SamplingType from spd.log import logger +from spd.metrics.pgd_utils import interpolate_pgd_mask from spd.models.component_model import ComponentModel, OutputWithCache from spd.models.components import make_mask_infos from spd.topology import TransformerTopology @@ -115,6 +129,7 @@ class PromptAttributionResult: """Result of computing prompt attributions for a prompt.""" edges: list[Edge] + edges_abs: list[Edge] # absolute-target variant: ∂|y|/∂x · x ci_masked_out_probs: Float[Tensor, "seq vocab"] # CI-masked (SPD model) softmax probabilities ci_masked_out_logits: Float[Tensor, "seq vocab"] # CI-masked (SPD model) raw logits target_out_probs: Float[Tensor, "seq vocab"] # Target model softmax probabilities @@ -128,6 +143,7 @@ class 
OptimizedPromptAttributionResult: """Result of computing prompt attributions with optimized CI values.""" edges: list[Edge] + edges_abs: list[Edge] # absolute-target variant: ∂|y|/∂x · x ci_masked_out_probs: Float[Tensor, "seq vocab"] # CI-masked (SPD model) softmax probabilities ci_masked_out_logits: Float[Tensor, "seq vocab"] # CI-masked (SPD model) raw logits target_out_probs: Float[Tensor, "seq vocab"] # Target model softmax probabilities @@ -135,7 +151,6 @@ class OptimizedPromptAttributionResult: node_ci_vals: dict[str, float] # layer:seq:c_idx -> ci_val node_subcomp_acts: dict[str, float] # layer:seq:c_idx -> subcomponent activation (v_i^T @ a) metrics: OptimizationMetrics # Final loss metrics from optimization - adv_pgd_out_logits: Float[Tensor, "seq vocab"] | None # Adversarial PGD output logits ProgressCallback = Callable[[int, int, str], None] # (current, total, stage) @@ -168,17 +183,22 @@ def _compute_edges_for_target( cache: dict[str, Tensor], loss_seq_pos: int, topology: TransformerTopology, -) -> list[Edge]: +) -> tuple[list[Edge], list[Edge]]: """Compute all edges flowing into a single target layer. For each alive (s_out, c_out) in the target layer, computes gradient-based - attribution strengths from all alive source components. + attribution strengths from all alive source components. Computes both signed + (∂y/∂x · x) and absolute-target (∂|y|/∂x · x) variants. Args: loss_seq_pos: Maximum sequence position to include (inclusive). Only compute edges for target positions <= loss_seq_pos. + + Returns: + (edges, edges_abs): Signed and absolute-target edge lists. 
""" edges: list[Edge] = [] + edges_abs: list[Edge] = [] out_pre_detach: Float[Tensor, "1 s C"] = cache[f"{target}_pre_detach"] in_post_detaches: list[Float[Tensor, "1 s C"]] = [ cache[f"{source}_post_detach"] for source in sources @@ -190,11 +210,19 @@ def _compute_edges_for_target( continue for c_out in s_out_alive_c: + target_val = out_pre_detach[0, s_out, c_out] grads = torch.autograd.grad( - outputs=out_pre_detach[0, s_out, c_out], + outputs=target_val, inputs=in_post_detaches, retain_graph=True, ) + # ∂|y|/∂x = sign(y) · ∂y/∂x — avoids a second backward pass. + # This works because target_val is a single scalar. In dataset_attributions/ + # harvester.py, the target is sum(|y_i|) over batch+seq — there each y_i has a + # different sign, so you can't factor out one scalar. The issue isn't the chain + # rule (sign·grad is always valid per-element), it's that abs breaks the + # grad(sum)=sum(grad) trick that makes the batch reduction a single backward pass. + target_sign = target_val.sign() with torch.no_grad(): canonical_target = topology.target_to_canon(target) for source, source_info, grad, in_post_detach in zip( @@ -203,27 +231,35 @@ def _compute_edges_for_target( canonical_source = topology.target_to_canon(source) is_cross_seq = topology.is_cross_seq_pair(canonical_source, canonical_target) weighted: Float[Tensor, "s C"] = (grad * in_post_detach)[0] + weighted_abs: Float[Tensor, "s C"] = weighted * target_sign if canonical_source == "embed": weighted = weighted.sum(dim=1, keepdim=True) + weighted_abs = weighted_abs.sum(dim=1, keepdim=True) s_in_range = range(s_out + 1) if is_cross_seq else [s_out] for s_in in s_in_range: for c_in in source_info.alive_c_idxs: if not source_info.alive_mask[s_in, c_in]: continue + src = Node(layer=canonical_source, seq_pos=s_in, component_idx=c_in) + tgt = Node(layer=canonical_target, seq_pos=s_out, component_idx=c_out) edges.append( Edge( - source=Node( - layer=canonical_source, seq_pos=s_in, component_idx=c_in - ), - 
target=Node( - layer=canonical_target, seq_pos=s_out, component_idx=c_out - ), + source=src, + target=tgt, strength=weighted[s_in, c_in].item(), is_cross_seq=is_cross_seq, ) ) - return edges + edges_abs.append( + Edge( + source=src, + target=tgt, + strength=weighted_abs[s_in, c_in].item(), + is_cross_seq=is_cross_seq, + ) + ) + return edges, edges_abs def compute_edges_from_ci( @@ -330,12 +366,13 @@ def compute_edges_from_ci( # Compute edges for each target layer t0 = time.perf_counter() edges: list[Edge] = [] + edges_abs: list[Edge] = [] total_source_layers = sum(len(sources) for sources in sources_by_target.values()) progress_count = 0 for target, sources in sources_by_target.items(): t_target = time.perf_counter() - target_edges = _compute_edges_for_target( + target_edges, target_edges_abs = _compute_edges_for_target( target=target, sources=sources, target_info=alive_info[target], @@ -345,6 +382,7 @@ def compute_edges_from_ci( topology=topology, ) edges.extend(target_edges) + edges_abs.extend(target_edges_abs) canonical_target = topology.target_to_canon(target) logger.info( f"[perf] {canonical_target}: {time.perf_counter() - t_target:.2f}s, " @@ -375,6 +413,7 @@ def compute_edges_from_ci( return PromptAttributionResult( edges=edges, + edges_abs=edges_abs, ci_masked_out_probs=ci_masked_out_probs[0, : loss_seq_pos + 1], ci_masked_out_logits=ci_masked_logits[0, : loss_seq_pos + 1], target_out_probs=target_out_probs[0, : loss_seq_pos + 1], @@ -508,6 +547,7 @@ def compute_prompt_attributions_optimized( output_prob_threshold: float, device: str, on_progress: ProgressCallback | None = None, + on_ci_snapshot: CISnapshotCallback | None = None, ) -> OptimizedPromptAttributionResult: """Compute prompt attributions using optimized sparse CI values. 
@@ -528,6 +568,7 @@ def compute_prompt_attributions_optimized( config=optim_config, device=device, on_progress=on_progress, + on_ci_snapshot=on_ci_snapshot, ) ci_outputs = optim_result.params.create_ci_outputs(model, device) @@ -557,13 +598,9 @@ def compute_prompt_attributions_optimized( loss_seq_pos=loss_seq_pos, ) - # Slice adversarial logits to match the loss_seq_pos range - adv_pgd_out_logits: Float[Tensor, "seq vocab"] | None = None - if optim_result.adv_pgd_out_logits is not None: - adv_pgd_out_logits = optim_result.adv_pgd_out_logits[: loss_seq_pos + 1] - return OptimizedPromptAttributionResult( edges=result.edges, + edges_abs=result.edges_abs, ci_masked_out_probs=result.ci_masked_out_probs, ci_masked_out_logits=result.ci_masked_out_logits, target_out_probs=result.target_out_probs, @@ -571,10 +608,81 @@ def compute_prompt_attributions_optimized( node_ci_vals=result.node_ci_vals, node_subcomp_acts=result.node_subcomp_acts, metrics=optim_result.metrics, - adv_pgd_out_logits=adv_pgd_out_logits, ) +def compute_prompt_attributions_optimized_batched( + model: ComponentModel, + topology: TransformerTopology, + tokens: Float[Tensor, "1 seq"], + sources_by_target: dict[str, list[str]], + configs: list[OptimCIConfig], + output_prob_threshold: float, + device: str, + on_progress: ProgressCallback | None = None, + on_ci_snapshot: CISnapshotCallback | None = None, +) -> list[OptimizedPromptAttributionResult]: + """Compute prompt attributions for multiple sparsity coefficients in one batched optimization.""" + with torch.no_grad(), bf16_autocast(): + target_logits = model(tokens) + target_out_probs = torch.softmax(target_logits, dim=-1) + + optim_results = optimize_ci_values_batched( + model=model, + tokens=tokens, + configs=configs, + device=device, + on_progress=on_progress, + on_ci_snapshot=on_ci_snapshot, + ) + + if on_progress is not None: + on_progress(0, len(optim_results), "graph") + + with torch.no_grad(), bf16_autocast(): + pre_weight_acts = model(tokens, 
cache_type="input").cache + + loss_seq_pos = configs[0].loss_config.position + + results: list[OptimizedPromptAttributionResult] = [] + for i, optim_result in enumerate(optim_results): + ci_outputs = optim_result.params.create_ci_outputs(model, device) + + result = compute_edges_from_ci( + model=model, + topology=topology, + tokens=tokens, + ci_lower_leaky=ci_outputs.lower_leaky, + pre_weight_acts=pre_weight_acts, + sources_by_target=sources_by_target, + target_out_probs=target_out_probs, + target_out_logits=target_logits, + output_prob_threshold=output_prob_threshold, + device=device, + on_progress=on_progress, + loss_seq_pos=loss_seq_pos, + ) + + results.append( + OptimizedPromptAttributionResult( + edges=result.edges, + edges_abs=result.edges_abs, + ci_masked_out_probs=result.ci_masked_out_probs, + ci_masked_out_logits=result.ci_masked_out_logits, + target_out_probs=result.target_out_probs, + target_out_logits=result.target_out_logits, + node_ci_vals=result.node_ci_vals, + node_subcomp_acts=result.node_subcomp_acts, + metrics=optim_result.metrics, + ) + ) + + if on_progress is not None: + on_progress(i + 1, len(optim_results), "graph") + + return results + + @dataclass class CIOnlyResult: """Result of computing CI values only (no attribution graph).""" @@ -666,94 +774,248 @@ def extract_node_subcomp_acts( return node_subcomp_acts -@dataclass -class InterventionResult: - """Result of intervention forward pass.""" +class TokenPrediction(BaseModel): + """A single token prediction with probability.""" + + token: str + token_id: int + prob: float + logit: float + target_prob: float + target_logit: float + + +class LabelPredictions(BaseModel): + """Prediction stats for the CE label token at the optimized position, per masking regime.""" + + position: int + ci: TokenPrediction + stochastic: TokenPrediction + adversarial: TokenPrediction + ablated: TokenPrediction | None + + +class InterventionResult(BaseModel): + """Unified result of an intervention evaluation under 
multiple masking regimes.""" input_tokens: list[str] - predictions_per_position: list[ - list[tuple[str, int, float, float, float, float]] - ] # [(token, id, spd_prob, logit, target_prob, target_logit)] + ci: list[list[TokenPrediction]] + stochastic: list[list[TokenPrediction]] + adversarial: list[list[TokenPrediction]] + ablated: list[list[TokenPrediction]] | None + ci_loss: float + stochastic_loss: float + adversarial_loss: float + ablated_loss: float | None + label: LabelPredictions | None + + +# Default eval PGD settings (distinct from optimization PGD which is a training regularizer) +DEFAULT_EVAL_PGD_CONFIG = AdvPGDConfig(n_steps=4, step_size=1.0, init="random") + + +def _extract_topk_predictions( + logits: Float[Tensor, "1 seq vocab"], + target_logits: Float[Tensor, "1 seq vocab"], + tokenizer: AppTokenizer, + top_k: int, +) -> list[list[TokenPrediction]]: + """Extract top-k token predictions per position, paired with target probs.""" + probs = torch.softmax(logits, dim=-1) + target_probs = torch.softmax(target_logits, dim=-1) + result: list[list[TokenPrediction]] = [] + for pos in range(probs.shape[1]): + top_vals, top_ids = torch.topk(probs[0, pos], top_k) + pos_preds: list[TokenPrediction] = [] + for p, tid_t in zip(top_vals, top_ids, strict=True): + tid = int(tid_t.item()) + pos_preds.append( + TokenPrediction( + token=tokenizer.get_tok_display(tid), + token_id=tid, + prob=float(p.item()), + logit=float(logits[0, pos, tid].item()), + target_prob=float(target_probs[0, pos, tid].item()), + target_logit=float(target_logits[0, pos, tid].item()), + ) + ) + result.append(pos_preds) + return result + + +def _extract_label_prediction( + logits: Float[Tensor, "1 seq vocab"], + target_logits: Float[Tensor, "1 seq vocab"], + tokenizer: AppTokenizer, + position: int, + label_token: int, +) -> TokenPrediction: + """Extract the prediction for a specific token at a specific position.""" + probs = torch.softmax(logits[0, position], dim=-1) + target_probs = 
torch.softmax(target_logits[0, position], dim=-1) + return TokenPrediction( + token=tokenizer.get_tok_display(label_token), + token_id=label_token, + prob=float(probs[label_token].item()), + logit=float(logits[0, position, label_token].item()), + target_prob=float(target_probs[label_token].item()), + target_logit=float(target_logits[0, position, label_token].item()), + ) -def compute_intervention_forward( +def compute_intervention( model: ComponentModel, tokens: Float[Tensor, "1 seq"], - active_nodes: list[tuple[str, int, int]], # [(layer, seq_pos, component_idx)] - top_k: int, + active_nodes: list[tuple[str, int, int]], + nodes_to_ablate: list[tuple[str, int, int]] | None, tokenizer: AppTokenizer, + adv_pgd_config: AdvPGDConfig, + loss_config: LossConfig, + sampling: SamplingType, + top_k: int, ) -> InterventionResult: - """Forward pass with only specified nodes active. + """Unified intervention evaluation: CI, stochastic, adversarial, and optionally ablated. Args: - model: ComponentModel to run intervention on. - tokens: Input tokens of shape [1, seq]. - active_nodes: List of (layer, seq_pos, component_idx) tuples specifying which nodes to activate. + active_nodes: (concrete_path, seq_pos, component_idx) tuples for selected nodes. + Used for CI, stochastic, and adversarial masking. + nodes_to_ablate: If provided, nodes to ablate in ablated (full target model minus these). + The frontend computes this as all_graph_nodes - selected_nodes. + If None, ablated is skipped. + loss_config: Loss for PGD adversary to maximize and for reporting metrics. + sampling: Sampling type for CI computation. top_k: Number of top predictions to return per position. - tokenizer: Tokenizer for decoding tokens. - - Returns: - InterventionResult with input tokens and top-k predictions per position. 
""" - seq_len = tokens.shape[1] device = tokens.device - # Build component masks: all zeros, then set 1s for active nodes - component_masks: dict[str, Float[Tensor, "1 seq C"]] = {} - for layer_name, C in model.module_to_c.items(): - component_masks[layer_name] = torch.zeros(1, seq_len, C, device=device) + # Compute natural CI alive masks (the model's own binarized CI, independent of graph) + with torch.no_grad(), bf16_autocast(): + output_with_cache: OutputWithCache = model(tokens, cache_type="input") + ci_outputs = model.calc_causal_importances( + pre_weight_acts=output_with_cache.cache, + sampling=sampling, + detach_inputs=False, + ) + alive_masks: dict[str, Bool[Tensor, "1 seq C"]] = { + k: v > 0 for k, v in ci_outputs.lower_leaky.items() + } + # Build binary CI masks from active nodes (selected = 1, rest = 0) + ci_masks: dict[str, Float[Tensor, "1 seq C"]] = {} + for layer_name, C in model.module_to_c.items(): + ci_masks[layer_name] = torch.zeros(1, seq_len, C, device=device) for layer, seq_pos, c_idx in active_nodes: - assert layer in component_masks, f"Layer {layer} not in model" - assert 0 <= seq_pos < seq_len, f"seq_pos {seq_pos} out of bounds [0, {seq_len})" - assert 0 <= c_idx < model.module_to_c[layer], ( - f"component_idx {c_idx} out of bounds [0, {model.module_to_c[layer]})" + ci_masks[layer][0, seq_pos, c_idx] = 1.0 + assert alive_masks[layer][0, seq_pos, c_idx], ( + f"Selected node {layer}:{seq_pos}:{c_idx} is not alive (CI=0)" ) - component_masks[layer][0, seq_pos, c_idx] = 1.0 - mask_infos = make_mask_infos(component_masks, routing_masks="all") + with torch.no_grad(), bf16_autocast(): + # Target forward (unmasked) + target_logits: Float[Tensor, "1 seq vocab"] = model(tokens) + + # CI forward (binary mask) + ci_mask_infos = make_mask_infos(ci_masks, routing_masks="all") + ci_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=ci_mask_infos) + + # Stochastic forward: ci + (1-ci) * uniform + stoch_masks = { + layer: ci_masks[layer] + (1 - 
ci_masks[layer]) * torch.rand_like(ci_masks[layer]) + for layer in ci_masks + } + stoch_mask_infos = make_mask_infos(stoch_masks, routing_masks="all") + stoch_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=stoch_mask_infos) + + # Target-sans forward (only if nodes_to_ablate provided) + ts_logits: Float[Tensor, "1 seq vocab"] | None = None + if nodes_to_ablate is not None: + ts_masks: dict[str, Float[Tensor, "1 seq C"]] = {} + for layer_name in ci_masks: + ts_masks[layer_name] = torch.ones_like(ci_masks[layer_name]) + for layer, seq_pos, c_idx in nodes_to_ablate: + ts_masks[layer][0, seq_pos, c_idx] = 0.0 + weight_deltas = model.calc_weight_deltas() + ts_wd = { + k: (v, torch.ones(tokens.shape, device=device)) for k, v in weight_deltas.items() + } + ts_mask_infos = make_mask_infos( + ts_masks, routing_masks="all", weight_deltas_and_masks=ts_wd + ) + ts_logits = model(tokens, mask_infos=ts_mask_infos) + # Adversarial: PGD optimizes alive-but-unselected components + adv_sources = run_adv_pgd( + model=model, + tokens=tokens, + ci=ci_masks, + alive_masks=alive_masks, + adv_config=adv_pgd_config, + target_out=target_logits, + loss_config=loss_config, + ) + # Non-alive positions get uniform random fill + adv_masks = interpolate_pgd_mask(ci_masks, adv_sources) + with torch.no_grad(): + for layer in adv_masks: + non_alive = ~alive_masks[layer] + adv_masks[layer][non_alive] = torch.rand(int(non_alive.sum().item()), device=device) with torch.no_grad(), bf16_autocast(): - # SPD model forward pass (with component masks) - spd_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=mask_infos) - spd_probs: Float[Tensor, "1 seq vocab"] = torch.softmax(spd_logits, dim=-1) + adv_mask_infos = make_mask_infos(adv_masks, routing_masks="all") + adv_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=adv_mask_infos) + + # Extract predictions and loss metrics + device_str = str(device) + with torch.no_grad(): + ci_preds = 
_extract_topk_predictions(ci_logits, target_logits, tokenizer, top_k) + stoch_preds = _extract_topk_predictions(stoch_logits, target_logits, tokenizer, top_k) + adv_preds = _extract_topk_predictions(adv_logits, target_logits, tokenizer, top_k) + + ci_loss = float( + compute_recon_loss(ci_logits, loss_config, target_logits, device_str).item() + ) + stoch_loss = float( + compute_recon_loss(stoch_logits, loss_config, target_logits, device_str).item() + ) + adv_loss = float( + compute_recon_loss(adv_logits, loss_config, target_logits, device_str).item() + ) - # Target model forward pass (no masks) - target_logits: Float[Tensor, "1 seq vocab"] = model(tokens) - target_out_probs: Float[Tensor, "1 seq vocab"] = torch.softmax(target_logits, dim=-1) - - # Get top-k predictions per position (based on SPD model's top-k) - predictions_per_position: list[list[tuple[str, int, float, float, float, float]]] = [] - for pos in range(seq_len): - pos_spd_probs = spd_probs[0, pos] - pos_spd_logits = spd_logits[0, pos] - pos_target_out_probs = target_out_probs[0, pos] - pos_target_logits = target_logits[0, pos] - top_probs, top_ids = torch.topk(pos_spd_probs, top_k) - - pos_predictions: list[tuple[str, int, float, float, float, float]] = [] - for spd_prob, token_id in zip(top_probs, top_ids, strict=True): - tid = int(token_id.item()) - token_str = tokenizer.get_tok_display(tid) - target_prob = float(pos_target_out_probs[tid].item()) - target_logit = float(pos_target_logits[tid].item()) - pos_predictions.append( - ( - token_str, - tid, - float(spd_prob.item()), - float(pos_spd_logits[tid].item()), - target_prob, - target_logit, - ) + ts_preds: list[list[TokenPrediction]] | None = None + ts_loss: float | None = None + if ts_logits is not None: + ts_preds = _extract_topk_predictions(ts_logits, target_logits, tokenizer, top_k) + ts_loss = float( + compute_recon_loss(ts_logits, loss_config, target_logits, device_str).item() ) - predictions_per_position.append(pos_predictions) - # Decode 
input tokens + label: LabelPredictions | None = None + if isinstance(loss_config, CELossConfig | LogitLossConfig): + pos, tid = loss_config.position, loss_config.label_token + ts_label = ( + _extract_label_prediction(ts_logits, target_logits, tokenizer, pos, tid) + if ts_logits is not None + else None + ) + label = LabelPredictions( + position=pos, + ci=_extract_label_prediction(ci_logits, target_logits, tokenizer, pos, tid), + stochastic=_extract_label_prediction(stoch_logits, target_logits, tokenizer, pos, tid), + adversarial=_extract_label_prediction(adv_logits, target_logits, tokenizer, pos, tid), + ablated=ts_label, + ) + input_tokens = tokenizer.get_spans([int(t.item()) for t in tokens[0]]) return InterventionResult( input_tokens=input_tokens, - predictions_per_position=predictions_per_position, + ci=ci_preds, + stochastic=stoch_preds, + adversarial=adv_preds, + ablated=ts_preds, + ci_loss=ci_loss, + stochastic_loss=stoch_loss, + adversarial_loss=adv_loss, + ablated_loss=ts_loss, + label=label, ) diff --git a/spd/app/backend/database.py b/spd/app/backend/database.py index 6b2b09552..f64593237 100644 --- a/spd/app/backend/database.py +++ b/spd/app/backend/database.py @@ -6,10 +6,13 @@ Interpretations are stored separately at SPD_OUT_DIR/autointerp//. 
""" +import fcntl import hashlib import io import json +import os import sqlite3 +from contextlib import contextmanager from pathlib import Path from typing import Literal @@ -17,14 +20,35 @@ from pydantic import BaseModel from spd.app.backend.compute import Edge, Node -from spd.app.backend.optim_cis import CELossConfig, KLLossConfig, LossConfig, MaskType -from spd.settings import REPO_ROOT +from spd.app.backend.optim_cis import ( + CELossConfig, + KLLossConfig, + LogitLossConfig, + MaskType, + PositionalLossConfig, +) +from spd.settings import SPD_OUT_DIR GraphType = Literal["standard", "optimized", "manual"] -# Persistent data directories -_APP_DATA_DIR = REPO_ROOT / ".data" / "app" -DEFAULT_DB_PATH = _APP_DATA_DIR / "prompt_attr.db" +_DEFAULT_DB_PATH = SPD_OUT_DIR / "app" / "prompt_attr.db" + + +def get_default_db_path() -> Path: + """Get the default database path. + + Checks env vars in order: + 1. SPD_INVESTIGATION_DIR - investigation mode, db at dir/app.db + 2. SPD_APP_DB_PATH - explicit override + 3. 
Default: SPD_OUT_DIR/app/prompt_attr.db + """ + investigation_dir = os.environ.get("SPD_INVESTIGATION_DIR") + if investigation_dir: + return Path(investigation_dir) / "app.db" + env_path = os.environ.get("SPD_APP_DB_PATH") + if env_path: + return Path(env_path) + return _DEFAULT_DB_PATH class Run(BaseModel): @@ -43,6 +67,11 @@ class PromptRecord(BaseModel): is_custom: bool = False +class PgdConfig(BaseModel): + n_steps: int + step_size: float + + class OptimizationParams(BaseModel): """Optimization parameters that affect graph computation.""" @@ -51,9 +80,12 @@ class OptimizationParams(BaseModel): pnorm: float beta: float mask_type: MaskType - loss: LossConfig - adv_pgd_n_steps: int | None = None - adv_pgd_step_size: float | None = None + loss: PositionalLossConfig + pgd: PgdConfig | None = None + # Computed metrics (persisted for display on reload) + ci_masked_label_prob: float | None = None + stoch_masked_label_prob: float | None = None + adv_pgd_label_prob: float | None = None class StoredGraph(BaseModel): @@ -66,9 +98,11 @@ class StoredGraph(BaseModel): # Core graph data (all types) edges: list[Edge] + edges_abs: list[Edge] | None = ( + None # absolute-target variant (∂|y|/∂x · x), None for old graphs + ) ci_masked_out_logits: torch.Tensor # [seq, vocab] target_out_logits: torch.Tensor # [seq, vocab] - adv_pgd_out_logits: torch.Tensor | None = None # [seq, vocab] adversarial PGD logits node_ci_vals: dict[str, float] # layer:seq:c_idx -> ci_val (required for all graphs) node_subcomp_acts: dict[str, float] = {} # layer:seq:c_idx -> subcomp act (v_i^T @ a) @@ -85,17 +119,7 @@ class InterventionRunRecord(BaseModel): id: int graph_id: int selected_nodes: list[str] # node keys that were selected - result_json: str # JSON-encoded InterventionResponse - created_at: str - - -class ForkedInterventionRunRecord(BaseModel): - """A forked intervention run with modified tokens.""" - - id: int - intervention_run_id: int - token_replacements: list[tuple[int, int]] # [(seq_pos, 
new_token_id), ...] - result_json: str # JSON-encoded InterventionResponse + result_json: str # JSON-encoded InterventionResult created_at: str @@ -111,7 +135,8 @@ class PromptAttrDB: """ def __init__(self, db_path: Path | None = None, check_same_thread: bool = True): - self.db_path = db_path or DEFAULT_DB_PATH + self.db_path = db_path or get_default_db_path() + self._lock_path = self.db_path.with_suffix(".db.lock") self._check_same_thread = check_same_thread self._conn: sqlite3.Connection | None = None @@ -135,6 +160,16 @@ def __enter__(self) -> "PromptAttrDB": def __exit__(self, *args: object) -> None: self.close() + @contextmanager + def _write_lock(self): + """Acquire an exclusive file lock for write operations (NFS-safe).""" + with open(self._lock_path, "w") as lock_fd: + try: + fcntl.flock(lock_fd, fcntl.LOCK_EX) + yield + finally: + fcntl.flock(lock_fd, fcntl.LOCK_UN) + # ------------------------------------------------------------------------- # Schema initialization # ------------------------------------------------------------------------- @@ -142,7 +177,7 @@ def __exit__(self, *args: object) -> None: def init_schema(self) -> None: """Initialize the database schema. 
Safe to call multiple times.""" conn = self._get_conn() - conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA journal_mode=DELETE") conn.execute("PRAGMA foreign_keys=ON") conn.executescript(""" CREATE TABLE IF NOT EXISTS runs ( @@ -178,12 +213,19 @@ def init_schema(self) -> None: adv_pgd_n_steps INTEGER, adv_pgd_step_size REAL, + -- Optimization metrics (NULL for non-optimized graphs) + ci_masked_label_prob REAL, + stoch_masked_label_prob REAL, + adv_pgd_label_prob REAL, + -- Manual graph params (NULL for non-manual graphs) included_nodes TEXT, -- JSON array of node keys in this graph included_nodes_hash TEXT, -- SHA256 hash of sorted JSON for uniqueness -- The actual graph data (JSON) edges_data TEXT NOT NULL, + -- Absolute-target edges (∂|y|/∂x · x), NULL for old graphs + edges_data_abs TEXT, -- Node CI values: "layer:seq:c_idx" -> ci_val (required for all graphs) node_ci_vals TEXT NOT NULL, -- Node subcomponent activations: "layer:seq:c_idx" -> v_i^T @ a @@ -216,24 +258,14 @@ def init_schema(self) -> None: id INTEGER PRIMARY KEY AUTOINCREMENT, graph_id INTEGER NOT NULL REFERENCES graphs(id), selected_nodes TEXT NOT NULL, -- JSON array of node keys - result TEXT NOT NULL, -- JSON InterventionResponse + result TEXT NOT NULL, -- JSON InterventionResult created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); CREATE INDEX IF NOT EXISTS idx_intervention_runs_graph ON intervention_runs(graph_id); - - CREATE TABLE IF NOT EXISTS forked_intervention_runs ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - intervention_run_id INTEGER NOT NULL REFERENCES intervention_runs(id) ON DELETE CASCADE, - token_replacements TEXT NOT NULL, -- JSON array of [seq_pos, new_token_id] tuples - result TEXT NOT NULL, -- JSON InterventionResponse - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ); - - CREATE INDEX IF NOT EXISTS idx_forked_intervention_runs_parent - ON forked_intervention_runs(intervention_run_id); """) + conn.commit() # 
------------------------------------------------------------------------- @@ -242,15 +274,16 @@ def init_schema(self) -> None: def create_run(self, wandb_path: str) -> int: """Create a new run. Returns the run ID.""" - conn = self._get_conn() - cursor = conn.execute( - "INSERT INTO runs (wandb_path) VALUES (?)", - (wandb_path,), - ) - conn.commit() - run_id = cursor.lastrowid - assert run_id is not None - return run_id + with self._write_lock(): + conn = self._get_conn() + cursor = conn.execute( + "INSERT INTO runs (wandb_path) VALUES (?)", + (wandb_path,), + ) + conn.commit() + run_id = cursor.lastrowid + assert run_id is not None + return run_id def get_run_by_wandb_path(self, wandb_path: str) -> Run | None: """Get a run by its wandb path.""" @@ -308,19 +341,20 @@ def add_custom_prompt( Returns: The prompt ID (existing or newly created). """ - existing_id = self.find_prompt_by_token_ids(run_id, token_ids, context_length) - if existing_id is not None: - return existing_id + with self._write_lock(): + existing_id = self.find_prompt_by_token_ids(run_id, token_ids, context_length) + if existing_id is not None: + return existing_id - conn = self._get_conn() - cursor = conn.execute( - "INSERT INTO prompts (run_id, token_ids, context_length, is_custom) VALUES (?, ?, ?, 1)", - (run_id, json.dumps(token_ids), context_length), - ) - prompt_id = cursor.lastrowid - assert prompt_id is not None - conn.commit() - return prompt_id + conn = self._get_conn() + cursor = conn.execute( + "INSERT INTO prompts (run_id, token_ids, context_length, is_custom) VALUES (?, ?, ?, 1)", + (run_id, json.dumps(token_ids), context_length), + ) + prompt_id = cursor.lastrowid + assert prompt_id is not None + conn.commit() + return prompt_id def get_prompt(self, prompt_id: int) -> PromptRecord | None: """Get a prompt by ID.""" @@ -384,24 +418,26 @@ def _node_to_dict(n: Node) -> dict[str, str | int]: "component_idx": n.component_idx, } - edges_json = json.dumps( - [ - { - "source": 
_node_to_dict(e.source), - "target": _node_to_dict(e.target), - "strength": e.strength, - "is_cross_seq": e.is_cross_seq, - } - for e in graph.edges - ] - ) + def _edges_to_json(edges: list[Edge]) -> str: + return json.dumps( + [ + { + "source": _node_to_dict(e.source), + "target": _node_to_dict(e.target), + "strength": e.strength, + "is_cross_seq": e.is_cross_seq, + } + for e in edges + ] + ) + + edges_json = _edges_to_json(graph.edges) + edges_abs_json = _edges_to_json(graph.edges_abs) if graph.edges_abs is not None else None buf = io.BytesIO() logits_dict: dict[str, torch.Tensor] = { "ci_masked": graph.ci_masked_out_logits, "target": graph.target_out_logits, } - if graph.adv_pgd_out_logits is not None: - logits_dict["adv_pgd"] = graph.adv_pgd_out_logits torch.save(logits_dict, buf) output_logits_blob = buf.getvalue() node_ci_vals_json = json.dumps(graph.node_ci_vals) @@ -417,6 +453,9 @@ def _node_to_dict(n: Node) -> dict[str, str | int]: loss_config_hash: str | None = None adv_pgd_n_steps = None adv_pgd_step_size = None + ci_masked_label_prob = None + stoch_masked_label_prob = None + adv_pgd_label_prob = None if graph.optimization_params: imp_min_coeff = graph.optimization_params.imp_min_coeff @@ -426,8 +465,15 @@ def _node_to_dict(n: Node) -> dict[str, str | int]: mask_type = graph.optimization_params.mask_type loss_config_json = graph.optimization_params.loss.model_dump_json() loss_config_hash = hashlib.sha256(loss_config_json.encode()).hexdigest() - adv_pgd_n_steps = graph.optimization_params.adv_pgd_n_steps - adv_pgd_step_size = graph.optimization_params.adv_pgd_step_size + adv_pgd_n_steps = ( + graph.optimization_params.pgd.n_steps if graph.optimization_params.pgd else None + ) + adv_pgd_step_size = ( + graph.optimization_params.pgd.step_size if graph.optimization_params.pgd else None + ) + ci_masked_label_prob = graph.optimization_params.ci_masked_label_prob + stoch_masked_label_prob = graph.optimization_params.stoch_masked_label_prob + adv_pgd_label_prob 
= graph.optimization_params.adv_pgd_label_prob # Extract manual-specific values (NULL for non-manual graphs) # Sort included_nodes and compute hash for reliable uniqueness @@ -437,64 +483,70 @@ def _node_to_dict(n: Node) -> dict[str, str | int]: included_nodes_json = json.dumps(sorted(graph.included_nodes)) included_nodes_hash = hashlib.sha256(included_nodes_json.encode()).hexdigest() - try: - cursor = conn.execute( - """INSERT INTO graphs - (prompt_id, graph_type, - imp_min_coeff, steps, pnorm, beta, mask_type, - loss_config, loss_config_hash, - adv_pgd_n_steps, adv_pgd_step_size, - included_nodes, included_nodes_hash, - edges_data, output_logits, node_ci_vals, node_subcomp_acts) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - ( - prompt_id, - graph.graph_type, - imp_min_coeff, - steps, - pnorm, - beta, - mask_type, - loss_config_json, - loss_config_hash, - adv_pgd_n_steps, - adv_pgd_step_size, - included_nodes_json, - included_nodes_hash, - edges_json, - output_logits_blob, - node_ci_vals_json, - node_subcomp_acts_json, - ), - ) - conn.commit() - graph_id = cursor.lastrowid - assert graph_id is not None - return graph_id - except sqlite3.IntegrityError as e: - match graph.graph_type: - case "standard": - raise ValueError( - f"Standard graph already exists for prompt_id={prompt_id}. " - "Use get_graphs() to retrieve existing graph or delete it first." - ) from e - case "optimized": - raise ValueError( - f"Optimized graph with same parameters already exists for prompt_id={prompt_id}." - ) from e - case "manual": - # Get-or-create semantics: return existing graph ID - conn.rollback() - row = conn.execute( - """SELECT id FROM graphs - WHERE prompt_id = ? 
AND graph_type = 'manual' - AND included_nodes_hash = ?""", - (prompt_id, included_nodes_hash), - ).fetchone() - if row: - return row["id"] - # Should not happen if constraint triggered - raise ValueError("A manual graph with the same nodes already exists.") from e + with self._write_lock(): + try: + cursor = conn.execute( + """INSERT INTO graphs + (prompt_id, graph_type, + imp_min_coeff, steps, pnorm, beta, mask_type, + loss_config, loss_config_hash, + adv_pgd_n_steps, adv_pgd_step_size, + ci_masked_label_prob, stoch_masked_label_prob, adv_pgd_label_prob, + included_nodes, included_nodes_hash, + edges_data, edges_data_abs, output_logits, node_ci_vals, node_subcomp_acts) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + prompt_id, + graph.graph_type, + imp_min_coeff, + steps, + pnorm, + beta, + mask_type, + loss_config_json, + loss_config_hash, + adv_pgd_n_steps, + adv_pgd_step_size, + ci_masked_label_prob, + stoch_masked_label_prob, + adv_pgd_label_prob, + included_nodes_json, + included_nodes_hash, + edges_json, + edges_abs_json, + output_logits_blob, + node_ci_vals_json, + node_subcomp_acts_json, + ), + ) + conn.commit() + graph_id = cursor.lastrowid + assert graph_id is not None + return graph_id + except sqlite3.IntegrityError as e: + match graph.graph_type: + case "standard": + raise ValueError( + f"Standard graph already exists for prompt_id={prompt_id}. " + "Use get_graphs() to retrieve existing graph or delete it first." + ) from e + case "optimized": + raise ValueError( + f"Optimized graph with same parameters already exists for prompt_id={prompt_id}." + ) from e + case "manual": + conn.rollback() + row = conn.execute( + """SELECT id FROM graphs + WHERE prompt_id = ? AND graph_type = 'manual' + AND included_nodes_hash = ?""", + (prompt_id, included_nodes_hash), + ).fetchone() + if row: + return row["id"] + raise ValueError( + "A manual graph with the same nodes already exists." 
+ ) from e def _row_to_stored_graph(self, row: sqlite3.Row) -> StoredGraph: """Convert a database row to a StoredGraph.""" @@ -506,19 +558,22 @@ def _node_from_dict(d: dict[str, str | int]) -> Node: component_idx=int(d["component_idx"]), ) - edges = [ - Edge( - source=_node_from_dict(e["source"]), - target=_node_from_dict(e["target"]), - strength=float(e["strength"]), - is_cross_seq=bool(e["is_cross_seq"]), - ) - for e in json.loads(row["edges_data"]) - ] + def _parse_edges(data: str) -> list[Edge]: + return [ + Edge( + source=_node_from_dict(e["source"]), + target=_node_from_dict(e["target"]), + strength=float(e["strength"]), + is_cross_seq=bool(e["is_cross_seq"]), + ) + for e in json.loads(data) + ] + + edges = _parse_edges(row["edges_data"]) + edges_abs = _parse_edges(row["edges_data_abs"]) if row["edges_data_abs"] else None logits_data = torch.load(io.BytesIO(row["output_logits"]), weights_only=True) ci_masked_out_logits: torch.Tensor = logits_data["ci_masked"] target_out_logits: torch.Tensor = logits_data["target"] - adv_pgd_out_logits: torch.Tensor | None = logits_data.get("adv_pgd") node_ci_vals: dict[str, float] = json.loads(row["node_ci_vals"]) node_subcomp_acts: dict[str, float] = json.loads(row["node_subcomp_acts"] or "{}") @@ -526,12 +581,18 @@ def _node_from_dict(d: dict[str, str | int]) -> Node: if row["graph_type"] == "optimized": loss_config_data = json.loads(row["loss_config"]) loss_type = loss_config_data["type"] - assert loss_type in ("ce", "kl"), f"Unknown loss type: {loss_type}" - loss_config: LossConfig - if loss_type == "ce": - loss_config = CELossConfig(**loss_config_data) - else: - loss_config = KLLossConfig(**loss_config_data) + assert loss_type in ("ce", "kl", "logit"), f"Unknown loss type: {loss_type}" + loss_config: PositionalLossConfig + match loss_type: + case "ce": + loss_config = CELossConfig(**loss_config_data) + case "kl": + loss_config = KLLossConfig(**loss_config_data) + case "logit": + loss_config = 
LogitLossConfig(**loss_config_data) + pgd = None + if row["adv_pgd_n_steps"] is not None: + pgd = PgdConfig(n_steps=row["adv_pgd_n_steps"], step_size=row["adv_pgd_step_size"]) opt_params = OptimizationParams( imp_min_coeff=row["imp_min_coeff"], steps=row["steps"], @@ -539,8 +600,10 @@ def _node_from_dict(d: dict[str, str | int]) -> Node: beta=row["beta"], mask_type=row["mask_type"], loss=loss_config, - adv_pgd_n_steps=row["adv_pgd_n_steps"], - adv_pgd_step_size=row["adv_pgd_step_size"], + pgd=pgd, + ci_masked_label_prob=row["ci_masked_label_prob"], + stoch_masked_label_prob=row["stoch_masked_label_prob"], + adv_pgd_label_prob=row["adv_pgd_label_prob"], ) # Parse manual-specific fields @@ -552,9 +615,9 @@ def _node_from_dict(d: dict[str, str | int]) -> Node: id=row["id"], graph_type=row["graph_type"], edges=edges, + edges_abs=edges_abs, ci_masked_out_logits=ci_masked_out_logits, target_out_logits=target_out_logits, - adv_pgd_out_logits=adv_pgd_out_logits, node_ci_vals=node_ci_vals, node_subcomp_acts=node_subcomp_acts, optimization_params=opt_params, @@ -572,9 +635,10 @@ def get_graphs(self, prompt_id: int) -> list[StoredGraph]: """ conn = self._get_conn() rows = conn.execute( - """SELECT id, graph_type, edges_data, output_logits, node_ci_vals, + """SELECT id, graph_type, edges_data, edges_data_abs, output_logits, node_ci_vals, node_subcomp_acts, imp_min_coeff, steps, pnorm, beta, mask_type, - loss_config, adv_pgd_n_steps, adv_pgd_step_size, included_nodes + loss_config, adv_pgd_n_steps, adv_pgd_step_size, included_nodes, + ci_masked_label_prob, stoch_masked_label_prob, adv_pgd_label_prob FROM graphs WHERE prompt_id = ? ORDER BY @@ -588,9 +652,11 @@ def get_graph(self, graph_id: int) -> tuple[StoredGraph, int] | None: """Retrieve a single graph by its ID. 
Returns (graph, prompt_id) or None.""" conn = self._get_conn() row = conn.execute( - """SELECT id, prompt_id, graph_type, edges_data, output_logits, node_ci_vals, - node_subcomp_acts, imp_min_coeff, steps, pnorm, beta, mask_type, - loss_config, adv_pgd_n_steps, adv_pgd_step_size, included_nodes + """SELECT id, prompt_id, graph_type, edges_data, edges_data_abs, output_logits, + node_ci_vals, node_subcomp_acts, imp_min_coeff, steps, pnorm, beta, + mask_type, loss_config, adv_pgd_n_steps, adv_pgd_step_size, + included_nodes, ci_masked_label_prob, stoch_masked_label_prob, + adv_pgd_label_prob FROM graphs WHERE id = ?""", (graph_id,), @@ -599,23 +665,18 @@ def get_graph(self, graph_id: int) -> tuple[StoredGraph, int] | None: return None return (self._row_to_stored_graph(row), row["prompt_id"]) - def delete_graphs_for_prompt(self, prompt_id: int) -> int: - """Delete all graphs for a prompt. Returns the number of deleted rows.""" - conn = self._get_conn() - cursor = conn.execute("DELETE FROM graphs WHERE prompt_id = ?", (prompt_id,)) - conn.commit() - return cursor.rowcount - - def delete_graphs_for_run(self, run_id: int) -> int: - """Delete all graphs for all prompts in a run. Returns the number of deleted rows.""" - conn = self._get_conn() - cursor = conn.execute( - """DELETE FROM graphs - WHERE prompt_id IN (SELECT id FROM prompts WHERE run_id = ?)""", - (run_id,), - ) - conn.commit() - return cursor.rowcount + def delete_prompt(self, prompt_id: int) -> None: + """Delete a prompt and all its graphs and intervention runs.""" + with self._write_lock(): + conn = self._get_conn() + graph_ids_query = "SELECT id FROM graphs WHERE prompt_id = ?"
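The `_write_lock` helper that `delete_prompt` and the other mutating methods enter can be sketched in isolation: an exclusive `fcntl.flock` on a sidecar lock file serializes writers across processes. This is a minimal sketch (POSIX-only; the helper and file names below are illustrative, not the repo's API):

```python
import fcntl
from contextlib import contextmanager
from pathlib import Path
from tempfile import TemporaryDirectory


@contextmanager
def write_lock(lock_path: Path):
    """Hold an exclusive advisory lock on lock_path for the duration of a write."""
    with open(lock_path, "w") as fd:
        try:
            fcntl.flock(fd, fcntl.LOCK_EX)  # blocks until no other holder remains
            yield
        finally:
            fcntl.flock(fd, fcntl.LOCK_UN)


with TemporaryDirectory() as tmp:
    lock_file = Path(tmp) / "prompt_attr.db.lock"
    with write_lock(lock_file):
        pass  # a serialized INSERT/DELETE + commit would go here
```

The lock file is separate from the SQLite database file itself; together with switching the journal mode away from WAL, this avoids relying on SQLite's own file locking, which is the usual motivation on network filesystems.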
+ conn.execute( + f"DELETE FROM intervention_runs WHERE graph_id IN ({graph_ids_query})", + (prompt_id,), + ) + conn.execute("DELETE FROM graphs WHERE prompt_id = ?", (prompt_id,)) + conn.execute("DELETE FROM prompts WHERE id = ?", (prompt_id,)) + conn.commit() # ------------------------------------------------------------------------- # Intervention run operations @@ -632,21 +693,22 @@ def save_intervention_run( Args: graph_id: The graph ID this run belongs to. selected_nodes: List of node keys that were selected. - result_json: JSON-encoded InterventionResponse. + result_json: JSON-encoded InterventionResult. Returns: The intervention run ID. """ - conn = self._get_conn() - cursor = conn.execute( - """INSERT INTO intervention_runs (graph_id, selected_nodes, result) - VALUES (?, ?, ?)""", - (graph_id, json.dumps(selected_nodes), result_json), - ) - conn.commit() - run_id = cursor.lastrowid - assert run_id is not None - return run_id + with self._write_lock(): + conn = self._get_conn() + cursor = conn.execute( + """INSERT INTO intervention_runs (graph_id, selected_nodes, result) + VALUES (?, ?, ?)""", + (graph_id, json.dumps(selected_nodes), result_json), + ) + conn.commit() + run_id = cursor.lastrowid + assert run_id is not None + return run_id def get_intervention_runs(self, graph_id: int) -> list[InterventionRunRecord]: """Get all intervention runs for a graph. @@ -679,102 +741,7 @@ def get_intervention_runs(self, graph_id: int) -> list[InterventionRunRecord]: def delete_intervention_run(self, run_id: int) -> None: """Delete an intervention run.""" - conn = self._get_conn() - conn.execute("DELETE FROM intervention_runs WHERE id = ?", (run_id,)) - conn.commit() - - def delete_intervention_runs_for_graph(self, graph_id: int) -> int: - """Delete all intervention runs for a graph. 
Returns count deleted.""" - conn = self._get_conn() - cursor = conn.execute("DELETE FROM intervention_runs WHERE graph_id = ?", (graph_id,)) - conn.commit() - return cursor.rowcount - - # ------------------------------------------------------------------------- - # Forked intervention run operations - # ------------------------------------------------------------------------- - - def save_forked_intervention_run( - self, - intervention_run_id: int, - token_replacements: list[tuple[int, int]], - result_json: str, - ) -> int: - """Save a forked intervention run. - - Args: - intervention_run_id: The parent intervention run ID. - token_replacements: List of (seq_pos, new_token_id) tuples. - result_json: JSON-encoded InterventionResponse. - - Returns: - The forked intervention run ID. - """ - conn = self._get_conn() - cursor = conn.execute( - """INSERT INTO forked_intervention_runs (intervention_run_id, token_replacements, result) - VALUES (?, ?, ?)""", - (intervention_run_id, json.dumps(token_replacements), result_json), - ) - conn.commit() - fork_id = cursor.lastrowid - assert fork_id is not None - return fork_id - - def get_forked_intervention_runs( - self, intervention_run_id: int - ) -> list[ForkedInterventionRunRecord]: - """Get all forked runs for an intervention run. - - Args: - intervention_run_id: The parent intervention run ID. - - Returns: - List of forked intervention run records, ordered by creation time. - """ - conn = self._get_conn() - rows = conn.execute( - """SELECT id, intervention_run_id, token_replacements, result, created_at - FROM forked_intervention_runs - WHERE intervention_run_id = ? 
- ORDER BY created_at""", - (intervention_run_id,), - ).fetchall() - - return [ - ForkedInterventionRunRecord( - id=row["id"], - intervention_run_id=row["intervention_run_id"], - token_replacements=json.loads(row["token_replacements"]), - result_json=row["result"], - created_at=row["created_at"], - ) - for row in rows - ] - - def get_intervention_run(self, run_id: int) -> InterventionRunRecord | None: - """Get a single intervention run by ID.""" - conn = self._get_conn() - row = conn.execute( - """SELECT id, graph_id, selected_nodes, result, created_at - FROM intervention_runs - WHERE id = ?""", - (run_id,), - ).fetchone() - - if row is None: - return None - - return InterventionRunRecord( - id=row["id"], - graph_id=row["graph_id"], - selected_nodes=json.loads(row["selected_nodes"]), - result_json=row["result"], - created_at=row["created_at"], - ) - - def delete_forked_intervention_run(self, fork_id: int) -> None: - """Delete a forked intervention run.""" - conn = self._get_conn() - conn.execute("DELETE FROM forked_intervention_runs WHERE id = ?", (fork_id,)) - conn.commit() + with self._write_lock(): + conn = self._get_conn() + conn.execute("DELETE FROM intervention_runs WHERE id = ?", (run_id,)) + conn.commit() diff --git a/spd/app/backend/optim_cis.py b/spd/app/backend/optim_cis.py index 48403feb6..5ec2f15b5 100644 --- a/spd/app/backend/optim_cis.py +++ b/spd/app/backend/optim_cis.py @@ -13,6 +13,7 @@ from spd.configs import ImportanceMinimalityLossConfig, PGDInitStrategy, SamplingType from spd.metrics import importance_minimality_loss +from spd.metrics.pgd_utils import get_pgd_init_tensor, interpolate_pgd_mask from spd.models.component_model import CIOutputs, ComponentModel, OutputWithCache from spd.models.components import make_mask_infos from spd.routing import AllLayersRouter @@ -48,16 +49,33 @@ class KLLossConfig(BaseModel): position: int -LossConfig = CELossConfig | KLLossConfig +class LogitLossConfig(BaseModel): + """Logit loss: maximize the pre-softmax 
logit for a specific token at a position.""" + type: Literal["logit"] = "logit" + coeff: float + position: int + label_token: int + + +class MeanKLLossConfig(BaseModel): + """Mean KL divergence loss: match target model distribution across all positions.""" + + type: Literal["mean_kl"] = "mean_kl" + coeff: float = 1.0 -def _compute_recon_loss( + +PositionalLossConfig = CELossConfig | KLLossConfig | LogitLossConfig +LossConfig = CELossConfig | KLLossConfig | LogitLossConfig | MeanKLLossConfig + + +def compute_recon_loss( logits: Tensor, loss_config: LossConfig, target_out: Tensor, device: str, ) -> Tensor: - """Compute recon loss (CE or KL) from model output logits at the configured position.""" + """Compute recon loss (CE, KL, logit, or mean KL) from model output logits.""" match loss_config: case CELossConfig(position=pos, label_token=label_token): return F.cross_entropy( @@ -68,14 +86,13 @@ def _compute_recon_loss( target_probs = F.softmax(target_out[0, pos, :], dim=-1) pred_log_probs = F.log_softmax(logits[0, pos, :], dim=-1) return F.kl_div(pred_log_probs, target_probs, reduction="sum") - - -def _interpolate_masks( - ci: dict[str, Tensor], - sources: dict[str, Tensor], -) -> dict[str, Tensor]: - """Compute PGD component masks: ci + (1 - ci) * source.""" - return {layer: ci[layer] + (1 - ci[layer]) * sources[layer] for layer in ci} + case LogitLossConfig(position=pos, label_token=label_token): + return -logits[0, pos, label_token] + case MeanKLLossConfig(): + target_probs = F.softmax(target_out, dim=-1) + pred_log_probs = F.log_softmax(logits, dim=-1) + # sum over vocab, mean over positions (consistent with batched version) + return F.kl_div(pred_log_probs, target_probs, reduction="none").sum(dim=-1).mean(dim=-1) @dataclass @@ -188,94 +205,6 @@ class OptimCIConfig: """Configuration for optimizing CI values on a single prompt.""" @@ -188,94 +205,6 @@ def create_optimizable_ci_params( ) -def compute_l0_stats( - ci_outputs: CIOutputs, - ci_alive_threshold: float, -) -> dict[str, float]: - """Compute L0 statistics for each layer.""" - stats: dict[str, float] = {} - for 
layer_name, layer_ci in ci_outputs.lower_leaky.items(): - l0_val = calc_ci_l_zero(layer_ci, ci_alive_threshold) - stats[f"l0/{layer_name}"] = l0_val - stats["l0/total"] = sum(stats.values()) - return stats - - -def compute_specific_pos_ce_kl( - model: ComponentModel, - batch: Tensor, - target_out: Tensor, - ci: dict[str, Tensor], - rounding_threshold: float, - loss_seq_pos: int, -) -> dict[str, float]: - """Compute CE and KL metrics for a specific sequence position. - - Args: - model: The ComponentModel. - batch: Input tokens of shape [1, seq_len]. - target_out: Target model output logits of shape [1, seq_len, vocab]. - ci: Causal importance values (lower_leaky) per layer. - rounding_threshold: Threshold for rounding CI values to binary masks. - loss_seq_pos: Sequence position to compute metrics for. - - Returns: - Dict with kl and ce_difference metrics for ci_masked, unmasked, and rounded_masked. - """ - assert batch.ndim == 2 and batch.shape[0] == 1, "Expected batch shape [1, seq_len]" - - # Get target logits at the specified position - target_logits = target_out[0, loss_seq_pos, :] # [vocab] - - def kl_vs_target(logits: Tensor) -> float: - """KL divergence between predicted and target logits at target position.""" - pos_logits = logits[0, loss_seq_pos, :] # [vocab] - target_probs = F.softmax(target_logits, dim=-1) - pred_log_probs = F.log_softmax(pos_logits, dim=-1) - return F.kl_div(pred_log_probs, target_probs, reduction="sum").item() - - def ce_vs_target(logits: Tensor) -> float: - """CE between predicted logits and target's argmax at target position.""" - pos_logits = logits[0, loss_seq_pos, :] # [vocab] - target_token = target_logits.argmax() - return F.cross_entropy(pos_logits.unsqueeze(0), target_token.unsqueeze(0)).item() - - # Target model CE (baseline) - target_ce = ce_vs_target(target_out) - - # CI masked - ci_mask_infos = make_mask_infos(ci) - with bf16_autocast(): - ci_masked_logits = model(batch, mask_infos=ci_mask_infos) - ci_masked_kl = 
kl_vs_target(ci_masked_logits) - ci_masked_ce = ce_vs_target(ci_masked_logits) - - # Unmasked (all components active) - unmasked_infos = make_mask_infos({k: torch.ones_like(v) for k, v in ci.items()}) - with bf16_autocast(): - unmasked_logits = model(batch, mask_infos=unmasked_infos) - unmasked_kl = kl_vs_target(unmasked_logits) - unmasked_ce = ce_vs_target(unmasked_logits) - - # Rounded masked (binary masks based on threshold) - rounded_mask_infos = make_mask_infos( - {k: (v > rounding_threshold).float() for k, v in ci.items()} - ) - with bf16_autocast(): - rounded_masked_logits = model(batch, mask_infos=rounded_mask_infos) - rounded_masked_kl = kl_vs_target(rounded_masked_logits) - rounded_masked_ce = ce_vs_target(rounded_masked_logits) - - return { - "kl_ci_masked": ci_masked_kl, - "kl_unmasked": unmasked_kl, - "kl_rounded_masked": rounded_masked_kl, - "ce_difference_ci_masked": ci_masked_ce - target_ce, - "ce_difference_unmasked": unmasked_ce - target_ce, - "ce_difference_rounded_masked": rounded_masked_ce - target_ce, - } - - @dataclass class OptimCIConfig: """Configuration for optimizing CI values on a single prompt.""" @@ -292,9 +221,9 @@ class OptimCIConfig: log_freq: int - # Loss config (exactly one of CE or KL) + # Loss config (CE or KL — must target a specific position) imp_min_config: ImportanceMinimalityLossConfig - loss_config: LossConfig + loss_config: PositionalLossConfig sampling: SamplingType @@ -306,43 +235,51 @@ class OptimCIConfig: ProgressCallback = Callable[[int, int, str], None] # (current, total, stage) +class CISnapshot(BaseModel): + """Snapshot of alive component counts during CI optimization for visualization.""" + + step: int + total_steps: int + layers: list[str] + seq_len: int + initial_alive: list[list[int]] # layers × seq + current_alive: list[list[int]] # layers × seq + l0_total: float + loss: float + + +CISnapshotCallback = Callable[[CISnapshot], None] + + @dataclass class OptimizeCIResult: """Result from CI optimization including 
params and final metrics.""" params: OptimizableCIParams metrics: OptimizationMetrics - adv_pgd_out_logits: Float[Tensor, "seq vocab"] | None = None -def _run_adv_pgd( +def run_adv_pgd( model: ComponentModel, tokens: Tensor, - ci_lower_leaky: dict[str, Float[Tensor, "1 seq C"]], + ci: dict[str, Float[Tensor, "1 seq C"]], alive_masks: dict[str, Bool[Tensor, "1 seq C"]], adv_config: AdvPGDConfig, - loss_config: LossConfig, target_out: Tensor, - device: str, + loss_config: LossConfig, ) -> dict[str, Float[Tensor, "1 seq C"]]: - """Run PGD to find adversarial sources maximizing reconstruction loss. + """Run PGD to find adversarial sources maximizing loss. Sources are optimized via signed gradient ascent. Only alive positions are optimized. Masks are computed as ci + (1 - ci) * source (same interpolation as training PGD). Returns detached adversarial source tensors. """ - ci_detached = {k: v.detach() for k, v in ci_lower_leaky.items()} + ci_detached = {k: v.detach() for k, v in ci.items()} adv_sources: dict[str, Tensor] = {} - for layer_name, ci in ci_detached.items(): - match adv_config.init: - case "random": - source = torch.rand_like(ci) - case "ones": - source = torch.ones_like(ci) - case "zeroes": - source = torch.zeros_like(ci) + for layer_name, ci_val in ci_detached.items(): + source = get_pgd_init_tensor(adv_config.init, tuple(ci_val.shape), str(ci_val.device)) source[~alive_masks[layer_name]] = 0.0 source.requires_grad_(True) adv_sources[layer_name] = source @@ -350,12 +287,13 @@ def _run_adv_pgd( source_list = list(adv_sources.values()) for _ in range(adv_config.n_steps): - mask_infos = make_mask_infos(_interpolate_masks(ci_detached, adv_sources)) + mask_infos = make_mask_infos(interpolate_pgd_mask(ci_detached, adv_sources)) with bf16_autocast(): out = model(tokens, mask_infos=mask_infos) - loss = _compute_recon_loss(out, loss_config, target_out, device) + loss = compute_recon_loss(out, loss_config, target_out, str(tokens.device)) + grads = 
torch.autograd.grad(loss, source_list) with torch.no_grad(): for (layer_name, source), grad in zip(adv_sources.items(), grads, strict=True): @@ -372,6 +310,7 @@ def optimize_ci_values( config: OptimCIConfig, device: str, on_progress: ProgressCallback | None = None, + on_ci_snapshot: CISnapshotCallback | None = None, ) -> OptimizeCIResult: """Optimize CI values for a single prompt. @@ -406,13 +345,40 @@ def optimize_ci_values( weight_deltas = model.calc_weight_deltas() + # Precompute snapshot metadata for CI visualization + snapshot_layers = list(alive_info.alive_counts.keys()) + snapshot_initial_alive = [alive_info.alive_counts[layer] for layer in snapshot_layers] + snapshot_seq_len = tokens.shape[1] + params = ci_params.get_parameters() optimizer = optim.AdamW(params, lr=config.lr, weight_decay=config.weight_decay) progress_interval = max(1, config.steps // 20) # Report ~20 times during optimization + latest_loss: float = 0.0 for step in tqdm(range(config.steps), desc="Optimizing CI values"): - if on_progress is not None and step % progress_interval == 0: - on_progress(step, config.steps, "optimizing") + if step % progress_interval == 0: + if on_progress is not None: + on_progress(step, config.steps, "optimizing") + + if on_ci_snapshot is not None: + with torch.no_grad(): + snap_ci = ci_params.create_ci_outputs(model, device) + current_alive = [ + (snap_ci.lower_leaky[layer][0] > 0.0).sum(dim=-1).tolist() + for layer in snapshot_layers + ] + on_ci_snapshot( + CISnapshot( + step=step, + total_steps=config.steps, + layers=snapshot_layers, + seq_len=snapshot_seq_len, + initial_alive=snapshot_initial_alive, + current_alive=current_alive, + l0_total=sum(sum(row) for row in current_alive), + loss=latest_loss, + ) + ) optimizer.zero_grad() @@ -444,82 +410,46 @@ def optimize_ci_values( p_anneal_end_frac=config.imp_min_config.p_anneal_end_frac, ) - recon_loss = _compute_recon_loss(recon_out, config.loss_config, target_out, device) + recon_loss = 
compute_recon_loss(recon_out, config.loss_config, target_out, device) total_loss = config.loss_config.coeff * recon_loss + imp_min_coeff * imp_min_loss + latest_loss = total_loss.item() # PGD adversarial loss (runs in tandem with recon) if config.adv_pgd is not None: - adv_sources = _run_adv_pgd( + adv_sources = run_adv_pgd( model=model, tokens=tokens, - ci_lower_leaky=ci_outputs.lower_leaky, + ci=ci_outputs.lower_leaky, alive_masks=alive_info.alive_masks, adv_config=config.adv_pgd, loss_config=config.loss_config, target_out=target_out, - device=device, ) pgd_mask_infos = make_mask_infos( - _interpolate_masks(ci_outputs.lower_leaky, adv_sources) + interpolate_pgd_mask(ci_outputs.lower_leaky, adv_sources) ) with bf16_autocast(): pgd_out = model(tokens, mask_infos=pgd_mask_infos) - pgd_loss = _compute_recon_loss(pgd_out, config.loss_config, target_out, device) + pgd_loss = compute_recon_loss(pgd_out, config.loss_config, target_out, device) total_loss = total_loss + config.loss_config.coeff * pgd_loss - if step % config.log_freq == 0 or step == config.steps - 1: - l0_stats = compute_l0_stats(ci_outputs, ci_alive_threshold=0.0) - - with torch.no_grad(): - ce_kl_stats = compute_specific_pos_ce_kl( - model=model, - batch=tokens, - target_out=target_out, - ci=ci_outputs.lower_leaky, - rounding_threshold=config.ce_kl_rounding_threshold, - loss_seq_pos=config.loss_config.position, - ) - - log_terms: dict[str, float] = { - "imp_min_loss": imp_min_loss.item(), - "total_loss": total_loss.item(), - "recon_loss": recon_loss.item(), - } - - if isinstance(config.loss_config, CELossConfig): - pos = config.loss_config.position - label_token = config.loss_config.label_token - recon_label_prob = F.softmax(recon_out[0, pos, :], dim=-1)[label_token] - log_terms["recon_masked_label_prob"] = recon_label_prob.item() - - with torch.no_grad(): - mask_infos = make_mask_infos(ci_outputs.lower_leaky, routing_masks="all") - logits = model(tokens, mask_infos=mask_infos) - probs = 
F.softmax(logits[0, pos, :], dim=-1) - log_terms["ci_masked_label_prob"] = float(probs[label_token].item()) - - tqdm.write(f"\n--- Step {step} ---") - for name, value in log_terms.items(): - tqdm.write(f" {name}: {value:.6f}") - for name, value in l0_stats.items(): - tqdm.write(f" {name}: {value:.2f}") - for name, value in ce_kl_stats.items(): - tqdm.write(f" {name}: {value:.6f}") - total_loss.backward() optimizer.step() # Compute final metrics after optimization with torch.no_grad(): final_ci_outputs = ci_params.create_ci_outputs(model, device) - final_l0_stats = compute_l0_stats(final_ci_outputs, ci_alive_threshold=0.0) + + total_l0 = sum( + calc_ci_l_zero(layer_ci, 0.0) for layer_ci in final_ci_outputs.lower_leaky.values() + ) final_ci_masked_label_prob: float | None = None final_stoch_masked_label_prob: float | None = None - if isinstance(config.loss_config, CELossConfig): + if isinstance(config.loss_config, CELossConfig | LogitLossConfig): pos = config.loss_config.position label_token = config.loss_config.label_token @@ -541,29 +471,26 @@ def optimize_ci_values( final_stoch_masked_label_prob = float(stoch_probs[label_token].item()) # Adversarial PGD final evaluation (needs gradients for PGD, so outside no_grad block) - adv_pgd_out_logits: Float[Tensor, "seq vocab"] | None = None final_adv_pgd_label_prob: float | None = None if config.adv_pgd is not None: - final_adv_sources = _run_adv_pgd( + final_adv_sources = run_adv_pgd( model=model, tokens=tokens, - ci_lower_leaky=final_ci_outputs.lower_leaky, + ci=final_ci_outputs.lower_leaky, alive_masks=alive_info.alive_masks, adv_config=config.adv_pgd, - loss_config=config.loss_config, target_out=target_out, - device=device, + loss_config=config.loss_config, ) with torch.no_grad(): adv_pgd_masks = make_mask_infos( - _interpolate_masks(final_ci_outputs.lower_leaky, final_adv_sources) + interpolate_pgd_mask(final_ci_outputs.lower_leaky, final_adv_sources) ) with bf16_autocast(): adv_logits = model(tokens, 
mask_infos=adv_pgd_masks) - adv_pgd_out_logits = adv_logits[0].detach() # [seq, vocab] - if isinstance(config.loss_config, CELossConfig): + if isinstance(config.loss_config, CELossConfig | LogitLossConfig): pos = config.loss_config.position label_token = config.loss_config.label_token adv_probs = F.softmax(adv_logits[0, pos, :], dim=-1) @@ -573,16 +500,335 @@ def optimize_ci_values( ci_masked_label_prob=final_ci_masked_label_prob, stoch_masked_label_prob=final_stoch_masked_label_prob, adv_pgd_label_prob=final_adv_pgd_label_prob, - l0_total=final_l0_stats["l0/total"], + l0_total=total_l0, ) return OptimizeCIResult( params=ci_params, metrics=metrics, - adv_pgd_out_logits=adv_pgd_out_logits, ) +def compute_recon_loss_batched( + logits: Float[Tensor, "N seq vocab"], + loss_config: LossConfig, + target_out: Float[Tensor, "N seq vocab"], + device: str, +) -> Float[Tensor, " N"]: + """Compute per-element reconstruction loss for batched logits.""" + match loss_config: + case CELossConfig(position=pos, label_token=label_token): + labels = torch.full((logits.shape[0],), label_token, device=device) + return F.cross_entropy(logits[:, pos, :], labels, reduction="none") + case KLLossConfig(position=pos): + target_probs = F.softmax(target_out[:, pos, :], dim=-1) + pred_log_probs = F.log_softmax(logits[:, pos, :], dim=-1) + return F.kl_div(pred_log_probs, target_probs, reduction="none").sum(dim=-1) + case LogitLossConfig(position=pos, label_token=label_token): + return -logits[:, pos, label_token] + case MeanKLLossConfig(): + target_probs = F.softmax(target_out, dim=-1) + pred_log_probs = F.log_softmax(logits, dim=-1) + return F.kl_div(pred_log_probs, target_probs, reduction="none").sum(dim=-1).mean(dim=-1) + + +def importance_minimality_loss_per_element( + ci_upper_leaky_batched: dict[str, Float[Tensor, "N seq C"]], + n_batch: int, + current_frac_of_training: float, + pnorm: float, + beta: float, + eps: float, + p_anneal_start_frac: float, + p_anneal_final_p: float | None, + 
p_anneal_end_frac: float, +) -> Float[Tensor, " N"]: + """Compute importance minimality loss independently for each batch element.""" + losses = [] + for i in range(n_batch): + element_ci = {k: v[i : i + 1] for k, v in ci_upper_leaky_batched.items()} + losses.append( + importance_minimality_loss( + ci_upper_leaky=element_ci, + current_frac_of_training=current_frac_of_training, + pnorm=pnorm, + beta=beta, + eps=eps, + p_anneal_start_frac=p_anneal_start_frac, + p_anneal_final_p=p_anneal_final_p, + p_anneal_end_frac=p_anneal_end_frac, + ) + ) + return torch.stack(losses) + + +def run_adv_pgd_batched( + model: ComponentModel, + tokens: Float[Tensor, "N seq"], + ci: dict[str, Float[Tensor, "N seq C"]], + alive_masks: dict[str, Bool[Tensor, "N seq C"]], + adv_config: AdvPGDConfig, + target_out: Float[Tensor, "N seq vocab"], + loss_config: LossConfig, +) -> dict[str, Float[Tensor, "N seq C"]]: + """Run PGD adversary with batched tensors. Returns detached adversarial sources.""" + ci_detached = {k: v.detach() for k, v in ci.items()} + + adv_sources: dict[str, Tensor] = {} + for layer_name, ci_val in ci_detached.items(): + source = get_pgd_init_tensor(adv_config.init, tuple(ci_val.shape), str(ci_val.device)) + source[~alive_masks[layer_name]] = 0.0 + source.requires_grad_(True) + adv_sources[layer_name] = source + + source_list = list(adv_sources.values()) + + for _ in range(adv_config.n_steps): + mask_infos = make_mask_infos(interpolate_pgd_mask(ci_detached, adv_sources)) + + with bf16_autocast(): + out = model(tokens, mask_infos=mask_infos) + + losses = compute_recon_loss_batched(out, loss_config, target_out, str(tokens.device)) + loss = losses.sum() + + grads = torch.autograd.grad(loss, source_list) + with torch.no_grad(): + for (layer_name, source), grad in zip(adv_sources.items(), grads, strict=True): + source.add_(adv_config.step_size * grad.sign()) + source.clamp_(0.0, 1.0) + source[~alive_masks[layer_name]] = 0.0 + + return {k: v.detach() for k, v in 
adv_sources.items()} + + +def optimize_ci_values_batched( + model: ComponentModel, + tokens: Float[Tensor, "1 seq"], + configs: list[OptimCIConfig], + device: str, + on_progress: ProgressCallback | None = None, + on_ci_snapshot: CISnapshotCallback | None = None, +) -> list[OptimizeCIResult]: + """Optimize CI values for N sparsity coefficients in a single batched loop. + + All configs must share the same loss_config, steps, mask_type, adv_pgd settings — + only imp_min_config.coeff varies between them. + """ + N = len(configs) + assert N > 0 + + config = configs[0] + imp_min_coeffs = torch.tensor([c.imp_min_config.coeff for c in configs], device=device) + for c in configs: + assert c.imp_min_config.coeff is not None + + model.requires_grad_(False) + + with torch.no_grad(), bf16_autocast(): + output_with_cache: OutputWithCache = model(tokens, cache_type="input") + initial_ci_outputs = model.calc_causal_importances( + pre_weight_acts=output_with_cache.cache, + sampling=config.sampling, + detach_inputs=False, + ) + target_out = output_with_cache.output.detach() + + alive_info = compute_alive_info(initial_ci_outputs.lower_leaky) + + ci_params_list = [ + create_optimizable_ci_params( + alive_info=alive_info, + initial_pre_sigmoid=initial_ci_outputs.pre_sigmoid, + ) + for _ in range(N) + ] + + weight_deltas = model.calc_weight_deltas() + + all_params: list[Tensor] = [] + for ci_params in ci_params_list: + all_params.extend(ci_params.get_parameters()) + + optimizer = optim.AdamW(all_params, lr=config.lr, weight_decay=config.weight_decay) + + tokens_batched = tokens.expand(N, -1) + target_out_batched = target_out.expand(N, -1, -1) + + snapshot_layers = list(alive_info.alive_counts.keys()) + snapshot_initial_alive = [alive_info.alive_counts[layer] for layer in snapshot_layers] + snapshot_seq_len = tokens.shape[1] + + progress_interval = max(1, config.steps // 20) + latest_loss = 0.0 + + for step in tqdm(range(config.steps), desc="Optimizing CI values (batched)"): + if step % 
progress_interval == 0: + if on_progress is not None: + on_progress(step, config.steps, "optimizing") + + if on_ci_snapshot is not None: + with torch.no_grad(): + snap_ci = ci_params_list[0].create_ci_outputs(model, device) + current_alive = [ + (snap_ci.lower_leaky[layer][0] > 0.0).sum(dim=-1).tolist() + for layer in snapshot_layers + ] + on_ci_snapshot( + CISnapshot( + step=step, + total_steps=config.steps, + layers=snapshot_layers, + seq_len=snapshot_seq_len, + initial_alive=snapshot_initial_alive, + current_alive=current_alive, + l0_total=sum(sum(row) for row in current_alive), + loss=latest_loss, + ) + ) + + optimizer.zero_grad() + + ci_outputs_list = [cp.create_ci_outputs(model, device) for cp in ci_params_list] + + layers = list(ci_outputs_list[0].lower_leaky.keys()) + batched_ci_lower_leaky: dict[str, Tensor] = { + layer: torch.cat([co.lower_leaky[layer] for co in ci_outputs_list], dim=0) + for layer in layers + } + batched_ci_upper_leaky: dict[str, Tensor] = { + layer: torch.cat([co.upper_leaky[layer] for co in ci_outputs_list], dim=0) + for layer in layers + } + + match config.mask_type: + case "stochastic": + recon_mask_infos = calc_stochastic_component_mask_info( + causal_importances=batched_ci_lower_leaky, + component_mask_sampling=config.sampling, + weight_deltas=weight_deltas, + router=AllLayersRouter(), + ) + case "ci": + recon_mask_infos = make_mask_infos(component_masks=batched_ci_lower_leaky) + + with bf16_autocast(): + recon_out = model(tokens_batched, mask_infos=recon_mask_infos) + + imp_min_losses = importance_minimality_loss_per_element( + ci_upper_leaky_batched=batched_ci_upper_leaky, + n_batch=N, + current_frac_of_training=step / config.steps, + pnorm=config.imp_min_config.pnorm, + beta=config.imp_min_config.beta, + eps=config.imp_min_config.eps, + p_anneal_start_frac=config.imp_min_config.p_anneal_start_frac, + p_anneal_final_p=config.imp_min_config.p_anneal_final_p, + p_anneal_end_frac=config.imp_min_config.p_anneal_end_frac, + ) + + 
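The per-element penalty call above can be sketched in isolation. A minimal pure-Python sketch, assuming an Lp-style penalty of the form (|ci| + eps) ** pnorm summed over components; the real `importance_minimality_loss` also applies a beta term and p-annealing, which are omitted here, and `lp_importance` is a hypothetical stand-in, not the repo's implementation:

```python
def lp_importance(ci_values, pnorm=1.0, eps=1e-12):
    """Hypothetical stand-in for importance_minimality_loss: an Lp penalty
    (|ci| + eps) ** pnorm summed over components for one batch element.
    Omits the beta weighting and p-annealing of the real implementation."""
    return sum((abs(v) + eps) ** pnorm for v in ci_values)


# One penalty per batch element, mirroring the stacked output of
# importance_minimality_loss_per_element: each element is scaled by its
# own imp_min coefficient later, so gradients stay independent per config.
batch_ci = [[0.9, 0.0, 0.2], [0.5, 0.5, 0.0]]
per_element = [lp_importance(row) for row in batch_ci]
```
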
recon_losses = compute_recon_loss_batched( + recon_out, config.loss_config, target_out_batched, device + ) + + loss_coeff = config.loss_config.coeff + total_loss = (loss_coeff * recon_losses + imp_min_coeffs * imp_min_losses).sum() + latest_loss = total_loss.item() + + if config.adv_pgd is not None: + batched_alive_masks = { + k: v.expand(N, -1, -1) for k, v in alive_info.alive_masks.items() + } + adv_sources = run_adv_pgd_batched( + model=model, + tokens=tokens_batched, + ci=batched_ci_lower_leaky, + alive_masks=batched_alive_masks, + adv_config=config.adv_pgd, + target_out=target_out_batched, + loss_config=config.loss_config, + ) + pgd_masks = interpolate_pgd_mask(batched_ci_lower_leaky, adv_sources) + pgd_mask_infos = make_mask_infos(pgd_masks) + with bf16_autocast(): + pgd_out = model(tokens_batched, mask_infos=pgd_mask_infos) + pgd_losses = compute_recon_loss_batched( + pgd_out, config.loss_config, target_out_batched, device + ) + total_loss = total_loss + (loss_coeff * pgd_losses).sum() + + total_loss.backward() + optimizer.step() + + # Compute final metrics per element + results: list[OptimizeCIResult] = [] + for ci_params in ci_params_list: + with torch.no_grad(): + final_ci = ci_params.create_ci_outputs(model, device) + total_l0 = sum( + calc_ci_l_zero(layer_ci, 0.0) for layer_ci in final_ci.lower_leaky.values() + ) + + ci_masked_label_prob: float | None = None + stoch_masked_label_prob: float | None = None + + if isinstance(config.loss_config, CELossConfig | LogitLossConfig): + pos = config.loss_config.position + label_token = config.loss_config.label_token + + ci_mask_infos = make_mask_infos(final_ci.lower_leaky, routing_masks="all") + ci_logits = model(tokens, mask_infos=ci_mask_infos) + ci_probs = F.softmax(ci_logits[0, pos, :], dim=-1) + ci_masked_label_prob = float(ci_probs[label_token].item()) + + stoch_mask_infos = calc_stochastic_component_mask_info( + causal_importances=final_ci.lower_leaky, + component_mask_sampling=config.sampling, + 
weight_deltas=weight_deltas, + router=AllLayersRouter(), + ) + stoch_logits = model(tokens, mask_infos=stoch_mask_infos) + stoch_probs = F.softmax(stoch_logits[0, pos, :], dim=-1) + stoch_masked_label_prob = float(stoch_probs[label_token].item()) + + adv_pgd_label_prob: float | None = None + if config.adv_pgd is not None: + final_adv_sources = run_adv_pgd( + model=model, + tokens=tokens, + ci=final_ci.lower_leaky, + alive_masks=alive_info.alive_masks, + adv_config=config.adv_pgd, + target_out=target_out, + loss_config=config.loss_config, + ) + with torch.no_grad(): + adv_masks = make_mask_infos( + interpolate_pgd_mask(final_ci.lower_leaky, final_adv_sources) + ) + with bf16_autocast(): + adv_logits = model(tokens, mask_infos=adv_masks) + if isinstance(config.loss_config, CELossConfig | LogitLossConfig): + pos = config.loss_config.position + label_token = config.loss_config.label_token + adv_probs = F.softmax(adv_logits[0, pos, :], dim=-1) + adv_pgd_label_prob = float(adv_probs[label_token].item()) + + results.append( + OptimizeCIResult( + params=ci_params, + metrics=OptimizationMetrics( + ci_masked_label_prob=ci_masked_label_prob, + stoch_masked_label_prob=stoch_masked_label_prob, + adv_pgd_label_prob=adv_pgd_label_prob, + l0_total=total_l0, + ), + ) + ) + + return results + + def get_out_dir() -> Path: """Get the output directory for optimization results.""" out_dir = Path(__file__).parent / "out" diff --git a/spd/app/backend/routers/__init__.py b/spd/app/backend/routers/__init__.py index b7a6f8ed3..7a3dfadcf 100644 --- a/spd/app/backend/routers/__init__.py +++ b/spd/app/backend/routers/__init__.py @@ -7,10 +7,14 @@ from spd.app.backend.routers.data_sources import router as data_sources_router from spd.app.backend.routers.dataset_attributions import router as dataset_attributions_router from spd.app.backend.routers.dataset_search import router as dataset_search_router +from spd.app.backend.routers.graph_interp import router as graph_interp_router from 
spd.app.backend.routers.graphs import router as graphs_router from spd.app.backend.routers.intervention import router as intervention_router +from spd.app.backend.routers.investigations import router as investigations_router +from spd.app.backend.routers.mcp import router as mcp_router from spd.app.backend.routers.pretrain_info import router as pretrain_info_router from spd.app.backend.routers.prompts import router as prompts_router +from spd.app.backend.routers.run_registry import router as run_registry_router from spd.app.backend.routers.runs import router as runs_router __all__ = [ @@ -21,9 +25,13 @@ "data_sources_router", "dataset_attributions_router", "dataset_search_router", + "graph_interp_router", "graphs_router", "intervention_router", + "investigations_router", + "mcp_router", "pretrain_info_router", "prompts_router", + "run_registry_router", "runs_router", ] diff --git a/spd/app/backend/routers/clusters.py b/spd/app/backend/routers/clusters.py index e2dbae37a..b2dc1d5b9 100644 --- a/spd/app/backend/routers/clusters.py +++ b/spd/app/backend/routers/clusters.py @@ -10,6 +10,7 @@ from spd.app.backend.utils import log_errors from spd.base_config import BaseConfig from spd.settings import SPD_OUT_DIR +from spd.topology import TransformerTopology router = APIRouter(prefix="/api/clusters", tags=["clusters"]) @@ -86,4 +87,17 @@ def load_cluster_mapping(file_path: str) -> ClusterMapping: f"but loaded run is '{run_state.run.wandb_path}'", ) - return ClusterMapping(mapping=parsed.clusters) + canonical_clusters = _to_canonical_keys(parsed.clusters, run_state.topology) + return ClusterMapping(mapping=canonical_clusters) + + +def _to_canonical_keys( + clusters: dict[str, int | None], topology: TransformerTopology +) -> dict[str, int | None]: + """Convert concrete component keys (e.g. 'h.3.mlp.down_proj:5') to canonical (e.g. 
'3.mlp.down:5').""" + result: dict[str, int | None] = {} + for key, cluster_id in clusters.items(): + layer, idx = key.rsplit(":", 1) + canonical_layer = topology.target_to_canon(layer) + result[f"{canonical_layer}:{idx}"] = cluster_id + return result diff --git a/spd/app/backend/routers/data_sources.py b/spd/app/backend/routers/data_sources.py index 5287d91bd..6888b339f 100644 --- a/spd/app/backend/routers/data_sources.py +++ b/spd/app/backend/routers/data_sources.py @@ -28,15 +28,21 @@ class AutointerpInfo(BaseModel): class AttributionsInfo(BaseModel): subrun_id: str - n_batches_processed: int n_tokens_processed: int ci_threshold: float +class GraphInterpInfo(BaseModel): + subrun_id: str + config: dict[str, Any] | None + label_counts: dict[str, int] + + class DataSourcesResponse(BaseModel): harvest: HarvestInfo | None autointerp: AutointerpInfo | None attributions: AttributionsInfo | None + graph_interp: GraphInterpInfo | None router = APIRouter(prefix="/api/data_sources", tags=["data_sources"]) @@ -70,13 +76,21 @@ def get_data_sources(loaded: DepLoadedRun) -> DataSourcesResponse: storage = loaded.attributions.get_attributions() attributions_info = AttributionsInfo( subrun_id=loaded.attributions.subrun_id, - n_batches_processed=storage.n_batches_processed, n_tokens_processed=storage.n_tokens_processed, ci_threshold=storage.ci_threshold, ) + graph_interp_info: GraphInterpInfo | None = None + if loaded.graph_interp is not None: + graph_interp_info = GraphInterpInfo( + subrun_id=loaded.graph_interp.subrun_id, + config=loaded.graph_interp.get_config(), + label_counts=loaded.graph_interp.get_label_counts(), + ) + return DataSourcesResponse( harvest=harvest_info, autointerp=autointerp_info, attributions=attributions_info, + graph_interp=graph_interp_info, ) diff --git a/spd/app/backend/routers/dataset_attributions.py b/spd/app/backend/routers/dataset_attributions.py index 4c3d07753..178eefc72 100644 --- a/spd/app/backend/routers/dataset_attributions.py +++ 
b/spd/app/backend/routers/dataset_attributions.py @@ -7,46 +7,43 @@ from typing import Annotated, Literal from fastapi import APIRouter, HTTPException, Query -from jaxtyping import Float from pydantic import BaseModel -from torch import Tensor from spd.app.backend.dependencies import DepLoadedRun from spd.app.backend.utils import log_errors +from spd.dataset_attributions.storage import AttrMetric, DatasetAttributionStorage from spd.dataset_attributions.storage import DatasetAttributionEntry as StorageEntry -from spd.dataset_attributions.storage import DatasetAttributionStorage +ATTR_METRICS: list[AttrMetric] = ["attr", "attr_abs"] -class DatasetAttributionEntry(BaseModel): - """A single entry in attribution results.""" +class DatasetAttributionEntry(BaseModel): component_key: str layer: str component_idx: int value: float + token_str: str | None = None class DatasetAttributionMetadata(BaseModel): - """Metadata about dataset attributions availability.""" - available: bool - n_batches_processed: int | None n_tokens_processed: int | None n_component_layer_keys: int | None - vocab_size: int | None - d_model: int | None ci_threshold: float | None class ComponentAttributions(BaseModel): - """All attribution data for a single component (sources and targets, positive and negative).""" - positive_sources: list[DatasetAttributionEntry] negative_sources: list[DatasetAttributionEntry] positive_targets: list[DatasetAttributionEntry] negative_targets: list[DatasetAttributionEntry] +class AllMetricAttributions(BaseModel): + attr: ComponentAttributions + attr_abs: ComponentAttributions + + router = APIRouter(prefix="/api/dataset_attributions", tags=["dataset_attributions"]) NOT_AVAILABLE_MSG = ( @@ -54,91 +51,67 @@ class ComponentAttributions(BaseModel): ) -def _to_concrete_key(canonical_layer: str, component_idx: int, loaded: DepLoadedRun) -> str: - """Translate canonical layer + idx to concrete storage key. - - "embed" maps to the concrete embedding path (e.g. "wte") in storage. 
- "output" is a pseudo-layer used as-is in storage. - """ - if canonical_layer == "output": - return f"output:{component_idx}" - concrete = loaded.topology.canon_to_target(canonical_layer) - return f"{concrete}:{component_idx}" - - def _require_storage(loaded: DepLoadedRun) -> DatasetAttributionStorage: - """Get storage or raise 404.""" if loaded.attributions is None: raise HTTPException(status_code=404, detail=NOT_AVAILABLE_MSG) return loaded.attributions.get_attributions() -def _require_source(storage: DatasetAttributionStorage, component_key: str) -> None: - """Validate component exists as a source or raise 404.""" - if not storage.has_source(component_key): - raise HTTPException( - status_code=404, - detail=f"Component {component_key} not found as source in attributions", - ) - - -def _require_target(storage: DatasetAttributionStorage, component_key: str) -> None: - """Validate component exists as a target or raise 404.""" - if not storage.has_target(component_key): - raise HTTPException( - status_code=404, - detail=f"Component {component_key} not found as target in attributions", - ) - - -def _get_w_unembed(loaded: DepLoadedRun) -> Float[Tensor, "d_model vocab"]: - """Get the unembedding matrix from the loaded model.""" - return loaded.topology.get_unembed_weight() - - def _to_api_entries( - loaded: DepLoadedRun, entries: list[StorageEntry] + entries: list[StorageEntry], loaded: DepLoadedRun ) -> list[DatasetAttributionEntry]: - """Convert storage entries to API response format with canonical keys.""" - - def _canonicalize_layer(layer: str) -> str: - if layer == "output": - return layer - return loaded.topology.target_to_canon(layer) - return [ DatasetAttributionEntry( - component_key=f"{_canonicalize_layer(e.layer)}:{e.component_idx}", - layer=_canonicalize_layer(e.layer), + component_key=e.component_key, + layer=e.layer, component_idx=e.component_idx, value=e.value, + token_str=loaded.tokenizer.decode([e.component_idx]) + if e.layer in ("embed", "output") + 
else None, ) for e in entries ] +def _get_component_attributions_for_metric( + storage: DatasetAttributionStorage, + loaded: DepLoadedRun, + component_key: str, + k: int, + metric: AttrMetric, +) -> ComponentAttributions: + return ComponentAttributions( + positive_sources=_to_api_entries( + storage.get_top_sources(component_key, k, "positive", metric), loaded + ), + negative_sources=_to_api_entries( + storage.get_top_sources(component_key, k, "negative", metric), loaded + ), + positive_targets=_to_api_entries( + storage.get_top_targets(component_key, k, "positive", metric), loaded + ), + negative_targets=_to_api_entries( + storage.get_top_targets(component_key, k, "negative", metric), loaded + ), + ) + + @router.get("/metadata") @log_errors def get_attribution_metadata(loaded: DepLoadedRun) -> DatasetAttributionMetadata: - """Get metadata about dataset attributions availability.""" if loaded.attributions is None: return DatasetAttributionMetadata( available=False, - n_batches_processed=None, n_tokens_processed=None, n_component_layer_keys=None, - vocab_size=None, - d_model=None, ci_threshold=None, ) storage = loaded.attributions.get_attributions() return DatasetAttributionMetadata( available=True, - n_batches_processed=storage.n_batches_processed, n_tokens_processed=storage.n_tokens_processed, n_component_layer_keys=storage.n_components, - vocab_size=storage.vocab_size, - d_model=storage.d_model, ci_threshold=storage.ci_threshold, ) @@ -150,58 +123,18 @@ def get_component_attributions( component_idx: int, loaded: DepLoadedRun, k: Annotated[int, Query(ge=1)] = 10, -) -> ComponentAttributions: - """Get all attribution data for a component (sources and targets, positive and negative).""" +) -> AllMetricAttributions: + """Get all attribution data for a component across all metrics.""" storage = _require_storage(loaded) - component_key = _to_concrete_key(layer, component_idx, loaded) - - # Component can be both a source and a target, so we need to check both - is_source 
= storage.has_source(component_key) - is_target = storage.has_target(component_key) - - if not is_source and not is_target: - raise HTTPException( - status_code=404, - detail=f"Component {component_key} not found in attributions", - ) - - w_unembed = _get_w_unembed(loaded) if is_source else None - - return ComponentAttributions( - positive_sources=_to_api_entries( - loaded, storage.get_top_sources(component_key, k, "positive") - ) - if is_target - else [], - negative_sources=_to_api_entries( - loaded, storage.get_top_sources(component_key, k, "negative") - ) - if is_target - else [], - positive_targets=_to_api_entries( - loaded, - storage.get_top_targets( - component_key, - k, - "positive", - w_unembed=w_unembed, - include_outputs=w_unembed is not None, - ), - ) - if is_source - else [], - negative_targets=_to_api_entries( - loaded, - storage.get_top_targets( - component_key, - k, - "negative", - w_unembed=w_unembed, - include_outputs=w_unembed is not None, - ), - ) - if is_source - else [], + component_key = f"{layer}:{component_idx}" + + return AllMetricAttributions( + **{ + metric: _get_component_attributions_for_metric( + storage, loaded, component_key, k, metric + ) + for metric in ATTR_METRICS + } ) @@ -213,16 +146,11 @@ def get_attribution_sources( loaded: DepLoadedRun, k: Annotated[int, Query(ge=1)] = 10, sign: Literal["positive", "negative"] = "positive", + metric: AttrMetric = "attr", ) -> list[DatasetAttributionEntry]: - """Get top-k source components that attribute TO this target over the dataset.""" storage = _require_storage(loaded) - target_key = _to_concrete_key(layer, component_idx, loaded) - _require_target(storage, target_key) - - w_unembed = _get_w_unembed(loaded) if layer == "output" else None - return _to_api_entries( - loaded, storage.get_top_sources(target_key, k, sign, w_unembed=w_unembed) + storage.get_top_sources(f"{layer}:{component_idx}", k, sign, metric), loaded ) @@ -234,35 +162,9 @@ def get_attribution_targets( loaded: DepLoadedRun, 
k: Annotated[int, Query(ge=1)] = 10, sign: Literal["positive", "negative"] = "positive", + metric: AttrMetric = "attr", ) -> list[DatasetAttributionEntry]: - """Get top-k target components this source attributes TO over the dataset.""" storage = _require_storage(loaded) - source_key = _to_concrete_key(layer, component_idx, loaded) - _require_source(storage, source_key) - - w_unembed = _get_w_unembed(loaded) - return _to_api_entries( - loaded, storage.get_top_targets(source_key, k, sign, w_unembed=w_unembed) + storage.get_top_targets(f"{layer}:{component_idx}", k, sign, metric), loaded ) - - -@router.get("/between/{source_layer}/{source_idx}/{target_layer}/{target_idx}") -@log_errors -def get_attribution_between( - source_layer: str, - source_idx: int, - target_layer: str, - target_idx: int, - loaded: DepLoadedRun, -) -> float: - """Get attribution strength from source component to target component.""" - storage = _require_storage(loaded) - source_key = _to_concrete_key(source_layer, source_idx, loaded) - target_key = _to_concrete_key(target_layer, target_idx, loaded) - _require_source(storage, source_key) - _require_target(storage, target_key) - - w_unembed = _get_w_unembed(loaded) if target_layer == "output" else None - - return storage.get_attribution(source_key, target_key, w_unembed=w_unembed) diff --git a/spd/app/backend/routers/graph_interp.py b/spd/app/backend/routers/graph_interp.py new file mode 100644 index 000000000..003075b4d --- /dev/null +++ b/spd/app/backend/routers/graph_interp.py @@ -0,0 +1,244 @@ +"""Graph interpretation endpoints. + +Serves context-aware component labels (output/input/unified) and the +prompt-edge graph produced by the graph_interp pipeline. 
+""" + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from spd.app.backend.dependencies import DepLoadedRun +from spd.app.backend.utils import log_errors +from spd.graph_interp.schemas import LabelResult +from spd.topology import TransformerTopology + +MAX_GRAPH_NODES = 500 + + +_ALREADY_CANONICAL = {"embed", "output"} + + +def _concrete_to_canonical_key(concrete_key: str, topology: TransformerTopology) -> str: + layer, idx = concrete_key.rsplit(":", 1) + if layer in _ALREADY_CANONICAL: + return concrete_key + canonical = topology.target_to_canon(layer) + return f"{canonical}:{idx}" + + +def _canonical_to_concrete_key( + canonical_layer: str, component_idx: int, topology: TransformerTopology +) -> str: + concrete = topology.canon_to_target(canonical_layer) + return f"{concrete}:{component_idx}" + + +# -- Schemas ------------------------------------------------------------------- + + +class GraphInterpHeadline(BaseModel): + label: str + confidence: str + output_label: str | None + input_label: str | None + + +class LabelDetail(BaseModel): + label: str + confidence: str + reasoning: str + prompt: str + + +class GraphInterpDetail(BaseModel): + output: LabelDetail | None + input: LabelDetail | None + unified: LabelDetail | None + + +class PromptEdgeResponse(BaseModel): + related_key: str + pass_name: str + attribution: float + related_label: str | None + related_confidence: str | None + token_str: str | None + + +class GraphInterpComponentDetail(BaseModel): + output: LabelDetail | None + input: LabelDetail | None + unified: LabelDetail | None + edges: list[PromptEdgeResponse] + + +class GraphNode(BaseModel): + component_key: str + label: str + confidence: str + + +class GraphEdge(BaseModel): + source: str + target: str + attribution: float + pass_name: str + + +class ModelGraphResponse(BaseModel): + nodes: list[GraphNode] + edges: list[GraphEdge] + + +# -- Router -------------------------------------------------------------------- + 
+router = APIRouter(prefix="/api/graph_interp", tags=["graph_interp"]) + + +@router.get("/labels") +@log_errors +def get_all_labels(loaded: DepLoadedRun) -> dict[str, GraphInterpHeadline]: + repo = loaded.graph_interp + if repo is None: + return {} + + topology = loaded.topology + unified = repo.get_all_unified_labels() + output = repo.get_all_output_labels() + input_ = repo.get_all_input_labels() + + all_keys = set(unified) | set(output) | set(input_) + result: dict[str, GraphInterpHeadline] = {} + + for concrete_key in all_keys: + u = unified.get(concrete_key) + o = output.get(concrete_key) + i = input_.get(concrete_key) + + label = u or o or i + assert label is not None + canonical_key = _concrete_to_canonical_key(concrete_key, topology) + + result[canonical_key] = GraphInterpHeadline( + label=label.label, + confidence=label.confidence, + output_label=o.label if o else None, + input_label=i.label if i else None, + ) + + return result + + +def _to_detail(label: LabelResult | None) -> LabelDetail | None: + if label is None: + return None + return LabelDetail( + label=label.label, + confidence=label.confidence, + reasoning=label.reasoning, + prompt=label.prompt, + ) + + +@router.get("/labels/{layer}/{c_idx}") +@log_errors +def get_label_detail(layer: str, c_idx: int, loaded: DepLoadedRun) -> GraphInterpDetail: + repo = loaded.graph_interp + if repo is None: + raise HTTPException(status_code=404, detail="Graph interp data not available") + + concrete_key = _canonical_to_concrete_key(layer, c_idx, loaded.topology) + + return GraphInterpDetail( + output=_to_detail(repo.get_output_label(concrete_key)), + input=_to_detail(repo.get_input_label(concrete_key)), + unified=_to_detail(repo.get_unified_label(concrete_key)), + ) + + +@router.get("/detail/{layer}/{c_idx}") +@log_errors +def get_component_detail( + layer: str, c_idx: int, loaded: DepLoadedRun +) -> GraphInterpComponentDetail: + repo = loaded.graph_interp + if repo is None: + raise HTTPException(status_code=404, 
detail="Graph interp data not available")
+
+    topology = loaded.topology
+    concrete_key = _canonical_to_concrete_key(layer, c_idx, topology)
+
+    raw_edges = repo.get_prompt_edges(concrete_key)
+    tokenizer = loaded.tokenizer
+    edges = []
+    for e in raw_edges:
+        rel_layer, rel_idx = e.related_key.rsplit(":", 1)
+        token_str = tokenizer.decode([int(rel_idx)]) if rel_layer in ("embed", "output") else None
+        edges.append(
+            PromptEdgeResponse(
+                related_key=_concrete_to_canonical_key(e.related_key, topology),
+                pass_name=e.pass_name,
+                attribution=e.attribution,
+                related_label=e.related_label,
+                related_confidence=e.related_confidence,
+                token_str=token_str,
+            )
+        )
+
+    return GraphInterpComponentDetail(
+        output=_to_detail(repo.get_output_label(concrete_key)),
+        input=_to_detail(repo.get_input_label(concrete_key)),
+        unified=_to_detail(repo.get_unified_label(concrete_key)),
+        edges=edges,
+    )
+
+
+@router.get("/graph")
+@log_errors
+def get_model_graph(loaded: DepLoadedRun) -> ModelGraphResponse:
+    repo = loaded.graph_interp
+    if repo is None:
+        raise HTTPException(status_code=404, detail="Graph interp data not available")
+
+    topology = loaded.topology
+
+    unified = repo.get_all_unified_labels()
+    nodes = []
+    for concrete_key, label in unified.items():
+        canonical_key = _concrete_to_canonical_key(concrete_key, topology)
+        nodes.append(
+            GraphNode(
+                component_key=canonical_key,
+                label=label.label,
+                confidence=label.confidence,
+            )
+        )
+
+    nodes = nodes[:MAX_GRAPH_NODES]
+    node_keys = {n.component_key for n in nodes}
+
+    raw_edges = repo.get_all_prompt_edges()
+    edges = []
+    for e in raw_edges:
+        comp_canon = _concrete_to_canonical_key(e.component_key, topology)
+        rel_canon = _concrete_to_canonical_key(e.related_key, topology)
+
+        match e.pass_name:
+            case "output":
+                source, target = comp_canon, rel_canon
+            case "input":
+                source, target = rel_canon, comp_canon
+            case _:
+                # Unknown pass_name: skip the edge rather than leave source/target unbound
+                continue
+
+        if source not in node_keys or target not in node_keys:
+            continue
+
+        edges.append(
+            GraphEdge(
+                
source=source, + target=target, + attribution=e.attribution, + pass_name=e.pass_name, + ) + ) + + return ModelGraphResponse(nodes=nodes, edges=edges) diff --git a/spd/app/backend/routers/graphs.py b/spd/app/backend/routers/graphs.py index a51b1649c..12cddb3cb 100644 --- a/spd/app/backend/routers/graphs.py +++ b/spd/app/backend/routers/graphs.py @@ -19,26 +19,96 @@ from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.compute import ( + DEFAULT_EVAL_PGD_CONFIG, + MAX_OUTPUT_NODES_PER_POS, Edge, + compute_intervention, compute_prompt_attributions, compute_prompt_attributions_optimized, + compute_prompt_attributions_optimized_batched, +) +from spd.app.backend.database import ( + GraphType, + OptimizationParams, + PgdConfig, + PromptAttrDB, + StoredGraph, ) -from spd.app.backend.database import GraphType, OptimizationParams, StoredGraph from spd.app.backend.dependencies import DepLoadedRun, DepStateManager from spd.app.backend.optim_cis import ( AdvPGDConfig, CELossConfig, + CISnapshot, KLLossConfig, + LogitLossConfig, LossConfig, MaskType, + MeanKLLossConfig, OptimCIConfig, + PositionalLossConfig, ) from spd.app.backend.schemas import OutputProbability from spd.app.backend.utils import log_errors -from spd.configs import ImportanceMinimalityLossConfig +from spd.configs import ImportanceMinimalityLossConfig, SamplingType from spd.log import logger +from spd.models.component_model import ComponentModel +from spd.topology import TransformerTopology from spd.utils.distributed_utils import get_device +NON_INTERVENTABLE_LAYERS = {"embed", "output"} + + +def _save_base_intervention_run( + graph_id: int, + model: ComponentModel, + tokens: torch.Tensor, + node_ci_vals: dict[str, float], + tokenizer: AppTokenizer, + topology: TransformerTopology, + db: PromptAttrDB, + sampling: SamplingType, + loss_config: LossConfig | None = None, +) -> None: + """Compute intervention for all interventable nodes and save as an intervention run.""" + interventable_keys = [ 
+ k + for k, ci in node_ci_vals.items() + if k.split(":")[0] not in NON_INTERVENTABLE_LAYERS and ci > 0 + ] + if not interventable_keys: + logger.warning( + f"Graph {graph_id}: no interventable nodes with CI > 0, skipping base intervention run" + ) + return + + active_nodes: list[tuple[str, int, int]] = [] + for key in interventable_keys: + canon_layer, seq_str, cidx_str = key.split(":") + concrete_path = topology.canon_to_target(canon_layer) + active_nodes.append((concrete_path, int(seq_str), int(cidx_str))) + + effective_loss_config: LossConfig = ( + loss_config if loss_config is not None else MeanKLLossConfig() + ) + + result = compute_intervention( + model=model, + tokens=tokens, + active_nodes=active_nodes, + nodes_to_ablate=None, + tokenizer=tokenizer, + adv_pgd_config=DEFAULT_EVAL_PGD_CONFIG, + loss_config=effective_loss_config, + sampling=sampling, + top_k=10, + ) + + db.save_intervention_run( + graph_id=graph_id, + selected_nodes=interventable_keys, + result_json=result.model_dump_json(), + ) + class EdgeData(BaseModel): """Edge in the attribution graph.""" @@ -65,12 +135,14 @@ class GraphData(BaseModel): graphType: GraphType tokens: list[str] edges: list[EdgeData] + edgesAbs: list[EdgeData] | None = None # absolute-target variant, None for old graphs outputProbs: dict[str, OutputProbability] nodeCiVals: dict[ str, float ] # node key -> CI value (or output prob for output nodes or 1 for embed node) nodeSubcompActs: dict[str, float] # node key -> subcomponent activation (v_i^T @ a) maxAbsAttr: float # max absolute edge value + maxAbsAttrAbs: float | None = None # max absolute edge value for abs-target variant maxAbsSubcompAct: float # max absolute subcomponent activation for normalization l0_total: int # total active components at current CI threshold @@ -93,6 +165,70 @@ class KLLossResult(BaseModel): position: int +class LogitLossResult(BaseModel): + """Logit loss result (maximize pre-softmax logit).""" + + type: Literal["logit"] = "logit" + coeff: float + 
position: int + label_token: int + label_str: str + + +LossType = Literal["ce", "kl", "logit"] +LossResult = CELossResult | KLLossResult | LogitLossResult + + +def _build_loss_config( + loss_type: LossType, + loss_coeff: float, + loss_position: int, + label_token: int | None, +) -> PositionalLossConfig: + match loss_type: + case "ce": + assert label_token is not None, "label_token is required for CE loss" + return CELossConfig(coeff=loss_coeff, position=loss_position, label_token=label_token) + case "kl": + return KLLossConfig(coeff=loss_coeff, position=loss_position) + case "logit": + assert label_token is not None, "label_token is required for logit loss" + return LogitLossConfig( + coeff=loss_coeff, position=loss_position, label_token=label_token + ) + + +def _build_loss_result( + loss_config: PositionalLossConfig, + tok_display: Callable[[int], str], +) -> LossResult: + match loss_config: + case CELossConfig(coeff=coeff, position=pos, label_token=label_tok): + return CELossResult( + coeff=coeff, position=pos, label_token=label_tok, label_str=tok_display(label_tok) + ) + case KLLossConfig(coeff=coeff, position=pos): + return KLLossResult(coeff=coeff, position=pos) + case LogitLossConfig(coeff=coeff, position=pos, label_token=label_tok): + return LogitLossResult( + coeff=coeff, position=pos, label_token=label_tok, label_str=tok_display(label_tok) + ) + + +def _maybe_pgd( + n_steps: int | None, step_size: float | None +) -> tuple[PgdConfig, AdvPGDConfig] | None: + assert (n_steps is None) == (step_size is None), ( + "adv_pgd n_steps and step_size must both be set or both be None" + ) + if n_steps is None: + return None + assert step_size is not None # for narrowing + return PgdConfig(n_steps=n_steps, step_size=step_size), AdvPGDConfig( + n_steps=n_steps, step_size=step_size, init="random" + ) + + class OptimizationMetricsResult(BaseModel): """Final loss metrics from CI optimization.""" @@ -112,10 +248,9 @@ class OptimizationResult(BaseModel): pnorm: float beta: 
float mask_type: MaskType - loss: CELossResult | KLLossResult + loss: CELossResult | KLLossResult | LogitLossResult metrics: OptimizationMetricsResult - adv_pgd_n_steps: int | None = None - adv_pgd_step_size: float | None = None + pgd: PgdConfig | None = None class GraphDataWithOptimization(GraphData): @@ -156,19 +291,6 @@ class TokenizeResponse(BaseModel): next_token_probs: list[float | None] # Probability of next token (last token is None) -class TokenInfo(BaseModel): - """A single token from the tokenizer vocabulary.""" - - id: int - string: str - - -class TokensResponse(BaseModel): - """Response containing all tokens in the vocabulary.""" - - tokens: list[TokenInfo] - - # SSE streaming message types class ProgressMessage(BaseModel): """Progress update during streaming computation.""" @@ -200,6 +322,12 @@ class CompleteMessageWithOptimization(BaseModel): data: GraphDataWithOptimization +class BatchGraphResult(BaseModel): + """Batch optimization result containing multiple graphs.""" + + graphs: list[GraphDataWithOptimization] + + router = APIRouter(prefix="/api/graphs", tags=["graphs"]) DEVICE = get_device() @@ -211,14 +339,10 @@ class CompleteMessageWithOptimization(BaseModel): ProgressCallback = Callable[[int, int, str], None] -MAX_OUTPUT_NODES_PER_POS = 15 - - def _build_out_probs( ci_masked_out_logits: torch.Tensor, target_out_logits: torch.Tensor, tok_display: Callable[[int], str], - adv_pgd_out_logits: torch.Tensor | None = None, ) -> dict[str, OutputProbability]: """Build output probs dict from logit tensors. 
@@ -226,9 +350,6 @@ def _build_out_probs( """ ci_masked_out_probs = torch.softmax(ci_masked_out_logits, dim=-1) target_out_probs = torch.softmax(target_out_logits, dim=-1) - adv_pgd_out_probs = ( - torch.softmax(adv_pgd_out_logits, dim=-1) if adv_pgd_out_logits is not None else None - ) out_probs: dict[str, OutputProbability] = {} for s in range(ci_masked_out_probs.shape[0]): @@ -243,65 +364,78 @@ def _build_out_probs( target_prob = float(target_out_probs[s, c_idx].item()) target_logit = float(target_out_logits[s, c_idx].item()) - adv_pgd_prob: float | None = None - adv_pgd_logit: float | None = None - if adv_pgd_out_probs is not None and adv_pgd_out_logits is not None: - adv_pgd_prob = round(float(adv_pgd_out_probs[s, c_idx].item()), 6) - adv_pgd_logit = round(float(adv_pgd_out_logits[s, c_idx].item()), 4) - key = f"{s}:{c_idx}" out_probs[key] = OutputProbability( prob=round(prob, 6), logit=round(logit, 4), target_prob=round(target_prob, 6), target_logit=round(target_logit, 4), - adv_pgd_prob=adv_pgd_prob, - adv_pgd_logit=adv_pgd_logit, token=tok_display(c_idx), ) return out_probs +CISnapshotCallback = Callable[[CISnapshot], None] + + def stream_computation( - work: Callable[[ProgressCallback], GraphData | GraphDataWithOptimization], + work: Callable[[ProgressCallback, CISnapshotCallback | None], BaseModel], + gpu_lock: threading.Lock, ) -> StreamingResponse: - """Run graph computation in a thread with SSE streaming for progress updates.""" + """Run graph computation in a thread with SSE streaming for progress updates. + + Acquires gpu_lock before starting and holds it until computation completes. + Raises 503 if the lock is already held by another operation. + """ + # Try to acquire lock non-blocking - fail fast if GPU is busy + if not gpu_lock.acquire(blocking=False): + raise HTTPException( + status_code=503, + detail="GPU operation already in progress. 
Please wait and retry.", + ) + progress_queue: queue.Queue[dict[str, Any]] = queue.Queue() def on_progress(current: int, total: int, stage: str) -> None: progress_queue.put({"type": "progress", "current": current, "total": total, "stage": stage}) + def on_ci_snapshot(snapshot: CISnapshot) -> None: + progress_queue.put({"type": "ci_snapshot", **snapshot.model_dump()}) + def compute_thread() -> None: try: - result = work(on_progress) + result = work(on_progress, on_ci_snapshot) progress_queue.put({"type": "result", "result": result}) except Exception as e: traceback.print_exc(file=sys.stderr) progress_queue.put({"type": "error", "error": str(e)}) def generate() -> Generator[str]: - thread = threading.Thread(target=compute_thread) - thread.start() - - while True: - try: - msg = progress_queue.get(timeout=0.1) - except queue.Empty: - if not thread.is_alive(): + try: + thread = threading.Thread(target=compute_thread) + thread.start() + + while True: + try: + msg = progress_queue.get(timeout=0.1) + except queue.Empty: + if not thread.is_alive(): + break + continue + + if msg["type"] in ("progress", "ci_snapshot"): + yield f"data: {json.dumps(msg)}\n\n" + elif msg["type"] == "error": + yield f"data: {json.dumps(msg)}\n\n" + break + elif msg["type"] == "result": + complete_data = {"type": "complete", "data": msg["result"].model_dump()} + yield f"data: {json.dumps(complete_data)}\n\n" break - continue - - if msg["type"] == "progress": - yield f"data: {json.dumps(msg)}\n\n" - elif msg["type"] == "error": - yield f"data: {json.dumps(msg)}\n\n" - break - elif msg["type"] == "result": - complete_data = {"type": "complete", "data": msg["result"].model_dump()} - yield f"data: {json.dumps(complete_data)}\n\n" - break - thread.join() + thread.join() + finally: + gpu_lock.release() return StreamingResponse(generate(), media_type="text/event-stream") @@ -343,40 +477,55 @@ def tokenize_text(text: str, loaded: DepLoadedRun) -> TokenizeResponse: ) -@router.get("/tokens") -@log_errors 
-def get_all_tokens(loaded: DepLoadedRun) -> TokensResponse: - """Get all tokens in the tokenizer vocabulary for client-side search.""" - tokens = [ - TokenInfo(id=tid, string=loaded.tokenizer.get_tok_display(tid)) - for tid in range(loaded.tokenizer.vocab_size) - ] - return TokensResponse(tokens=tokens) +class TokenSearchResult(BaseModel): + """A token search result with model probability at the queried position.""" + + id: int + string: str + prob: float class TokenSearchResponse(BaseModel): """Response from token search endpoint.""" - tokens: list[TokenInfo] + tokens: list[TokenSearchResult] @router.get("/tokens/search") @log_errors def search_tokens( q: Annotated[str, Query(min_length=1)], + prompt_id: Annotated[int, Query()], + position: Annotated[int, Query()], loaded: DepLoadedRun, - limit: Annotated[int, Query(ge=1, le=50)] = 10, + manager: DepStateManager, + limit: Annotated[int, Query(ge=1, le=50)] = 20, ) -> TokenSearchResponse: - """Search tokens by substring match. Returns up to `limit` results.""" + """Search tokens by substring match, sorted by target model probability at position.""" + prompt = manager.state.db.get_prompt(prompt_id) + if prompt is None: + raise HTTPException(status_code=404, detail=f"prompt {prompt_id} not found") + if not (0 <= position < len(prompt.token_ids)): + raise HTTPException( + status_code=422, + detail=f"position {position} out of range for prompt with {len(prompt.token_ids)} tokens", + ) + + device = next(loaded.model.parameters()).device + tokens_tensor = torch.tensor([prompt.token_ids], device=device) + with manager.gpu_lock(), torch.no_grad(): + logits = loaded.model(tokens_tensor) + probs = torch.softmax(logits[0, position], dim=-1) + query = q.lower() - matches: list[TokenInfo] = [] + matches: list[TokenSearchResult] = [] for tid in range(loaded.tokenizer.vocab_size): string = loaded.tokenizer.get_tok_display(tid) if query in string.lower(): - matches.append(TokenInfo(id=tid, string=string)) - if len(matches) >= 
limit: - break - return TokenSearchResponse(tokens=matches) + matches.append(TokenSearchResult(id=tid, string=string, prob=probs[tid].item())) + + matches.sort(key=lambda m: m.prob, reverse=True) + return TokenSearchResponse(tokens=matches[:limit]) NormalizeType = Literal["none", "target", "layer"] @@ -450,7 +599,9 @@ def compute_graph_stream( spans = loaded.tokenizer.get_spans(token_ids) tokens_tensor = torch.tensor([token_ids], device=DEVICE) - def work(on_progress: ProgressCallback) -> GraphData: + def work( + on_progress: ProgressCallback, _on_ci_snapshot: CISnapshotCallback | None + ) -> GraphData: t_total = time.perf_counter() result = compute_prompt_attributions( @@ -474,6 +625,7 @@ def work(on_progress: ProgressCallback) -> GraphData: graph=StoredGraph( graph_type=graph_type, edges=result.edges, + edges_abs=result.edges_abs, ci_masked_out_logits=ci_masked_out_logits, target_out_logits=target_out_logits, node_ci_vals=result.node_ci_vals, @@ -483,6 +635,19 @@ def work(on_progress: ProgressCallback) -> GraphData: ) logger.info(f"[perf] save_graph: {time.perf_counter() - t0:.2f}s") + t0 = time.perf_counter() + _save_base_intervention_run( + graph_id=graph_id, + model=loaded.model, + tokens=tokens_tensor, + node_ci_vals=result.node_ci_vals, + tokenizer=loaded.tokenizer, + topology=loaded.topology, + db=db, + sampling=loaded.config.sampling, + ) + logger.info(f"[perf] base intervention run: {time.perf_counter() - t0:.2f}s") + t0 = time.perf_counter() fg = filter_graph_for_display( raw_edges=result.edges, @@ -494,6 +659,7 @@ def work(on_progress: ProgressCallback) -> GraphData: num_tokens=len(token_ids), ci_threshold=ci_threshold, normalize=normalize, + raw_edges_abs=result.edges_abs, ) logger.info( f"[perf] filter_graph: {time.perf_counter() - t0:.2f}s ({len(fg.edges)} edges after filter)" @@ -505,15 +671,17 @@ def work(on_progress: ProgressCallback) -> GraphData: graphType=graph_type, tokens=spans, edges=fg.edges, + edgesAbs=fg.edges_abs, 
outputProbs=fg.out_probs,
             nodeCiVals=fg.node_ci_vals,
             nodeSubcompActs=result.node_subcomp_acts,
             maxAbsAttr=fg.max_abs_attr,
+            maxAbsAttrAbs=fg.max_abs_attr_abs,
             maxAbsSubcompAct=fg.max_abs_subcomp_act,
             l0_total=fg.l0_total,
         )
 
-    return stream_computation(work)
+    return stream_computation(work, manager._gpu_lock)
 
 
 def _edge_to_edge_data(edge: Edge) -> EdgeData:
@@ -557,9 +725,6 @@ def get_group_key(edge: Edge) -> str:
     return out_edges
 
 
-LossType = Literal["ce", "kl"]
-
-
 @router.post("/optimized/stream")
 @log_errors
 def compute_graph_optimized_stream(
@@ -586,19 +751,8 @@ def compute_graph_optimized_stream(
-    label_token is required when loss_type is "ce".
+    label_token is required when loss_type is "ce" or "logit".
     adv_pgd_n_steps and adv_pgd_step_size enable adversarial PGD when both are provided.
     """
-    # Build loss config based on type
-    loss_config: LossConfig
-    match loss_type:
-        case "ce":
-            if label_token is None:
-                raise HTTPException(status_code=400, detail="label_token is required for CE loss")
-            loss_config = CELossConfig(
-                coeff=loss_coeff, position=loss_position, label_token=label_token
-            )
-        case "kl":
-            loss_config = KLLossConfig(coeff=loss_coeff, position=loss_position)
-
-    lr = 1e-2
+    loss_config = _build_loss_config(loss_type, loss_coeff, loss_position, label_token)
+    pgd_configs = _maybe_pgd(adv_pgd_n_steps, adv_pgd_step_size)
 
     db = manager.db
     prompt = db.get_prompt(prompt_id)
@@ -612,11 +766,9 @@ def compute_graph_optimized_stream(
             detail=f"loss_position {loss_position} out of bounds for prompt with {len(token_ids)} tokens",
         )
 
-    label_str = loaded.tokenizer.get_tok_display(label_token) if label_token is not None else None
     spans = loaded.tokenizer.get_spans(token_ids)
     tokens_tensor = torch.tensor([token_ids], device=DEVICE)
 
-    # Slice tokens to only include positions <= loss_position
     num_tokens = loss_position + 1
     spans_sliced = spans[:num_tokens]
 
@@ -627,13 +779,12 @@ def compute_graph_optimized_stream(
         beta=beta,
         mask_type=mask_type,
         loss=loss_config,
-        adv_pgd_n_steps=adv_pgd_n_steps,
-        
adv_pgd_step_size=adv_pgd_step_size, + pgd=pgd_configs[0] if pgd_configs else None, ) optim_config = OptimCIConfig( seed=0, - lr=lr, + lr=1e-2, steps=steps, weight_decay=0.0, lr_schedule="cosine", @@ -645,12 +796,12 @@ def compute_graph_optimized_stream( sampling=loaded.config.sampling, ce_kl_rounding_threshold=0.5, mask_type=mask_type, - adv_pgd=AdvPGDConfig(n_steps=adv_pgd_n_steps, step_size=adv_pgd_step_size, init="random") - if adv_pgd_n_steps is not None and adv_pgd_step_size is not None - else None, + adv_pgd=pgd_configs[1] if pgd_configs else None, ) - def work(on_progress: ProgressCallback) -> GraphDataWithOptimization: + def work( + on_progress: ProgressCallback, on_ci_snapshot: CISnapshotCallback | None + ) -> GraphDataWithOptimization: result = compute_prompt_attributions_optimized( model=loaded.model, topology=loaded.topology, @@ -660,28 +811,42 @@ def work(on_progress: ProgressCallback) -> GraphDataWithOptimization: output_prob_threshold=0.01, device=DEVICE, on_progress=on_progress, + on_ci_snapshot=on_ci_snapshot, ) ci_masked_out_logits = result.ci_masked_out_logits.cpu() target_out_logits = result.target_out_logits.cpu() - adv_pgd_out_logits = ( - result.adv_pgd_out_logits.cpu() if result.adv_pgd_out_logits is not None else None - ) + + opt_params.ci_masked_label_prob = result.metrics.ci_masked_label_prob + opt_params.stoch_masked_label_prob = result.metrics.stoch_masked_label_prob + opt_params.adv_pgd_label_prob = result.metrics.adv_pgd_label_prob graph_id = db.save_graph( prompt_id=prompt_id, graph=StoredGraph( graph_type="optimized", edges=result.edges, + edges_abs=result.edges_abs, ci_masked_out_logits=ci_masked_out_logits, target_out_logits=target_out_logits, - adv_pgd_out_logits=adv_pgd_out_logits, node_ci_vals=result.node_ci_vals, node_subcomp_acts=result.node_subcomp_acts, optimization_params=opt_params, ), ) + _save_base_intervention_run( + graph_id=graph_id, + model=loaded.model, + tokens=tokens_tensor, + node_ci_vals=result.node_ci_vals, + 
tokenizer=loaded.tokenizer, + topology=loaded.topology, + db=db, + sampling=loaded.config.sampling, + loss_config=loss_config, + ) + fg = filter_graph_for_display( raw_edges=result.edges, node_ci_vals=result.node_ci_vals, @@ -692,32 +857,20 @@ def work(on_progress: ProgressCallback) -> GraphDataWithOptimization: num_tokens=num_tokens, ci_threshold=ci_threshold, normalize=normalize, - adv_pgd_out_logits=adv_pgd_out_logits, + raw_edges_abs=result.edges_abs, ) - # Build loss result based on config type - loss_result: CELossResult | KLLossResult - match loss_config: - case CELossConfig(coeff=coeff, position=pos, label_token=label_tok): - assert label_str is not None - loss_result = CELossResult( - coeff=coeff, - position=pos, - label_token=label_tok, - label_str=label_str, - ) - case KLLossConfig(coeff=coeff, position=pos): - loss_result = KLLossResult(coeff=coeff, position=pos) - return GraphDataWithOptimization( id=graph_id, graphType="optimized", tokens=spans_sliced, edges=fg.edges, + edgesAbs=fg.edges_abs, outputProbs=fg.out_probs, nodeCiVals=fg.node_ci_vals, nodeSubcompActs=result.node_subcomp_acts, maxAbsAttr=fg.max_abs_attr, + maxAbsAttrAbs=fg.max_abs_attr_abs, maxAbsSubcompAct=fg.max_abs_subcomp_act, l0_total=fg.l0_total, optimization=OptimizationResult( @@ -726,19 +879,203 @@ def work(on_progress: ProgressCallback) -> GraphDataWithOptimization: pnorm=pnorm, beta=beta, mask_type=mask_type, - loss=loss_result, + loss=_build_loss_result(loss_config, loaded.tokenizer.get_tok_display), metrics=OptimizationMetricsResult( ci_masked_label_prob=result.metrics.ci_masked_label_prob, stoch_masked_label_prob=result.metrics.stoch_masked_label_prob, adv_pgd_label_prob=result.metrics.adv_pgd_label_prob, l0_total=result.metrics.l0_total, ), - adv_pgd_n_steps=adv_pgd_n_steps, - adv_pgd_step_size=adv_pgd_step_size, + pgd=pgd_configs[0] if pgd_configs else None, + ), + ) + + return stream_computation(work, manager._gpu_lock) + + +class BatchOptimizedRequest(BaseModel): + 
"""Request body for batch optimized graph computation.""" + + prompt_id: int + imp_min_coeffs: list[float] + steps: int + pnorm: float + beta: float + normalize: NormalizeType + ci_threshold: float + mask_type: MaskType + loss_type: LossType + loss_coeff: float + loss_position: int + label_token: int | None = None + adv_pgd_n_steps: int | None = None + adv_pgd_step_size: float | None = None + + +@router.post("/optimized/batch/stream") +@log_errors +def compute_graph_optimized_batch_stream( + body: BatchOptimizedRequest, + loaded: DepLoadedRun, + manager: DepStateManager, +): + """Compute optimized graphs for multiple sparsity coefficients in one batched optimization. + + Returns N graphs (one per imp_min_coeff) via SSE streaming. + All coefficients share the same loss config, steps, and other hyperparameters. + """ + assert len(body.imp_min_coeffs) > 0, "At least one coefficient required" + assert len(body.imp_min_coeffs) <= 20, "Too many coefficients (max 20)" + + loss_config = _build_loss_config( + body.loss_type, body.loss_coeff, body.loss_position, body.label_token + ) + pgd_configs = _maybe_pgd(body.adv_pgd_n_steps, body.adv_pgd_step_size) + + db = manager.db + prompt = db.get_prompt(body.prompt_id) + assert prompt is not None, f"prompt {body.prompt_id} not found" + + token_ids = prompt.token_ids + assert body.loss_position < len(token_ids), ( + f"loss_position {body.loss_position} out of bounds for prompt with {len(token_ids)} tokens" + ) + + spans = loaded.tokenizer.get_spans(token_ids) + tokens_tensor = torch.tensor([token_ids], device=DEVICE) + + num_tokens = body.loss_position + 1 + spans_sliced = spans[:num_tokens] + + configs = [ + OptimCIConfig( + seed=0, + lr=1e-2, + steps=body.steps, + weight_decay=0.0, + lr_schedule="cosine", + lr_exponential_halflife=None, + lr_warmup_pct=0.01, + log_freq=max(1, body.steps // 4), + imp_min_config=ImportanceMinimalityLossConfig( + coeff=coeff, pnorm=body.pnorm, beta=body.beta ), + loss_config=loss_config, + 
sampling=loaded.config.sampling, + ce_kl_rounding_threshold=0.5, + mask_type=body.mask_type, + adv_pgd=pgd_configs[1] if pgd_configs else None, ) + for coeff in body.imp_min_coeffs + ] - return stream_computation(work) + def work( + on_progress: ProgressCallback, on_ci_snapshot: CISnapshotCallback | None + ) -> BatchGraphResult: + results = compute_prompt_attributions_optimized_batched( + model=loaded.model, + topology=loaded.topology, + tokens=tokens_tensor, + sources_by_target=loaded.sources_by_target, + configs=configs, + output_prob_threshold=0.01, + device=DEVICE, + on_progress=on_progress, + on_ci_snapshot=on_ci_snapshot, + ) + + graphs: list[GraphDataWithOptimization] = [] + for result, coeff in zip(results, body.imp_min_coeffs, strict=True): + ci_masked_out_logits = result.ci_masked_out_logits.cpu() + target_out_logits = result.target_out_logits.cpu() + + opt_params = OptimizationParams( + imp_min_coeff=coeff, + steps=body.steps, + pnorm=body.pnorm, + beta=body.beta, + mask_type=body.mask_type, + loss=loss_config, + pgd=pgd_configs[0] if pgd_configs else None, + ) + opt_params.ci_masked_label_prob = result.metrics.ci_masked_label_prob + opt_params.stoch_masked_label_prob = result.metrics.stoch_masked_label_prob + opt_params.adv_pgd_label_prob = result.metrics.adv_pgd_label_prob + + graph_id = db.save_graph( + prompt_id=body.prompt_id, + graph=StoredGraph( + graph_type="optimized", + edges=result.edges, + edges_abs=result.edges_abs, + ci_masked_out_logits=ci_masked_out_logits, + target_out_logits=target_out_logits, + node_ci_vals=result.node_ci_vals, + node_subcomp_acts=result.node_subcomp_acts, + optimization_params=opt_params, + ), + ) + + _save_base_intervention_run( + graph_id=graph_id, + model=loaded.model, + tokens=tokens_tensor, + node_ci_vals=result.node_ci_vals, + tokenizer=loaded.tokenizer, + topology=loaded.topology, + db=db, + sampling=loaded.config.sampling, + loss_config=loss_config, + ) + + fg = filter_graph_for_display( + 
raw_edges=result.edges, + node_ci_vals=result.node_ci_vals, + node_subcomp_acts=result.node_subcomp_acts, + ci_masked_out_logits=ci_masked_out_logits, + target_out_logits=target_out_logits, + tok_display=loaded.tokenizer.get_tok_display, + num_tokens=num_tokens, + ci_threshold=body.ci_threshold, + normalize=body.normalize, + raw_edges_abs=result.edges_abs, + ) + + graphs.append( + GraphDataWithOptimization( + id=graph_id, + graphType="optimized", + tokens=spans_sliced, + edges=fg.edges, + edgesAbs=fg.edges_abs, + outputProbs=fg.out_probs, + nodeCiVals=fg.node_ci_vals, + nodeSubcompActs=result.node_subcomp_acts, + maxAbsAttr=fg.max_abs_attr, + maxAbsAttrAbs=fg.max_abs_attr_abs, + maxAbsSubcompAct=fg.max_abs_subcomp_act, + l0_total=fg.l0_total, + optimization=OptimizationResult( + imp_min_coeff=coeff, + steps=body.steps, + pnorm=body.pnorm, + beta=body.beta, + mask_type=body.mask_type, + loss=_build_loss_result(loss_config, loaded.tokenizer.get_tok_display), + metrics=OptimizationMetricsResult( + ci_masked_label_prob=result.metrics.ci_masked_label_prob, + stoch_masked_label_prob=result.metrics.stoch_masked_label_prob, + adv_pgd_label_prob=result.metrics.adv_pgd_label_prob, + l0_total=result.metrics.l0_total, + ), + pgd=pgd_configs[0] if pgd_configs else None, + ), + ) + ) + + return BatchGraphResult(graphs=graphs) + + return stream_computation(work, manager._gpu_lock) @dataclass @@ -746,9 +1083,11 @@ class FilteredGraph: """Result of filtering a raw graph for display.""" edges: list[EdgeData] + edges_abs: list[EdgeData] | None # absolute-target variant, None for old graphs node_ci_vals: dict[str, float] # with pseudo nodes out_probs: dict[str, OutputProbability] max_abs_attr: float + max_abs_attr_abs: float | None # max abs for absolute-target edges max_abs_subcomp_act: float l0_total: int @@ -763,8 +1102,8 @@ def filter_graph_for_display( num_tokens: int, ci_threshold: float, normalize: NormalizeType, + raw_edges_abs: list[Edge] | None = None, edge_limit: int = 
GLOBAL_EDGE_LIMIT, - adv_pgd_out_logits: torch.Tensor | None = None, ) -> FilteredGraph: """Filter and transform a raw attribution graph for display. @@ -775,9 +1114,7 @@ def filter_graph_for_display( 5. Normalize edge strengths (if requested) 6. Cap edges at edge_limit """ - out_probs = _build_out_probs( - ci_masked_out_logits, target_out_logits, tok_display, adv_pgd_out_logits - ) + out_probs = _build_out_probs(ci_masked_out_logits, target_out_logits, tok_display) filtered_node_ci_vals = {k: v for k, v in node_ci_vals.items() if v > ci_threshold} @@ -789,25 +1126,33 @@ def filter_graph_for_display( seq_pos, token_id = key.split(":") node_ci_vals_with_pseudo[f"output:{seq_pos}:{token_id}"] = out_prob.prob - # Filter edges to only those connecting surviving nodes + # Filter, normalize, sort, and truncate an edge list to the surviving node set. node_keys = set(node_ci_vals_with_pseudo.keys()) - edges = [e for e in raw_edges if str(e.source) in node_keys and str(e.target) in node_keys] - edges = _normalize_edges(edges=edges, normalize=normalize) - max_abs_attr = compute_max_abs_attr(edges=edges) + def _filter_edges(raw: list[Edge]) -> tuple[list[EdgeData], float]: + filtered = [e for e in raw if str(e.source) in node_keys and str(e.target) in node_keys] + filtered = _normalize_edges(edges=filtered, normalize=normalize) + max_abs = compute_max_abs_attr(edges=filtered) + filtered = sorted(filtered, key=lambda e: abs(e.strength), reverse=True) + if len(filtered) > edge_limit: + logger.warning(f"Edge limit {edge_limit} exceeded ({len(filtered)} edges), truncating") + filtered = filtered[:edge_limit] + return [_edge_to_edge_data(e) for e in filtered], max_abs - # Always sort by abs(strength) desc so frontend can just slice(0, topK) without re-sorting - edges = sorted(edges, key=lambda e: abs(e.strength), reverse=True) + edges_out, max_abs_attr = _filter_edges(raw_edges) - if len(edges) > edge_limit: - logger.warning(f"Edge limit {edge_limit} exceeded ({len(edges)} edges), 
truncating") - edges = edges[:edge_limit] + edges_abs_out: list[EdgeData] | None = None + max_abs_attr_abs: float | None = None + if raw_edges_abs is not None: + edges_abs_out, max_abs_attr_abs = _filter_edges(raw_edges_abs) return FilteredGraph( - edges=[_edge_to_edge_data(e) for e in edges], + edges=edges_out, + edges_abs=edges_abs_out, node_ci_vals=node_ci_vals_with_pseudo, out_probs=out_probs, max_abs_attr=max_abs_attr, + max_abs_attr_abs=max_abs_attr_abs, max_abs_subcomp_act=compute_max_abs_subcomp_act(node_subcomp_acts), l0_total=len(filtered_node_ci_vals), ) @@ -840,7 +1185,7 @@ def stored_graph_to_response( num_tokens=num_tokens, ci_threshold=ci_threshold, normalize=normalize, - adv_pgd_out_logits=graph.adv_pgd_out_logits, + raw_edges_abs=graph.edges_abs, ) if not is_optimized: @@ -849,10 +1194,12 @@ def stored_graph_to_response( graphType=graph.graph_type, tokens=spans, edges=fg.edges, + edgesAbs=fg.edges_abs, outputProbs=fg.out_probs, nodeCiVals=fg.node_ci_vals, nodeSubcompActs=graph.node_subcomp_acts, maxAbsAttr=fg.max_abs_attr, + maxAbsAttrAbs=fg.max_abs_attr_abs, maxAbsSubcompAct=fg.max_abs_subcomp_act, l0_total=fg.l0_total, ) @@ -860,29 +1207,17 @@ def stored_graph_to_response( assert graph.optimization_params is not None opt = graph.optimization_params - # Build loss result based on stored config type - loss_result: CELossResult | KLLossResult - match opt.loss: - case CELossConfig(coeff=coeff, position=pos, label_token=label_tok): - label_str = tokenizer.get_tok_display(label_tok) - loss_result = CELossResult( - coeff=coeff, - position=pos, - label_token=label_tok, - label_str=label_str, - ) - case KLLossConfig(coeff=coeff, position=pos): - loss_result = KLLossResult(coeff=coeff, position=pos) - return GraphDataWithOptimization( id=graph.id, graphType=graph.graph_type, tokens=spans, edges=fg.edges, + edgesAbs=fg.edges_abs, outputProbs=fg.out_probs, nodeCiVals=fg.node_ci_vals, nodeSubcompActs=graph.node_subcomp_acts, maxAbsAttr=fg.max_abs_attr, + 
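The `_filter_edges` closure in this hunk applies one pipeline to both edge lists: filter to surviving nodes, compute the max absolute strength, sort by |strength| descending, truncate to the edge limit. A standalone sketch of that pipeline with a stand-in `Edge` type (normalization omitted; names here are illustrative, not the repo's actual classes):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Edge:
    source: str
    target: str
    strength: float


def filter_edges(
    raw: list[Edge], node_keys: set[str], edge_limit: int = 3
) -> tuple[list[Edge], float]:
    """Keep edges whose endpoints both survive, sort by |strength| desc, truncate."""
    kept = [e for e in raw if e.source in node_keys and e.target in node_keys]
    # max is taken before truncation, mirroring the order in the hunk above
    max_abs = max((abs(e.strength) for e in kept), default=0.0)
    kept.sort(key=lambda e: abs(e.strength), reverse=True)
    return kept[:edge_limit], max_abs


edges = [
    Edge("a", "b", -0.9),
    Edge("a", "c", 0.4),
    Edge("b", "c", 0.1),
    Edge("x", "b", 5.0),  # dropped: "x" is not a surviving node
]
kept, max_abs = filter_edges(edges, node_keys={"a", "b", "c"})
```

Sorting before truncating lets the frontend slice the top-K edges without re-sorting, as the original comment notes.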
maxAbsAttrAbs=fg.max_abs_attr_abs, maxAbsSubcompAct=fg.max_abs_subcomp_act, l0_total=fg.l0_total, optimization=OptimizationResult( @@ -891,11 +1226,14 @@ def stored_graph_to_response( pnorm=opt.pnorm, beta=opt.beta, mask_type=opt.mask_type, - loss=loss_result, - # Metrics not stored in DB for cached graphs - use l0_total from graph - metrics=OptimizationMetricsResult(l0_total=float(fg.l0_total)), - adv_pgd_n_steps=opt.adv_pgd_n_steps, - adv_pgd_step_size=opt.adv_pgd_step_size, + loss=_build_loss_result(opt.loss, tokenizer.get_tok_display), + metrics=OptimizationMetricsResult( + l0_total=float(fg.l0_total), + ci_masked_label_prob=opt.ci_masked_label_prob, + stoch_masked_label_prob=opt.stoch_masked_label_prob, + adv_pgd_label_prob=opt.adv_pgd_label_prob, + ), + pgd=opt.pgd, ), ) diff --git a/spd/app/backend/routers/intervention.py b/spd/app/backend/routers/intervention.py index 1ccdbf86c..e26a73462 100644 --- a/spd/app/backend/routers/intervention.py +++ b/spd/app/backend/routers/intervention.py @@ -4,8 +4,12 @@ from fastapi import APIRouter, HTTPException from pydantic import BaseModel -from spd.app.backend.compute import compute_intervention_forward +from spd.app.backend.compute import ( + InterventionResult, + compute_intervention, +) from spd.app.backend.dependencies import DepDB, DepLoadedRun, DepStateManager +from spd.app.backend.optim_cis import AdvPGDConfig, LossConfig, MeanKLLossConfig from spd.app.backend.utils import log_errors from spd.topology import TransformerTopology from spd.utils.distributed_utils import get_device @@ -15,56 +19,19 @@ # ============================================================================= -class InterventionNode(BaseModel): - """A specific node to activate during intervention.""" - - layer: str - seq_pos: int - component_idx: int - - -class InterventionRequest(BaseModel): - """Request for intervention forward pass.""" - - text: str - nodes: list[InterventionNode] - top_k: int - - -class TokenPrediction(BaseModel): - """A 
single token prediction with probability.""" - - token: str - token_id: int - spd_prob: float - target_prob: float - logit: float - target_logit: float - - -class InterventionResponse(BaseModel): - """Response from intervention forward pass.""" - - input_tokens: list[str] - predictions_per_position: list[list[TokenPrediction]] +class AdvPgdParams(BaseModel): + n_steps: int + step_size: float class RunInterventionRequest(BaseModel): """Request to run and save an intervention.""" graph_id: int - text: str selected_nodes: list[str] # node keys (layer:seq:cIdx) - top_k: int = 10 - - -class ForkedInterventionRunSummary(BaseModel): - """Summary of a forked intervention run with modified tokens.""" - - id: int - token_replacements: list[tuple[int, int]] # [(seq_pos, new_token_id), ...] - result: InterventionResponse - created_at: str + nodes_to_ablate: list[str] | None = None # node keys to ablate in ablated (omit to skip) + top_k: int + adv_pgd: AdvPgdParams class InterventionRunSummary(BaseModel): @@ -72,16 +39,8 @@ class InterventionRunSummary(BaseModel): id: int selected_nodes: list[str] - result: InterventionResponse + result: InterventionResult created_at: str - forked_runs: list[ForkedInterventionRunSummary] - - -class ForkInterventionRequest(BaseModel): - """Request to fork an intervention run with modified tokens.""" - - token_replacements: list[tuple[int, int]] # [(seq_pos, new_token_id), ...] 
- top_k: int = 10 router = APIRouter(prefix="/api/intervention", tags=["intervention"]) @@ -104,100 +63,15 @@ def _parse_node_key(key: str, topology: TransformerTopology) -> tuple[str, int, return concrete_path, int(seq_str), int(cidx_str) -def _run_intervention_forward( - text: str, - selected_nodes: list[str], - top_k: int, - loaded: DepLoadedRun, -) -> InterventionResponse: - """Run intervention forward pass and return response.""" - token_ids = loaded.tokenizer.encode(text) - tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) - - active_nodes = [_parse_node_key(key, loaded.topology) for key in selected_nodes] - - seq_len = tokens.shape[1] - for _, seq_pos, _ in active_nodes: - if seq_pos >= seq_len: - raise ValueError(f"seq_pos {seq_pos} out of bounds for text with {seq_len} tokens") - - result = compute_intervention_forward( - model=loaded.model, - tokens=tokens, - active_nodes=active_nodes, - top_k=top_k, - tokenizer=loaded.tokenizer, - ) - - predictions_per_position = [ - [ - TokenPrediction( - token=token, - token_id=token_id, - spd_prob=spd_prob, - target_prob=target_prob, - logit=logit, - target_logit=target_logit, - ) - for token, token_id, spd_prob, logit, target_prob, target_logit in pos_predictions - ] - for pos_predictions in result.predictions_per_position - ] - - return InterventionResponse( - input_tokens=result.input_tokens, - predictions_per_position=predictions_per_position, - ) - - -@router.post("") -@log_errors -def run_intervention(request: InterventionRequest, loaded: DepLoadedRun) -> InterventionResponse: - """Run intervention forward pass with specified nodes active (legacy endpoint).""" - token_ids = loaded.tokenizer.encode(request.text) - tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) - - active_nodes = [ - ( - loaded.topology.canon_to_target(n.layer), - n.seq_pos, - n.component_idx, - ) - for n in request.nodes - ] - - seq_len = tokens.shape[1] +def _parse_and_validate_active_nodes( + 
selected_nodes: list[str], topology: TransformerTopology, seq_len: int +) -> list[tuple[str, int, int]]: + """Parse node keys and validate sequence bounds for the current prompt.""" + active_nodes = [_parse_node_key(key, topology) for key in selected_nodes] for _, seq_pos, _ in active_nodes: if seq_pos >= seq_len: raise ValueError(f"seq_pos {seq_pos} out of bounds for text with {seq_len} tokens") - - result = compute_intervention_forward( - model=loaded.model, - tokens=tokens, - active_nodes=active_nodes, - top_k=request.top_k, - tokenizer=loaded.tokenizer, - ) - - predictions_per_position = [ - [ - TokenPrediction( - token=token, - token_id=token_id, - spd_prob=spd_prob, - target_prob=target_prob, - logit=logit, - target_logit=target_logit, - ) - for token, token_id, spd_prob, logit, target_prob, target_logit in pos_predictions - ] - for pos_predictions in result.predictions_per_position - ] - - return InterventionResponse( - input_tokens=result.input_tokens, - predictions_per_position=predictions_per_position, - ) + return active_nodes @router.post("/run") @@ -206,19 +80,59 @@ def run_and_save_intervention( request: RunInterventionRequest, loaded: DepLoadedRun, db: DepDB, + manager: DepStateManager, ) -> InterventionRunSummary: """Run an intervention and save the result.""" - response = _run_intervention_forward( - text=request.text, - selected_nodes=request.selected_nodes, - top_k=request.top_k, - loaded=loaded, - ) + with manager.gpu_lock(): + graph_record = db.get_graph(request.graph_id) + if graph_record is None: + raise HTTPException(status_code=404, detail="Graph not found") + graph, prompt_id = graph_record + + prompt = db.get_prompt(prompt_id) + if prompt is None: + raise HTTPException(status_code=404, detail="Prompt not found") + + token_ids = prompt.token_ids + active_nodes = _parse_and_validate_active_nodes( + request.selected_nodes, loaded.topology, len(token_ids) + ) + nodes_to_ablate = ( + _parse_and_validate_active_nodes( + request.nodes_to_ablate, 
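Node keys follow the `layer:seq_pos:component_idx` format noted in the request model. A minimal sketch of the parse-and-validate step (the real `_parse_and_validate_active_nodes` also maps canonical layer names through the topology, which is omitted here; splitting from the right is an assumption that layer names never contain `:`):

```python
def parse_node_key(key: str) -> tuple[str, int, int]:
    # rsplit so a dotted layer name like "0.mlp.up" survives intact;
    # assumes layer names contain no ":" characters.
    layer, seq_str, cidx_str = key.rsplit(":", 2)
    return layer, int(seq_str), int(cidx_str)


def parse_and_validate(selected: list[str], seq_len: int) -> list[tuple[str, int, int]]:
    """Parse node keys and reject positions beyond the prompt length."""
    nodes = [parse_node_key(k) for k in selected]
    for _, seq_pos, _ in nodes:
        if seq_pos >= seq_len:
            raise ValueError(f"seq_pos {seq_pos} out of bounds for text with {seq_len} tokens")
    return nodes


nodes = parse_and_validate(["0.mlp.up:3:17", "2.attn.o:0:5"], seq_len=8)
```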
loaded.topology, len(token_ids) + ) + if request.nodes_to_ablate is not None + else None + ) + tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) + + # Use graph's loss config if optimized, else mean KL + loss_config: LossConfig = ( + graph.optimization_params.loss + if graph.optimization_params is not None + else MeanKLLossConfig() + ) + + result = compute_intervention( + model=loaded.model, + tokens=tokens, + active_nodes=active_nodes, + nodes_to_ablate=nodes_to_ablate, + tokenizer=loaded.tokenizer, + adv_pgd_config=AdvPGDConfig( + n_steps=request.adv_pgd.n_steps, + step_size=request.adv_pgd.step_size, + init="random", + ), + loss_config=loss_config, + sampling=loaded.config.sampling, + top_k=request.top_k, + ) run_id = db.save_intervention_run( graph_id=request.graph_id, selected_nodes=request.selected_nodes, - result_json=response.model_dump_json(), + result_json=result.model_dump_json(), ) record = db.get_intervention_runs(request.graph_id) @@ -228,41 +142,25 @@ def run_and_save_intervention( return InterventionRunSummary( id=run_id, selected_nodes=request.selected_nodes, - result=response, + result=result, created_at=saved_run.created_at, - forked_runs=[], ) @router.get("/runs/{graph_id}") @log_errors def get_intervention_runs(graph_id: int, db: DepDB) -> list[InterventionRunSummary]: - """Get all intervention runs for a graph, including forked runs.""" + """Get all intervention runs for a graph.""" records = db.get_intervention_runs(graph_id) - results = [] - for r in records: - # Get forked runs for this intervention run - forked_records = db.get_forked_intervention_runs(r.id) - forked_runs = [ - ForkedInterventionRunSummary( - id=fr.id, - token_replacements=fr.token_replacements, - result=InterventionResponse.model_validate_json(fr.result_json), - created_at=fr.created_at, - ) - for fr in forked_records - ] - - results.append( - InterventionRunSummary( - id=r.id, - selected_nodes=r.selected_nodes, - 
result=InterventionResponse.model_validate_json(r.result_json), - created_at=r.created_at, - forked_runs=forked_runs, - ) + return [ + InterventionRunSummary( + id=r.id, + selected_nodes=r.selected_nodes, + result=InterventionResult.model_validate_json(r.result_json), + created_at=r.created_at, ) - return results + for r in records + ] @router.delete("/runs/{run_id}") @@ -271,86 +169,3 @@ def delete_intervention_run(run_id: int, db: DepDB) -> dict[str, bool]: """Delete an intervention run.""" db.delete_intervention_run(run_id) return {"success": True} - - -@router.post("/runs/{run_id}/fork") -@log_errors -def fork_intervention_run( - run_id: int, - request: ForkInterventionRequest, - loaded: DepLoadedRun, - manager: DepStateManager, -) -> ForkedInterventionRunSummary: - """Fork an intervention run with modified tokens. - - Takes the same selected_nodes from the parent run, applies token replacements - to the original prompt, and runs the intervention forward pass. - """ - db = manager.db - - # Get the parent intervention run - parent_run = db.get_intervention_run(run_id) - if parent_run is None: - raise HTTPException(status_code=404, detail="Intervention run not found") - - # Get the prompt_id from the graph - conn = db._get_conn() - row = conn.execute( - "SELECT prompt_id FROM graphs WHERE id = ?", (parent_run.graph_id,) - ).fetchone() - if row is None: - raise HTTPException(status_code=404, detail="Graph not found") - prompt_id = row["prompt_id"] - - # Get the prompt to get original token_ids - prompt = db.get_prompt(prompt_id) - if prompt is None: - raise HTTPException(status_code=404, detail="Prompt not found") - - # Apply token replacements to get modified token_ids - modified_token_ids = list(prompt.token_ids) # Make a copy - for seq_pos, new_token_id in request.token_replacements: - if seq_pos < 0 or seq_pos >= len(modified_token_ids): - raise HTTPException( - status_code=400, - detail=f"Invalid seq_pos {seq_pos} for prompt with {len(modified_token_ids)} 
tokens", - ) - modified_token_ids[seq_pos] = new_token_id - - # Decode the modified tokens back to text - modified_text = loaded.tokenizer.decode(modified_token_ids) - - # Run the intervention forward pass with modified tokens but same selected nodes - response = _run_intervention_forward( - text=modified_text, - selected_nodes=parent_run.selected_nodes, - top_k=request.top_k, - loaded=loaded, - ) - - # Save the forked run - fork_id = db.save_forked_intervention_run( - intervention_run_id=run_id, - token_replacements=request.token_replacements, - result_json=response.model_dump_json(), - ) - - # Get the saved record for created_at - forked_records = db.get_forked_intervention_runs(run_id) - saved_fork = next((f for f in forked_records if f.id == fork_id), None) - assert saved_fork is not None - - return ForkedInterventionRunSummary( - id=fork_id, - token_replacements=request.token_replacements, - result=response, - created_at=saved_fork.created_at, - ) - - -@router.delete("/forks/{fork_id}") -@log_errors -def delete_forked_intervention_run(fork_id: int, db: DepDB) -> dict[str, bool]: - """Delete a forked intervention run.""" - db.delete_forked_intervention_run(fork_id) - return {"success": True} diff --git a/spd/app/backend/routers/investigations.py b/spd/app/backend/routers/investigations.py new file mode 100644 index 000000000..3fa8297d0 --- /dev/null +++ b/spd/app/backend/routers/investigations.py @@ -0,0 +1,314 @@ +"""Investigations endpoint for viewing agent investigation results. + +Lists and serves investigation data from SPD_OUT_DIR/investigations/. +Each investigation directory contains findings from a single agent run. 
+""" + +import json +from datetime import datetime +from pathlib import Path +from typing import Any + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from spd.app.backend.dependencies import DepLoadedRun +from spd.settings import DEFAULT_PARTITION_NAME, SPD_OUT_DIR +from spd.utils.wandb_utils import parse_wandb_run_path + +router = APIRouter(prefix="/api/investigations", tags=["investigations"]) + +INVESTIGATIONS_DIR = SPD_OUT_DIR / "investigations" + + +class InvestigationSummary(BaseModel): + """Summary of a single investigation.""" + + id: str + wandb_path: str | None + prompt: str | None + created_at: str + has_research_log: bool + has_explanations: bool + event_count: int + last_event_time: str | None + last_event_message: str | None + title: str | None + summary: str | None + status: str | None + + +class EventEntry(BaseModel): + """A single event from events.jsonl.""" + + event_type: str + timestamp: str + message: str + details: dict[str, Any] | None = None + + +class InvestigationDetail(BaseModel): + """Full detail of an investigation including logs.""" + + id: str + wandb_path: str | None + prompt: str | None + created_at: str + research_log: str | None + events: list[EventEntry] + explanations: list[dict[str, Any]] + artifact_ids: list[str] + title: str | None + summary: str | None + status: str | None + + +def _parse_metadata(inv_path: Path) -> dict[str, Any] | None: + """Parse metadata.json from an investigation directory.""" + metadata_path = inv_path / "metadata.json" + if not metadata_path.exists(): + return None + try: + data: dict[str, Any] = json.loads(metadata_path.read_text()) + return data + except json.JSONDecodeError: + return None + + +def _get_last_event(events_path: Path) -> tuple[str | None, str | None, int]: + """Get the last event timestamp, message, and total count from events.jsonl.""" + if not events_path.exists(): + return None, None, 0 + + last_time = None + last_msg = None + count = 0 + + with 
open(events_path) as f: + for line in f: + line = line.strip() + if not line: + continue + count += 1 + try: + event = json.loads(line) + last_time = event.get("timestamp") + last_msg = event.get("message") + except json.JSONDecodeError: + continue + + return last_time, last_msg, count + + +def _parse_task_summary(inv_path: Path) -> tuple[str | None, str | None, str | None]: + """Parse summary.json from an investigation directory. Returns (title, summary, status).""" + summary_path = inv_path / "summary.json" + if not summary_path.exists(): + return None, None, None + try: + data: dict[str, Any] = json.loads(summary_path.read_text()) + return data.get("title"), data.get("summary"), data.get("status") + except json.JSONDecodeError: + return None, None, None + + +def _list_artifact_ids(inv_path: Path) -> list[str]: + """List all artifact IDs for an investigation.""" + artifacts_dir = inv_path / "artifacts" + if not artifacts_dir.exists(): + return [] + return [f.stem for f in sorted(artifacts_dir.glob("graph_*.json"))] + + +def _get_created_at(inv_path: Path, metadata: dict[str, Any] | None) -> str: + """Get creation time for an investigation.""" + events_path = inv_path / "events.jsonl" + if events_path.exists(): + try: + with open(events_path) as f: + first_line = f.readline().strip() + if first_line: + event = json.loads(first_line) + if "timestamp" in event: + return event["timestamp"] + except json.JSONDecodeError: + pass + + if metadata and "created_at" in metadata: + return metadata["created_at"] + + return datetime.fromtimestamp(inv_path.stat().st_mtime).isoformat() + + +@router.get("") +def list_investigations(loaded: DepLoadedRun) -> list[InvestigationSummary]: + """List investigations for the currently loaded run.""" + if not INVESTIGATIONS_DIR.exists(): + return [] + + wandb_path = loaded.run.wandb_path + results = [] + + for inv_path in INVESTIGATIONS_DIR.iterdir(): + if not inv_path.is_dir() or not inv_path.name.startswith("inv-"): + continue + + inv_id 
= inv_path.name + metadata = _parse_metadata(inv_path) + + meta_wandb_path = metadata.get("wandb_path") if metadata else None + if meta_wandb_path is None: + continue + # Normalize to canonical form for comparison (strips "runs/", "wandb:" prefix, etc.) + try: + e, p, r = parse_wandb_run_path(meta_wandb_path) + canonical_meta_path = f"{e}/{p}/{r}" + except ValueError: + continue + if canonical_meta_path != wandb_path: + continue + + events_path = inv_path / "events.jsonl" + last_time, last_msg, event_count = _get_last_event(events_path) + title, summary, status = _parse_task_summary(inv_path) + + explanations_path = inv_path / "explanations.jsonl" + + results.append( + InvestigationSummary( + id=inv_id, + wandb_path=meta_wandb_path, + prompt=metadata.get("prompt") if metadata else None, + created_at=_get_created_at(inv_path, metadata), + has_research_log=(inv_path / "research_log.md").exists(), + has_explanations=explanations_path.exists() + and explanations_path.stat().st_size > 0, + event_count=event_count, + last_event_time=last_time, + last_event_message=last_msg, + title=title, + summary=summary, + status=status, + ) + ) + + results.sort(key=lambda x: x.created_at, reverse=True) + return results + + +@router.get("/{inv_id}") +def get_investigation(inv_id: str) -> InvestigationDetail: + """Get full details of an investigation.""" + inv_path = INVESTIGATIONS_DIR / inv_id + + if not inv_path.exists() or not inv_path.is_dir(): + raise HTTPException(status_code=404, detail=f"Investigation {inv_id} not found") + + metadata = _parse_metadata(inv_path) + + research_log = None + research_log_path = inv_path / "research_log.md" + if research_log_path.exists(): + research_log = research_log_path.read_text() + + events = [] + events_path = inv_path / "events.jsonl" + if events_path.exists(): + with open(events_path) as f: + for line in f: + line = line.strip() + if not line: + continue + try: + event = json.loads(line) + events.append( + EventEntry( + 
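The listing filters investigations by canonicalizing each stored wandb path before comparing it to the loaded run. The repo's `parse_wandb_run_path` is the real implementation; the normalizer below is a hypothetical stand-in to show the idea, and the exact prefixes it strips are assumptions:

```python
def canonical_wandb_path(path: str) -> str:
    # Hypothetical: strip an optional "wandb:" scheme and a "runs/" segment,
    # yielding the canonical "entity/project/run_id" form used for comparison.
    parts = [p for p in path.removeprefix("wandb:").split("/") if p and p != "runs"]
    if len(parts) != 3:
        raise ValueError(f"cannot parse wandb run path: {path!r}")
    return "/".join(parts)


assert canonical_wandb_path("wandb:my-team/spd/runs/abc123") == "my-team/spd/abc123"
assert canonical_wandb_path("my-team/spd/abc123") == "my-team/spd/abc123"
```

Comparing canonical forms means a metadata file written with any accepted spelling still matches the loaded run.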
event_type=event.get("event_type", "unknown"), + timestamp=event.get("timestamp", ""), + message=event.get("message", ""), + details=event.get("details"), + ) + ) + except json.JSONDecodeError: + continue + + explanations: list[dict[str, Any]] = [] + explanations_path = inv_path / "explanations.jsonl" + if explanations_path.exists(): + with open(explanations_path) as f: + for line in f: + line = line.strip() + if not line: + continue + try: + explanations.append(json.loads(line)) + except json.JSONDecodeError: + continue + + title, summary, status = _parse_task_summary(inv_path) + artifact_ids = _list_artifact_ids(inv_path) + + return InvestigationDetail( + id=inv_id, + wandb_path=metadata.get("wandb_path") if metadata else None, + prompt=metadata.get("prompt") if metadata else None, + created_at=_get_created_at(inv_path, metadata), + research_log=research_log, + events=events, + explanations=explanations, + artifact_ids=artifact_ids, + title=title, + summary=summary, + status=status, + ) + + +class LaunchRequest(BaseModel): + prompt: str + + +class LaunchResponse(BaseModel): + inv_id: str + job_id: str + + +@router.post("/launch") +def launch_investigation_endpoint(request: LaunchRequest, loaded: DepLoadedRun) -> LaunchResponse: + """Launch a new investigation for the currently loaded run.""" + from spd.investigate.scripts.run_slurm import launch_investigation + + result = launch_investigation( + wandb_path=loaded.run.wandb_path, + prompt=request.prompt, + context_length=loaded.context_length, + max_turns=50, + partition=DEFAULT_PARTITION_NAME, # TODO: remove when investigate module drops required partition + time="8:00:00", + job_suffix=None, + ) + return LaunchResponse(inv_id=result.inv_id, job_id=result.job_id) + + +@router.get("/{inv_id}/artifacts") +def list_artifacts(inv_id: str) -> list[str]: + """List all artifact IDs for an investigation.""" + inv_path = INVESTIGATIONS_DIR / inv_id + if not inv_path.exists(): + raise HTTPException(status_code=404, 
detail=f"Investigation {inv_id} not found") + return _list_artifact_ids(inv_path) + + +@router.get("/{inv_id}/artifacts/{artifact_id}") +def get_artifact(inv_id: str, artifact_id: str) -> dict[str, Any]: + """Get a specific artifact by ID.""" + inv_path = INVESTIGATIONS_DIR / inv_id + artifact_path = inv_path / "artifacts" / f"{artifact_id}.json" + + if not artifact_path.exists(): + raise HTTPException( + status_code=404, + detail=f"Artifact {artifact_id} not found in {inv_id}", + ) + + data: dict[str, Any] = json.loads(artifact_path.read_text()) + return data diff --git a/spd/app/backend/routers/mcp.py b/spd/app/backend/routers/mcp.py new file mode 100644 index 000000000..488e48b4f --- /dev/null +++ b/spd/app/backend/routers/mcp.py @@ -0,0 +1,1487 @@ +"""MCP (Model Context Protocol) endpoint for Claude Code integration. + +This router implements the MCP JSON-RPC protocol over HTTP, allowing Claude Code +to use SPD tools directly with proper schemas and streaming progress. + +MCP Spec: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports +""" + +import inspect +import json +import queue +import threading +import traceback +from collections.abc import Callable, Generator +from dataclasses import dataclass +from datetime import UTC, datetime +from pathlib import Path +from typing import Any, Literal + +import torch +from fastapi import APIRouter, Request +from fastapi.responses import JSONResponse, StreamingResponse +from pydantic import BaseModel + +from spd.app.backend.compute import ( + compute_ci_only, + compute_prompt_attributions_optimized, +) +from spd.app.backend.database import StoredGraph +from spd.app.backend.optim_cis import CELossConfig, OptimCIConfig +from spd.app.backend.routers.graphs import _build_out_probs +from spd.app.backend.routers.pretrain_info import _get_pretrain_info +from spd.app.backend.state import StateManager +from spd.configs import ImportanceMinimalityLossConfig +from spd.harvest import analysis +from spd.log 
import logger +from spd.utils.distributed_utils import get_device + +router = APIRouter(tags=["mcp"]) + +DEVICE = get_device() + +# MCP protocol version +MCP_PROTOCOL_VERSION = "2024-11-05" + + +@dataclass +class InvestigationConfig: + """Configuration for investigation mode. All paths are required when in investigation mode.""" + + events_log_path: Path + investigation_dir: Path + + +_investigation_config: InvestigationConfig | None = None + + +def set_investigation_config(config: InvestigationConfig) -> None: + """Configure MCP for investigation mode.""" + global _investigation_config + _investigation_config = config + + +def _log_event(event_type: str, message: str, details: dict[str, Any] | None = None) -> None: + """Log an event to the events file if in investigation mode.""" + if _investigation_config is None: + return + event = { + "event_type": event_type, + "timestamp": datetime.now(UTC).isoformat(), + "message": message, + "details": details or {}, + } + with open(_investigation_config.events_log_path, "a") as f: + f.write(json.dumps(event) + "\n") + + +# ============================================================================= +# MCP Protocol Types +# ============================================================================= + + +class MCPRequest(BaseModel): + """JSON-RPC 2.0 request.""" + + jsonrpc: Literal["2.0"] + id: int | str | None = None + method: str + params: dict[str, Any] | None = None + + +class MCPResponse(BaseModel): + """JSON-RPC 2.0 response. + + Per JSON-RPC 2.0 spec, exactly one of result/error must be present (not both, not neither). + Use model_dump(exclude_none=True) when serializing to avoid including null fields. 
+ """ + + jsonrpc: Literal["2.0"] = "2.0" + id: int | str | None + result: Any | None = None + error: dict[str, Any] | None = None + + +class ToolDefinition(BaseModel): + """MCP tool definition.""" + + name: str + description: str + inputSchema: dict[str, Any] + + +# ============================================================================= +# Tool Definitions +# ============================================================================= + +TOOLS: list[ToolDefinition] = [ + ToolDefinition( + name="optimize_graph", + description="""Optimize a sparse circuit for a specific behavior. + +Given a prompt and target token, finds the minimal set of components that produce the target prediction. +Returns the optimized graph with component CI values and edges showing information flow. + +This is the primary tool for understanding how the model produces a specific output.""", + inputSchema={ + "type": "object", + "properties": { + "prompt_text": { + "type": "string", + "description": "The input text to analyze (e.g., 'The boy said that')", + }, + "target_token": { + "type": "string", + "description": "The token to predict (e.g., ' he'). Include leading space if needed.", + }, + "loss_position": { + "type": "integer", + "description": "Position to optimize prediction at (0-indexed, usually last position). If not specified, uses the last position.", + }, + "steps": { + "type": "integer", + "description": "Optimization steps (default: 100, more = sparser but slower)", + "default": 100, + }, + "ci_threshold": { + "type": "number", + "description": "CI threshold for including components (default: 0.5, lower = more components)", + "default": 0.5, + }, + }, + "required": ["prompt_text", "target_token"], + }, + ), + ToolDefinition( + name="get_component_info", + description="""Get detailed information about a component. + +Returns the component's interpretation (what it does), token statistics (what tokens +activate it and what it predicts), and correlated components. 
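Each `ToolDefinition` carries a JSON-Schema `inputSchema`, so a dispatcher can cheaply reject calls missing required arguments before doing any GPU work. A minimal sketch checking only the `required` list (not types; full validation would use a JSON Schema library):

```python
from typing import Any


def missing_required(schema: dict[str, Any], args: dict[str, Any]) -> list[str]:
    """Return the names of required arguments absent from the call."""
    return [name for name in schema.get("required", []) if name not in args]


# Trimmed-down version of the optimize_graph schema above
schema = {
    "type": "object",
    "properties": {
        "prompt_text": {"type": "string"},
        "target_token": {"type": "string"},
        "steps": {"type": "integer", "default": 100},
    },
    "required": ["prompt_text", "target_token"],
}

assert missing_required(schema, {"prompt_text": "The boy said that"}) == ["target_token"]
assert missing_required(schema, {"prompt_text": "x", "target_token": " he"}) == []
```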
+ +Use this to understand what role a component plays in a circuit.""", + inputSchema={ + "type": "object", + "properties": { + "layer": { + "type": "string", + "description": "Canonical layer name (e.g., '0.mlp.up', '2.attn.o')", + }, + "component_idx": { + "type": "integer", + "description": "Component index within the layer", + }, + "top_k": { + "type": "integer", + "description": "Number of top tokens/correlations to return (default: 20)", + "default": 20, + }, + }, + "required": ["layer", "component_idx"], + }, + ), + ToolDefinition( + name="run_ablation", + description="""Run an ablation experiment with only selected components active. + +Tests a hypothesis by running the model with a sparse set of components. +Returns predictions showing what the circuit produces vs the full model. + +Use this to verify that identified components are necessary and sufficient.""", + inputSchema={ + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "Input text for the ablation", + }, + "selected_nodes": { + "type": "array", + "items": {"type": "string"}, + "description": "Node keys to keep active (format: 'layer:seq_pos:component_idx')", + }, + "top_k": { + "type": "integer", + "description": "Number of top predictions to return per position (default: 10)", + "default": 10, + }, + }, + "required": ["text", "selected_nodes"], + }, + ), + ToolDefinition( + name="search_dataset", + description="""Search the SimpleStories training dataset for patterns. + +Finds stories containing the query string. Use this to find examples of +specific linguistic patterns (pronouns, verb forms, etc.) 
for investigation.""", + inputSchema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Text to search for (case-insensitive)", + }, + "limit": { + "type": "integer", + "description": "Maximum results to return (default: 20)", + "default": 20, + }, + }, + "required": ["query"], + }, + ), + ToolDefinition( + name="create_prompt", + description="""Create a prompt for analysis. + +Tokenizes the text and returns token IDs and next-token probabilities. +The returned prompt_id can be used with other tools.""", + inputSchema={ + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "The text to create a prompt from", + }, + }, + "required": ["text"], + }, + ), + ToolDefinition( + name="update_research_log", + description="""Append content to your research log. + +Use this to document your investigation progress, findings, and next steps. +The research log is your primary output for humans to follow your work. + +Call this frequently (every few minutes) with updates on what you're doing.""", + inputSchema={ + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "Markdown content to append to the research log", + }, + }, + "required": ["content"], + }, + ), + ToolDefinition( + name="save_explanation", + description="""Save a complete behavior explanation. + +Use this when you have finished investigating a behavior and want to document +your findings. This creates a structured record of the behavior, the components +involved, and your explanation of how they work together. 
+ +Only call this for complete, validated explanations - not preliminary hypotheses.""", + inputSchema={ + "type": "object", + "properties": { + "subject_prompt": { + "type": "string", + "description": "A prompt that demonstrates the behavior", + }, + "behavior_description": { + "type": "string", + "description": "Clear description of the behavior", + }, + "components_involved": { + "type": "array", + "items": { + "type": "object", + "properties": { + "component_key": { + "type": "string", + "description": "Component key (e.g., '0.mlp.up:5')", + }, + "role": { + "type": "string", + "description": "The role this component plays", + }, + "interpretation": { + "type": "string", + "description": "Auto-interp label if available", + }, + }, + "required": ["component_key", "role"], + }, + "description": "List of components and their roles", + }, + "explanation": { + "type": "string", + "description": "How the components work together", + }, + "supporting_evidence": { + "type": "array", + "items": { + "type": "object", + "properties": { + "evidence_type": { + "type": "string", + "enum": [ + "ablation", + "attribution", + "activation_pattern", + "correlation", + "other", + ], + }, + "description": {"type": "string"}, + "details": {"type": "object"}, + }, + "required": ["evidence_type", "description"], + }, + "description": "Evidence supporting this explanation", + }, + "confidence": { + "type": "string", + "enum": ["high", "medium", "low"], + "description": "Your confidence level", + }, + "alternative_hypotheses": { + "type": "array", + "items": {"type": "string"}, + "description": "Other hypotheses you considered", + }, + "limitations": { + "type": "array", + "items": {"type": "string"}, + "description": "Known limitations of this explanation", + }, + }, + "required": [ + "subject_prompt", + "behavior_description", + "components_involved", + "explanation", + "confidence", + ], + }, + ), + ToolDefinition( + name="set_investigation_summary", + description="""Set a title and 
summary for your investigation.
+
+Call this when you've completed your investigation (or periodically as you make progress)
+to provide a human-readable title and summary that will be shown in the investigations UI.
+
+The title should be short and descriptive. The summary should be 1-3 sentences
+explaining what you investigated and what you found.""",
+        inputSchema={
+            "type": "object",
+            "properties": {
+                "title": {
+                    "type": "string",
+                    "description": "Short title for the investigation (e.g., 'Gendered Pronoun Circuit')",
+                },
+                "summary": {
+                    "type": "string",
+                    "description": "Brief summary of findings (1-3 sentences)",
+                },
+                "status": {
+                    "type": "string",
+                    "enum": ["in_progress", "completed", "inconclusive"],
+                    "description": "Current status of the investigation",
+                    "default": "in_progress",
+                },
+            },
+            "required": ["title", "summary"],
+        },
+    ),
+    ToolDefinition(
+        name="save_graph_artifact",
+        description="""Save a graph as an artifact for inclusion in your research report.
+
+After calling optimize_graph and getting a graph_id, call this to save the graph
+as an artifact. Then reference it in your research log using the spd:graph syntax:
+
+```spd:graph
+artifact: graph_001
+```
+
+This allows humans reviewing your investigation to see interactive circuit visualizations
+inline with your research notes.""",
+        inputSchema={
+            "type": "object",
+            "properties": {
+                "graph_id": {
+                    "type": "integer",
+                    "description": "The graph ID returned by optimize_graph",
+                },
+                "caption": {
+                    "type": "string",
+                    "description": "Optional caption describing what this graph shows",
+                },
+                "ci_threshold": {
+                    "type": "number",
+                    "description": "CI threshold for including nodes (default: 0.5)",
+                    "default": 0.5,
+                },
+                "edge_limit": {
+                    "type": "integer",
+                    "description": "Maximum number of edges to keep, strongest first (default: 5000)",
+                    "default": 5000,
+                },
+            },
+            "required": ["graph_id"],
+        },
+    ),
+    ToolDefinition(
+        name="probe_component",
+        description="""Fast CI probing on custom text.
+
+Computes causal importance values and subcomponent activations for a specific component
+across all positions in the input text. Also returns next-token probabilities.
+ +Use this for quick, targeted analysis of how a component responds to specific inputs.""", + inputSchema={ + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "The input text to probe", + }, + "layer": { + "type": "string", + "description": "Canonical layer name (e.g., '0.mlp.up')", + }, + "component_idx": { + "type": "integer", + "description": "Component index within the layer", + }, + }, + "required": ["text", "layer", "component_idx"], + }, + ), + ToolDefinition( + name="get_component_activation_examples", + description="""Get activation examples from harvest data for a component. + +Returns examples showing token windows where the component fires, along with +CI values and activation strengths at each position. + +Use this to understand what inputs activate a component.""", + inputSchema={ + "type": "object", + "properties": { + "layer": { + "type": "string", + "description": "Canonical layer name (e.g., '0.mlp.up')", + }, + "component_idx": { + "type": "integer", + "description": "Component index within the layer", + }, + "limit": { + "type": "integer", + "description": "Maximum number of examples to return (default: 10)", + "default": 10, + }, + }, + "required": ["layer", "component_idx"], + }, + ), + ToolDefinition( + name="get_component_attributions", + description="""Get dataset-level component dependencies from pre-computed attributions. + +Returns the top source and target components that this component attributes to/from, +aggregated over the training dataset. Both positive and negative attributions are returned. 
+ +Use this to understand a component's role in the broader network.""", + inputSchema={ + "type": "object", + "properties": { + "layer": { + "type": "string", + "description": "Canonical layer name (e.g., '0.mlp.up') or 'output'", + }, + "component_idx": { + "type": "integer", + "description": "Component index within the layer", + }, + "k": { + "type": "integer", + "description": "Number of top attributions to return per direction (default: 10)", + "default": 10, + }, + }, + "required": ["layer", "component_idx"], + }, + ), + ToolDefinition( + name="get_model_info", + description="""Get architecture details about the pretrained model. + +Returns model type, summary, target model config, topology, and pretrain info. +No parameters required.""", + inputSchema={ + "type": "object", + "properties": {}, + }, + ), +] + + +# ============================================================================= +# Tool Implementations +# ============================================================================= + + +def _get_state(): + """Get state manager and loaded run, raising clear errors if not available.""" + manager = StateManager.get() + if manager.run_state is None: + raise ValueError("No run loaded. The backend must load a run first.") + return manager, manager.run_state + + +def _canonicalize_layer(layer: str, loaded: Any) -> str: + """Translate concrete layer name to canonical, passing through 'output'.""" + if layer == "output": + return layer + return loaded.topology.target_to_canon(layer) + + +def _canonicalize_key(concrete_key: str, loaded: Any) -> str: + """Translate concrete component key (e.g. 'h.0.mlp.c_fc:444') to canonical ('0.mlp.up:444').""" + layer, idx = concrete_key.rsplit(":", 1) + return f"{_canonicalize_layer(layer, loaded)}:{idx}" + + +def _tool_optimize_graph(params: dict[str, Any]) -> Generator[dict[str, Any]]: + """Optimize a sparse circuit for a behavior. 
Yields progress events.""" + manager, loaded = _get_state() + + prompt_text = params["prompt_text"] + target_token = params["target_token"] + steps = params.get("steps", 100) + ci_threshold = params.get("ci_threshold", 0.5) + + # Tokenize prompt + token_ids = loaded.tokenizer.encode(prompt_text) + if not token_ids: + raise ValueError("Prompt text produced no tokens") + + # Find target token ID + target_token_ids = loaded.tokenizer.encode(target_token) + if len(target_token_ids) != 1: + raise ValueError( + f"Target token '{target_token}' tokenizes to {len(target_token_ids)} tokens, expected 1. " + f"Token IDs: {target_token_ids}" + ) + label_token = target_token_ids[0] + + # Determine loss position + loss_position = params.get("loss_position") + if loss_position is None: + loss_position = len(token_ids) - 1 + + if loss_position >= len(token_ids): + raise ValueError( + f"loss_position {loss_position} out of bounds for prompt with {len(token_ids)} tokens" + ) + + _log_event( + "tool_start", + f"optimize_graph: '{prompt_text}' → '{target_token}'", + {"steps": steps, "loss_position": loss_position}, + ) + + yield {"type": "progress", "current": 0, "total": steps, "stage": "starting optimization"} + + # Create prompt in DB + prompt_id = manager.db.add_custom_prompt( + run_id=loaded.run.id, + token_ids=token_ids, + context_length=loaded.context_length, + ) + + # Build optimization config + loss_config = CELossConfig(coeff=1.0, position=loss_position, label_token=label_token) + + optim_config = OptimCIConfig( + adv_pgd=None, # AdvPGDConfig(n_steps=10, step_size=0.01, init="random"), + seed=0, + lr=1e-2, + steps=steps, + weight_decay=0.0, + lr_schedule="cosine", + lr_exponential_halflife=None, + lr_warmup_pct=0.01, + log_freq=max(1, steps // 10), + imp_min_config=ImportanceMinimalityLossConfig(coeff=0.1, pnorm=0.5, beta=0.0), + loss_config=loss_config, + sampling=loaded.config.sampling, + ce_kl_rounding_threshold=0.5, + mask_type="ci", + ) + + tokens_tensor = 
torch.tensor([token_ids], device=DEVICE) + progress_queue: queue.Queue[dict[str, Any]] = queue.Queue() + + def on_progress(current: int, total: int, stage: str) -> None: + progress_queue.put({"current": current, "total": total, "stage": stage}) + + # Run optimization in thread + result_holder: list[Any] = [] + error_holder: list[Exception] = [] + + def compute(): + try: + with manager.gpu_lock(): + result = compute_prompt_attributions_optimized( + model=loaded.model, + topology=loaded.topology, + tokens=tokens_tensor, + sources_by_target=loaded.sources_by_target, + optim_config=optim_config, + output_prob_threshold=0.01, + device=DEVICE, + on_progress=on_progress, + ) + result_holder.append(result) + except Exception as e: + error_holder.append(e) + + thread = threading.Thread(target=compute) + thread.start() + + # Yield progress events (throttle logging to every 10% or 10 steps) + last_logged_step = -1 + log_interval = max(1, steps // 10) + + while thread.is_alive() or not progress_queue.empty(): + try: + progress = progress_queue.get(timeout=0.1) + current = progress["current"] + # Log to events.jsonl at intervals (for human monitoring) + if current - last_logged_step >= log_interval or current == progress["total"]: + _log_event( + "optimization_progress", + f"optimize_graph: step {current}/{progress['total']} ({progress['stage']})", + {"prompt": prompt_text, "target": target_token, **progress}, + ) + last_logged_step = current + # Always yield to SSE stream (for Claude) + yield {"type": "progress", **progress} + except queue.Empty: + continue + + thread.join() + + if error_holder: + raise error_holder[0] + + if not result_holder: + raise RuntimeError("Optimization completed but no result was produced") + + result = result_holder[0] + + ci_masked_out_logits = result.ci_masked_out_logits.cpu() + target_out_logits = result.target_out_logits.cpu() + + # Build output probs for response + out_probs = _build_out_probs( + ci_masked_out_logits, + target_out_logits, + 
loaded.tokenizer.get_tok_display, + ) + + # Save graph to DB + from spd.app.backend.database import OptimizationParams + + opt_params = OptimizationParams( + imp_min_coeff=0.1, + steps=steps, + pnorm=0.5, + beta=0.0, + mask_type="ci", + loss=loss_config, + ci_masked_label_prob=result.metrics.ci_masked_label_prob, + stoch_masked_label_prob=result.metrics.stoch_masked_label_prob, + adv_pgd_label_prob=result.metrics.adv_pgd_label_prob, + ) + graph_id = manager.db.save_graph( + prompt_id=prompt_id, + graph=StoredGraph( + graph_type="optimized", + edges=result.edges, + edges_abs=result.edges_abs, + ci_masked_out_logits=ci_masked_out_logits, + target_out_logits=target_out_logits, + node_ci_vals=result.node_ci_vals, + node_subcomp_acts=result.node_subcomp_acts, + optimization_params=opt_params, + ), + ) + + # Filter nodes by CI threshold + active_components = {k: v for k, v in result.node_ci_vals.items() if v >= ci_threshold} + + # Get target token probability + target_key = f"{loss_position}:{label_token}" + target_prob = out_probs.get(target_key) + + token_strings = [loaded.tokenizer.get_tok_display(t) for t in token_ids] + + final_result = { + "graph_id": graph_id, + "prompt_id": prompt_id, + "tokens": token_strings, + "target_token": target_token, + "target_token_id": label_token, + "target_position": loss_position, + "target_probability": target_prob.prob if target_prob else None, + "target_probability_baseline": target_prob.target_prob if target_prob else None, + "active_components": active_components, + "total_active": len(active_components), + "output_probs": {k: {"prob": v.prob, "token": v.token} for k, v in out_probs.items()}, + } + + _log_event( + "tool_complete", + f"optimize_graph complete: {len(active_components)} active components", + {"graph_id": graph_id, "target_prob": target_prob.prob if target_prob else None}, + ) + + yield {"type": "result", "data": final_result} + + +def _tool_get_component_info(params: dict[str, Any]) -> dict[str, Any]: + """Get 
detailed information about a component.""" + _, loaded = _get_state() + + layer = params["layer"] + component_idx = params["component_idx"] + top_k = params.get("top_k", 20) + canonical_key = f"{layer}:{component_idx}" + + # Harvest/interp repos store concrete keys (e.g. "h.0.mlp.c_fc:444") + concrete_layer = loaded.topology.canon_to_target(layer) + concrete_key = f"{concrete_layer}:{component_idx}" + + _log_event( + "tool_call", + f"get_component_info: {canonical_key}", + {"layer": layer, "idx": component_idx}, + ) + + result: dict[str, Any] = {"component_key": canonical_key} + + # Get interpretation + if loaded.interp is not None: + interp = loaded.interp.get_interpretation(concrete_key) + if interp is not None: + result["interpretation"] = { + "label": interp.label, + "confidence": interp.confidence, + "reasoning": interp.reasoning, + } + else: + result["interpretation"] = None + else: + result["interpretation"] = None + + # Get token stats + assert loaded.harvest is not None, "harvest data not loaded" + token_stats = loaded.harvest.get_token_stats() + if token_stats is not None: + input_stats = analysis.get_input_token_stats( + token_stats, concrete_key, loaded.tokenizer, top_k + ) + output_stats = analysis.get_output_token_stats( + token_stats, concrete_key, loaded.tokenizer, top_k + ) + if input_stats and output_stats: + result["token_stats"] = { + "input": { + "top_recall": input_stats.top_recall, + "top_precision": input_stats.top_precision, + "top_pmi": input_stats.top_pmi, + }, + "output": { + "top_recall": output_stats.top_recall, + "top_precision": output_stats.top_precision, + "top_pmi": output_stats.top_pmi, + "bottom_pmi": output_stats.bottom_pmi, + }, + } + else: + result["token_stats"] = None + else: + result["token_stats"] = None + + # Get correlations (return canonical keys) + correlations = loaded.harvest.get_correlations() + if correlations is not None and analysis.has_component(correlations, concrete_key): + result["correlated_components"] = { 
+ "precision": [ + {"key": _canonicalize_key(c.component_key, loaded), "score": c.score} + for c in analysis.get_correlated_components( + correlations, concrete_key, "precision", top_k + ) + ], + "pmi": [ + {"key": _canonicalize_key(c.component_key, loaded), "score": c.score} + for c in analysis.get_correlated_components( + correlations, concrete_key, "pmi", top_k + ) + ], + } + else: + result["correlated_components"] = None + + return result + + +def _tool_run_ablation(params: dict[str, Any]) -> dict[str, Any]: + """Run ablation with selected components.""" + from spd.app.backend.compute import ( + DEFAULT_EVAL_PGD_CONFIG, + compute_intervention, + ) + from spd.app.backend.optim_cis import MeanKLLossConfig + + manager, loaded = _get_state() + + text = params["text"] + selected_nodes = params["selected_nodes"] + top_k = params.get("top_k", 10) + + _log_event( + "tool_call", + f"run_ablation: '{text[:50]}...' with {len(selected_nodes)} nodes", + {"text": text, "n_nodes": len(selected_nodes)}, + ) + + token_ids = loaded.tokenizer.encode(text) + tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) + + active_nodes = [] + for key in selected_nodes: + parts = key.split(":") + if len(parts) != 3: + raise ValueError(f"Invalid node key format: {key!r} (expected 'layer:seq:cIdx')") + layer, seq_str, cidx_str = parts + if layer in ("wte", "embed", "output"): + raise ValueError(f"Cannot intervene on {layer!r} nodes - only internal layers allowed") + active_nodes.append((layer, int(seq_str), int(cidx_str))) + + with manager.gpu_lock(): + result = compute_intervention( + model=loaded.model, + tokens=tokens, + active_nodes=active_nodes, + nodes_to_ablate=None, + tokenizer=loaded.tokenizer, + adv_pgd_config=DEFAULT_EVAL_PGD_CONFIG, + loss_config=MeanKLLossConfig(), + sampling=loaded.config.sampling, + top_k=top_k, + ) + + predictions = [] + for pos_predictions in result.ci: + pos_result = [] + for pred in pos_predictions: + pos_result.append( + { + "token": 
pred.token, + "token_id": pred.token_id, + "circuit_prob": round(pred.prob, 6), + "full_model_prob": round(pred.target_prob, 6), + } + ) + predictions.append(pos_result) + + return { + "input_tokens": result.input_tokens, + "predictions_per_position": predictions, + "selected_nodes": selected_nodes, + } + + +def _tool_search_dataset(params: dict[str, Any]) -> dict[str, Any]: + """Search the SimpleStories dataset.""" + import time + + from datasets import Dataset, load_dataset + + query = params["query"] + limit = params.get("limit", 20) + search_query = query.lower() + + _log_event("tool_call", f"search_dataset: '{query}'", {"query": query, "limit": limit}) + + start_time = time.time() + dataset = load_dataset("lennart-finke/SimpleStories", split="train") + assert isinstance(dataset, Dataset) + + filtered = dataset.filter( + lambda x: search_query in x["story"].lower(), + num_proc=4, + ) + + results = [] + for i, item in enumerate(filtered): + if i >= limit: + break + item_dict: dict[str, Any] = dict(item) + story: str = item_dict["story"] + results.append( + { + "story": story[:500] + "..." 
if len(story) > 500 else story, + "occurrence_count": story.lower().count(search_query), + } + ) + + return { + "query": query, + "total_matches": len(filtered), + "returned": len(results), + "search_time_seconds": round(time.time() - start_time, 2), + "results": results, + } + + +def _tool_create_prompt(params: dict[str, Any]) -> dict[str, Any]: + """Create a prompt from text.""" + manager, loaded = _get_state() + + text = params["text"] + + _log_event("tool_call", f"create_prompt: '{text[:50]}...'", {"text": text}) + + token_ids = loaded.tokenizer.encode(text) + if not token_ids: + raise ValueError("Text produced no tokens") + + prompt_id = manager.db.add_custom_prompt( + run_id=loaded.run.id, + token_ids=token_ids, + context_length=loaded.context_length, + ) + + # Compute next token probs + tokens_tensor = torch.tensor([token_ids], device=DEVICE) + with torch.no_grad(): + logits = loaded.model(tokens_tensor) + probs = torch.softmax(logits, dim=-1) + + next_token_probs = [] + for i in range(len(token_ids) - 1): + next_token_id = token_ids[i + 1] + prob = probs[0, i, next_token_id].item() + next_token_probs.append(round(prob, 6)) + next_token_probs.append(None) + + token_strings = [loaded.tokenizer.get_tok_display(t) for t in token_ids] + + return { + "prompt_id": prompt_id, + "text": text, + "tokens": token_strings, + "token_ids": token_ids, + "next_token_probs": next_token_probs, + } + + +def _require_investigation_config() -> InvestigationConfig: + """Get investigation config, raising if not in investigation mode.""" + assert _investigation_config is not None, "Not running in investigation mode" + return _investigation_config + + +def _tool_update_research_log(params: dict[str, Any]) -> dict[str, Any]: + """Append content to the research log.""" + config = _require_investigation_config() + content = params["content"] + research_log_path = config.investigation_dir / "research_log.md" + + _log_event( + "tool_call", f"update_research_log: {len(content)} chars", 
{"preview": content[:100]} + ) + + with open(research_log_path, "a") as f: + f.write(content) + if not content.endswith("\n"): + f.write("\n") + + return {"status": "ok", "path": str(research_log_path)} + + +def _tool_save_explanation(params: dict[str, Any]) -> dict[str, Any]: + """Save a behavior explanation to explanations.jsonl.""" + from spd.investigate.schemas import BehaviorExplanation, ComponentInfo, Evidence + + config = _require_investigation_config() + + _log_event( + "tool_call", + f"save_explanation: '{params['behavior_description'][:50]}...'", + {"prompt": params["subject_prompt"]}, + ) + + components = [ + ComponentInfo( + component_key=c["component_key"], + role=c["role"], + interpretation=c.get("interpretation"), + ) + for c in params["components_involved"] + ] + + evidence = [ + Evidence( + evidence_type=e["evidence_type"], + description=e["description"], + details=e.get("details", {}), + ) + for e in params.get("supporting_evidence", []) + ] + + explanation = BehaviorExplanation( + subject_prompt=params["subject_prompt"], + behavior_description=params["behavior_description"], + components_involved=components, + explanation=params["explanation"], + supporting_evidence=evidence, + confidence=params["confidence"], + alternative_hypotheses=params.get("alternative_hypotheses", []), + limitations=params.get("limitations", []), + ) + + explanations_path = config.investigation_dir / "explanations.jsonl" + with open(explanations_path, "a") as f: + f.write(explanation.model_dump_json() + "\n") + + _log_event( + "explanation", + f"Saved explanation: {params['behavior_description']}", + {"confidence": params["confidence"], "n_components": len(components)}, + ) + + return {"status": "ok", "path": str(explanations_path)} + + +def _tool_set_investigation_summary(params: dict[str, Any]) -> dict[str, Any]: + """Set the investigation title and summary.""" + config = _require_investigation_config() + + summary = { + "title": params["title"], + "summary": 
params["summary"], + "status": params.get("status", "in_progress"), + "updated_at": datetime.now(UTC).isoformat(), + } + + _log_event( + "tool_call", + f"set_investigation_summary: {params['title']}", + summary, + ) + + summary_path = config.investigation_dir / "summary.json" + summary_path.write_text(json.dumps(summary, indent=2)) + + return {"status": "ok", "path": str(summary_path)} + + +def _tool_save_graph_artifact(params: dict[str, Any]) -> dict[str, Any]: + """Save a graph as an artifact for the research report. + + Uses the same filtering logic as the main graph API: + 1. Filter nodes by CI threshold + 2. Add pseudo nodes (wte, output) + 3. Filter edges to only active nodes + 4. Apply edge limit + """ + config = _require_investigation_config() + manager, loaded = _get_state() + + graph_id = params["graph_id"] + caption = params.get("caption") + ci_threshold = params.get("ci_threshold", 0.5) + edge_limit = params.get("edge_limit", 5000) + + _log_event( + "tool_call", + f"save_graph_artifact: graph_id={graph_id}", + {"graph_id": graph_id, "caption": caption}, + ) + + # Fetch graph from DB + result = manager.db.get_graph(graph_id) + if result is None: + raise ValueError(f"Graph with id={graph_id} not found") + + graph, prompt_id = result + + # Get tokens from prompt + prompt_record = manager.db.get_prompt(prompt_id) + if prompt_record is None: + raise ValueError(f"Prompt with id={prompt_id} not found") + + tokens = [loaded.tokenizer.get_tok_display(tid) for tid in prompt_record.token_ids] + num_tokens = len(tokens) + + # Create artifacts directory + artifacts_dir = config.investigation_dir / "artifacts" + artifacts_dir.mkdir(exist_ok=True) + + # Generate artifact ID (find max existing number to avoid collisions) + existing_nums = [] + for f in artifacts_dir.glob("graph_*.json"): + try: + num = int(f.stem.split("_")[1]) + existing_nums.append(num) + except (IndexError, ValueError): + continue + artifact_num = max(existing_nums, default=0) + 1 + artifact_id = 
f"graph_{artifact_num:03d}" + + # Compute out_probs from stored logits + out_probs = _build_out_probs( + graph.ci_masked_out_logits, + graph.target_out_logits, + loaded.tokenizer.get_tok_display, + ) + + # Step 1: Filter nodes by CI threshold (same as main graph API) + filtered_ci_vals = {k: v for k, v in graph.node_ci_vals.items() if v > ci_threshold} + l0_total = len(filtered_ci_vals) + + # Step 2: Add pseudo nodes (embed and output) - same as _add_pseudo_layer_nodes + node_ci_vals_with_pseudo = dict(filtered_ci_vals) + for seq_pos in range(num_tokens): + node_ci_vals_with_pseudo[f"embed:{seq_pos}:0"] = 1.0 + for key, out_prob in out_probs.items(): + seq_pos, token_id = key.split(":") + node_ci_vals_with_pseudo[f"output:{seq_pos}:{token_id}"] = out_prob.prob + + # Step 3: Filter edges to only active nodes + active_node_keys = set(node_ci_vals_with_pseudo.keys()) + filtered_edges = [ + e + for e in graph.edges + if str(e.source) in active_node_keys and str(e.target) in active_node_keys + ] + + # Step 4: Sort by strength and apply edge limit + filtered_edges.sort(key=lambda e: abs(e.strength), reverse=True) + filtered_edges = filtered_edges[:edge_limit] + + # Build edges data + edges_data = [ + { + "src": str(e.source), + "tgt": str(e.target), + "val": e.strength, + } + for e in filtered_edges + ] + + # Compute max abs attr from filtered edges + max_abs_attr = max((abs(e.strength) for e in filtered_edges), default=0.0) + + # Filter nodeSubcompActs to match nodeCiVals + filtered_subcomp_acts = { + k: v for k, v in graph.node_subcomp_acts.items() if k in node_ci_vals_with_pseudo + } + + # Build artifact data (self-contained GraphData, same structure as API response) + artifact = { + "type": "graph", + "id": artifact_id, + "caption": caption, + "graph_id": graph_id, + "data": { + "tokens": tokens, + "edges": edges_data, + "outputProbs": { + k: { + "prob": v.prob, + "logit": v.logit, + "target_prob": v.target_prob, + "target_logit": v.target_logit, + "token": v.token, 
+ } + for k, v in out_probs.items() + }, + "nodeCiVals": node_ci_vals_with_pseudo, + "nodeSubcompActs": filtered_subcomp_acts, + "maxAbsAttr": max_abs_attr, + "l0_total": l0_total, + }, + } + + # Save artifact + artifact_path = artifacts_dir / f"{artifact_id}.json" + artifact_path.write_text(json.dumps(artifact, indent=2)) + + _log_event( + "artifact_saved", + f"Saved graph artifact: {artifact_id}", + {"artifact_id": artifact_id, "graph_id": graph_id, "path": str(artifact_path)}, + ) + + return {"artifact_id": artifact_id, "path": str(artifact_path)} + + +def _tool_probe_component(params: dict[str, Any]) -> dict[str, Any]: + """Fast CI probing on custom text for a specific component.""" + manager, loaded = _get_state() + + text = params["text"] + layer = params["layer"] + component_idx = params["component_idx"] + + _log_event( + "tool_call", + f"probe_component: '{text[:50]}...' layer={layer} idx={component_idx}", + {"text": text, "layer": layer, "component_idx": component_idx}, + ) + + token_ids = loaded.tokenizer.encode(text) + assert token_ids, "Text produced no tokens" + tokens_tensor = torch.tensor([token_ids], device=DEVICE) + + concrete_layer = loaded.topology.canon_to_target(layer) + + with manager.gpu_lock(): + result = compute_ci_only( + model=loaded.model, tokens=tokens_tensor, sampling=loaded.config.sampling + ) + + ci_values = result.ci_lower_leaky[concrete_layer][0, :, component_idx].tolist() + subcomp_acts = result.component_acts[concrete_layer][0, :, component_idx].tolist() + + # Get next token probs from target model output + next_token_probs = [] + for i in range(len(token_ids) - 1): + next_token_id = token_ids[i + 1] + prob = result.target_out_probs[0, i, next_token_id].item() + next_token_probs.append(round(prob, 6)) + next_token_probs.append(None) + + token_strings = [loaded.tokenizer.get_tok_display(t) for t in token_ids] + + return { + "tokens": token_strings, + "ci_values": ci_values, + "subcomp_acts": subcomp_acts, + "next_token_probs": 
next_token_probs, + } + + +def _tool_get_component_activation_examples(params: dict[str, Any]) -> dict[str, Any]: + """Get activation examples from harvest data.""" + _, loaded = _get_state() + + layer = params["layer"] + component_idx = params["component_idx"] + limit = params.get("limit", 10) + + concrete_layer = loaded.topology.canon_to_target(layer) + component_key = f"{concrete_layer}:{component_idx}" + + _log_event( + "tool_call", + f"get_component_activation_examples: {component_key}", + {"layer": layer, "component_idx": component_idx, "limit": limit}, + ) + + assert loaded.harvest is not None, "harvest data not loaded" + canonical_key = f"{layer}:{component_idx}" + comp = loaded.harvest.get_component(component_key) + if comp is None: + return {"component_key": canonical_key, "examples": [], "total": 0} + + examples = [] + for ex in comp.activation_examples[:limit]: + token_strings = [loaded.tokenizer.get_tok_display(t) for t in ex.token_ids] + examples.append( + { + "tokens": token_strings, + "ci_values": ex.activations["causal_importance"], + "component_acts": ex.activations["component_activation"], + } + ) + + return { + "component_key": canonical_key, + "examples": examples, + "total": len(comp.activation_examples), + "mean_ci": comp.mean_activations["causal_importance"], + } + + +def _tool_get_model_info(_params: dict[str, Any]) -> dict[str, Any]: + """Get architecture details about the pretrained model.""" + _, loaded = _get_state() + + _log_event("tool_call", "get_model_info", {}) + + info = _get_pretrain_info(loaded.config) + return info.model_dump() + + +# ============================================================================= +# MCP Protocol Handler +# ============================================================================= + + +_STREAMING_TOOLS: dict[str, Callable[..., Generator[dict[str, Any]]]] = { + "optimize_graph": _tool_optimize_graph, +} + +_SIMPLE_TOOLS: dict[str, Callable[..., dict[str, Any]]] = { + "get_component_info": 
_tool_get_component_info, + "run_ablation": _tool_run_ablation, + "search_dataset": _tool_search_dataset, + "create_prompt": _tool_create_prompt, + "update_research_log": _tool_update_research_log, + "save_explanation": _tool_save_explanation, + "set_investigation_summary": _tool_set_investigation_summary, + "save_graph_artifact": _tool_save_graph_artifact, + "probe_component": _tool_probe_component, + "get_component_activation_examples": _tool_get_component_activation_examples, + # "get_component_attributions": _tool_get_component_attributions, + "get_model_info": _tool_get_model_info, +} + + +def _handle_initialize(_params: dict[str, Any] | None) -> dict[str, Any]: + """Handle initialize request.""" + return { + "protocolVersion": MCP_PROTOCOL_VERSION, + "capabilities": {"tools": {}}, + "serverInfo": {"name": "spd-app", "version": "1.0.0"}, + } + + +def _handle_tools_list() -> dict[str, Any]: + """Handle tools/list request.""" + return {"tools": [t.model_dump() for t in TOOLS]} + + +def _handle_tools_call( + params: dict[str, Any], +) -> Generator[dict[str, Any]] | dict[str, Any]: + """Handle tools/call request. May return generator for streaming tools.""" + name = params.get("name") + arguments = params.get("arguments", {}) + + if name in _STREAMING_TOOLS: + return _STREAMING_TOOLS[name](arguments) + + if name in _SIMPLE_TOOLS: + result = _SIMPLE_TOOLS[name](arguments) + return {"content": [{"type": "text", "text": json.dumps(result, indent=2)}]} + + raise ValueError(f"Unknown tool: {name}") + + +@router.post("/mcp") +async def mcp_endpoint(request: Request): + """MCP JSON-RPC endpoint. + + Handles initialize, tools/list, and tools/call methods. + Returns SSE stream for streaming tools, JSON for others. 
+ """ + try: + body = await request.json() + mcp_request = MCPRequest(**body) + except Exception as e: + return JSONResponse( + status_code=400, + content=MCPResponse( + id=None, error={"code": -32700, "message": f"Parse error: {e}"} + ).model_dump(exclude_none=True), + ) + + logger.info(f"[MCP] {mcp_request.method} (id={mcp_request.id})") + + try: + if mcp_request.method == "initialize": + result = _handle_initialize(mcp_request.params) + return JSONResponse( + content=MCPResponse(id=mcp_request.id, result=result).model_dump(exclude_none=True), + headers={"Mcp-Session-Id": "spd-session"}, + ) + + elif mcp_request.method == "notifications/initialized": + # Client confirms initialization + return JSONResponse(status_code=202, content={}) + + elif mcp_request.method == "tools/list": + result = _handle_tools_list() + return JSONResponse( + content=MCPResponse(id=mcp_request.id, result=result).model_dump(exclude_none=True) + ) + + elif mcp_request.method == "tools/call": + if mcp_request.params is None: + raise ValueError("tools/call requires params") + + result = _handle_tools_call(mcp_request.params) + + # Check if result is a generator (streaming) + if inspect.isgenerator(result): + # Streaming response via SSE + gen = result # Capture for closure + + def generate_sse() -> Generator[str]: + try: + final_result = None + for event in gen: + if event.get("type") == "progress": + # Send progress notification + progress_msg = { + "jsonrpc": "2.0", + "method": "notifications/progress", + "params": event, + } + yield f"data: {json.dumps(progress_msg)}\n\n" + elif event.get("type") == "result": + final_result = event["data"] + + # Send final response + response = MCPResponse( + id=mcp_request.id, + result={ + "content": [ + {"type": "text", "text": json.dumps(final_result, indent=2)} + ] + }, + ) + yield f"data: {json.dumps(response.model_dump(exclude_none=True))}\n\n" + except Exception as e: + tb = traceback.format_exc() + logger.error(f"[MCP] Tool error: {e}\n{tb}") + 
error_response = MCPResponse( + id=mcp_request.id, + error={"code": -32000, "message": str(e)}, + ) + yield f"data: {json.dumps(error_response.model_dump(exclude_none=True))}\n\n" + + return StreamingResponse(generate_sse(), media_type="text/event-stream") + + else: + # Non-streaming response + return JSONResponse( + content=MCPResponse(id=mcp_request.id, result=result).model_dump( + exclude_none=True + ) + ) + + else: + return JSONResponse( + content=MCPResponse( + id=mcp_request.id, + error={"code": -32601, "message": f"Method not found: {mcp_request.method}"}, + ).model_dump(exclude_none=True) + ) + + except Exception as e: + tb = traceback.format_exc() + logger.error(f"[MCP] Error handling {mcp_request.method}: {e}\n{tb}") + return JSONResponse( + content=MCPResponse( + id=mcp_request.id, + error={"code": -32000, "message": str(e)}, + ).model_dump(exclude_none=True) + ) diff --git a/spd/app/backend/routers/pretrain_info.py b/spd/app/backend/routers/pretrain_info.py index 2872423c9..424f7b035 100644 --- a/spd/app/backend/routers/pretrain_info.py +++ b/spd/app/backend/routers/pretrain_info.py @@ -38,6 +38,7 @@ class TopologyInfo(BaseModel): class PretrainInfoResponse(BaseModel): model_type: str summary: str + dataset_short: str | None target_model_config: dict[str, Any] | None pretrain_config: dict[str, Any] | None pretrain_wandb_path: str | None @@ -161,6 +162,27 @@ def _build_summary(model_type: str, target_model_config: dict[str, Any] | None) return " · ".join(parts) +_DATASET_SHORT_NAMES: dict[str, str] = { + "simplestories": "SS", + "pile": "Pile", + "tinystories": "TS", +} + + +def _get_dataset_short(pretrain_config: dict[str, Any] | None) -> str | None: + """Extract a short dataset label from the pretrain config.""" + if pretrain_config is None: + return None + dataset_name: str = ( + pretrain_config.get("train_dataset_config", {}).get("name", "") + or pretrain_config.get("dataset", "") + ).lower() + for key, short in _DATASET_SHORT_NAMES.items(): + if key 
in dataset_name: + return short + return None + + def _get_pretrain_info(spd_config: Config) -> PretrainInfoResponse: """Extract pretrain info from an SPD config.""" model_class_name = spd_config.pretrained_model_class @@ -190,10 +212,12 @@ def _get_pretrain_info(spd_config: Config) -> PretrainInfoResponse: n_blocks = target_model_config.get("n_layer", 0) if target_model_config else 0 topology = _build_topology(model_type, n_blocks) summary = _build_summary(model_type, target_model_config) + dataset_short = _get_dataset_short(pretrain_config) return PretrainInfoResponse( model_type=model_type, summary=summary, + dataset_short=dataset_short, target_model_config=target_model_config, pretrain_config=pretrain_config, pretrain_wandb_path=pretrain_wandb_path, diff --git a/spd/app/backend/routers/prompts.py b/spd/app/backend/routers/prompts.py index 7f06bcf68..2cf0da197 100644 --- a/spd/app/backend/routers/prompts.py +++ b/spd/app/backend/routers/prompts.py @@ -6,16 +6,6 @@ from spd.app.backend.dependencies import DepLoadedRun, DepStateManager from spd.app.backend.utils import log_errors -from spd.utils.distributed_utils import get_device - -# TODO: Re-enable these endpoints when dependencies are available: -# - extract_active_from_ci from database -# - PromptSearchQuery, PromptSearchResponse from schemas -# - DatasetConfig, LMTaskConfig from configs -# - create_data_loader, extract_batch_data from data -# - logger from utils - -DEVICE = get_device() # ============================================================================= # Schemas diff --git a/spd/app/backend/routers/run_registry.py b/spd/app/backend/routers/run_registry.py new file mode 100644 index 000000000..d44a7108f --- /dev/null +++ b/spd/app/backend/routers/run_registry.py @@ -0,0 +1,95 @@ +"""Run registry endpoint. + +Returns architecture and data availability for requested SPD runs. +The canonical run list lives in the frontend; the backend just hydrates it. 
+""" + +import asyncio +from pathlib import Path + +from fastapi import APIRouter +from pydantic import BaseModel + +from spd.app.backend.routers.pretrain_info import _get_pretrain_info, _load_spd_config_lightweight +from spd.app.backend.utils import log_errors +from spd.log import logger +from spd.settings import SPD_OUT_DIR +from spd.utils.wandb_utils import parse_wandb_run_path + +router = APIRouter(prefix="/api/run_registry", tags=["run_registry"]) + + +class DataAvailability(BaseModel): + harvest: bool + autointerp: bool + attributions: bool + graph_interp: bool + + +class RunInfoResponse(BaseModel): + wandb_run_id: str + architecture: str | None + availability: DataAvailability + + +def _has_glob_match(pattern_dir: Path, glob_pattern: str) -> bool: + """Check if any file matches a glob pattern under a directory.""" + if not pattern_dir.exists(): + return False + return next(pattern_dir.glob(glob_pattern), None) is not None + + +def _check_availability(run_id: str) -> DataAvailability: + """Lightweight filesystem checks for post-processing data availability.""" + harvest_dir = SPD_OUT_DIR / "harvest" / run_id + autointerp_dir = SPD_OUT_DIR / "autointerp" / run_id + attributions_dir = SPD_OUT_DIR / "dataset_attributions" / run_id + graph_interp_dir = SPD_OUT_DIR / "graph_interp" / run_id + + return DataAvailability( + harvest=_has_glob_match(harvest_dir, "h-*/harvest.db"), + autointerp=_has_glob_match(autointerp_dir, "a-*/.done"), + attributions=_has_glob_match(attributions_dir, "da-*/dataset_attributions.pt"), + graph_interp=_has_glob_match(graph_interp_dir, "*/interp.db"), + ) + + +def _get_architecture_summary(wandb_path: str) -> str | None: + """Get a short architecture label for a run. 
Returns None on failure.""" + try: + spd_config = _load_spd_config_lightweight(wandb_path) + info = _get_pretrain_info(spd_config) + parts: list[str] = [] + if info.dataset_short: + parts.append(info.dataset_short) + parts.append(info.model_type) + cfg = info.target_model_config + if cfg: + n_layer = cfg.get("n_layer") + n_embd = cfg.get("n_embd") + if n_layer is not None: + parts.append(f"{n_layer}L") + if n_embd is not None: + parts.append(f"d{n_embd}") + return " ".join(parts) + except Exception: + logger.exception(f"[run_registry] Failed to get architecture for {wandb_path}") + return None + + +def _build_run_info(wandb_run_id: str) -> RunInfoResponse: + _, _, run_id = parse_wandb_run_path(wandb_run_id) + return RunInfoResponse( + wandb_run_id=wandb_run_id, + architecture=_get_architecture_summary(wandb_run_id), + availability=_check_availability(run_id), + ) + + +@router.post("") +@log_errors +async def get_run_info(wandb_run_ids: list[str]) -> list[RunInfoResponse]: + """Return architecture and availability for the requested runs.""" + loop = asyncio.get_running_loop() + tasks = [loop.run_in_executor(None, _build_run_info, wid) for wid in wandb_run_ids] + return list(await asyncio.gather(*tasks)) diff --git a/spd/app/backend/routers/runs.py b/spd/app/backend/routers/runs.py index 0989cea54..3b9c2300e 100644 --- a/spd/app/backend/routers/runs.py +++ b/spd/app/backend/routers/runs.py @@ -15,6 +15,7 @@ from spd.autointerp.repo import InterpRepo from spd.configs import LMTaskConfig from spd.dataset_attributions.repo import AttributionRepo +from spd.graph_interp.repo import GraphInterpRepo from spd.harvest.repo import HarvestRepo from spd.log import logger from spd.models.component_model import ComponentModel, SPDRunInfo @@ -42,6 +43,8 @@ class LoadedRun(BaseModel): backend_user: str dataset_attributions_available: bool dataset_search_enabled: bool + graph_interp_available: bool + autointerp_available: bool router = APIRouter(prefix="/api", tags=["runs"]) @@ 
-128,6 +131,7 @@ def load_run(wandb_path: str, context_length: int, manager: DepStateManager): harvest=HarvestRepo.open_most_recent(run_id), interp=InterpRepo.open(run_id), attributions=AttributionRepo.open(run_id), + graph_interp=GraphInterpRepo.open(run_id), ) logger.info(f"[API] Run {run.id} loaded on {DEVICE}") @@ -165,6 +169,8 @@ def get_status(manager: DepStateManager) -> LoadedRun | None: backend_user=getpass.getuser(), dataset_attributions_available=manager.run_state.attributions is not None, dataset_search_enabled=dataset_search_enabled, + graph_interp_available=manager.run_state.graph_interp is not None, + autointerp_available=manager.run_state.interp is not None, ) diff --git a/spd/app/backend/schemas.py b/spd/app/backend/schemas.py index bc99dc831..bf5f3e9b2 100644 --- a/spd/app/backend/schemas.py +++ b/spd/app/backend/schemas.py @@ -19,8 +19,6 @@ class OutputProbability(BaseModel): logit: float # CI-masked (SPD model) raw logit target_prob: float # Target model probability target_logit: float # Target model raw logit - adv_pgd_prob: float | None = None # Adversarial PGD probability - adv_pgd_logit: float | None = None # Adversarial PGD raw logit token: str @@ -29,17 +27,6 @@ class OutputProbability(BaseModel): # ============================================================================= -class ActivationContextsGenerationConfig(BaseModel): - """Configuration for generating activation contexts.""" - - importance_threshold: float = 0.01 - n_batches: int = 100 - batch_size: int = 32 - n_tokens_either_side: int = 5 - topk_examples: int = 20 - separation_tokens: int = 0 - - class SubcomponentMetadata(BaseModel): """Lightweight metadata for a subcomponent (without examples/token_prs)""" diff --git a/spd/app/backend/server.py b/spd/app/backend/server.py index 3804ce756..afbff5db6 100644 --- a/spd/app/backend/server.py +++ b/spd/app/backend/server.py @@ -8,10 +8,12 @@ python -m spd.app.backend.server --port 8000 """ +import os import time import traceback 
from collections.abc import Awaitable, Callable from contextlib import asynccontextmanager +from pathlib import Path import fire import torch @@ -32,14 +34,19 @@ data_sources_router, dataset_attributions_router, dataset_search_router, + graph_interp_router, graphs_router, intervention_router, + investigations_router, + mcp_router, pretrain_info_router, prompts_router, + run_registry_router, runs_router, ) from spd.app.backend.state import StateManager from spd.log import logger +from spd.settings import SPD_APP_DEFAULT_RUN from spd.utils.distributed_utils import get_device DEVICE = get_device() @@ -48,6 +55,8 @@ @asynccontextmanager async def lifespan(app: FastAPI): # pyright: ignore[reportUnusedParameter] """Initialize DB connection at startup. Model loaded on-demand via /api/runs/load.""" + from spd.app.backend.routers.mcp import InvestigationConfig, set_investigation_config + manager = StateManager.get() db = PromptAttrDB(check_same_thread=False) @@ -58,6 +67,24 @@ async def lifespan(app: FastAPI): # pyright: ignore[reportUnusedParameter] logger.info(f"[STARTUP] Device: {DEVICE}") logger.info(f"[STARTUP] CUDA available: {torch.cuda.is_available()}") + # Configure MCP for investigation mode (derives paths from investigation dir) + investigation_dir = os.environ.get("SPD_INVESTIGATION_DIR") + if investigation_dir: + inv_dir = Path(investigation_dir) + set_investigation_config( + InvestigationConfig( + events_log_path=inv_dir / "events.jsonl", + investigation_dir=inv_dir, + ) + ) + logger.info(f"[STARTUP] Investigation mode enabled: dir={investigation_dir}") + + if SPD_APP_DEFAULT_RUN is not None: + from spd.app.backend.routers.runs import load_run + + logger.info(f"[STARTUP] Auto-loading default run: {SPD_APP_DEFAULT_RUN}") + load_run(SPD_APP_DEFAULT_RUN, context_length=512, manager=manager) + yield manager.close() @@ -157,8 +184,12 @@ async def global_exception_handler(request: Request, exc: Exception) -> JSONResp app.include_router(dataset_search_router) 
app.include_router(dataset_attributions_router) app.include_router(agents_router) +app.include_router(investigations_router) +app.include_router(mcp_router) app.include_router(data_sources_router) +app.include_router(graph_interp_router) app.include_router(pretrain_info_router) +app.include_router(run_registry_router) def cli(port: int = 8000) -> None: diff --git a/spd/app/backend/state.py b/spd/app/backend/state.py index cf71c2bc6..2cdabda73 100644 --- a/spd/app/backend/state.py +++ b/spd/app/backend/state.py @@ -5,14 +5,20 @@ - StateManager: Singleton managing app-wide state with proper lifecycle """ +import threading +from collections.abc import Generator +from contextlib import contextmanager from dataclasses import dataclass, field from typing import Any +from fastapi import HTTPException + from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.database import PromptAttrDB, Run from spd.autointerp.repo import InterpRepo from spd.configs import Config from spd.dataset_attributions.repo import AttributionRepo +from spd.graph_interp.repo import GraphInterpRepo from spd.harvest.repo import HarvestRepo from spd.models.component_model import ComponentModel from spd.topology import TransformerTopology @@ -32,6 +38,7 @@ class RunState: harvest: HarvestRepo | None interp: InterpRepo | None attributions: AttributionRepo | None + graph_interp: GraphInterpRepo | None @dataclass @@ -62,6 +69,7 @@ class StateManager: def __init__(self) -> None: self._state: AppState | None = None + self._gpu_lock = threading.Lock() @classmethod def get(cls) -> "StateManager": @@ -104,3 +112,21 @@ def close(self) -> None: """Clean up resources.""" if self._state is not None: self._state.db.close() + + @contextmanager + def gpu_lock(self) -> Generator[None]: + """Acquire GPU lock or fail with 503 if another GPU operation is in progress. + + Use this for GPU-intensive endpoints to prevent concurrent operations + that would cause the server to hang. 
+ """ + acquired = self._gpu_lock.acquire(blocking=False) + if not acquired: + raise HTTPException( + status_code=503, + detail="GPU operation already in progress. Please wait and retry.", + ) + try: + yield + finally: + self._gpu_lock.release() diff --git a/spd/app/frontend/package-lock.json b/spd/app/frontend/package-lock.json index b6c451303..32da0218c 100644 --- a/spd/app/frontend/package-lock.json +++ b/spd/app/frontend/package-lock.json @@ -7,6 +7,9 @@ "": { "name": "frontend", "version": "0.0.0", + "dependencies": { + "marked": "^17.0.1" + }, "devDependencies": { "@eslint/js": "^9.38.0", "@sveltejs/vite-plugin-svelte": "^6.2.1", @@ -2347,6 +2350,18 @@ "@jridgewell/sourcemap-codec": "^1.5.5" } }, + "node_modules/marked": { + "version": "17.0.1", + "resolved": "https://registry.npmjs.org/marked/-/marked-17.0.1.tgz", + "integrity": "sha512-boeBdiS0ghpWcSwoNm/jJBwdpFaMnZWRzjA6SkUMYb40SVaN1x7mmfGKp0jvexGcx+7y2La5zRZsYFZI6Qpypg==", + "license": "MIT", + "bin": { + "marked": "bin/marked.js" + }, + "engines": { + "node": ">= 20" + } + }, "node_modules/merge2": { "version": "1.4.1", "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", diff --git a/spd/app/frontend/package.json b/spd/app/frontend/package.json index f54e1bb3d..f298885ce 100644 --- a/spd/app/frontend/package.json +++ b/spd/app/frontend/package.json @@ -27,5 +27,8 @@ "typescript": "~5.9.3", "typescript-eslint": "^8.46.2", "vite": "^7.1.7" + }, + "dependencies": { + "marked": "^17.0.1" } } diff --git a/spd/app/frontend/src/app.css b/spd/app/frontend/src/app.css index 8bb0c490f..bf6649aee 100644 --- a/spd/app/frontend/src/app.css +++ b/spd/app/frontend/src/app.css @@ -1,22 +1,22 @@ :root { - /* Punchy Research - crisp whites, bold contrasts */ + /* Goodfire-inspired - warm whites, navy text, vibrant blue accent */ --bg-base: #ffffff; --bg-surface: #ffffff; --bg-elevated: #ffffff; - --bg-inset: #f8f9fa; - --bg-hover: #f3f4f6; + --bg-inset: #f7f6f2; + --bg-hover: #f0efeb; - --border-subtle: 
#e0e0e0; - --border-default: #c0c0c0; - --border-strong: #888888; + --border-subtle: #e5e3dc; + --border-default: #c8c5bc; + --border-strong: #8a8780; - --text-primary: #111111; - --text-secondary: #555555; - --text-muted: #999999; + --text-primary: #1d272a; + --text-secondary: #646464; + --text-muted: #b4b4b4; - --accent-primary: #2563eb; - --accent-primary-bright: #3b82f6; - --accent-primary-dim: #1d4ed8; + --accent-primary: #7c4d33; + --accent-primary-bright: #96613f; + --accent-primary-dim: #5e3a27; --status-positive: #16a34a; --status-positive-bright: #22c55e; @@ -24,8 +24,10 @@ --status-negative-bright: #ef4444; --status-warning: #eab308; --status-warning-bright: #facc15; - --status-info: #2563eb; - --status-info-bright: #3b82f6; + --status-info: #4d65ff; + --status-info-bright: #6b7fff; + + --focus-ring: #4d65ff; /* Typography - Clean system fonts with mono for code */ --font-mono: "SF Mono", "Menlo", "Monaco", "Consolas", monospace; diff --git a/spd/app/frontend/src/components/ActivationContextsPagedTable.svelte b/spd/app/frontend/src/components/ActivationContextsPagedTable.svelte index e20ba1adf..c9c304950 100644 --- a/spd/app/frontend/src/components/ActivationContextsPagedTable.svelte +++ b/spd/app/frontend/src/components/ActivationContextsPagedTable.svelte @@ -1,23 +1,29 @@ @@ -106,21 +112,22 @@
- @@ -129,58 +136,69 @@
-
- {#if displaySettings.centerOnPeak} -
- {#each paginatedIndices as idx (idx)} - {@const fp = firingPositions[idx]} -
-
- -
-
- + {#if loading} +
+
+ {#each Array(pageSize) as _, i (i)} +
+ {/each} +
+
+ {:else} + {@const d = loaded!} +
+ {#if displaySettings.centerOnPeak} +
+ {#each paginatedIndices as idx (idx)} + {@const fp = firingPositions[idx]} +
+
+ +
+
+ +
+
+ +
-
+ {/each} +
+ {:else} +
+ {#each paginatedIndices as idx (idx)} +
-
- {/each} -
- {:else} -
- {#each paginatedIndices as idx (idx)} -
- -
- {/each} -
- {/if} -
+ {/each} +
+ {/if} +
+ {/if}
diff --git a/spd/app/frontend/src/components/ActivationContextsViewer.svelte b/spd/app/frontend/src/components/ActivationContextsViewer.svelte index d20831c1a..3df0eeee7 100644 --- a/spd/app/frontend/src/components/ActivationContextsViewer.svelte +++ b/spd/app/frontend/src/components/ActivationContextsViewer.svelte @@ -1,15 +1,16 @@
@@ -288,7 +304,7 @@
@@ -412,26 +428,22 @@ {/if} - +
+ + {#if currentGraphInterpLabel && componentData.graphInterpDetail.status === "loaded" && componentData.graphInterpDetail.data} + + {/if} +
- {#if componentData.componentDetail.status === "loading"} -
Loading component data...
- {:else if componentData.componentDetail.status === "loaded"} - - {:else if componentData.componentDetail.status === "error"} - Error loading component data: {String(componentData.componentDetail.error)} + {#if activationExamples.status === "error"} + Error loading component data: {String(activationExamples.error)} {:else} - Something went wrong loading component data. + {/if} + import { getContext, onMount } from "svelte"; + import { computeMaxAbsComponentAct } from "../lib/colors"; + import { mapLoadable } from "../lib/index"; + import { anyCorrelationStatsEnabled } from "../lib/displaySettings.svelte"; + import { useComponentDataExpectCached } from "../lib/useComponentDataExpectCached.svelte"; + import { RUN_KEY, type RunContext } from "../lib/useRun.svelte"; + import ActivationContextsPagedTable, { type ActivationExamplesData } from "./ActivationContextsPagedTable.svelte"; + import ComponentProbeInput from "./ComponentProbeInput.svelte"; + import ComponentCorrelationMetrics from "./ui/ComponentCorrelationMetrics.svelte"; + import DatasetAttributionsSection from "./ui/DatasetAttributionsSection.svelte"; + import InterpretationBadge from "./ui/InterpretationBadge.svelte"; + import SectionHeader from "./ui/SectionHeader.svelte"; + import StatusText from "./ui/StatusText.svelte"; + import TokenStatsSection from "./ui/TokenStatsSection.svelte"; + + const runState = getContext(RUN_KEY); + + type Props = { + layer: string; + cIdx: number; + }; + + let { layer, cIdx }: Props = $props(); + + const intruderScore = $derived(runState.getIntruderScore(`${layer}:${cIdx}`)); + + const componentData = useComponentDataExpectCached(); + + onMount(() => { + componentData.load(layer, cIdx); + }); + + const inputTokenLists = $derived.by(() => { + const tokenStats = componentData.tokenStats; + if (tokenStats.status !== "loaded" || tokenStats.data === null) return null; + return [ + { + title: "Top Precision", + mathNotation: "P(component fires | token)", + items: 
tokenStats.data.input.top_precision.map(([token, value]) => ({ + token, + value, + })), + maxScale: 1, + }, + ]; + }); + + const outputTokenLists = $derived.by(() => { + const tokenStats = componentData.tokenStats; + if (tokenStats.status !== "loaded" || tokenStats.data === null) return null; + const maxAbsPmi = Math.max( + tokenStats.data.output.top_pmi[0]?.[1] ?? 0, + Math.abs(tokenStats.data.output.bottom_pmi?.[0]?.[1] ?? 0), + ); + return [ + { + title: "Top PMI", + mathNotation: "positive association with predictions", + items: tokenStats.data.output.top_pmi.map(([token, value]) => ({ token, value })), + maxScale: maxAbsPmi, + }, + { + title: "Bottom PMI", + mathNotation: "negative association with predictions", + items: tokenStats.data.output.bottom_pmi.map(([token, value]) => ({ + token, + value, + })), + maxScale: maxAbsPmi, + }, + ]; + }); + + function formatNumericalValue(val: number): string { + return Math.abs(val) < 0.001 ? val.toExponential(2) : val.toFixed(3); + } + + const maxAbsComponentAct = $derived.by(() => { + if (componentData.componentDetail.status !== "loaded") return 1; + return computeMaxAbsComponentAct(componentData.componentDetail.data.example_component_acts); + }); + + const activationExamples = $derived( + mapLoadable( + componentData.componentDetail, + (d): ActivationExamplesData => ({ + tokens: d.example_tokens, + ci: d.example_ci, + componentActs: d.example_component_acts, + maxAbsComponentAct, + }), + ), + ); + + +
+
+

{layer}:{cIdx}

+
+ {#if componentData.componentDetail.status === "loaded"} + Mean CI: {formatNumericalValue(componentData.componentDetail.data.mean_ci)} + {/if} + {#if intruderScore !== null} + Intruder: {Math.round(intruderScore * 100)}% + {/if} +
+
+ + + +
+ + {#if activationExamples.status === "error"} + Error loading details: {String(activationExamples.error)} + {:else if activationExamples.status === "loaded" && activationExamples.data.tokens.length === 0} + + {:else} + + {/if} +
+ + + + {#if componentData.datasetAttributions.status === "uninitialized"} + uninitialized + {:else if componentData.datasetAttributions.status === "loaded"} + {#if componentData.datasetAttributions.data !== null} + + {:else} + No dataset attributions available. + {/if} + {:else if componentData.datasetAttributions.status === "loading"} +
+ + Loading... +
+ {:else if componentData.datasetAttributions.status === "error"} +
+ + Error: {String(componentData.datasetAttributions.error)} +
+ {/if} + +
+ +
+ {#if componentData.tokenStats.status === "loading" || componentData.tokenStats.status === "uninitialized"} + Loading token stats... + {:else if componentData.tokenStats.status === "error"} + Error: {String(componentData.tokenStats.error)} + {:else} + + + + {/if} +
+
+ + {#if anyCorrelationStatsEnabled()} +
+ + {#if componentData.correlations.status === "loading"} + Loading... + {:else if componentData.correlations.status === "loaded" && componentData.correlations.data} + + {:else if componentData.correlations.status === "error"} + Error loading correlations: {String(componentData.correlations.error)} + {:else} + No correlations available. + {/if} +
+ {/if} +
+ + diff --git a/spd/app/frontend/src/components/ClusterPathInput.svelte b/spd/app/frontend/src/components/ClusterPathInput.svelte index 27b066c2c..6adcfb2b3 100644 --- a/spd/app/frontend/src/components/ClusterPathInput.svelte +++ b/spd/app/frontend/src/components/ClusterPathInput.svelte @@ -1,8 +1,8 @@ + +
+ {#if clusterMapping} + + {:else} + No clusters loaded. Use the cluster path input in the header bar to load a cluster mapping. + {/if} +
+ + diff --git a/spd/app/frontend/src/components/ClustersViewer.svelte b/spd/app/frontend/src/components/ClustersViewer.svelte new file mode 100644 index 000000000..6ff194c82 --- /dev/null +++ b/spd/app/frontend/src/components/ClustersViewer.svelte @@ -0,0 +1,249 @@ + + +
+ {#if selectedClusterId === null} +
+

Clusters ({clusterGroups.sorted.length})

+ {#each clusterGroups.sorted as [clusterId, members] (clusterId)} + {@const previewLabels = getPreviewLabels(members)} + + {/each} + {#if clusterGroups.singletons.length > 0} + + {/if} +
+ {:else} +
+
+ +

+ {selectedClusterId === "unclustered" ? "Unclustered" : `Cluster ${selectedClusterId}`} +

+ {selectedMembers.length} components +
+
+ {#each selectedMembers as member (`${member.layer}:${member.cIdx}`)} +
+ +
+ {/each} +
+
+ {/if} +
+ + diff --git a/spd/app/frontend/src/components/DataSourcesTab.svelte b/spd/app/frontend/src/components/DataSourcesTab.svelte index bc9282c27..c54a1fa0a 100644 --- a/spd/app/frontend/src/components/DataSourcesTab.svelte +++ b/spd/app/frontend/src/components/DataSourcesTab.svelte @@ -29,7 +29,7 @@ }); function formatConfigValue(value: unknown): string { - if (value === null || value === undefined) return "—"; + if (value === null || value === undefined) return "\u2014"; if (typeof value === "object") return JSON.stringify(value); return String(value); } @@ -50,116 +50,91 @@ } -
- {#if runState.run.status === "loaded" && runState.run.data.config_yaml} -
-

Run Config

-
{runState.run.data.config_yaml}
-
- {/if} - - - {#if pretrainData.status === "loading"} -
-

Target Model

-

Loading target model info...

-
- {:else if pretrainData.status === "loaded"} - {@const pt = pretrainData.data} -
-

Target Model

-
- Architecture - {pt.summary} - - {#if pt.pretrain_wandb_path} - Pretrain run - {pt.pretrain_wandb_path} - {/if} -
+
+ +
+ {#if runState.run.status === "loaded" && runState.run.data.config_yaml} +
+

Run Config

+
{runState.run.data.config_yaml}
+
+ {/if} - {#if pt.topology} -
-

Topology

- +
+

Target Model

+ {#if pretrainData.status === "loading"} +

Loading...

+ {:else if pretrainData.status === "loaded"} + {@const pt = pretrainData.data} +
+ Architecture + {pt.summary} + {#if pt.pretrain_wandb_path} + Pretrain run + {pt.pretrain_wandb_path} + {/if}
+ {#if pt.topology} +
+ +
+ {/if} + {#if pt.pretrain_config} +
+ Pretraining config +
{formatPretrainConfigYaml(pt.pretrain_config)}
+
+ {/if} + {:else if pretrainData.status === "error"} +

Failed to load target model info

{/if} - - {#if pt.pretrain_config} -
- Pretraining config -
{formatPretrainConfigYaml(pt.pretrain_config)}
-
- {/if} -
- {:else if pretrainData.status === "error"} -
-

Target Model

-

Failed to load target model info

- {/if} - - {#if data.status === "loading"} -

Loading data sources...

- {:else if data.status === "error"} -

Failed to load data sources: {data.error}

- {:else if data.status === "loaded"} - {@const { harvest, autointerp, attributions } = data.data} - - {#if !harvest && !autointerp && !attributions} -

No pipeline data available for this run.

- {/if} - - {#if harvest} -
-

Harvest

+
+ + +
+ +
+
+ +

Harvest

+
+ {#if data.status === "loading"} +

Loading...

+ {:else if data.status === "loaded" && data.data.harvest} + {@const harvest = data.data.harvest}
Subrun {harvest.subrun_id} - Components {harvest.n_components.toLocaleString()} - Intruder eval {harvest.has_intruder_scores ? "yes" : "no"} - {#each Object.entries(harvest.config) as [key, value] (key)} {key} {formatConfigValue(value)} {/each}
-
- {/if} - - {#if attributions} -
-

Dataset Attributions

-
- Subrun - {attributions.subrun_id} - - Batches - {attributions.n_batches_processed.toLocaleString()} - - Tokens - {attributions.n_tokens_processed.toLocaleString()} - - CI threshold - {attributions.ci_threshold} -
-
- {/if} + {:else if data.status === "loaded"} +

Not available

+ {/if} +
- {#if autointerp} -
-

Autointerp

+ +
+
+ +

Autointerp

+
+ {#if data.status === "loading"} +

Loading...

+ {:else if data.status === "loaded" && data.data.autointerp} + {@const autointerp = data.data.autointerp}
Subrun {autointerp.subrun_id} - Interpretations {autointerp.n_interpretations.toLocaleString()} - Eval scores {#if autointerp.eval_scores.length > 0} @@ -168,65 +143,156 @@ none {/if} - {#each Object.entries(autointerp.config) as [key, value] (key)} {key} {formatConfigValue(value)} {/each}
-
- {/if} - {/if} + {:else if data.status === "loaded"} +

Not available

+ {/if} +
+ + +
+
+ +

Dataset Attributions

+
+ {#if data.status === "loading"} +

Loading...

+ {:else if data.status === "loaded" && data.data.attributions} + {@const attributions = data.data.attributions} +
+ Subrun + {attributions.subrun_id} + Tokens + {attributions.n_tokens_processed.toLocaleString()} + CI threshold + {attributions.ci_threshold} +
+ {:else if data.status === "loaded"} +

Not available

+ {/if} +
+ + +
+
+ +

Graph Interp

+
+ {#if data.status === "loading"} +

Loading...

+ {:else if data.status === "loaded" && data.data.graph_interp} + {@const graph_interp = data.data.graph_interp} +
+ Subrun + {graph_interp.subrun_id} + {#each Object.entries(graph_interp.label_counts) as [key, value] (key)} + {key} labels + {value.toLocaleString()} + {/each} + {#if graph_interp.config} + {#each Object.entries(graph_interp.config) as [key, value] (key)} + {key} + {formatConfigValue(value)} + {/each} + {/if} +
+ {:else if data.status === "loaded"} +

Not available

+ {/if} +
+
diff --git a/spd/app/frontend/src/components/InvestigationsTab.svelte b/spd/app/frontend/src/components/InvestigationsTab.svelte new file mode 100644 index 000000000..a7dea1423 --- /dev/null +++ b/spd/app/frontend/src/components/InvestigationsTab.svelte @@ -0,0 +1,642 @@ + + +
+ {#if selected?.status === "loaded"} + +
+ +

{selected.data.title || formatId(selected.data.id)}

+ + {#if selected.data.status} + + {selected.data.status} + + {/if} +
+ + {#if selected.data.summary} +

{selected.data.summary}

+ {/if} + + +

+ {formatId(selected.data.id)} · Started {formatDate(selected.data.created_at)} + {#if selected.data.wandb_path} + · {selected.data.wandb_path} + {/if} +

+ +
+ + +
+ +
+ {#if activeTab === "research"} + {#if selected.data.research_log} + + {:else} +

No research log available

+ {/if} + {:else} +
+ {#each selected.data.events as event, i (i)} +
+ + {event.event_type} + + {formatDate(event.timestamp)} + {event.message} + {#if event.details && Object.keys(event.details).length > 0} +
+ Details +
{JSON.stringify(event.details, null, 2)}
+
+ {/if} +
+ {:else} +

No events recorded

+ {/each} +
+ {/if} +
+ {:else if selected?.status === "loading"} +
Loading investigation...
+ {:else} + +
+

Investigations

+ +
+ +
{ + e.preventDefault(); + handleLaunch(); + }} + > + + +
+ {#if launchState.status === "error"} +
{launchState.error}
+ {/if} + + {#if investigations.status === "loading"} +
Loading investigations...
+ {:else if investigations.status === "error"} +
{investigations.error}
+ {:else if investigations.status === "loaded"} +
+ {#each investigations.data as inv (inv.id)} + + {:else} +

+ No investigations found. Run spd-investigate to create one. +

+ {/each} +
+ {/if} + {/if} +
+ + diff --git a/spd/app/frontend/src/components/ModelGraph.svelte b/spd/app/frontend/src/components/ModelGraph.svelte new file mode 100644 index 000000000..ae49a25e3 --- /dev/null +++ b/spd/app/frontend/src/components/ModelGraph.svelte @@ -0,0 +1,520 @@ + + +
+ +
+
+ + + +
+
+ +
+
+ +
+
+ +
+
+ {filteredNodes.length} nodes, {visibleEdges.length} edges +
+
+ + + +
+
+ + + + + + {#if tooltipNode} +
+
{tooltipNode.label}
+
+ {tooltipNode.confidence} + {tooltipNode.key} +
+
+ {/if} + + + {#if selectedNodeKey} + {@const selectedNode = layout.nodes.get(selectedNodeKey)} + {#if selectedNode} +
+
+ {selectedNode.label} + {selectedNode.confidence} + +
+
{selectedNode.key}
+
+ {#if selectedNodeEdges.length > 0} +
+ {#each selectedNodeEdges as e, i (i)} + {@const other = e.source === selectedNodeKey ? e.target : e.source} + {@const otherNode = layout.nodes.get(other)} +
+ {e.source === selectedNodeKey ? "to" : "from"} + {otherNode?.label ?? other} + {e.attribution.toFixed(3)} +
+ {/each} +
+ {:else} + No edges + {/if} +
+
+ {/if} + {/if} +
+ + diff --git a/spd/app/frontend/src/components/ProbColoredTokens.svelte b/spd/app/frontend/src/components/ProbColoredTokens.svelte index 9a81b4858..7b30870c3 100644 --- a/spd/app/frontend/src/components/ProbColoredTokens.svelte +++ b/spd/app/frontend/src/components/ProbColoredTokens.svelte @@ -1,5 +1,7 @@ {#each tokens as tok, i (i)}{tok}{#each tokens as tok, i (i)}{@const prob = getProbAtPosition(nextTokenProbs, i)}{/each} @@ -33,34 +26,10 @@ display: inline-flex; flex-wrap: wrap; gap: 1px; - font-family: var(--font-mono); } - .prob-token { - padding: 1px 2px; + .prob-token-wrapper { border-right: 1px solid var(--border-subtle); - position: relative; - white-space: pre; - } - - .prob-token::after { - content: attr(data-tooltip); - position: absolute; - top: calc(100% + 4px); - left: 0; - background: var(--bg-elevated); - border: 1px solid var(--border-strong); - color: var(--text-primary); - padding: var(--space-1) var(--space-2); - font-size: var(--text-xs); - font-family: var(--font-mono); - white-space: nowrap; - opacity: 0; - pointer-events: none; - z-index: 1000; - } - - .prob-token:hover::after { - opacity: 1; + padding: 1px 0; } diff --git a/spd/app/frontend/src/components/PromptAttributionsGraph.svelte b/spd/app/frontend/src/components/PromptAttributionsGraph.svelte index da0d0e935..c5e182917 100644 --- a/spd/app/frontend/src/components/PromptAttributionsGraph.svelte +++ b/spd/app/frontend/src/components/PromptAttributionsGraph.svelte @@ -1,6 +1,6 @@
@@ -40,7 +57,15 @@ {/if} @@ -93,6 +128,10 @@ {runState.run.error}
{/if} + +
+ +
{#if runState.prompts.status === "loaded"}
@@ -109,6 +148,11 @@
+ {#if runState.clusterMapping} +
+ +
+ {/if} {:else if runState.run.status === "loading" || runState.prompts.status === "loading"}

Loading run...

diff --git a/spd/app/frontend/src/components/TokenHighlights.svelte b/spd/app/frontend/src/components/TokenHighlights.svelte index 4916c1f00..456cdbc72 100644 --- a/spd/app/frontend/src/components/TokenHighlights.svelte +++ b/spd/app/frontend/src/components/TokenHighlights.svelte @@ -1,5 +1,6 @@ + +
+ {#if caption} +
{caption}
+ {/if} + + +
+ + +
+ + + {#each Object.entries(layout.layerYPositions) as [layer, y] (layer)} + + {getRowLabel(getRowKey(layer))} + + {/each} + + +
+ +
+ + + + + {@html edgesSvg} + + + + {#each Object.entries(layout.nodePositions) as [key, pos] (key)} + {@const style = nodeStyles[key]} + {@const [layer, seqIdxStr, cIdxStr] = key.split(":")} + {@const seqIdx = parseInt(seqIdxStr)} + {@const cIdx = parseInt(cIdxStr)} + + handleNodeHover(e, layer, seqIdx, cIdx)} + onmouseleave={handleNodeLeave} + /> + + + {/each} + + + + +
+ + + {#each data.tokens as token, i (i)} + {@const colLeft = layout.seqXStarts[i] + 8} + + {token} + + [{i}] + {/each} + + +
+
+
+ +
+ L0: {data.l0_total} · Edges: {filteredEdges.length} +
+ + + {#if hoveredNode && runState} + (isHoveringTooltip = true)} + onMouseLeave={() => { + isHoveringTooltip = false; + hoveredNode = null; + }} + /> + {/if} +
+ + diff --git a/spd/app/frontend/src/components/investigations/ResearchLogViewer.svelte b/spd/app/frontend/src/components/investigations/ResearchLogViewer.svelte new file mode 100644 index 000000000..ef9b8d40d --- /dev/null +++ b/spd/app/frontend/src/components/investigations/ResearchLogViewer.svelte @@ -0,0 +1,223 @@ + + +
+ {#each contentBlocks as block, i (i)} + {#if block.type === "html"} + +
{@html block.content}
+ {:else if block.type === "graph"} + {@const artifact = artifacts[block.artifactId]} + {#if artifact} + + {:else if artifactsLoading} +
+ Loading graph: {block.artifactId}... +
+ {:else} +
+ Graph artifact not found: {block.artifactId} +
+ {/if} + {/if} + {/each} +
+ + diff --git a/spd/app/frontend/src/components/prompt-attr/ComponentNodeCard.svelte b/spd/app/frontend/src/components/prompt-attr/ComponentNodeCard.svelte index a0d663208..8d6ef9263 100644 --- a/spd/app/frontend/src/components/prompt-attr/ComponentNodeCard.svelte +++ b/spd/app/frontend/src/components/prompt-attr/ComponentNodeCard.svelte @@ -1,9 +1,11 @@
@@ -208,30 +204,26 @@
- +
+ + {#if graphInterpLabel && componentData.graphInterpDetail.status === "loaded" && componentData.graphInterpDetail.data} + + {/if} +
- {#if componentData.componentDetail.status === "uninitialized"} - uninitialized - {:else if componentData.componentDetail.status === "loading"} - Loading details... - {:else if componentData.componentDetail.status === "loaded"} - {#if componentData.componentDetail.data.example_tokens.length > 0} - - {/if} - {:else if componentData.componentDetail.status === "error"} - Error loading details: {String(componentData.componentDetail.error)} + {#if activationExamples.status === "error"} + Error loading details: {String(activationExamples.error)} + {:else if activationExamples.status === "loaded" && activationExamples.data.tokens.length === 0} + + {:else} + {/if}
@@ -243,34 +235,29 @@ title="Prompt Attributions" incomingLabel="Incoming" outgoingLabel="Outgoing" - {incomingPositive} - {incomingNegative} - {outgoingPositive} - {outgoingNegative} + {incoming} + {outgoing} pageSize={COMPONENT_CARD_CONSTANTS.PROMPT_ATTRIBUTIONS_PAGE_SIZE} onClick={handleEdgeNodeClick} - {tokens} - {outputProbs} /> {/if} - {#if componentData.datasetAttributions.status === "uninitialized"} - uninitialized + {#if componentData.datasetAttributions.status === "loading" || componentData.datasetAttributions.status === "uninitialized"} +
+ +
+
+
+
+
{:else if componentData.datasetAttributions.status === "loaded"} {#if componentData.datasetAttributions.data !== null} - {:else} - No dataset attributions available. {/if} - {:else if componentData.datasetAttributions.status === "loading"} -
- - Loading... -
{:else if componentData.datasetAttributions.status === "error"}
@@ -282,7 +269,12 @@
{#if componentData.tokenStats === null || componentData.tokenStats.status === "loading"} - Loading token stats... +
+
+
+
+
+
{:else if componentData.tokenStats.status === "error"} Error: {String(componentData.tokenStats.error)} {:else} @@ -306,7 +298,10 @@
{#if componentData.correlations.status === "loading"} - Loading... +
+
+
+
{:else if componentData.correlations.status === "loaded" && componentData.correlations.data} diff --git a/spd/app/frontend/src/components/prompt-attr/ComputeProgressOverlay.svelte b/spd/app/frontend/src/components/prompt-attr/ComputeProgressOverlay.svelte index 1f3c09a01..23619c281 100644 --- a/spd/app/frontend/src/components/prompt-attr/ComputeProgressOverlay.svelte +++ b/spd/app/frontend/src/components/prompt-attr/ComputeProgressOverlay.svelte @@ -1,46 +1,54 @@
-
- {#each state.stages as stage, i (i)} - {@const isCurrent = i === state.currentStage} - {@const isComplete = i < state.currentStage} -
-
- {i + 1} - {stage.name} - {#if isComplete} - - {/if} -
- {#if isCurrent} -
- {#if stage.progress !== null} -
- {:else} -
+
+ {#if ciSnapshot} + + {/if} +
+ {#each state.stages as stage, i (i)} + {@const isCurrent = i === state.currentStage} + {@const isComplete = i < state.currentStage} +
+
+ {i + 1} + {stage.name} + {#if isComplete} + {/if}
- {:else if isComplete} -
-
-
- {:else} -
- {/if} -
- {/each} + {#if isCurrent} +
+ {#if stage.progress !== null} +
+ {:else} +
+ {/if} +
+ {:else if isComplete} +
+
+
+ {:else} +
+ {/if} +
+ {/each} +
@@ -54,6 +62,13 @@ z-index: 100; } + .content { + display: flex; + flex-direction: column; + align-items: center; + gap: var(--space-6); + } + .stages { display: flex; flex-direction: column; diff --git a/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte b/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte index 95433d57e..84d0fdc25 100644 --- a/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte +++ b/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte @@ -1,10 +1,11 @@
+ {#if displaySettings.showEdgeAttributions && wteOutgoing.length > 0} + {}} + /> + {/if} {:else if isOutput} - + {:else if !hideNodeCard} {#key `${hoveredNode.layer}:${hoveredNode.cIdx}`} diff --git a/spd/app/frontend/src/components/prompt-attr/OptimizationGrid.svelte b/spd/app/frontend/src/components/prompt-attr/OptimizationGrid.svelte new file mode 100644 index 000000000..3fa6a0213 --- /dev/null +++ b/spd/app/frontend/src/components/prompt-attr/OptimizationGrid.svelte @@ -0,0 +1,134 @@ + + +
+
+ + Step {snapshot.step}/{snapshot.total_steps} + + + L0: {Math.round(snapshot.l0_total)} / {initialL0} + ({(fractionRemaining * 100).toFixed(0)}%) + + {#if snapshot.loss > 0} + loss: {snapshot.loss.toFixed(4)} + {/if} +
+ +
+ + diff --git a/spd/app/frontend/src/components/prompt-attr/OptimizationParams.svelte b/spd/app/frontend/src/components/prompt-attr/OptimizationParams.svelte index fa3413320..83d9f0594 100644 --- a/spd/app/frontend/src/components/prompt-attr/OptimizationParams.svelte +++ b/spd/app/frontend/src/components/prompt-attr/OptimizationParams.svelte @@ -21,42 +21,86 @@
- steps{optimization.steps} - imp_min{optimization.imp_min_coeff} - pnorm{optimization.pnorm} - beta{optimization.beta} - mask{optimization.mask_type} - + steps{optimization.steps} + imp_min{optimization.imp_min_coeff} + pnorm{optimization.pnorm} + beta{optimization.beta} + mask{optimization.mask_type} + {optimization.loss.type}{optimization.loss.coeff} - - pos{optimization.loss.position}{#if tokenAtPos !== null} - ({tokenAtPos}){/if} + + pos + {optimization.loss.position} + {#if tokenAtPos !== null} + ({tokenAtPos}) + {/if} - {#if optimization.loss.type === "ce"} - + {#if optimization.loss.type === "ce" || optimization.loss.type === "logit"} + label({optimization.loss.label_str}) {/if} - {#if optimization.adv_pgd_n_steps !== null} - - adv_steps{optimization.adv_pgd_n_steps} + {#if optimization.pgd} + + pgd_steps{optimization.pgd.n_steps} - - adv_lr{optimization.adv_pgd_step_size} + + pgd_lr{optimization.pgd.step_size} {/if} - + L0{optimization.metrics.l0_total.toFixed(1)} - {#if optimization.loss.type === "ce"} - + {#if optimization.loss.type === "ce" || optimization.loss.type === "logit"} + CI prob{formatProb(optimization.metrics.ci_masked_label_prob)} - + stoch prob{formatProb(optimization.metrics.stoch_masked_label_prob)} + + adv prob{formatProb(optimization.metrics.adv_pgd_label_prob)} + {/if}
diff --git a/spd/app/frontend/src/components/prompt-attr/OptimizationSettings.svelte b/spd/app/frontend/src/components/prompt-attr/OptimizationSettings.svelte index f762cdff9..3c15034b4 100644 --- a/spd/app/frontend/src/components/prompt-attr/OptimizationSettings.svelte +++ b/spd/app/frontend/src/components/prompt-attr/OptimizationSettings.svelte @@ -1,15 +1,19 @@
@@ -68,48 +68,57 @@ /> Cross-Entropy +
- -
- At position - { - if (e.currentTarget.value === "") return; - const position = parseInt(e.currentTarget.value); - onChange({ ...config, loss: { ...config.loss, position } }); - }} - min={0} - max={tokens.length - 1} - step={1} - /> - {#if tokenAtSeqPos !== null} - ({tokenAtSeqPos}) - {/if} - {#if config.loss.type === "ce"} - , predict - { - if (config.loss.type !== "ce") - throw new Error( - "inconsistent state: Token dropdown rendered but loss not type CE but no label token", - ); - - if (tokenId !== null) { - onChange({ - ...config, - loss: { ...config.loss, labelTokenId: tokenId, labelTokenText: tokenString }, - }); - } - }} - placeholder="token..." - /> - {/if} + +
+ +
+ {#each tokens as tok, i (i)} + {@const prob = getProbAtPosition(nextTokenProbs, i)} + + {/each} +
+
+ pos {config.loss.position} + {#if config.loss.type === "ce" || config.loss.type === "logit"} + {config.loss.type === "logit" ? "maximize" : "predict"} + { + if (config.loss.type !== "ce" && config.loss.type !== "logit") + throw new Error("inconsistent state: Token dropdown rendered but loss type has no label"); + + if (tokenId !== null) { + onChange({ + ...config, + loss: { ...config.loss, labelTokenId: tokenId, labelTokenText: tokenString }, + }); + } + }} + placeholder="token..." + /> + {/if} +
@@ -256,7 +265,7 @@ display: flex; flex-direction: column; gap: var(--space-3); - max-width: 400px; + max-width: 500px; } .loss-type-options { @@ -296,44 +305,99 @@ color: var(--text-primary); } - .target-section { + .position-section { display: flex; - align-items: center; + flex-direction: column; gap: var(--space-2); - flex-wrap: wrap; - padding: var(--space-2); - background: var(--bg-surface); - border: 1px solid var(--border-default); } - .target-label { + .section-label { font-size: var(--text-xs); + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.05em; color: var(--text-muted); } - .pos-input { - width: 50px; - padding: var(--space-1) var(--space-2); + .token-strip { + display: flex; + flex-wrap: wrap; + gap: 2px; + padding: var(--space-2); + background: var(--bg-inset); border: 1px solid var(--border-default); - background: var(--bg-base); - color: var(--text-primary); - font-size: var(--text-sm); font-family: var(--font-mono); + font-size: var(--text-sm); } - .pos-input:focus { - outline: none; - border-color: var(--accent-primary-dim); + .strip-token { + padding: 2px 2px; + border: 1px solid var(--border-subtle); + border-radius: 2px; + cursor: pointer; + white-space: pre; + font-family: inherit; + font-size: inherit; + color: var(--text-primary); + background: transparent; + position: relative; + transition: + border-color var(--transition-fast), + box-shadow var(--transition-fast); } - .token { - white-space: pre; + .strip-token:hover { + border-color: var(--border-strong); + } + + .strip-token.selected { + border-color: var(--accent-primary); + box-shadow: 0 0 0 1px var(--accent-primary); + z-index: 1; + } + + .strip-token::after { + content: attr(title); + position: absolute; + bottom: calc(100% + 4px); + left: 50%; + transform: translateX(-50%); + background: var(--bg-elevated); + border: 1px solid var(--border-strong); + color: var(--text-primary); + padding: var(--space-1) var(--space-2); + font-size: var(--text-xs); + 
white-space: nowrap; + opacity: 0; + pointer-events: none; + z-index: 100; + border-radius: var(--radius-sm); + } + + .strip-token:hover::after { + opacity: 1; + } + + .position-info { + display: flex; + align-items: center; + gap: var(--space-2); + } + + .pos-label { + font-size: var(--text-xs); font-family: var(--font-mono); + color: var(--text-muted); background: var(--bg-inset); - padding: 0 var(--space-1); + padding: var(--space-1) var(--space-2); border-radius: var(--radius-sm); } + .predict-label { + font-size: var(--text-xs); + color: var(--text-muted); + } + .slider-section { display: flex; flex-direction: column; @@ -346,14 +410,6 @@ align-items: center; } - .section-label { - font-size: var(--text-xs); - font-weight: 600; - text-transform: uppercase; - letter-spacing: 0.05em; - color: var(--text-muted); - } - .imp-min-input { width: 80px; padding: var(--space-1) var(--space-2); diff --git a/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte b/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte index 4fdd57b1b..9dcab7f0e 100644 --- a/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte +++ b/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte @@ -1,14 +1,25 @@
@@ -61,13 +64,6 @@ 2, )})
- {#if singlePosEntry.adv_pgd_prob !== null && singlePosEntry.adv_pgd_logit !== null} -
- Adversarial: {(singlePosEntry.adv_pgd_prob * 100).toFixed(1)}% (logit: {singlePosEntry.adv_pgd_logit.toFixed( - 2, - )}) -
- {/if}

Position: @@ -83,10 +79,6 @@ Logit Target Logit - {#if hasAdvPgd} - Adv - Logit - {/if} @@ -97,15 +89,22 @@ {pos.logit.toFixed(2)} {(pos.target_prob * 100).toFixed(2)}% {pos.target_logit.toFixed(2)} - {#if hasAdvPgd} - {pos.adv_pgd_prob !== null ? (pos.adv_pgd_prob * 100).toFixed(2) + "%" : "—"} - {pos.adv_pgd_logit !== null ? pos.adv_pgd_logit.toFixed(2) : "—"} - {/if} {/each} {/if} + {#if displaySettings.showEdgeAttributions && outputIncoming.length > 0} + {}} + /> + {/if}

diff --git a/spd/app/frontend/src/components/ui/DisplaySettingsDropdown.svelte b/spd/app/frontend/src/components/ui/DisplaySettingsDropdown.svelte index 077917234..63a1062d0 100644 --- a/spd/app/frontend/src/components/ui/DisplaySettingsDropdown.svelte +++ b/spd/app/frontend/src/components/ui/DisplaySettingsDropdown.svelte @@ -1,16 +1,19 @@
Center on peak +
+
+

Edge Variant

+

Attribution target: value or |value|

+
+ {#each edgeVariants as variant (variant)} + + {/each} +
+

Component Filtering

Filter components in Components tab by mean CI

diff --git a/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte b/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte index 844cb7c04..ad3f821f5 100644 --- a/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte +++ b/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte @@ -1,5 +1,5 @@ -{#if hasAnyIncoming} +{#if incoming.length > 0}
-
- {#if incomingPositive.length > 0} -
- -
- {/if} - {#if incomingNegative.length > 0} -
- -
- {/if} -
+
{/if} -{#if hasAnyOutgoing} +{#if outgoing.length > 0}
-
- {#if outgoingPositive.length > 0} -
- -
- {/if} - {#if outgoingNegative.length > 0} -
- -
- {/if} -
+
{/if} @@ -110,17 +36,4 @@ flex-direction: column; gap: var(--space-2); } - - .pos-neg-row { - display: grid; - grid-template-columns: 1fr 1fr; - gap: var(--space-3); - } - - .edge-list { - min-width: 0; - display: flex; - flex-direction: column; - gap: var(--space-1); - } diff --git a/spd/app/frontend/src/components/ui/EdgeAttributionList.svelte b/spd/app/frontend/src/components/ui/EdgeAttributionList.svelte index aa98848b5..3809aceaf 100644 --- a/spd/app/frontend/src/components/ui/EdgeAttributionList.svelte +++ b/spd/app/frontend/src/components/ui/EdgeAttributionList.svelte @@ -1,7 +1,8 @@ + +
+ + + {#if expanded && detail} +
+
+
+ Input + {#if detail.input?.reasoning} +

{detail.input.reasoning}

+ {/if} + {#each incomingEdges as edge (edge.related_key)} +
+ {formatComponentKey(edge.related_key, edge.token_str)} + 0} + class:negative={edge.attribution < 0} + > + {edge.attribution > 0 ? "+" : ""}{edge.attribution.toFixed(3)} + + {#if edge.related_label} + {edge.related_label} + {/if} +
+ {/each} +
+
+ Output + {#if detail.output?.reasoning} +

{detail.output.reasoning}

+ {/if} + {#each outgoingEdges as edge (edge.related_key)} +
+ {formatComponentKey(edge.related_key, edge.token_str)} + 0} + class:negative={edge.attribution < 0} + > + {edge.attribution > 0 ? "+" : ""}{edge.attribution.toFixed(3)} + + {#if edge.related_label} + {edge.related_label} + {/if} +
+ {/each} +
+
+
+ {/if} +
+ + diff --git a/spd/app/frontend/src/components/ui/TokenSpan.svelte b/spd/app/frontend/src/components/ui/TokenSpan.svelte new file mode 100644 index 000000000..4a5fa9b81 --- /dev/null +++ b/spd/app/frontend/src/components/ui/TokenSpan.svelte @@ -0,0 +1,43 @@ + + +{sanitizeToken(token)} + + diff --git a/spd/app/frontend/src/lib/api/correlations.ts b/spd/app/frontend/src/lib/api/correlations.ts index 2e56c3c7e..8dcc63f04 100644 --- a/spd/app/frontend/src/lib/api/correlations.ts +++ b/spd/app/frontend/src/lib/api/correlations.ts @@ -3,7 +3,7 @@ */ import type { SubcomponentCorrelationsResponse, TokenStatsResponse } from "../promptAttributionsTypes"; -import { apiUrl, fetchJson } from "./index"; +import { ApiError, apiUrl, fetchJson } from "./index"; export async function getComponentCorrelations( layer: string, @@ -47,10 +47,18 @@ export async function getIntruderScores(): Promise> { return fetchJson>("/api/correlations/intruder_scores"); } -export async function getInterpretationDetail(layer: string, componentIdx: number): Promise { - return fetchJson( - `/api/correlations/interpretations/${encodeURIComponent(layer)}/${componentIdx}`, - ); +export async function getInterpretationDetail( + layer: string, + componentIdx: number, +): Promise { + try { + return await fetchJson( + `/api/correlations/interpretations/${encodeURIComponent(layer)}/${componentIdx}`, + ); + } catch (e) { + if (e instanceof ApiError && e.status === 404) return null; + throw e; + } } export async function requestComponentInterpretation( diff --git a/spd/app/frontend/src/lib/api/dataSources.ts b/spd/app/frontend/src/lib/api/dataSources.ts index e715af1b1..ac20b7220 100644 --- a/spd/app/frontend/src/lib/api/dataSources.ts +++ b/spd/app/frontend/src/lib/api/dataSources.ts @@ -20,15 +20,21 @@ export type AutointerpInfo = { export type AttributionsInfo = { subrun_id: string; - n_batches_processed: number; n_tokens_processed: number; ci_threshold: number; }; +export type GraphInterpInfoDS = { + 
subrun_id: string; + config: Record<string, unknown> | null; + label_counts: Record<string, number>; +}; + export type DataSourcesResponse = { harvest: HarvestInfo | null; autointerp: AutointerpInfo | null; attributions: AttributionsInfo | null; + graph_interp: GraphInterpInfoDS | null; }; export async function fetchDataSources(): Promise<DataSourcesResponse> { diff --git a/spd/app/frontend/src/lib/api/datasetAttributions.ts b/spd/app/frontend/src/lib/api/datasetAttributions.ts index f995a33f6..030eae6c6 100644 --- a/spd/app/frontend/src/lib/api/datasetAttributions.ts +++ b/spd/app/frontend/src/lib/api/datasetAttributions.ts @@ -9,15 +9,23 @@ export type DatasetAttributionEntry = { layer: string; component_idx: number; value: number; + token_str: string | null; }; -export type ComponentAttributions = { +export type SignedAttributions = { positive_sources: DatasetAttributionEntry[]; negative_sources: DatasetAttributionEntry[]; positive_targets: DatasetAttributionEntry[]; negative_targets: DatasetAttributionEntry[]; }; +export type AttrMetric = "attr" | "attr_abs"; + +export type AllMetricAttributions = { + attr: SignedAttributions; + attr_abs: SignedAttributions; +}; + export type DatasetAttributionsMetadata = { available: boolean; }; @@ -30,8 +38,8 @@ export async function getComponentAttributions( layer: string, componentIdx: number, k: number = 10, -): Promise<ComponentAttributions> { +): Promise<AllMetricAttributions> { const url = apiUrl(`/api/dataset_attributions/${encodeURIComponent(layer)}/${componentIdx}`); url.searchParams.set("k", String(k)); - return fetchJson<ComponentAttributions>(url.toString()); + return fetchJson<AllMetricAttributions>(url.toString()); } diff --git a/spd/app/frontend/src/lib/api/graphInterp.ts b/spd/app/frontend/src/lib/api/graphInterp.ts new file mode 100644 index 000000000..8229e757c --- /dev/null +++ b/spd/app/frontend/src/lib/api/graphInterp.ts @@ -0,0 +1,81 @@ +/** + * API client for /api/graph_interp endpoints.
+ */ + +import { fetchJson } from "./index"; + +export type GraphInterpHeadline = { + label: string; + confidence: string; + output_label: string | null; + input_label: string | null; +}; + +export type LabelDetail = { + label: string; + confidence: string; + reasoning: string; + prompt: string; +}; + +export type GraphInterpDetail = { + output: LabelDetail | null; + input: LabelDetail | null; + unified: LabelDetail | null; +}; + +export type PromptEdgeResponse = { + related_key: string; + pass_name: string; + attribution: number; + related_label: string | null; + related_confidence: string | null; + token_str: string | null; +}; + +export type GraphInterpComponentDetail = { + output: LabelDetail | null; + input: LabelDetail | null; + unified: LabelDetail | null; + edges: PromptEdgeResponse[]; +}; + +export type GraphNode = { + component_key: string; + label: string; + confidence: string; +}; + +export type GraphEdge = { + source: string; + target: string; + attribution: number; + pass_name: string; +}; + +export type ModelGraphResponse = { + nodes: GraphNode[]; + edges: GraphEdge[]; +}; + +export type GraphInterpInfo = { + subrun_id: string; + config: Record<string, unknown> | null; + label_counts: Record<string, number>; +}; + +export async function getAllGraphInterpLabels(): Promise<Record<string, GraphInterpHeadline>> { + return fetchJson<Record<string, GraphInterpHeadline>>("/api/graph_interp/labels"); +} + +export async function getGraphInterpDetail(layer: string, cIdx: number): Promise<GraphInterpDetail> { + return fetchJson<GraphInterpDetail>(`/api/graph_interp/labels/${layer}/${cIdx}`); +} + +export async function getGraphInterpComponentDetail(layer: string, cIdx: number): Promise<GraphInterpComponentDetail> { + return fetchJson<GraphInterpComponentDetail>(`/api/graph_interp/detail/${layer}/${cIdx}`); +} + +export async function getModelGraph(): Promise<ModelGraphResponse> { + return fetchJson<ModelGraphResponse>("/api/graph_interp/graph"); +} diff --git a/spd/app/frontend/src/lib/api/graphs.ts b/spd/app/frontend/src/lib/api/graphs.ts index 42490d531..788a6a130 100644 --- a/spd/app/frontend/src/lib/api/graphs.ts +++ b/spd/app/frontend/src/lib/api/graphs.ts @@ -2,11 +2,25 @@ * API client for
/api/graphs endpoints. */ -import type { GraphData, TokenizeResponse, TokenInfo } from "../promptAttributionsTypes"; +import type { GraphData, EdgeData, TokenizeResponse, TokenSearchResult, CISnapshot } from "../promptAttributionsTypes"; import { buildEdgeIndexes } from "../promptAttributionsTypes"; -import { setArchitecture } from "../layerAliasing"; import { apiUrl, ApiError, fetchJson } from "./index"; +/** Hydrate a raw API graph response into a full GraphData with edge indexes. */ +function hydrateGraph(raw: Record<string, unknown>): GraphData { + const g = raw as Omit<GraphData, "edgesBySource" | "edgesByTarget" | "edgesAbsBySource" | "edgesAbsByTarget">; + const { edgesBySource, edgesByTarget } = buildEdgeIndexes(g.edges); + const edgesAbs = (g.edgesAbs as EdgeData[] | null) ?? null; + let edgesAbsBySource: Map<string, EdgeData[]> | null = null; + let edgesAbsByTarget: Map<string, EdgeData[]> | null = null; + if (edgesAbs) { + const absIndexes = buildEdgeIndexes(edgesAbs); + edgesAbsBySource = absIndexes.edgesBySource; + edgesAbsByTarget = absIndexes.edgesByTarget; + } + return { ...g, edgesBySource, edgesByTarget, edgesAbs, edgesAbsBySource, edgesAbsByTarget } as GraphData; +} + export type NormalizeType = "none" | "target" | "layer"; export type GraphProgress = { @@ -23,14 +37,13 @@ export type ComputeGraphParams = { includedNodes?: string[]; }; -/** - * Parse SSE stream and return GraphData result. - * Handles progress updates, errors, and completion messages. - */ -async function parseGraphSSEStream( +/** Generic SSE stream parser. Delegates result extraction to the caller via extractResult.
*/ +async function parseSSEStream<T>( response: Response, + extractResult: (data: Record<string, unknown>) => T, onProgress?: (progress: GraphProgress) => void, -): Promise<GraphData> { + onCISnapshot?: (snapshot: CISnapshot) => void, +): Promise<T> { const reader = response.body?.getReader(); if (!reader) { throw new Error("Response body is not readable"); @@ -38,7 +51,7 @@ const decoder = new TextDecoder(); let buffer = ""; - let result: GraphData | null = null; + let result: T | null = null; while (true) { const { done, value } = await reader.read(); @@ -56,19 +69,12 @@ if (data.type === "progress" && onProgress) { onProgress({ current: data.current, total: data.total, stage: data.stage }); + } else if (data.type === "ci_snapshot" && onCISnapshot) { + onCISnapshot(data as CISnapshot); } else if (data.type === "error") { throw new ApiError(data.error, 500); } else if (data.type === "complete") { - // Extract all unique layer names from edges to detect architecture - const layerNames = new Set<string>(); - for (const edge of data.data.edges) { - layerNames.add(edge.src.split(":")[0]); - layerNames.add(edge.tgt.split(":")[0]); - } - setArchitecture(Array.from(layerNames)); - - const { edgesBySource, edgesByTarget } = buildEdgeIndexes(data.data.edges); - result = { ...data.data, edgesBySource, edgesByTarget }; + result = extractResult(data); await reader.cancel(); break; } @@ -102,11 +108,11 @@ throw new ApiError(error.detail || `HTTP ${response.status}`, response.status); } - return parseGraphSSEStream(response, onProgress); + return parseSSEStream(response, (data) => hydrateGraph(data.data as Record<string, unknown>), onProgress); } export type MaskType = "stochastic" | "ci"; -export type LossType = "ce" | "kl"; +export type LossType = "ce" | "kl" | "logit"; export type ComputeGraphOptimizedParams = { promptId: number; @@ -128,6 +134,7 @@ export async function
computeGraphOptimizedStream( params: ComputeGraphOptimizedParams, onProgress?: (progress: GraphProgress) => void, + onCISnapshot?: (snapshot: CISnapshot) => void, ): Promise<GraphData> { const url = apiUrl("/api/graphs/optimized/stream"); url.searchParams.set("prompt_id", String(params.promptId)); @@ -157,26 +164,79 @@ throw new ApiError(error.detail || `HTTP ${response.status}`, response.status); } - return parseGraphSSEStream(response, onProgress); + return parseSSEStream( + response, + (data) => hydrateGraph(data.data as Record<string, unknown>), + onProgress, + onCISnapshot, + ); +} + +export type ComputeGraphOptimizedBatchParams = { + promptId: number; + impMinCoeffs: number[]; + steps: number; + pnorm: number; + beta: number; + normalize: NormalizeType; + ciThreshold: number; + maskType: MaskType; + lossType: LossType; + lossCoeff: number; + lossPosition: number; + labelToken?: number; + advPgdNSteps?: number; + advPgdStepSize?: number; +}; + +export async function computeGraphOptimizedBatchStream( + params: ComputeGraphOptimizedBatchParams, + onProgress?: (progress: GraphProgress) => void, + onCISnapshot?: (snapshot: CISnapshot) => void, +): Promise<GraphData[]> { + const url = apiUrl("/api/graphs/optimized/batch/stream"); + + const body: Record<string, unknown> = { + prompt_id: params.promptId, + imp_min_coeffs: params.impMinCoeffs, + steps: params.steps, + pnorm: params.pnorm, + beta: params.beta, + normalize: params.normalize, + ci_threshold: params.ciThreshold, + mask_type: params.maskType, + loss_type: params.lossType, + loss_coeff: params.lossCoeff, + loss_position: params.lossPosition, + }; + if (params.labelToken !== undefined) body.label_token = params.labelToken; + if (params.advPgdNSteps !== undefined) body.adv_pgd_n_steps = params.advPgdNSteps; + if (params.advPgdStepSize !== undefined) body.adv_pgd_step_size = params.advPgdStepSize; + + const response = await fetch(url.toString(), { + method: "POST", + headers: { "Content-Type": "application/json" },
body: JSON.stringify(body), + }); + if (!response.ok) { + const error = await response.json(); + throw new ApiError(error.detail || `HTTP ${response.status}`, response.status); + } + + return parseSSEStream( + response, + (data) => (data.data as { graphs: Record<string, unknown>[] }).graphs.map((g) => hydrateGraph(g)), + onProgress, + onCISnapshot, + ); } export async function getGraphs(promptId: number, normalize: NormalizeType, ciThreshold: number): Promise { const url = apiUrl(`/api/graphs/${promptId}`); url.searchParams.set("normalize", normalize); url.searchParams.set("ci_threshold", String(ciThreshold)); - const graphs = await fetchJson[]>(url.toString()); - return graphs.map((g) => { - // Extract all unique layer names from edges to detect architecture - const layerNames = new Set<string>(); - for (const edge of g.edges) { - layerNames.add(edge.src.split(":")[0]); - layerNames.add(edge.tgt.split(":")[0]); - } - setArchitecture(Array.from(layerNames)); - - const { edgesBySource, edgesByTarget } = buildEdgeIndexes(g.edges); - return { ...g, edgesBySource, edgesByTarget }; - }); + const graphs = await fetchJson<Record<string, unknown>[]>(url.toString()); + return graphs.map((g) => hydrateGraph(g)); } export async function tokenizeText(text: string): Promise { @@ -185,15 +245,17 @@ export async function tokenizeText(text: string): Promise { return fetchJson(url.toString(), { method: "POST" }); } -export async function getAllTokens(): Promise { - const response = await fetchJson<{ tokens: TokenInfo[] }>("/api/graphs/tokens"); - return response.tokens; -} - -export async function searchTokens(query: string, limit: number = 10): Promise { +export async function searchTokens( + query: string, + promptId: number, + position: number, + limit: number = 20, +): Promise { const url = apiUrl("/api/graphs/tokens/search"); url.searchParams.set("q", query); url.searchParams.set("limit", String(limit)); - const response = await fetchJson<{ tokens: TokenInfo[] }>(url.toString()); + url.searchParams.set("prompt_id", 
String(promptId)); + url.searchParams.set("position", String(position)); + const response = await fetchJson<{ tokens: TokenSearchResult[] }>(url.toString()); return response.tokens; } diff --git a/spd/app/frontend/src/lib/api/index.ts b/spd/app/frontend/src/lib/api/index.ts index 773663636..d2d810283 100644 --- a/spd/app/frontend/src/lib/api/index.ts +++ b/spd/app/frontend/src/lib/api/index.ts @@ -51,5 +51,8 @@ export * from "./datasetAttributions"; export * from "./intervention"; export * from "./dataset"; export * from "./clusters"; +export * from "./investigations"; export * from "./dataSources"; +export * from "./graphInterp"; export * from "./pretrainInfo"; +export * from "./runRegistry"; diff --git a/spd/app/frontend/src/lib/api/intervention.ts b/spd/app/frontend/src/lib/api/intervention.ts index 689c29cc1..154228181 100644 --- a/spd/app/frontend/src/lib/api/intervention.ts +++ b/spd/app/frontend/src/lib/api/intervention.ts @@ -2,11 +2,7 @@ * API client for /api/intervention endpoints. 
*/ -import type { - ForkedInterventionRunSummary, - InterventionRunSummary, - RunInterventionRequest, -} from "../interventionTypes"; +import type { InterventionRunSummary, RunInterventionRequest } from "../interventionTypes"; export async function runAndSaveIntervention(request: RunInterventionRequest): Promise { const response = await fetch("/api/intervention/run", { @@ -39,30 +35,3 @@ export async function deleteInterventionRun(runId: number): Promise { throw new Error(error.detail || "Failed to delete intervention run"); } } - -export async function forkInterventionRun( - runId: number, - tokenReplacements: [number, number][], - topK: number = 10, -): Promise { - const response = await fetch(`/api/intervention/runs/${runId}/fork`, { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ token_replacements: tokenReplacements, top_k: topK }), - }); - if (!response.ok) { - const error = await response.json(); - throw new Error(error.detail || "Failed to fork intervention run"); - } - return (await response.json()) as ForkedInterventionRunSummary; -} - -export async function deleteForkedInterventionRun(forkId: number): Promise { - const response = await fetch(`/api/intervention/forks/${forkId}`, { - method: "DELETE", - }); - if (!response.ok) { - const error = await response.json(); - throw new Error(error.detail || "Failed to delete forked intervention run"); - } -} diff --git a/spd/app/frontend/src/lib/api/investigations.ts b/spd/app/frontend/src/lib/api/investigations.ts new file mode 100644 index 000000000..42f1fb1f3 --- /dev/null +++ b/spd/app/frontend/src/lib/api/investigations.ts @@ -0,0 +1,101 @@ +/** + * API client for investigation results. 
+ */ + +export interface InvestigationSummary { + id: string; // inv_id (e.g., "inv-abc12345") + wandb_path: string | null; + prompt: string | null; + created_at: string; + has_research_log: boolean; + has_explanations: boolean; + event_count: number; + last_event_time: string | null; + last_event_message: string | null; + // Agent-provided summary + title: string | null; + summary: string | null; + status: string | null; // in_progress, completed, inconclusive +} + +export interface EventEntry { + event_type: string; + timestamp: string; + message: string; + details: Record<string, unknown> | null; +} + +export interface InvestigationDetail { + id: string; + wandb_path: string | null; + prompt: string | null; + created_at: string; + research_log: string | null; + events: EventEntry[]; + explanations: Record<string, unknown>[]; + artifact_ids: string[]; // List of artifact IDs available for this investigation + // Agent-provided summary + title: string | null; + summary: string | null; + status: string | null; +} + +import type { EdgeData, OutputProbability } from "../promptAttributionsTypes"; + +/** Data for a graph artifact (subset of GraphData, self-contained for offline viewing) */ +export interface ArtifactGraphData { + tokens: string[]; + edges: EdgeData[]; + outputProbs: Record<string, OutputProbability>; + nodeCiVals: Record<string, number>; + nodeSubcompActs: Record<string, number>; + maxAbsAttr: number; + l0_total: number; +} + +export interface GraphArtifact { + type: "graph"; + id: string; + caption: string | null; + graph_id: number; + data: ArtifactGraphData; +} + +export interface LaunchResponse { + inv_id: string; + job_id: string; +} + +export async function launchInvestigation(prompt: string): Promise<LaunchResponse> { + const res = await fetch("/api/investigations/launch", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ prompt }), + }); + if (!res.ok) throw new Error(`Failed to launch investigation: ${res.statusText}`); + return res.json(); +} + +export async function listInvestigations(): Promise<InvestigationSummary[]> { + 
const res = await fetch("/api/investigations"); + if (!res.ok) throw new Error(`Failed to list investigations: ${res.statusText}`); + return res.json(); +} + +export async function getInvestigation(invId: string): Promise<InvestigationDetail> { + const res = await fetch(`/api/investigations/${invId}`); + if (!res.ok) throw new Error(`Failed to get investigation: ${res.statusText}`); + return res.json(); +} + +export async function listArtifacts(invId: string): Promise { + const res = await fetch(`/api/investigations/${invId}/artifacts`); + if (!res.ok) throw new Error(`Failed to list artifacts: ${res.statusText}`); + return res.json(); +} + +export async function getArtifact(invId: string, artifactId: string): Promise<GraphArtifact> { + const res = await fetch(`/api/investigations/${invId}/artifacts/${artifactId}`); + if (!res.ok) throw new Error(`Failed to get artifact: ${res.statusText}`); + return res.json(); +} diff --git a/spd/app/frontend/src/lib/api/pretrainInfo.ts b/spd/app/frontend/src/lib/api/pretrainInfo.ts index 7092c735a..0cd66bd97 100644 --- a/spd/app/frontend/src/lib/api/pretrainInfo.ts +++ b/spd/app/frontend/src/lib/api/pretrainInfo.ts @@ -20,6 +20,7 @@ export type TopologyInfo = { export type PretrainInfoResponse = { model_type: string; summary: string; + dataset_short: string | null; target_model_config: Record<string, unknown> | null; pretrain_config: Record<string, unknown> | null; pretrain_wandb_path: string | null; diff --git a/spd/app/frontend/src/lib/api/runRegistry.ts b/spd/app/frontend/src/lib/api/runRegistry.ts new file mode 100644 index 000000000..c727f4dcc --- /dev/null +++ b/spd/app/frontend/src/lib/api/runRegistry.ts @@ -0,0 +1,26 @@ +/** + * API client for /api/run_registry endpoint. 
+ */ + +import { fetchJson } from "./index"; + +export type DataAvailability = { + harvest: boolean; + autointerp: boolean; + attributions: boolean; + graph_interp: boolean; +}; + +export type RunInfoResponse = { + wandb_run_id: string; + architecture: string | null; + availability: DataAvailability; +}; + +export async function fetchRunInfo(wandbRunIds: string[]): Promise { + return fetchJson("/api/run_registry", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(wandbRunIds), + }); +} diff --git a/spd/app/frontend/src/lib/api/runs.ts b/spd/app/frontend/src/lib/api/runs.ts index 1430632a4..d898c8671 100644 --- a/spd/app/frontend/src/lib/api/runs.ts +++ b/spd/app/frontend/src/lib/api/runs.ts @@ -14,6 +14,8 @@ export type LoadedRun = { backend_user: string; dataset_attributions_available: boolean; dataset_search_enabled: boolean; + graph_interp_available: boolean; + autointerp_available: boolean; }; export async function getStatus(): Promise { diff --git a/spd/app/frontend/src/lib/colors.ts b/spd/app/frontend/src/lib/colors.ts index e64cc696d..d15462693 100644 --- a/spd/app/frontend/src/lib/colors.ts +++ b/spd/app/frontend/src/lib/colors.ts @@ -7,17 +7,17 @@ */ export const colors = { - // Text - punchy contrast (matches --text-*) - textPrimary: "#111111", - textSecondary: "#555555", - textMuted: "#999999", + // Text - warm navy contrast (matches --text-*) + textPrimary: "#1d272a", + textSecondary: "#646464", + textMuted: "#b4b4b4", // Status colors for edges/data (matches --accent-primary, --status-negative) - positive: "#2563eb", + positive: "#4d65ff", negative: "#dc2626", // RGB components for dynamic opacity - positiveRgb: { r: 37, g: 99, b: 235 }, // blue - matches --accent-primary + positiveRgb: { r: 77, g: 101, b: 255 }, // vibrant blue - matches --accent-primary negativeRgb: { r: 220, g: 38, b: 38 }, // red - matches --status-negative // Output node gradient (green) - matches --status-positive @@ -28,10 +28,10 @@ 
export const colors = { tokenHighlightOpacity: 0.4, // Node default - nodeDefault: "#6b7280", + nodeDefault: "#8a8780", // Accent (for active states) - matches --accent-primary - accent: "#2563eb", + accent: "#7C4D33", // Set overlap visualization (A/B/intersection) setOverlap: { diff --git a/spd/app/frontend/src/lib/componentKeys.ts b/spd/app/frontend/src/lib/componentKeys.ts new file mode 100644 index 000000000..ff83bda06 --- /dev/null +++ b/spd/app/frontend/src/lib/componentKeys.ts @@ -0,0 +1,17 @@ +/** + * Utilities for component key display (e.g. rendering embed/output keys with token strings). + */ + +export function isTokenNode(key: string): boolean { + const layer = key.split(":")[0]; + return layer === "embed" || layer === "output"; +} + +export function formatComponentKey(key: string, tokenStr: string | null): string { + if (tokenStr && isTokenNode(key)) { + const layer = key.split(":")[0]; + const label = layer === "embed" ? "input" : "output"; + return `'${tokenStr}' (${label})`; + } + return key; +} diff --git a/spd/app/frontend/src/lib/displaySettings.svelte.ts b/spd/app/frontend/src/lib/displaySettings.svelte.ts index db3e3f7c9..6998214ee 100644 --- a/spd/app/frontend/src/lib/displaySettings.svelte.ts +++ b/spd/app/frontend/src/lib/displaySettings.svelte.ts @@ -13,6 +13,14 @@ export const NODE_COLOR_MODE_LABELS: Record = { subcomp_act: "Subcomp Act", }; +// Edge variant for attribution graphs +export type EdgeVariant = "signed" | "abs_target"; + +export const EDGE_VARIANT_LABELS: Record = { + signed: "Signed", + abs_target: "Abs Target", +}; + // Example color mode for activation contexts viewer export type ExampleColorMode = "ci" | "component_act" | "both"; @@ -48,6 +56,8 @@ export const displaySettings = $state({ meanCiCutoff: 1e-7, centerOnPeak: false, showAutoInterpPromptButton: false, + curvedEdges: true, + edgeVariant: "signed" as EdgeVariant, }); export function anyCorrelationStatsEnabled() { diff --git 
a/spd/app/frontend/src/lib/interventionTypes.ts b/spd/app/frontend/src/lib/interventionTypes.ts index c364be243..db2794da5 100644 --- a/spd/app/frontend/src/lib/interventionTypes.ts +++ b/spd/app/frontend/src/lib/interventionTypes.ts @@ -1,46 +1,105 @@ /** Types for the intervention forward pass feature */ -export type InterventionNode = { - layer: string; - seq_pos: number; - component_idx: number; -}; +/** Default eval PGD settings (distinct from training PGD which is an optimization regularizer) */ +export const EVAL_PGD_N_STEPS = 4; +export const EVAL_PGD_STEP_SIZE = 1.0; export type TokenPrediction = { token: string; token_id: number; - spd_prob: number; - target_prob: number; + prob: number; logit: number; + target_prob: number; target_logit: number; }; -export type InterventionResponse = { - input_tokens: string[]; - predictions_per_position: TokenPrediction[][]; +export type LabelPredictions = { + position: number; + ci: TokenPrediction; + stochastic: TokenPrediction; + adversarial: TokenPrediction; + ablated: TokenPrediction | null; }; -/** A forked intervention run with modified tokens */ -export type ForkedInterventionRunSummary = { - id: number; - token_replacements: [number, number][]; // [(seq_pos, new_token_id), ...] 
- result: InterventionResponse; - created_at: string; +export type InterventionResult = { + input_tokens: string[]; + ci: TokenPrediction[][]; + stochastic: TokenPrediction[][]; + adversarial: TokenPrediction[][]; + ablated: TokenPrediction[][] | null; + ci_loss: number; + stochastic_loss: number; + adversarial_loss: number; + ablated_loss: number | null; + label: LabelPredictions | null; }; /** Persisted intervention run from the server */ export type InterventionRunSummary = { id: number; selected_nodes: string[]; // node keys (layer:seq:cIdx) - result: InterventionResponse; + result: InterventionResult; created_at: string; - forked_runs?: ForkedInterventionRunSummary[]; // child runs with modified tokens }; /** Request to run and save an intervention */ export type RunInterventionRequest = { graph_id: number; - text: string; selected_nodes: string[]; - top_k?: number; + nodes_to_ablate?: string[]; + top_k: number; + adv_pgd: { n_steps: number; step_size: number }; }; + +// --- Frontend-only run lifecycle types --- + +import { SvelteSet } from "svelte/reactivity"; +import { isInterventableNode } from "./promptAttributionsTypes"; + +/** Draft run: cloned from a parent, editable node selection. No forwarded results yet. */ +export type DraftRun = { + kind: "draft"; + parentId: number; + selectedNodes: SvelteSet; +}; + +/** Baked run: forwarded and immutable. Wraps a persisted InterventionRunSummary. */ +export type BakedRun = { + kind: "baked"; + id: number; + selectedNodes: Set; + result: InterventionResult; + createdAt: string; +}; + +export type InterventionRun = DraftRun | BakedRun; + +export type InterventionState = { + runs: InterventionRun[]; + activeIndex: number; +}; + +/** Build initial InterventionState from persisted runs. + * The first persisted run is the base run (all CI > 0 nodes), auto-created during graph computation. 
*/ +export function buildInterventionState(persistedRuns: InterventionRunSummary[]): InterventionState { + if (persistedRuns.length === 0) throw new Error("Graph must have at least one intervention run (the base run)"); + const runs: InterventionRun[] = persistedRuns.map( + (r): BakedRun => ({ + kind: "baked", + id: r.id, + selectedNodes: new Set(r.selected_nodes), + result: r.result, + createdAt: r.created_at, + }), + ); + return { runs, activeIndex: 0 }; +} + +/** Get all interventable node keys with CI > 0 from a nodeCiVals record */ +export function getInterventableNodes(nodeCiVals: Record<string, number>): Set<string> { + const nodes = new Set<string>(); + for (const [nodeKey, ci] of Object.entries(nodeCiVals)) { + if (isInterventableNode(nodeKey) && ci > 0) nodes.add(nodeKey); + } + return nodes; +} diff --git a/spd/app/frontend/src/lib/layerAliasing.ts b/spd/app/frontend/src/lib/layerAliasing.ts deleted file mode 100644 index 2c5269543..000000000 --- a/spd/app/frontend/src/lib/layerAliasing.ts +++ /dev/null @@ -1,219 +0,0 @@ -/** - * Layer aliasing system - transforms internal module names to human-readable aliases. 
- * - * Formats: - * - Internal: "h.0.mlp.c_fc", "h.1.attn.q_proj" - * - Aliased: "L0.mlp.in", "L1.attn.q" - * - * Handles multiple architectures: - * - GPT-2: c_fc -> mlp.in, down_proj -> mlp.out - * - Llama SwiGLU: gate_proj -> mlp.gate, up_proj -> mlp.up, down_proj -> mlp.down - * - Attention: q_proj -> attn.q, k_proj -> attn.k, v_proj -> attn.v, o_proj -> attn.o - * - Special: lm_head -> W_U, embed/output unchanged - */ - -type Architecture = "gpt2" | "llama" | "unknown"; - -/** Mapping of internal module names to aliases by architecture */ -const ALIASES: Record> = { - gpt2: { - // MLP - c_fc: "in", - down_proj: "out", - // Attention - q_proj: "q", - k_proj: "k", - v_proj: "v", - o_proj: "o", - }, - llama: { - // MLP (SwiGLU) - gate_proj: "gate", - up_proj: "up", - down_proj: "down", - // Attention - q_proj: "q", - k_proj: "k", - v_proj: "v", - o_proj: "o", - }, - unknown: { - // Fallback - just do attention mappings - q_proj: "q", - k_proj: "k", - v_proj: "v", - o_proj: "o", - }, -}; - -/** Special layers with fixed display names */ -const SPECIAL_LAYERS: Record = { - lm_head: "W_U", - embed: "embed", - output: "output", -}; - -// Cache for detected architecture from the full model -let cachedArchitecture: Architecture | null = null; - -/** - * Detect architecture from a collection of layer names. - * Llama has gate_proj/up_proj, GPT-2 has c_fc. - * - * This should be called once with all available layer names to establish - * the architecture for the session, ensuring down_proj is aliased correctly. - */ -export function detectArchitectureFromLayers(layers: string[]): Architecture { - const hasLlamaLayers = layers.some((layer) => layer.includes("gate_proj") || layer.includes("up_proj")); - if (hasLlamaLayers) { - return "llama"; - } - - const hasGPT2Layers = layers.some((layer) => layer.includes("c_fc")); - if (hasGPT2Layers) { - return "gpt2"; - } - - return "unknown"; -} - -/** - * Set the architecture for aliasing operations. 
- * Call this when you have access to all layer names (e.g., when loading a graph). - */ -export function setArchitecture(layers: string[]): void { - cachedArchitecture = detectArchitectureFromLayers(layers); -} - -/** - * Detect architecture from layer name. - * Uses cached architecture if available (set via setArchitecture()), - * otherwise falls back to single-layer detection. - * - * Note: down_proj appears in both architectures with different meanings: - * - GPT-2: down_proj -> "out" (second MLP projection) - * - Llama: down_proj -> "down" (third MLP projection after gate/up) - * - * Single-layer detection cannot distinguish these cases reliably. - */ -function detectArchitecture(layer: string): Architecture { - // Use cached architecture if available - if (cachedArchitecture !== null) { - return cachedArchitecture; - } - - // Fallback: single-layer detection (less reliable for down_proj) - if (layer.includes("gate_proj") || layer.includes("up_proj")) { - return "llama"; - } - if (layer.includes("c_fc")) { - return "gpt2"; - } - // down_proj is ambiguous without context, default to GPT-2 - if (layer.includes("down_proj")) { - return "gpt2"; - } - return "unknown"; -} - -/** - * Parse a layer name into components. - * Returns null for special layers (embed, output, lm_head) or unrecognized formats. - */ -function parseLayerName(layer: string): { block: number; moduleType: string; submodule: string } | null { - if (layer in SPECIAL_LAYERS) { - return null; - } - - const match = layer.match(/^h\.(\d+)\.(attn|mlp)\.(\w+)$/); - if (!match) { - return null; - } - - const [, blockStr, moduleType, submodule] = match; - return { - block: parseInt(blockStr), - moduleType, - submodule, - }; -} - -/** - * Transform a layer name to its aliased form. 
- * - * Examples: - * - "h.0.mlp.c_fc" -> "L0.mlp.in" - * - "h.2.attn.q_proj" -> "L2.attn.q" - * - "lm_head" -> "W_U" - * - "embed" -> "embed" - */ -export function getLayerAlias(layer: string): string { - if (layer in SPECIAL_LAYERS) { - return SPECIAL_LAYERS[layer]; - } - - const parsed = parseLayerName(layer); - if (!parsed) { - return layer; - } - - const arch = detectArchitecture(layer); - const alias = ALIASES[arch][parsed.submodule]; - - if (!alias) { - return `L${parsed.block}.${parsed.moduleType}.${parsed.submodule}`; - } - - return `L${parsed.block}.${parsed.moduleType}.${alias}`; -} - -/** - * Get a row label for grouped display in graphs. - * - * @param layer - Internal layer name (e.g., "h.0.mlp.c_fc") - * @param isQkvGroup - Whether this represents a grouped QKV row - * @returns Label (e.g., "L0.mlp.in", "L2.attn.qkv") - * - * @example - * getAliasedRowLabel("h.0.mlp.c_fc") // => "L0.mlp.in" - * getAliasedRowLabel("h.2.attn.q_proj", true) // => "L2.attn.qkv" - */ -export function getAliasedRowLabel(layer: string, isQkvGroup = false): string { - if (layer in SPECIAL_LAYERS) { - return SPECIAL_LAYERS[layer]; - } - - const parsed = parseLayerName(layer); - if (!parsed) { - return layer; - } - - if (isQkvGroup) { - return `L${parsed.block}.${parsed.moduleType}.qkv`; - } - - const arch = detectArchitecture(layer); - const alias = ALIASES[arch][parsed.submodule]; - - if (!alias) { - return `L${parsed.block}.${parsed.moduleType}.${parsed.submodule}`; - } - - return `L${parsed.block}.${parsed.moduleType}.${alias}`; -} - -/** - * Format a node key with aliased layer names. - * - * Node keys are "layer:seq:cIdx" or "layer:cIdx" format. 
- * - * Examples: - * - "h.0.mlp.c_fc:3:5" -> "L0.mlp.in:3:5" - * - "h.1.attn.q_proj:2:10" -> "L1.attn.q:2:10" - */ -export function formatNodeKeyWithAliases(nodeKey: string): string { - const parts = nodeKey.split(":"); - const layer = parts[0]; - const aliasedLayer = getLayerAlias(layer); - return [aliasedLayer, ...parts.slice(1)].join(":"); -} diff --git a/spd/app/frontend/src/lib/promptAttributionsTypes.ts b/spd/app/frontend/src/lib/promptAttributionsTypes.ts index fc705fad0..8d601d158 100644 --- a/spd/app/frontend/src/lib/promptAttributionsTypes.ts +++ b/spd/app/frontend/src/lib/promptAttributionsTypes.ts @@ -20,18 +20,45 @@ export type EdgeAttribution = { key: string; // "layer:seq:cIdx" for prompt or "layer:cIdx" for dataset value: number; // raw attribution value (positive or negative) normalizedMagnitude: number; // |value| / maxAbsValue, for color intensity (0-1) + tokenStr: string | null; // resolved token string for embed/output layers }; +/** Sort edges by |val| desc, take top N, normalize magnitudes to [0,1]. */ +export function topEdgeAttributions( + edges: EdgeData[], + getKey: (e: EdgeData) => string, + n: number, + resolveTokenStr?: (key: string) => string | null, +): EdgeAttribution[] { + const sorted = [...edges].sort((a, b) => Math.abs(b.val) - Math.abs(a.val)).slice(0, n); + const maxAbsVal = Math.abs(sorted[0]?.val || 1); + return sorted.map((e) => ({ + key: getKey(e), + value: e.val, + normalizedMagnitude: Math.abs(e.val) / maxAbsVal, + tokenStr: resolveTokenStr ? 
resolveTokenStr(getKey(e)) : null, + })); +} + export type OutputProbability = { prob: number; // CI-masked (SPD model) probability logit: number; // CI-masked (SPD model) raw logit target_prob: number; // Target model probability target_logit: number; // Target model raw logit - adv_pgd_prob: number | null; // Adversarial PGD probability - adv_pgd_logit: number | null; // Adversarial PGD raw logit token: string; }; +export type CISnapshot = { + step: number; + total_steps: number; + layers: string[]; + seq_len: number; + initial_alive: number[][]; + current_alive: number[][]; + l0_total: number; + loss: number; +}; + export type GraphType = "standard" | "optimized" | "manual"; export type GraphData = { @@ -41,10 +68,15 @@ export type GraphData = { edges: EdgeData[]; edgesBySource: Map; // nodeKey -> edges where this node is source edgesByTarget: Map; // nodeKey -> edges where this node is target + // Absolute-target variant (∂|y|/∂x · x), null for old graphs + edgesAbs: EdgeData[] | null; + edgesAbsBySource: Map | null; + edgesAbsByTarget: Map | null; outputProbs: Record; // key is "seq:cIdx" nodeCiVals: Record; // node key -> CI value (or output prob for output nodes or 1 for wte node) nodeSubcompActs: Record; // node key -> subcomponent activation (v_i^T @ a) maxAbsAttr: number; // max absolute edge value + maxAbsAttrAbs: number | null; // max absolute edge value for abs-target variant maxAbsSubcompAct: number; // max absolute subcomponent activation for normalization l0_total: number; // total active components at current CI threshold optimization?: OptimizationResult; @@ -93,7 +125,15 @@ export type KLLossResult = { position: number; }; -export type LossResult = CELossResult | KLLossResult; +export type LogitLossResult = { + type: "logit"; + coeff: number; + position: number; + label_token: number; + label_str: string; +}; + +export type LossResult = CELossResult | KLLossResult | LogitLossResult; export type OptimizationMetrics = { ci_masked_label_prob: number 
| null; // Probability of label under CI mask (CE loss only) @@ -102,6 +142,11 @@ export type OptimizationMetrics = { l0_total: number; // Total L0 (active components) }; +export type PgdConfig = { + n_steps: number; + step_size: number; +}; + export type OptimizationResult = { imp_min_coeff: number; steps: number; @@ -110,8 +155,7 @@ export type OptimizationResult = { mask_type: MaskType; loss: LossResult; metrics: OptimizationMetrics; - adv_pgd_n_steps: number | null; - adv_pgd_step_size: number | null; + pgd: PgdConfig | null; }; export type SubcomponentMetadata = { @@ -169,11 +213,33 @@ export type TokenizeResponse = { next_token_probs: (number | null)[]; // Probability of next token (last is null) }; -export type TokenInfo = { +export type TokenSearchResult = { id: number; string: string; + prob: number; }; +/** Select active edge set based on variant preference. Falls back to signed if abs unavailable. */ +export function getActiveEdges( + data: GraphData, + variant: "signed" | "abs_target", +): { edges: EdgeData[]; bySource: Map; byTarget: Map; maxAbsAttr: number } { + if (variant === "abs_target" && data.edgesAbs) { + return { + edges: data.edgesAbs, + bySource: data.edgesAbsBySource!, + byTarget: data.edgesAbsByTarget!, + maxAbsAttr: data.maxAbsAttrAbs || 1, + }; + } + return { + edges: data.edges, + bySource: data.edgesBySource, + byTarget: data.edgesByTarget, + maxAbsAttr: data.maxAbsAttr || 1, + }; +} + // Client-side computed types export type NodePosition = { @@ -233,7 +299,7 @@ export function formatNodeKeyForDisplay(nodeKey: string, displayNames: Record>({ status: "uninitialized" }); let tokenStats = $state>({ status: "uninitialized" }); - let datasetAttributions = $state>({ status: "uninitialized" }); + let datasetAttributions = $state>({ status: "uninitialized" }); let interpretationDetail = $state>({ status: "uninitialized" }); + let graphInterpDetail = $state>({ status: "uninitialized" }); // Current coords being loaded/displayed (for 
interpretation lookup) let currentCoords = $state(null); @@ -132,20 +134,40 @@ export function useComponentData() { datasetAttributions = { status: "loaded", data: null }; } - // Fetch interpretation detail (404 = no interpretation for this component) - getInterpretationDetail(layer, cIdx) - .then((data) => { - if (isStale()) return; - interpretationDetail = { status: "loaded", data }; - }) - .catch((error) => { - if (isStale()) return; - if (error instanceof ApiError && error.status === 404) { - interpretationDetail = { status: "loaded", data: null }; - } else { + const interpState = untrack(() => runState.getInterpretation(`${layer}:${cIdx}`)); + if (interpState.status === "loaded" && interpState.data.status !== "none") { + getInterpretationDetail(layer, cIdx) + .then((data) => { + if (isStale()) return; + interpretationDetail = { status: "loaded", data }; + }) + .catch((error) => { + if (isStale()) return; interpretationDetail = { status: "error", error }; - } - }); + }); + } else { + interpretationDetail = { status: "loaded", data: null }; + } + + // Fetch graph interp detail (skip if not available for this run) + if (runState.graphInterpAvailable) { + graphInterpDetail = { status: "loading" }; + getGraphInterpComponentDetail(layer, cIdx) + .then((data) => { + if (isStale()) return; + graphInterpDetail = { status: "loaded", data }; + }) + .catch((error) => { + if (isStale()) return; + if (error instanceof ApiError && error.status === 404) { + graphInterpDetail = { status: "loaded", data: null }; + } else { + graphInterpDetail = { status: "error", error }; + } + }); + } else { + graphInterpDetail = { status: "loaded", data: null }; + } } /** @@ -159,6 +181,7 @@ export function useComponentData() { tokenStats = { status: "uninitialized" }; datasetAttributions = { status: "uninitialized" }; interpretationDetail = { status: "uninitialized" }; + graphInterpDetail = { status: "uninitialized" }; } // Interpretation is derived from the global cache - reactive to both 
coords and cache
@@ -212,6 +235,9 @@ export function useComponentData() {
         get interpretationDetail() {
             return interpretationDetail;
         },
+        get graphInterpDetail() {
+            return graphInterpDetail;
+        },
         load,
         reset,
         generateInterpretation,
diff --git a/spd/app/frontend/src/lib/useComponentDataExpectCached.svelte.ts b/spd/app/frontend/src/lib/useComponentDataExpectCached.svelte.ts
index f32dab70a..d76c5da9e 100644
--- a/spd/app/frontend/src/lib/useComponentDataExpectCached.svelte.ts
+++ b/spd/app/frontend/src/lib/useComponentDataExpectCached.svelte.ts
@@ -6,7 +6,7 @@
  * examples (200). Dataset attributions and interpretation detail are on-demand.
  */
-import { getContext } from "svelte";
+import { getContext, untrack } from "svelte";
 import type { Loadable } from ".";
 import {
     ApiError,
@@ -14,10 +14,11 @@ import {
     getComponentAttributions,
     getComponentCorrelations,
     getComponentTokenStats,
+    getGraphInterpComponentDetail,
     getInterpretationDetail,
     requestComponentInterpretation,
 } from "./api";
-import type { ComponentAttributions, InterpretationDetail } from "./api";
+import type { AllMetricAttributions, GraphInterpComponentDetail, InterpretationDetail } from "./api";
 import type {
     SubcomponentCorrelationsResponse,
     SubcomponentActivationContexts,
@@ -29,7 +30,7 @@ const DATASET_ATTRIBUTIONS_TOP_K = 20;
 /** Fetch more activation examples in background after initial cached load */
 const ACTIVATION_EXAMPLES_FULL_LIMIT = 200;
 
-export type { ComponentAttributions as DatasetAttributions };
+export type { AllMetricAttributions as DatasetAttributions };
 
 export type ComponentCoords = { layer: string; cIdx: number };
 
@@ -39,8 +40,9 @@ export function useComponentDataExpectCached() {
     let componentDetail = $state>({ status: "uninitialized" });
     let correlations = $state>({ status: "uninitialized" });
     let tokenStats = $state>({ status: "uninitialized" });
-    let datasetAttributions = $state>({ status: "uninitialized" });
+    let datasetAttributions = $state>({ status: "uninitialized" });
     let interpretationDetail = $state>({ status: "uninitialized" });
+    let graphInterpDetail = $state>({ status: "uninitialized" });
     let currentCoords = $state(null);
     let requestId = 0;
@@ -87,21 +89,41 @@ export function useComponentDataExpectCached() {
             datasetAttributions = { status: "loaded", data: null };
         }
 
-        // Fetch interpretation detail on-demand (not cached)
-        interpretationDetail = { status: "loading" };
-        getInterpretationDetail(layer, cIdx)
-            .then((data) => {
-                if (isStale()) return;
-                interpretationDetail = { status: "loaded", data };
-            })
-            .catch((error) => {
-                if (isStale()) return;
-                if (error instanceof ApiError && error.status === 404) {
-                    interpretationDetail = { status: "loaded", data: null };
-                } else {
+        const interpState = untrack(() => runState.getInterpretation(`${layer}:${cIdx}`));
+        if (interpState.status === "loaded" && interpState.data.status !== "none") {
+            interpretationDetail = { status: "loading" };
+            getInterpretationDetail(layer, cIdx)
+                .then((data) => {
+                    if (isStale()) return;
+                    interpretationDetail = { status: "loaded", data };
+                })
+                .catch((error) => {
+                    if (isStale()) return;
                     interpretationDetail = { status: "error", error };
-                }
-            });
+                });
+        } else {
+            interpretationDetail = { status: "loaded", data: null };
+        }
+
+        // Fetch graph interp detail
+        if (runState.graphInterpAvailable) {
+            graphInterpDetail = { status: "loading" };
+            getGraphInterpComponentDetail(layer, cIdx)
+                .then((data) => {
+                    if (isStale()) return;
+                    graphInterpDetail = { status: "loaded", data };
+                })
+                .catch((error) => {
+                    if (isStale()) return;
+                    if (error instanceof ApiError && error.status === 404) {
+                        graphInterpDetail = { status: "loaded", data: null };
+                    } else {
+                        graphInterpDetail = { status: "error", error };
+                    }
+                });
+        } else {
+            graphInterpDetail = { status: "loaded", data: null };
+        }
     }
 
     function load(layer: string, cIdx: number) {
@@ -144,6 +166,7 @@
         tokenStats = { status: "uninitialized" };
         datasetAttributions = { status: "uninitialized" };
         interpretationDetail = { status: "uninitialized" };
+        graphInterpDetail = { status: "uninitialized" };
     }
 
     // Interpretation is derived from the global cache
@@ -197,6 +220,9 @@
         get interpretationDetail() {
             return interpretationDetail;
         },
+        get graphInterpDetail() {
+            return graphInterpDetail;
+        },
         load,
         reset,
         generateInterpretation,
diff --git a/spd/app/frontend/src/lib/useRun.svelte.ts b/spd/app/frontend/src/lib/useRun.svelte.ts
index de6d20c7d..1cfc3cca6 100644
--- a/spd/app/frontend/src/lib/useRun.svelte.ts
+++ b/spd/app/frontend/src/lib/useRun.svelte.ts
@@ -7,13 +7,8 @@
 import type { Loadable } from ".";
 import * as api from "./api";
-import type { LoadedRun as RunData, InterpretationHeadline } from "./api";
-import type {
-    PromptPreview,
-    SubcomponentActivationContexts,
-    TokenInfo,
-    SubcomponentMetadata,
-} from "./promptAttributionsTypes";
+import type { LoadedRun as RunData, InterpretationHeadline, GraphInterpHeadline } from "./api";
+import type { PromptPreview, SubcomponentActivationContexts, SubcomponentMetadata } from "./promptAttributionsTypes";
 
 /** Maps component keys to cluster IDs. Singletons (unclustered components) have null values. */
 export type ClusterMappingData = Record;
@@ -46,17 +41,15 @@ export function useRun() {
     /** Intruder eval scores keyed by component key */
     let intruderScores = $state>>({ status: "uninitialized" });
 
+    /** Graph interp labels keyed by component key (layer:cIdx) */
+    let graphInterpLabels = $state>>({ status: "uninitialized" });
+
     /** Cluster mapping for the current run */
     let clusterMapping = $state(null);
 
     /** Available prompts for the current run */
     let prompts = $state>({ status: "uninitialized" });
 
-    /** All tokens in the tokenizer for the current run */
-    let allTokens = $state>({ status: "uninitialized" });
-
-    /** Model topology info for frontend layout */
-
     /** Activation contexts summary (null = harvest not available) */
     let activationContextsSummary = $state | null>>({
         status: "uninitialized",
@@ -68,9 +61,9 @@ export function useRun() {
     /** Reset all run-scoped state */
     function resetRunScopedState() {
         prompts = { status: "uninitialized" };
-        allTokens = { status: "uninitialized" };
         interpretations = { status: "uninitialized" };
         intruderScores = { status: "uninitialized" };
+        graphInterpLabels = { status: "uninitialized" };
         activationContextsSummary = { status: "uninitialized" };
         _componentDetailsCache = {};
         clusterMapping = null;
@@ -88,6 +81,9 @@ export function useRun() {
         api.getIntruderScores()
             .then((data) => (intruderScores = { status: "loaded", data }))
             .catch((error) => (intruderScores = { status: "error", error }));
+        api.getAllGraphInterpLabels()
+            .then((data) => (graphInterpLabels = { status: "loaded", data }))
+            .catch((error) => (graphInterpLabels = { status: "error", error }));
         api.getAllInterpretations()
             .then((i) => {
                 interpretations = {
@@ -106,14 +102,6 @@ export function useRun() {
             .catch((error) => (interpretations = { status: "error", error }));
     }
 
-    /** Fetch tokens - must complete before run is considered loaded */
-    async function fetchTokens(): Promise {
-        allTokens = { status: "loading" };
-        const tokens = await api.getAllTokens();
-        allTokens = { status: "loaded", data: tokens };
-        return tokens;
-    }
-
     async function loadRun(wandbPath: string, contextLength: number) {
         run = { status: "loading" };
         try {
@@ -122,8 +110,6 @@ export function useRun() {
             if (status) {
                 run = { status: "loaded", data: status };
                 fetchRunScopedData();
-                // Fetch tokens in background (no longer blocks UI - used only by token search)
-                fetchTokens();
             } else {
                 run = { status: "error", error: "Failed to load run" };
             }
@@ -142,10 +128,6 @@ export function useRun() {
         try {
             const status = await api.getStatus();
             if (status) {
-                // Fetch tokens and model info if we don't have them (e.g., page refresh)
-                if (allTokens.status === "uninitialized") {
-                    await fetchTokens();
-                }
                 run = { status: "loaded", data: status };
                 // Fetch other run-scoped data if we don't have it
                 if (interpretations.status === "uninitialized") {
@@ -230,6 +212,11 @@ export function useRun() {
         return clusterMapping?.data[key] ?? null;
     }
 
+    function getGraphInterpLabel(componentKey: string): GraphInterpHeadline | null {
+        if (graphInterpLabels.status !== "loaded") return null;
+        return graphInterpLabels.data[componentKey] ?? null;
+    }
+
     return {
         get run() {
             return run;
@@ -237,21 +224,27 @@
         get interpretations() {
             return interpretations;
         },
+        get graphInterpLabels() {
+            return graphInterpLabels;
+        },
         get clusterMapping() {
             return clusterMapping;
         },
         get prompts() {
             return prompts;
         },
-        get allTokens() {
-            return allTokens;
-        },
         get activationContextsSummary() {
             return activationContextsSummary;
         },
         get datasetAttributionsAvailable() {
            return run.status === "loaded" && run.data.dataset_attributions_available;
         },
+        get graphInterpAvailable() {
+            return run.status === "loaded" && run.data.graph_interp_available;
+        },
+        get autoInterpAvailable() {
+            return run.status === "loaded" && run.data.autointerp_available;
+        },
         loadRun,
         clearRun,
         syncStatus,
@@ -259,6 +252,7 @@
         getInterpretation,
         setInterpretation,
         getIntruderScore,
+        getGraphInterpLabel,
         getActivationContextDetail,
         loadActivationContextsSummary,
         setClusterMapping,
diff --git a/spd/app/frontend/vite.config.ts b/spd/app/frontend/vite.config.ts
index fc72bbc92..a08d086fb 100644
--- a/spd/app/frontend/vite.config.ts
+++ b/spd/app/frontend/vite.config.ts
@@ -9,6 +9,7 @@
 const backendUrl = process.env.BACKEND_URL || "http://localhost:8000";
 
 export default defineConfig({
     plugins: [svelte()],
     server: {
+        hmr: false,
         proxy: {
             "/api": {
                 target: backendUrl,
diff --git a/spd/app/run_app.py b/spd/app/run_app.py
index 6aff0ce4c..c61174d1e 100755
--- a/spd/app/run_app.py
+++ b/spd/app/run_app.py
@@ -303,7 +303,7 @@ def spawn_frontend(
         return proc
 
     def monitor_child_liveness(self) -> None:
-        log_lines_to_show = 5
+        log_lines_to_show = 20
         prev_lines: list[str] = []
 
         while True:
diff --git a/spd/autointerp/db.py b/spd/autointerp/db.py
index 4cba168b7..f05227f5c 100644
--- a/spd/autointerp/db.py
+++ b/spd/autointerp/db.py
@@ -1,6 +1,5 @@
 """SQLite database for autointerp data (interpretations and scores).
 NFS-hosted, single writer then read-only."""
-import sqlite3
 from pathlib import Path
 
 import orjson
diff --git a/spd/autointerp/prompt_helpers.py b/spd/autointerp/prompt_helpers.py
index ac1f412fd..7b2da592e 100644
--- a/spd/autointerp/prompt_helpers.py
+++ b/spd/autointerp/prompt_helpers.py
@@ -9,6 +9,17 @@
 from spd.app.backend.utils import delimit_tokens
 from spd.harvest.analysis import TokenPRLift
 from spd.harvest.schemas import ComponentData
+from spd.utils.markdown import Md
+
+
+def token_pmi_pairs(
+    app_tok: AppTokenizer,
+    token_pmi_top: list[tuple[int, float]] | None,
+) -> list[tuple[str, float]] | None:
+    if not token_pmi_top:
+        return None
+    return [(app_tok.get_tok_display(tid), pmi) for tid, pmi in token_pmi_top]
+
 
 DATASET_DESCRIPTIONS: dict[str, str] = {
     "SimpleStories/SimpleStories": (
@@ -43,11 +54,7 @@ def ordinal(n: int) -> str:
 
 
 def human_layer_desc(canonical: str, n_blocks: int) -> str:
-    """Convert canonical layer string to human-readable description.
-
-    '0.mlp.up' -> 'MLP up-projection in the 1st of 4 blocks'
-    '1.attn.q' -> 'attention query projection in the 2nd of 4 blocks'
-    """
+    """'0.mlp.up' -> 'MLP up-projection in the 1st of 4 blocks'"""
     m = re.match(r"(\d+)\.(.*)", canonical)
     if not m:
         return canonical
@@ -58,7 +65,6 @@
 
 def layer_position_note(canonical: str, n_blocks: int) -> str:
-    """Brief note about what layer position means for interpretation."""
     m = re.match(r"(\d+)\.", canonical)
     if not m:
         return ""
@@ -86,75 +92,73 @@ def density_note(firing_density: float) -> str:
 
 
 def build_output_section(
     output_stats: TokenPRLift,
     output_pmi: list[tuple[str, float]] | None,
-) -> str:
-    section = ""
-
+) -> Md:
+    md = Md()
     if output_pmi:
-        section += (
+        md.labeled_list(
             "**Output PMI (pointwise mutual information, in nats: how much more likely "
             "a token is to be produced when this component fires, vs its base rate. "
-            "0 = no association, 1 = ~3x more likely, 2 = ~7x, 3 = ~20x):**\n"
+            "0 = no association, 1 = ~3x more likely, 2 = ~7x, 3 = ~20x):**",
+            [f"{repr(tok)}: {pmi:.2f}" for tok, pmi in output_pmi[:10]],
         )
-        for tok, pmi in output_pmi[:10]:
-            section += f"- {repr(tok)}: {pmi:.2f}\n"
-
     if output_stats.top_precision:
-        section += "\n**Output precision — of all probability mass for token X, what fraction is at positions where this component fires?**\n"
-        for tok, prec in output_stats.top_precision[:10]:
-            section += f"- {repr(tok)}: {prec * 100:.0f}%\n"
-
-    return section
+        md.labeled_list(
+            "**Output precision — of all probability mass for token X, what fraction "
+            "is at positions where this component fires?**",
+            [f"{repr(tok)}: {prec * 100:.0f}%" for tok, prec in output_stats.top_precision[:10]],
+        )
+    return md
 
 
 def build_input_section(
     input_stats: TokenPRLift,
     input_pmi: list[tuple[str, float]] | None,
-) -> str:
-    section = ""
-
+) -> Md:
+    md = Md()
     if input_pmi:
-        section += "**Input PMI (same metric as above, for input tokens):**\n"
-        for tok, pmi in input_pmi[:6]:
-            section += f"- {repr(tok)}: {pmi:.2f}\n"
-
+        md.labeled_list(
+            "**Input PMI (same metric as above, for input tokens):**",
+            [f"{repr(tok)}: {pmi:.2f}" for tok, pmi in input_pmi[:6]],
+        )
     if input_stats.top_precision:
-        section += "\n**Input precision — probability the component fires given the current token is X:**\n"
-        for tok, prec in input_stats.top_precision[:8]:
-            section += f"- {repr(tok)}: {prec * 100:.0f}%\n"
-
-    return section
+        md.labeled_list(
+            "**Input precision — probability the component fires given the current token is X:**",
+            [f"{repr(tok)}: {prec * 100:.0f}%" for tok, prec in input_stats.top_precision[:8]],
+        )
+    return md
 
 
-def build_fires_on_examples(
+def _build_examples(
     component: ComponentData,
     app_tok: AppTokenizer,
     max_examples: int,
-) -> str:
-    section = ""
-    examples = component.activation_examples[:max_examples]
+    shift_firings: bool,
+) -> Md:
+    lines: list[str] = []
+    for i, ex in enumerate(component.activation_examples[:max_examples]):
+        if not any(ex.firings):
+            continue
+        spans = app_tok.get_spans(ex.token_ids)
+        firings = [False] + ex.firings[:-1] if shift_firings else ex.firings
+        tokens = list(zip(spans, firings, strict=True))
+        lines.append(f"{i + 1}. {delimit_tokens(tokens)}")
+    md = Md()
+    if lines:
+        md.p("\n".join(lines))
+    return md
 
-    for i, ex in enumerate(examples):
-        if any(ex.firings):
-            spans = app_tok.get_spans(ex.token_ids)
-            tokens = list(zip(spans, ex.firings, strict=True))
-            section += f"{i + 1}. {delimit_tokens(tokens)}\n"
 
-    return section
+def build_fires_on_examples(
+    component: ComponentData,
+    app_tok: AppTokenizer,
+    max_examples: int,
+) -> Md:
+    return _build_examples(component, app_tok, max_examples, shift_firings=False)
 
 
 def build_says_examples(
     component: ComponentData,
     app_tok: AppTokenizer,
     max_examples: int,
-) -> str:
-    section = ""
-    examples = component.activation_examples[:max_examples]
-
-    for i, ex in enumerate(examples):
-        if any(ex.firings):
-            spans = app_tok.get_spans(ex.token_ids)
-            shifted_firings = [False] + ex.firings[:-1]
-            tokens = list(zip(spans, shifted_firings, strict=True))
-            section += f"{i + 1}. {delimit_tokens(tokens)}\n"
-
-    return section
+) -> Md:
+    return _build_examples(component, app_tok, max_examples, shift_firings=True)
diff --git a/spd/autointerp/scoring/scripts/run_label_scoring.py b/spd/autointerp/scoring/scripts/run_label_scoring.py
index be2efa388..fd95f9763 100644
--- a/spd/autointerp/scoring/scripts/run_label_scoring.py
+++ b/spd/autointerp/scoring/scripts/run_label_scoring.py
@@ -25,7 +25,7 @@ def main(
     decomposition_id: str,
     scorer_type: LabelScorerType,
     config_json: dict[str, Any],
-    harvest_subrun_id: str | None = None,
+    harvest_subrun_id: str,
 ) -> None:
     assert isinstance(config_json, dict), f"Expected dict from fire, got {type(config_json)}"
     load_dotenv()
@@ -44,15 +44,11 @@ def main(
     # Separate writable DB for saving scores (the repo's DB is readonly/immutable)
     score_db = InterpDB(interp_repo._subrun_dir / "interp.db")
 
-    if harvest_subrun_id is not None:
-        harvest = HarvestRepo(
-            decomposition_id=decomposition_id,
-            subrun_id=harvest_subrun_id,
-            readonly=True,
-        )
-    else:
-        harvest = HarvestRepo.open_most_recent(decomposition_id, readonly=True)
-        assert harvest is not None, f"No harvest data for {decomposition_id}"
+    harvest = HarvestRepo(
+        decomposition_id=decomposition_id,
+        subrun_id=harvest_subrun_id,
+        readonly=True,
+    )
 
     components = harvest.get_all_components()
 
@@ -99,18 +95,16 @@ def get_command(
     decomposition_id: str,
     scorer_type: LabelScorerType,
     config: AutointerpEvalConfig,
-    harvest_subrun_id: str | None = None,
+    harvest_subrun_id: str,
 ) -> str:
     config_json = config.model_dump_json(exclude_none=True)
-    cmd = (
+    return (
         f"python -m spd.autointerp.scoring.scripts.run_label_scoring "
         f"--decomposition_id {decomposition_id} "
         f"--scorer_type {scorer_type} "
         f"--config_json '{config_json}' "
+        f"--harvest_subrun_id {harvest_subrun_id} "
     )
-    if harvest_subrun_id is not None:
-        cmd += f" --harvest_subrun_id {harvest_subrun_id} "
-    return cmd
 
 
 if __name__ == "__main__":
diff --git a/spd/autointerp/scripts/run_interpret.py b/spd/autointerp/scripts/run_interpret.py
index da056dc35..263329172 100644
--- a/spd/autointerp/scripts/run_interpret.py
+++ b/spd/autointerp/scripts/run_interpret.py
@@ -22,7 +22,7 @@
 def main(
     decomposition_id: str,
     config_json: dict[str, Any],
-    harvest_subrun_id: str | None = None,
+    harvest_subrun_id: str,
     autointerp_subrun_id: str | None = None,
 ) -> None:
     assert isinstance(config_json, dict), f"Expected dict from fire, got {type(config_json)}"
@@ -32,12 +32,7 @@ def main(
     openrouter_api_key = os.environ.get("OPENROUTER_API_KEY")
     assert openrouter_api_key, "OPENROUTER_API_KEY not set"
 
-    if harvest_subrun_id is not None:
-        harvest = HarvestRepo(decomposition_id, subrun_id=harvest_subrun_id, readonly=False)
-    else:
-        harvest = HarvestRepo.open_most_recent(decomposition_id, readonly=False)
-        if harvest is None:
-            raise ValueError(f"No harvest data found for {decomposition_id}")
+    harvest = HarvestRepo(decomposition_id, subrun_id=harvest_subrun_id, readonly=False)
 
     if autointerp_subrun_id is not None:
         subrun_dir = get_autointerp_dir(decomposition_id) / autointerp_subrun_id
@@ -76,7 +71,7 @@ def main(
 def get_command(
     decomposition_id: str,
     config: AutointerpConfig,
-    harvest_subrun_id: str | None = None,
+    harvest_subrun_id: str,
     autointerp_subrun_id: str | None = None,
 ) -> str:
     config_json = config.model_dump_json(exclude_none=True)
@@ -84,9 +79,8 @@ def get_command(
         "python -m spd.autointerp.scripts.run_interpret "
         f"--decomposition_id {decomposition_id} "
         f"--config_json '{config_json}' "
+        f"--harvest_subrun_id {harvest_subrun_id} "
     )
-    if harvest_subrun_id is not None:
-        cmd += f"--harvest_subrun_id {harvest_subrun_id} "
     if autointerp_subrun_id is not None:
         cmd += f"--autointerp_subrun_id {autointerp_subrun_id} "
     return cmd
diff --git a/spd/autointerp/scripts/run_slurm.py b/spd/autointerp/scripts/run_slurm.py
index bcd83b94c..4a1cd0bdc 100644
--- a/spd/autointerp/scripts/run_slurm.py
+++ b/spd/autointerp/scripts/run_slurm.py
@@ -30,9 +30,9 @@ class AutointerpSubmitResult:
 def submit_autointerp(
     decomposition_id: str,
     config: AutointerpSlurmConfig,
+    harvest_subrun_id: str,
     dependency_job_id: str | None = None,
     snapshot_branch: str | None = None,
-    harvest_subrun_id: str | None = None,
 ) -> AutointerpSubmitResult:
     """Submit the autointerp pipeline to SLURM.
diff --git a/spd/autointerp/scripts/run_slurm_cli.py b/spd/autointerp/scripts/run_slurm_cli.py
index fffc75c60..56db16499 100644
--- a/spd/autointerp/scripts/run_slurm_cli.py
+++ b/spd/autointerp/scripts/run_slurm_cli.py
@@ -10,18 +10,19 @@
 import fire
 
 
-def main(decomposition_id: str, config: str) -> None:
+def main(decomposition_id: str, config: str, harvest_subrun_id: str) -> None:
     """Submit autointerp pipeline (interpret + evals) to SLURM.
 
     Args:
         decomposition_id: ID of the target decomposition run.
         config: Path to AutointerpSlurmConfig YAML/JSON.
+        harvest_subrun_id: Harvest subrun to use (e.g. "h-20260306_120000").
     """
     from spd.autointerp.config import AutointerpSlurmConfig
     from spd.autointerp.scripts.run_slurm import submit_autointerp
 
     slurm_config = AutointerpSlurmConfig.from_file(config)
-    submit_autointerp(decomposition_id, slurm_config)
+    submit_autointerp(decomposition_id, slurm_config, harvest_subrun_id=harvest_subrun_id)
 
 
 def cli() -> None:
diff --git a/spd/autointerp/strategies/compact_skeptical.py b/spd/autointerp/strategies/compact_skeptical.py
index 6998bfda8..857a11144 100644
--- a/spd/autointerp/strategies/compact_skeptical.py
+++ b/spd/autointerp/strategies/compact_skeptical.py
@@ -6,10 +6,15 @@
 from spd.app.backend.app_tokenizer import AppTokenizer
 from spd.autointerp.config import CompactSkepticalConfig
-from spd.autointerp.prompt_helpers import DATASET_DESCRIPTIONS, build_fires_on_examples
+from spd.autointerp.prompt_helpers import (
+    DATASET_DESCRIPTIONS,
+    build_fires_on_examples,
+    token_pmi_pairs,
+)
 from spd.autointerp.schemas import ModelMetadata
 from spd.harvest.analysis import TokenPRLift
 from spd.harvest.schemas import ComponentData
+from spd.utils.markdown import Md
 
 SPD_CONTEXT = (
     "Each component has a causal importance (CI) value per token position. "
@@ -29,29 +34,18 @@ def format_prompt(
     output_pmi: list[tuple[str, float]] | None = None
 
     if config.include_pmi:
-        input_pmi = (
-            [(app_tok.get_tok_display(tid), pmi) for tid, pmi in component.input_token_pmi.top]
-            if component.input_token_pmi.top
-            else None
-        )
-        output_pmi = (
-            [(app_tok.get_tok_display(tid), pmi) for tid, pmi in component.output_token_pmi.top]
-            if component.output_token_pmi.top
-            else None
-        )
+        input_pmi = token_pmi_pairs(app_tok, component.input_token_pmi.top)
+        output_pmi = token_pmi_pairs(app_tok, component.output_token_pmi.top)
 
     input_section = _build_input_section(input_token_stats, input_pmi)
     output_section = _build_output_section(output_token_stats, output_pmi)
-    examples_section = build_fires_on_examples(
-        component,
-        app_tok,
-        config.max_examples,
-    )
+    examples_section = build_fires_on_examples(component, app_tok, config.max_examples)
 
-    if component.firing_density > 0.0:
-        rate_str = f"~1 in {int(1 / component.firing_density)} tokens"
-    else:
-        rate_str = "extremely rare"  # TODO(oli) make this string better. does this even happen?
+    rate_str = (
+        f"~1 in {int(1 / component.firing_density)} tokens"
+        if component.firing_density > 0.0
+        else "extremely rare"
+    )
 
     layer_desc = model_metadata.layer_descriptions.get(component.layer, component.layer)
 
@@ -60,83 +54,89 @@ def format_prompt(
         dataset_desc = DATASET_DESCRIPTIONS[model_metadata.dataset_name]
         dataset_line = f", dataset: {dataset_desc}"
 
-    spd_context_block = f"\n{SPD_CONTEXT}\n" if config.include_spd_context else ""
     forbidden = ", ".join(config.forbidden_words) if config.forbidden_words else "(none)"
 
-    return f"""\
-Label this neural network component.
-{spd_context_block}
-## Context
-- Model: {model_metadata.model_class} ({model_metadata.n_blocks} blocks){dataset_line}
-- Component location: {layer_desc}
-- Component firing rate: {component.firing_density * 100:.2f}% ({rate_str})
-
-## Token correlations
+    md = Md()
+    md.p("Label this neural network component.")
 
-{input_section}
-{output_section}
+    if config.include_spd_context:
+        md.p(SPD_CONTEXT)
 
-## Activation examples (active tokens in <>)
-
-{examples_section}
-
-## Task
+    md.h(2, "Context").bullets(
+        [
+            f"Model: {model_metadata.model_class} ({model_metadata.n_blocks} blocks){dataset_line}",
+            f"Component location: {layer_desc}",
+            f"Component firing rate: {component.firing_density * 100:.2f}% ({rate_str})",
+        ]
+    )
 
-Give a 2-{config.label_max_words} word label for what this component detects.
+    md.h(2, "Token correlations")
+    md.extend(input_section).extend(output_section)
 
-Be SKEPTICAL. If you can't identify specific tokens or a tight grammatical pattern, say "unclear".
+    md.h(2, "Activation examples (active tokens in <>)")
+    md.extend(examples_section)
 
-Rules:
-1. Good labels name SPECIFIC tokens: "'the'", "##ing suffix", "she/her pronouns"
-2. Say "unclear" if: tokens are too varied, pattern is abstract, or evidence is weak
-3. FORBIDDEN words (too vague): {forbidden}
-4. Lowercase only
-5. Confidence: "high" = clear, specific pattern with strong evidence; "medium" = plausible but noisy; "low" = speculative
+    md.h(2, "Task")
+    md.p(f"Give a 2-{config.label_max_words} word label for what this component detects.")
+    md.p(
+        "Be SKEPTICAL. If you can't identify specific tokens or a tight grammatical "
+        'pattern, say "unclear".'
+    )
+    md.p("Rules:")
+    md.numbered(
+        [
+            'Good labels name SPECIFIC tokens: "\'the\'", "##ing suffix", "she/her pronouns"',
+            'Say "unclear" if: tokens are too varied, pattern is abstract, or evidence is weak',
+            f"FORBIDDEN words (too vague): {forbidden}",
+            "Lowercase only",
+            'Confidence: "high" = clear, specific pattern with strong evidence; '
+            '"medium" = plausible but noisy; "low" = speculative',
+        ]
+    )
+    md.p(
+        'GOOD: "##ed suffix", "\'and\' conjunction", "she/her/hers", "period then capital", "unclear"\n'
+        'BAD: "various words and punctuation", "verbs and adjectives", "tokens near commas"'
+    )
 
-GOOD: "##ed suffix", "'and' conjunction", "she/her/hers", "period then capital", "unclear"
-BAD: "various words and punctuation", "verbs and adjectives", "tokens near commas"
-"""
+    return md.build()
 
 
 def _build_input_section(
     input_stats: TokenPRLift,
     input_pmi: list[tuple[str, float]] | None,
-) -> str:
-    section = ""
-
+) -> Md:
+    md = Md()
     if input_stats.top_recall:
-        section += "**Input tokens with highest recall (most common current tokens when the component is firing)**\n"
-        for tok, recall in input_stats.top_recall[:8]:
-            section += f"- {repr(tok)}: {recall * 100:.0f}%\n"
-
+        md.labeled_list(
+            "**Input tokens with highest recall (most common current tokens when the component is firing)**",
+            [f"{repr(tok)}: {recall * 100:.0f}%" for tok, recall in input_stats.top_recall[:8]],
+        )
     if input_stats.top_precision:
-        section += "\n**Input tokens with highest precision (probability the component fires given the current token is X)**\n"
-        for tok, prec in input_stats.top_precision[:8]:
-            section += f"- {repr(tok)}: {prec * 100:.0f}%\n"
-
+        md.labeled_list(
+            "**Input tokens with highest precision (probability the component fires given the current token is X)**",
+            [f"{repr(tok)}: {prec * 100:.0f}%" for tok, prec in input_stats.top_precision[:8]],
+        )
     if input_pmi:
-        section += "\n**Input tokens with highest PMI (pointwise mutual information. Tokens with higher-than-base-rate likelihood of co-occurrence with the component firing)**\n"
-        for tok, pmi in input_pmi[:6]:
-            section += f"- {repr(tok)}: {pmi:.2f}\n"
-
-    return section
+        md.labeled_list(
+            "**Input tokens with highest PMI (pointwise mutual information. Tokens with higher-than-base-rate likelihood of co-occurrence with the component firing)**",
+            [f"{repr(tok)}: {pmi:.2f}" for tok, pmi in input_pmi[:6]],
+        )
+    return md
 
 
 def _build_output_section(
     output_stats: TokenPRLift,
     output_pmi: list[tuple[str, float]] | None,
-) -> str:
-    section = ""
-
+) -> Md:
+    md = Md()
     if output_stats.top_precision:
-        section += "**Output precision — of all predicted probability for token X, what fraction is at positions where this component fires?**\n"
-        for tok, prec in output_stats.top_precision[:10]:
-            section += f"- {repr(tok)}: {prec * 100:.0f}%\n"
-
+        md.labeled_list(
+            "**Output precision — of all predicted probability for token X, what fraction is at positions where this component fires?**",
+            [f"{repr(tok)}: {prec * 100:.0f}%" for tok, prec in output_stats.top_precision[:10]],
+        )
     if output_pmi:
-        section += "\n**Output PMI — tokens the model predicts at higher-than-base-rate when this component fires:**\n"
-        for tok, pmi in output_pmi[:6]:
-            section += f"- {repr(tok)}: {pmi:.2f}\n"
-
-    return section
+        md.labeled_list(
+            "**Output PMI — tokens the model predicts at higher-than-base-rate when this component fires:**",
+            [f"{repr(tok)}: {pmi:.2f}" for tok, pmi in output_pmi[:6]],
+        )
+    return md
diff --git a/spd/autointerp/strategies/dual_view.py b/spd/autointerp/strategies/dual_view.py
index 1206c73f5..8256f740d 100644
--- a/spd/autointerp/strategies/dual_view.py
+++ b/spd/autointerp/strategies/dual_view.py
@@ -18,10 +18,12 @@
     density_note,
     human_layer_desc,
     layer_position_note,
+    token_pmi_pairs,
 )
 from spd.autointerp.schemas import ModelMetadata
 from spd.harvest.analysis import TokenPRLift
 from spd.harvest.schemas import ComponentData
+from spd.utils.markdown import Md
 
 
 def format_prompt(
@@ -36,26 +38,19 @@ def format_prompt(
     output_pmi: list[tuple[str, float]] | None = None
 
     if config.include_pmi:
-        input_pmi = (
-            [(app_tok.get_tok_display(tid), pmi) for tid, pmi in component.input_token_pmi.top]
-            if component.input_token_pmi.top
-            else None
-        )
-        output_pmi = (
-            [(app_tok.get_tok_display(tid), pmi) for tid, pmi in component.output_token_pmi.top]
-            if component.output_token_pmi.top
-            else None
-        )
+        input_pmi = token_pmi_pairs(app_tok, component.input_token_pmi.top)
+        output_pmi = token_pmi_pairs(app_tok, component.output_token_pmi.top)
 
     output_section = build_output_section(output_token_stats, output_pmi)
     input_section = build_input_section(input_token_stats, input_pmi)
     fires_on_examples = build_fires_on_examples(component, app_tok, config.max_examples)
     says_examples = build_says_examples(component, app_tok, config.max_examples)
 
-    if component.firing_density > 0.0:
-        rate_str = f"~1 in {int(1 / component.firing_density)} tokens"
-    else:
-        rate_str = "extremely rare"
+    rate_str = (
+        f"~1 in {int(1 / component.firing_density)} tokens"
+        if component.firing_density > 0.0
+        else "extremely rare"
+    )
 
     canonical = model_metadata.layer_descriptions.get(component.layer, component.layer)
     layer_desc = human_layer_desc(canonical, model_metadata.n_blocks)
@@ -77,48 +72,64 @@ def format_prompt(
         else ""
     )
 
-    return f"""\
-Describe what this neural network component does.
-
-Each component is a learned linear transformation inside a weight matrix. It has an input function (what causes it to fire) and an output function (what tokens it causes the model to produce). These are often different — a component might fire on periods but produce sentence-opening words, or fire on prepositions but produce abstract nouns.
-
-Consider all of the evidence below critically. Token statistics can be noisy, especially for high-density components. The activation examples are sampled and may not be representative. Look for patterns that are consistent across multiple sources of evidence.
-
-## Context
-- Model: {model_metadata.model_class} ({model_metadata.n_blocks} blocks){dataset_line}
-- Component location: {layer_desc}
-- Component firing rate: {component.firing_density * 100:.2f}% ({rate_str})
-
-{context_notes}
-
-## Output tokens (what the model produces when this component fires)
-
-{output_section}
-## Input tokens (what causes this component to fire)
-
-{input_section}
-## Activation examples — where the component fires
-
-<> mark tokens where this component is active.
+    md = Md()
+    md.p(
+        "Describe what this neural network component does.\n\n"
+        "Each component is a learned linear transformation inside a weight matrix. "
+        "It has an input function (what causes it to fire) and an output function "
+        "(what tokens it causes the model to produce). These are often different — "
+        "a component might fire on periods but produce sentence-opening words, or "
+        "fire on prepositions but produce abstract nouns.\n\n"
+        "Consider all of the evidence below critically. Token statistics can be noisy, "
+        "especially for high-density components. The activation examples are sampled "
+        "and may not be representative. Look for patterns that are consistent across "
+        "multiple sources of evidence."
+    )
 
-{fires_on_examples}
-## Activation examples — what the model produces
+    md.h(2, "Context").bullets(
+        [
+            f"Model: {model_metadata.model_class} ({model_metadata.n_blocks} blocks){dataset_line}",
+            f"Component location: {layer_desc}",
+            f"Component firing rate: {component.firing_density * 100:.2f}% ({rate_str})",
+        ]
+    )
+    if context_notes:
+        md.p(context_notes)
 
-Same examples with <> shifted right by one — showing the token that follows each firing position.
+    md.h(2, "Output tokens (what the model produces when this component fires)")
+    md.extend(output_section)
 
-{says_examples}
+    md.h(2, "Input tokens (what causes this component to fire)")
+    md.extend(input_section)
 
-## Task
+    md.h(2, "Activation examples — where the component fires")
+    md.p("<> mark tokens where this component is active.")
+    md.extend(fires_on_examples)
 
-Give a {config.label_max_words}-word-or-fewer label describing this component's function. The label should read like a short description of the job this component does in the network. Use both the input and output evidence.
+    md.h(2, "Activation examples — what the model produces")
+    md.p(
+        "Same examples with <> shifted right by one — "
+        "showing the token that follows each firing position."
+    )
+    md.extend(says_examples)
 
-Examples of good labels across different component types:
-- "word stem completion (stems → suffixes)"
-- "closes dialogue with quotation marks"
-- "object pronouns after verbs"
-- "story-ending moral resolution vocabulary"
-- "aquatic scene vocabulary (frog, river, pond)"
-- "'of course' and abstract nouns after prepositions"
+    md.h(2, "Task")
+    md.p(
+        f"Give a {config.label_max_words}-word-or-fewer label describing this component's "
+        "function. The label should read like a short description of the job this component "
+        "does in the network. Use both the input and output evidence."
+    )
+    md.p(
+        "Examples of good labels across different component types:\n"
+        '- "word stem completion (stems → suffixes)"\n'
+        '- "closes dialogue with quotation marks"\n'
+        '- "object pronouns after verbs"\n'
+        '- "story-ending moral resolution vocabulary"\n'
+        '- "aquatic scene vocabulary (frog, river, pond)"\n'
+        "- \"'of course' and abstract nouns after prepositions\""
+    )
+    md.p(
+        f'Say "unclear" if the evidence is too weak or diffuse. {forbidden_sentence}Lowercase only.'
+    )
 
-Say "unclear" if the evidence is too weak or diffuse. {forbidden_sentence}Lowercase only.
-"""
+    return md.build()
diff --git a/spd/clustering/CLAUDE.md b/spd/clustering/CLAUDE.md
index a063596f5..f502f8785 100644
--- a/spd/clustering/CLAUDE.md
+++ b/spd/clustering/CLAUDE.md
@@ -108,6 +108,18 @@ DistancesArray  # Float[np.ndarray, "n_iters n_ens n_ens"]
 - `matching_dist.py` - Optimal matching distance via Hungarian algorithm
 - `merge_pair_samplers.py` - Strategies for selecting which pair to merge
 
+## Utility Scripts
+
+**`get_cluster_mapping.py`**: Extracts cluster assignments at a specific iteration from a clustering run, outputs JSON mapping component labels to cluster indices (singletons mapped to `null`).
+
+```bash
+python -m spd.clustering.scripts.get_cluster_mapping /path/to/clustering_run --iteration 299
+```
+
+## App Integration
+
+To make a cluster mapping available in the app's dropdown for a run, add its path to `CANONICAL_RUNS` in `spd/app/frontend/src/lib/registry.ts` under the corresponding run's `clusterMappings` array.
+
 ## Config Files
 
 Configs live in `spd/clustering/configs/`:
diff --git a/spd/clustering/configs/pipeline-dev-simplestories.yaml b/spd/clustering/configs/pipeline-dev-simplestories.yaml
index eccda019f..80e9c63bc 100644
--- a/spd/clustering/configs/pipeline-dev-simplestories.yaml
+++ b/spd/clustering/configs/pipeline-dev-simplestories.yaml
@@ -6,4 +6,4 @@ slurm_partition: null
 wandb_project: "spd"
 wandb_entity: "goodfire"
 create_git_snapshot: false
-clustering_run_config_path: "spd/clustering/configs/crc/simplestories_dev.json"
\ No newline at end of file
+clustering_run_config_path: "spd/clustering/configs/crc/ss_llama_simple_mlp.json"
\ No newline at end of file
diff --git a/spd/dataset_attributions/config.py b/spd/dataset_attributions/config.py
index 6f02df0f9..8a515ab7e 100644
--- a/spd/dataset_attributions/config.py
+++ b/spd/dataset_attributions/config.py
@@ -14,7 +14,7 @@
 class DatasetAttributionConfig(BaseConfig):
     spd_run_wandb_path: str
-    harvest_subrun_id: str | None = None
+    harvest_subrun_id: str
     n_batches: int | Literal["whole_dataset"] = 10_000
     batch_size: int = 32
     ci_threshold: float = 0.0
diff --git a/spd/dataset_attributions/harvest.py b/spd/dataset_attributions/harvest.py
index 2633267c3..da2f53505 100644
--- a/spd/dataset_attributions/harvest.py
+++ b/spd/dataset_attributions/harvest.py
@@ -36,7 +36,7 @@
 def _build_alive_masks(
     model: ComponentModel,
     run_id: str,
-    harvest_subrun_id: str | None,
+    harvest_subrun_id: str,
 ) -> dict[str, Bool[Tensor, " n_components"]]:
     """Build masks of alive components (firing_density > 0) per target layer.
 
@@ -48,11 +48,7 @@
         for layer in model.target_module_paths
     }
 
-    if harvest_subrun_id is not None:
-        harvest = HarvestRepo(decomposition_id=run_id, subrun_id=harvest_subrun_id, readonly=True)
-    else:
-        harvest = HarvestRepo.open_most_recent(run_id, readonly=True)
-        assert harvest is not None, f"No harvest data for {run_id}"
+    harvest = HarvestRepo(decomposition_id=run_id, subrun_id=harvest_subrun_id, readonly=True)
 
     summary = harvest.get_summary()
     assert summary is not None, "Harvest summary not available"
@@ -73,7 +69,6 @@
 def harvest_attributions(
     rank: int,
     world_size: int,
 ) -> None:
-
     device = torch.device(get_device())
     logger.info(f"Loading model on {device}")
diff --git a/spd/dataset_attributions/scripts/run_slurm.py b/spd/dataset_attributions/scripts/run_slurm.py
index 9e04cac46..961b652c8 100644
--- a/spd/dataset_attributions/scripts/run_slurm.py
+++ b/spd/dataset_attributions/scripts/run_slurm.py
@@ -45,25 +45,8 @@ def submit_attributions(
     job_suffix: str | None = None,
     snapshot_branch: str | None = None,
     dependency_job_id: str | None = None,
-    harvest_subrun_id: str | None = None,
 ) -> AttributionsSubmitResult:
-    """Submit multi-GPU attribution harvesting job to SLURM.
-
-    Submits a job array where each task processes a subset of batches, then
-    submits a merge job that depends on all workers completing. Creates a git
-    snapshot to ensure consistent code across all workers.
-
-    Args:
-        wandb_path: WandB run path for the target decomposition run.
-        config: Attribution SLURM configuration.
-        job_suffix: Optional suffix for SLURM job names (e.g., "1h" -> "spd-attr-1h").
-        snapshot_branch: Git snapshot branch to use. If None, creates a new snapshot.
-        dependency_job_id: SLURM job to wait for before starting (e.g. harvest merge).
-        harvest_subrun_id: Harvest subrun for alive masks. If None, uses most recent.
-
-    Returns:
-        AttributionsSubmitResult with array, merge results and subrun ID.
-    """
+    """Submit multi-GPU attribution harvesting job to SLURM."""
     n_gpus = config.n_gpus
     partition = config.partition
     time = config.time
@@ -80,10 +63,7 @@
     suffix = f"-{job_suffix}" if job_suffix else ""
     array_job_name = f"spd-attr{suffix}"
 
-    inner_config = config.config
-    if harvest_subrun_id is not None and inner_config.harvest_subrun_id is None:
-        inner_config = inner_config.model_copy(update={"harvest_subrun_id": harvest_subrun_id})
-    config_json = inner_config.model_dump_json(exclude_none=True)
+    config_json = config.config.model_dump_json(exclude_none=True)
 
     # SLURM arrays are 1-indexed, so task ID 1 -> rank 0, etc.
     worker_commands = []
diff --git a/spd/graph_interp/db.py b/spd/graph_interp/db.py
index 1d10cda29..1a6a83e37 100644
--- a/spd/graph_interp/db.py
+++ b/spd/graph_interp/db.py
@@ -48,6 +48,9 @@
 """
 
 
+_LABEL_TABLES = ("output_labels", "input_labels", "unified_labels")
+
+
 class GraphInterpDB:
     """NFS-hosted. Uses open_nfs_sqlite (no WAL).
Single writer, then read-only.""" @@ -60,11 +63,12 @@ def __init__(self, db_path: Path, readonly: bool = False) -> None: def mark_done(self) -> None: (self._db_path.parent / DONE_MARKER).touch() - # -- Output labels --------------------------------------------------------- + # -- Label CRUD (shared across output/input/unified) ----------------------- - def save_output_label(self, result: LabelResult) -> None: + def _save_label(self, table: str, result: LabelResult) -> None: + assert table in _LABEL_TABLES self._conn.execute( - "INSERT OR REPLACE INTO output_labels VALUES (?, ?, ?, ?, ?, ?)", + f"INSERT OR REPLACE INTO {table} VALUES (?, ?, ?, ?, ?, ?)", ( result.component_key, result.label, @@ -76,73 +80,52 @@ def save_output_label(self, result: LabelResult) -> None: ) self._conn.commit() - def get_output_label(self, component_key: str) -> LabelResult | None: + def _get_label(self, table: str, component_key: str) -> LabelResult | None: + assert table in _LABEL_TABLES row = self._conn.execute( - "SELECT * FROM output_labels WHERE component_key = ?", (component_key,) + f"SELECT * FROM {table} WHERE component_key = ?", (component_key,) ).fetchone() if row is None: return None return _row_to_label_result(row) - def get_all_output_labels(self) -> dict[str, LabelResult]: - rows = self._conn.execute("SELECT * FROM output_labels").fetchall() + def _get_all_labels(self, table: str) -> dict[str, LabelResult]: + assert table in _LABEL_TABLES + rows = self._conn.execute(f"SELECT * FROM {table}").fetchall() return {row["component_key"]: _row_to_label_result(row) for row in rows} + # -- Output labels --------------------------------------------------------- + + def save_output_label(self, result: LabelResult) -> None: + self._save_label("output_labels", result) + + def get_output_label(self, component_key: str) -> LabelResult | None: + return self._get_label("output_labels", component_key) + + def get_all_output_labels(self) -> dict[str, LabelResult]: + return 
self._get_all_labels("output_labels") + # -- Input labels ---------------------------------------------------------- def save_input_label(self, result: LabelResult) -> None: - self._conn.execute( - "INSERT OR REPLACE INTO input_labels VALUES (?, ?, ?, ?, ?, ?)", - ( - result.component_key, - result.label, - result.confidence, - result.reasoning, - result.raw_response, - result.prompt, - ), - ) - self._conn.commit() + self._save_label("input_labels", result) def get_input_label(self, component_key: str) -> LabelResult | None: - row = self._conn.execute( - "SELECT * FROM input_labels WHERE component_key = ?", (component_key,) - ).fetchone() - if row is None: - return None - return _row_to_label_result(row) + return self._get_label("input_labels", component_key) def get_all_input_labels(self) -> dict[str, LabelResult]: - rows = self._conn.execute("SELECT * FROM input_labels").fetchall() - return {row["component_key"]: _row_to_label_result(row) for row in rows} + return self._get_all_labels("input_labels") # -- Unified labels -------------------------------------------------------- def save_unified_label(self, result: LabelResult) -> None: - self._conn.execute( - "INSERT OR REPLACE INTO unified_labels VALUES (?, ?, ?, ?, ?, ?)", - ( - result.component_key, - result.label, - result.confidence, - result.reasoning, - result.raw_response, - result.prompt, - ), - ) - self._conn.commit() + self._save_label("unified_labels", result) def get_unified_label(self, component_key: str) -> LabelResult | None: - row = self._conn.execute( - "SELECT * FROM unified_labels WHERE component_key = ?", (component_key,) - ).fetchone() - if row is None: - return None - return _row_to_label_result(row) + return self._get_label("unified_labels", component_key) def get_all_unified_labels(self) -> dict[str, LabelResult]: - rows = self._conn.execute("SELECT * FROM unified_labels").fetchall() - return {row["component_key"]: _row_to_label_result(row) for row in rows} + return 
self._get_all_labels("unified_labels") def get_completed_unified_keys(self) -> set[str]: rows = self._conn.execute("SELECT component_key FROM unified_labels").fetchall() @@ -178,12 +161,10 @@ def get_all_prompt_edges(self) -> list[PromptEdge]: rows = self._conn.execute("SELECT * FROM prompt_edges").fetchall() return [_row_to_prompt_edge(row) for row in rows] - # -- Config ---------------------------------------------------------------- - # -- Stats ----------------------------------------------------------------- def get_label_count(self, table: str) -> int: - assert table in ("output_labels", "input_labels", "unified_labels") + assert table in _LABEL_TABLES row = self._conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone() assert row is not None return row[0] diff --git a/spd/graph_interp/interpret.py b/spd/graph_interp/interpret.py index 00b6e243b..f0c92725a 100644 --- a/spd/graph_interp/interpret.py +++ b/spd/graph_interp/interpret.py @@ -12,7 +12,6 @@ import asyncio from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable -from functools import partial from pathlib import Path from typing import Literal @@ -36,13 +35,15 @@ format_unification_prompt, ) from spd.graph_interp.schemas import LabelResult, PromptEdge -from spd.harvest.analysis import get_input_token_stats, get_output_token_stats +from spd.harvest.analysis import TokenPRLift, get_input_token_stats, get_output_token_stats from spd.harvest.repo import HarvestRepo +from spd.harvest.schemas import ComponentData from spd.harvest.storage import CorrelationStorage, TokenStatsStorage from spd.log import logger GetRelated = Callable[[str, dict[str, LabelResult]], list[RelatedComponent]] Step = Callable[[list[str], dict[str, LabelResult]], Awaitable[dict[str, LabelResult]]] +MakePrompt = Callable[["ComponentData", "TokenPRLift", list[RelatedComponent]], str] def run_graph_interp( @@ -106,23 +107,13 @@ def _to_canon(concrete_key: str) -> str: layer, idx = concrete_key.rsplit(":", 1) return 
f"{concrete_to_canon[layer]}:{idx}" - def _make_get_targets(metric: AttrMetric) -> "graph_context.GetAttributed": + def _make_get_attributed( + method: Callable[..., list[DatasetAttributionEntry]], metric: AttrMetric + ) -> "graph_context.GetAttributed": def get( key: str, k: int, sign: Literal["positive", "negative"] ) -> list[DatasetAttributionEntry]: - return _translate_entries( - attribution_storage.get_top_targets(_to_canon(key), k=k, sign=sign, metric=metric) - ) - - return get - - def _make_get_sources(metric: AttrMetric) -> "graph_context.GetAttributed": - def get( - key: str, k: int, sign: Literal["positive", "negative"] - ) -> list[DatasetAttributionEntry]: - return _translate_entries( - attribution_storage.get_top_sources(_to_canon(key), k=k, sign=sign, metric=metric) - ) + return _translate_entries(method(_to_canon(key), k=k, sign=sign, metric=metric)) return get @@ -138,77 +129,49 @@ def get(key: str, labels_so_far: dict[str, LabelResult]) -> list[RelatedComponen return get - # -- Layer processors ------------------------------------------------------ + # -- Layer processor (shared for output and input passes) -------------------- - async def process_output_layer( + def _make_process_layer( get_related: GetRelated, save_label: Callable[[LabelResult], None], - pending: list[str], - labels_so_far: dict[str, LabelResult], - ) -> dict[str, LabelResult]: - def jobs() -> Iterable[LLMJob]: - for key in pending: - component = harvest.get_component(key) - assert component is not None, f"Component {key} not found in harvest DB" - o_stats = get_output_token_stats(token_stats, key, app_tok, top_k=50) - assert o_stats is not None, f"No output token stats for {key}" - - related = get_related(key, labels_so_far) - db.save_prompt_edges([ - PromptEdge( - component_key=key, related_key=r.component_key, - pass_name="output", attribution=r.attribution, - related_label=r.label, related_confidence=r.confidence, + pass_name: Literal["output", "input"], + get_token_stats: 
Callable[[str], TokenPRLift | None], + make_prompt: MakePrompt, + ) -> Step: + async def process( + pending: list[str], + labels_so_far: dict[str, LabelResult], + ) -> dict[str, LabelResult]: + def jobs() -> Iterable[LLMJob]: + for key in pending: + component = harvest.get_component(key) + assert component is not None, f"Component {key} not found in harvest DB" + stats = get_token_stats(key) + assert stats is not None, f"No {pass_name} token stats for {key}" + + related = get_related(key, labels_so_far) + db.save_prompt_edges( + [ + PromptEdge( + component_key=key, + related_key=r.component_key, + pass_name=pass_name, + attribution=r.attribution, + related_label=r.label, + related_confidence=r.confidence, + ) + for r in related + ] ) - for r in related - ]) - prompt = format_output_prompt( - component=component, - model_metadata=model_metadata, - app_tok=app_tok, - output_token_stats=o_stats, - related=related, - label_max_words=config.label_max_words, - max_examples=config.max_examples, - ) - yield LLMJob(prompt=prompt, schema=LABEL_SCHEMA, key=key) - - return await _collect_labels(llm_map, jobs(), len(pending), save_label) - - async def process_input_layer( - get_related: GetRelated, - save_label: Callable[[LabelResult], None], - pending: list[str], - labels_so_far: dict[str, LabelResult], - ) -> dict[str, LabelResult]: - def jobs() -> Iterable[LLMJob]: - for key in pending: - component = harvest.get_component(key) - assert component is not None, f"Component {key} not found in harvest DB" - i_stats = get_input_token_stats(token_stats, key, app_tok, top_k=20) - assert i_stats is not None, f"No input token stats for {key}" - - related = get_related(key, labels_so_far) - db.save_prompt_edges([ - PromptEdge( - component_key=key, related_key=r.component_key, - pass_name="input", attribution=r.attribution, - related_label=r.label, related_confidence=r.confidence, + yield LLMJob( + prompt=make_prompt(component, stats, related), + schema=LABEL_SCHEMA, + key=key, ) - for 
r in related - ]) - prompt = format_input_prompt( - component=component, - model_metadata=model_metadata, - app_tok=app_tok, - input_token_stats=i_stats, - related=related, - label_max_words=config.label_max_words, - max_examples=config.max_examples, - ) - yield LLMJob(prompt=prompt, schema=LABEL_SCHEMA, key=key) - return await _collect_labels(llm_map, jobs(), len(pending), save_label) + return await _collect_labels(llm_map, jobs(), len(pending), save_label) + + return process # -- Scan (fold over layers) ----------------------------------------------- @@ -282,18 +245,48 @@ def jobs() -> Iterable[LLMJob]: db = GraphInterpDB(db_path) metric = config.attr_metric - get_targets = _make_get_targets(metric) - get_sources = _make_get_sources(metric) + get_targets = _make_get_attributed(attribution_storage.get_top_targets, metric) + get_sources = _make_get_attributed(attribution_storage.get_top_sources, metric) + + def _output_prompt( + component: ComponentData, stats: TokenPRLift, related: list[RelatedComponent] + ) -> str: + return format_output_prompt( + component=component, + model_metadata=model_metadata, + app_tok=app_tok, + output_token_stats=stats, + related=related, + label_max_words=config.label_max_words, + max_examples=config.max_examples, + ) + + def _input_prompt( + component: ComponentData, stats: TokenPRLift, related: list[RelatedComponent] + ) -> str: + return format_input_prompt( + component=component, + model_metadata=model_metadata, + app_tok=app_tok, + input_token_stats=stats, + related=related, + label_max_words=config.label_max_words, + max_examples=config.max_examples, + ) - label_output = partial( - process_output_layer, + label_output = _make_process_layer( _get_related(get_targets), db.save_output_label, + "output", + lambda key: get_output_token_stats(token_stats, key, app_tok, top_k=50), + _output_prompt, ) - label_input = partial( - process_input_layer, + label_input = _make_process_layer( _get_related(get_sources), db.save_input_label, + 
"input", + lambda key: get_input_token_stats(token_stats, key, app_tok, top_k=20), + _input_prompt, ) async def _run() -> None: @@ -368,5 +361,3 @@ def _check_error_rate(n_errors: int, n_done: int) -> None: raise RuntimeError( f"Error rate {n_errors / total:.0%} ({n_errors}/{total}) exceeds 5% threshold" ) - - diff --git a/spd/graph_interp/prompts.py b/spd/graph_interp/prompts.py index f2160ed0c..f00a0d671 100644 --- a/spd/graph_interp/prompts.py +++ b/spd/graph_interp/prompts.py @@ -15,12 +15,14 @@ density_note, human_layer_desc, layer_position_note, + token_pmi_pairs, ) from spd.autointerp.schemas import ModelMetadata from spd.graph_interp.graph_context import RelatedComponent from spd.graph_interp.schemas import LabelResult from spd.harvest.analysis import TokenPRLift from spd.harvest.schemas import ComponentData +from spd.utils.markdown import Md LABEL_SCHEMA: dict[str, object] = { "type": "object", @@ -33,11 +35,17 @@ "additionalProperties": False, } +JSON_INSTRUCTION = ( + 'Respond with JSON: {"label": "...", "confidence": "low|medium|high", "reasoning": "..."}' +) + +UNCLEAR_NOTE = 'Say "unclear" if the evidence is too weak.' 
+ def _component_header( component: ComponentData, model_metadata: ModelMetadata, -) -> str: +) -> Md: canonical = model_metadata.layer_descriptions.get(component.layer, component.layer) layer_desc = human_layer_desc(canonical, model_metadata.n_blocks) position_note = layer_position_note(canonical, model_metadata.n_blocks) @@ -49,13 +57,17 @@ def _component_header( else "extremely rare" ) + md = Md() + md.h(2, "Context").bullets( + [ + f"Component: {layer_desc} (component {component.component_idx}), {model_metadata.n_blocks}-block model", + f"Firing rate: {component.firing_density * 100:.2f}% ({rate_str})", + ] + ) context_notes = " ".join(filter(None, [position_note, dens_note])) - - return f"""\ -## Context -- Component: {layer_desc} (component {component.component_idx}), {model_metadata.n_blocks}-block model -- Firing rate: {component.firing_density * 100:.2f}% ({rate_str}) -{context_notes}""" + if context_notes: + md.p(context_notes) + return md def format_output_prompt( @@ -67,36 +79,35 @@ def format_output_prompt( label_max_words: int, max_examples: int, ) -> str: - header = _component_header(component, model_metadata) + output_pmi = token_pmi_pairs(app_tok, component.output_token_pmi.top) - output_pmi = ( - [(app_tok.get_tok_display(tid), pmi) for tid, pmi in component.output_token_pmi.top] - if component.output_token_pmi.top - else None + md = Md() + md.p( + "You are analyzing a component in a neural network to understand its " + "OUTPUT FUNCTION — what it does when it fires." ) - output_section = build_output_section(output_token_stats, output_pmi) - says = build_says_examples(component, app_tok, max_examples) - related_table = _format_related_table(related, model_metadata, app_tok) + md.extend(_component_header(component, model_metadata)) - return f"""\ -You are analyzing a component in a neural network to understand its OUTPUT FUNCTION — what it does when it fires. 
+ md.h(2, "Output tokens (what the model produces when this component fires)") + md.extend(build_output_section(output_token_stats, output_pmi)) -{header} + md.h(2, "Activation examples — what the model produces") + md.extend(build_says_examples(component, app_tok, max_examples)) -## Output tokens (what the model produces when this component fires) -{output_section} -## Activation examples — what the model produces -{says} -## Downstream components (what this component influences) -These components in later layers are most influenced by this component (by gradient attribution): -{related_table} -## Task -Give a {label_max_words}-word-or-fewer label describing this component's OUTPUT FUNCTION — what it does when it fires. + md.h(2, "Downstream components (what this component influences)") + md.p( + "These components in later layers are most influenced by this component (by gradient attribution):" + ) + md.extend(_format_related(related, model_metadata, app_tok)) -Say "unclear" if the evidence is too weak. + md.h(2, "Task") + md.p( + f"Give a {label_max_words}-word-or-fewer label describing this component's " + "OUTPUT FUNCTION — what it does when it fires." + ) + md.p(UNCLEAR_NOTE).p(JSON_INSTRUCTION) -Respond with JSON: {{"label": "...", "confidence": "low|medium|high", "reasoning": "..."}} -""" + return md.build() def format_input_prompt( @@ -108,36 +119,33 @@ def format_input_prompt( label_max_words: int, max_examples: int, ) -> str: - header = _component_header(component, model_metadata) + input_pmi = token_pmi_pairs(app_tok, component.input_token_pmi.top) - input_pmi = ( - [(app_tok.get_tok_display(tid), pmi) for tid, pmi in component.input_token_pmi.top] - if component.input_token_pmi.top - else None + md = Md() + md.p( + "You are analyzing a component in a neural network to understand its " + "INPUT FUNCTION — what triggers it to fire." 
) - input_section = build_input_section(input_token_stats, input_pmi) - fires_on = build_fires_on_examples(component, app_tok, max_examples) - related_table = _format_related_table(related, model_metadata, app_tok) + md.extend(_component_header(component, model_metadata)) - return f"""\ -You are analyzing a component in a neural network to understand its INPUT FUNCTION — what triggers it to fire. + md.h(2, "Input tokens (what causes this component to fire)") + md.extend(build_input_section(input_token_stats, input_pmi)) -{header} + md.h(2, "Activation examples — where the component fires") + md.extend(build_fires_on_examples(component, app_tok, max_examples)) -## Input tokens (what causes this component to fire) -{input_section} -## Activation examples — where the component fires -{fires_on} -## Upstream components (what feeds into this component) -These components in earlier layers most strongly attribute to this component: -{related_table} -## Task -Give a {label_max_words}-word-or-fewer label describing this component's INPUT FUNCTION — what conditions trigger it to fire. + md.h(2, "Upstream components (what feeds into this component)") + md.p("These components in earlier layers most strongly attribute to this component:") + md.extend(_format_related(related, model_metadata, app_tok)) -Say "unclear" if the evidence is too weak. + md.h(2, "Task") + md.p( + f"Give a {label_max_words}-word-or-fewer label describing this component's " + "INPUT FUNCTION — what conditions trigger it to fire." 
+ ) + md.p(UNCLEAR_NOTE).p(JSON_INSTRUCTION) -Respond with JSON: {{"label": "...", "confidence": "low|medium|high", "reasoning": "..."}} -""" + return md.build() def format_unification_prompt( @@ -149,45 +157,47 @@ def format_unification_prompt( label_max_words: int, max_examples: int, ) -> str: - header = _component_header(component, model_metadata) - fires_on = build_fires_on_examples(component, app_tok, max_examples) - says = build_says_examples(component, app_tok, max_examples) - - return f"""\ -A neural network component has been analyzed from two perspectives. - -{header} - -## Activation examples — where the component fires -{fires_on} -## Activation examples — what the model produces -{says} -## Two-perspective analysis - -OUTPUT FUNCTION: "{output_label.label}" (confidence: {output_label.confidence}) - Reasoning: {output_label.reasoning} - -INPUT FUNCTION: "{input_label.label}" (confidence: {input_label.confidence}) - Reasoning: {input_label.reasoning} + md = Md() + md.p("A neural network component has been analyzed from two perspectives.") + md.extend(_component_header(component, model_metadata)) + + md.h(2, "Activation examples — where the component fires") + md.extend(build_fires_on_examples(component, app_tok, max_examples)) + + md.h(2, "Activation examples — what the model produces") + md.extend(build_says_examples(component, app_tok, max_examples)) + + md.h(2, "Two-perspective analysis") + md.p( + f'OUTPUT FUNCTION: "{output_label.label}" (confidence: {output_label.confidence})\n' + f" Reasoning: {output_label.reasoning}\n\n" + f'INPUT FUNCTION: "{input_label.label}" (confidence: {input_label.confidence})\n' + f" Reasoning: {input_label.reasoning}" + ) -## Task -Synthesize these into a single unified label (max {label_max_words} words) that captures the component's complete role. If input and output suggest the same concept, unify them. If they describe genuinely different aspects (e.g. fires on X, produces Y), combine both. 
+ md.h(2, "Task") + md.p( + f"Synthesize these into a single unified label (max {label_max_words} words) " + "that captures the component's complete role. If input and output suggest the " + "same concept, unify them. If they describe genuinely different aspects " + "(e.g. fires on X, produces Y), combine both." + ) + md.p(JSON_INSTRUCTION) -Respond with JSON: {{"label": "...", "confidence": "low|medium|high", "reasoning": "..."}} -""" + return md.build() -def _format_related_table( +def _format_related( components: list[RelatedComponent], model_metadata: ModelMetadata, app_tok: AppTokenizer, -) -> str: - # Filter: only show labeled components and token entries (embed/output) +) -> Md: visible = [n for n in components if n.label is not None or _is_token_entry(n.component_key)] + md = Md() if not visible: - return "(no related components with labels found)\n" + md.p("(no related components with labels found)") + return md - # Normalize attributions: strongest = 1.0 max_attr = max(abs(n.attribution) for n in visible) norm = max_attr if max_attr > 0 else 1.0 @@ -195,18 +205,14 @@ def _format_related_table( for n in visible: display = _component_display(n.component_key, model_metadata, app_tok) rel_attr = n.attribution / norm - - parts = [f" {display} (relative attribution: {rel_attr:+.2f}"] - if n.pmi is not None: - parts.append(f", co-firing PMI: {n.pmi:.2f}") - parts.append(")") - - line = "".join(parts) + pmi_str = f", co-firing PMI: {n.pmi:.2f}" if n.pmi is not None else "" + line = f" {display} (relative attribution: {rel_attr:+.2f}{pmi_str})" if n.label is not None: line += f'\n label: "{n.label}" (confidence: {n.confidence})' lines.append(line) - return "\n".join(lines) + "\n" + md.p("\n".join(lines)) + return md def _is_token_entry(key: str) -> bool: diff --git a/spd/graph_interp/repo.py b/spd/graph_interp/repo.py index 6667c4e1e..a76590e3d 100644 --- a/spd/graph_interp/repo.py +++ b/spd/graph_interp/repo.py @@ -11,7 +11,8 @@ import yaml -from 
spd.graph_interp.db import DONE_MARKER, GraphInterpDB +from spd.autointerp.db import DONE_MARKER +from spd.graph_interp.db import GraphInterpDB from spd.graph_interp.schemas import LabelResult, PromptEdge, get_graph_interp_dir diff --git a/spd/harvest/analysis.py b/spd/harvest/analysis.py index d0d92aac3..739e78fe9 100644 --- a/spd/harvest/analysis.py +++ b/spd/harvest/analysis.py @@ -106,7 +106,6 @@ def get_correlated_components( return output - def has_component(storage: CorrelationStorage, component_key: str) -> bool: """Check if a component exists in the storage.""" return component_key in storage.key_to_idx diff --git a/spd/investigate/agent_prompt.py b/spd/investigate/agent_prompt.py index d53a47ac3..1e68073f1 100644 --- a/spd/investigate/agent_prompt.py +++ b/spd/investigate/agent_prompt.py @@ -160,27 +160,23 @@ def _format_model_info(model_info: dict[str, Any]) -> str: - """Format model architecture info for inclusion in the agent prompt.""" parts = [f"- **Architecture**: {model_info.get('summary', 'Unknown')}"] - target_config = model_info.get("target_model_config") - if target_config: - if "n_layer" in target_config: - parts.append(f"- **Layers**: {target_config['n_layer']}") - if "n_embd" in target_config: - parts.append(f"- **Hidden dim**: {target_config['n_embd']}") - if "vocab_size" in target_config: - parts.append(f"- **Vocab size**: {target_config['vocab_size']}") - if "n_ctx" in target_config: - parts.append(f"- **Context length**: {target_config['n_ctx']}") + tc = model_info.get("target_model_config", {}) + for key, label in [ + ("n_layer", "Layers"), + ("n_embd", "Hidden dim"), + ("vocab_size", "Vocab size"), + ("n_ctx", "Context length"), + ]: + if key in tc: + parts.append(f"- **{label}**: {tc[key]}") topology = model_info.get("topology") if topology and topology.get("block_structure"): block = topology["block_structure"][0] - attn = ", ".join(block.get("attn_projections", [])) - ffn = ", ".join(block.get("ffn_projections", [])) - 
parts.append(f"- **Attention projections**: {attn}") - parts.append(f"- **FFN projections**: {ffn}") + parts.append(f"- **Attention projections**: {', '.join(block.get('attn_projections', []))}") + parts.append(f"- **FFN projections**: {', '.join(block.get('ffn_projections', []))}") return "\n".join(parts) diff --git a/spd/investigate/scripts/run_agent.py b/spd/investigate/scripts/run_agent.py index 54806ed36..1cf9910fb 100644 --- a/spd/investigate/scripts/run_agent.py +++ b/spd/investigate/scripts/run_agent.py @@ -30,7 +30,6 @@ def write_mcp_config(inv_dir: Path, port: int) -> Path: - """Write MCP configuration file for Claude Code.""" mcp_config = { "mcpServers": { "spd": { @@ -45,7 +44,6 @@ def write_mcp_config(inv_dir: Path, port: int) -> Path: def write_claude_settings(inv_dir: Path) -> None: - """Write Claude Code settings to pre-grant MCP tool permissions.""" claude_dir = inv_dir / ".claude" claude_dir.mkdir(exist_ok=True) settings = {"permissions": {"allow": ["mcp__spd__*"]}} @@ -53,7 +51,6 @@ def write_claude_settings(inv_dir: Path) -> None: def find_available_port(start_port: int = 8000, max_attempts: int = 100) -> int: - """Find an available port starting from start_port.""" for offset in range(max_attempts): port = start_port + offset with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -68,7 +65,6 @@ def find_available_port(start_port: int = 8000, max_attempts: int = 100) -> int: def wait_for_backend(port: int, timeout: float = 120.0) -> bool: - """Wait for the backend to become healthy.""" url = f"http://localhost:{port}/api/health" start = time.time() while time.time() - start < timeout: @@ -83,7 +79,6 @@ def wait_for_backend(port: int, timeout: float = 120.0) -> bool: def load_run(port: int, wandb_path: str, context_length: int) -> None: - """Load the SPD run into the backend. 
     Raises on failure."""
     url = f"http://localhost:{port}/api/runs/load"
     params = {"wandb_path": wandb_path, "context_length": context_length}
     resp = requests.post(url, params=params, timeout=300)
@@ -93,15 +88,12 @@ def load_run(port: int, wandb_path: str, context_length: int) -> None:
 
 
 def fetch_model_info(port: int) -> dict[str, Any]:
-    """Fetch model architecture info from the backend."""
     resp = requests.get(f"http://localhost:{port}/api/pretrain_info/loaded", timeout=30)
     assert resp.status_code == 200, f"Failed to fetch model info: {resp.status_code} {resp.text}"
-    result: dict[str, Any] = resp.json()
-    return result
+    return resp.json()
 
 
 def log_event(events_path: Path, event: InvestigationEvent) -> None:
-    """Append an event to the events log."""
     with open(events_path, "a") as f:
         f.write(event.model_dump_json() + "\n")
@@ -174,8 +166,7 @@ def run_agent(
         stderr=subprocess.STDOUT,
     )
 
-    def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None:
-        _ = frame
+    def cleanup(signum: int | None = None, _frame: FrameType | None = None) -> None:
         logger.info(f"[{inv_id}] Cleaning up...")
         if backend_proc.poll() is None:
             backend_proc.terminate()
diff --git a/spd/postprocess/__init__.py b/spd/postprocess/__init__.py
index 35922436b..ccd209fcd 100644
--- a/spd/postprocess/__init__.py
+++ b/spd/postprocess/__init__.py
@@ -99,12 +99,15 @@ def postprocess(config: PostprocessConfig, dependency_job_id: str | None = None)
     attr_result = None
     if config.attributions is not None:
         assert isinstance(decomp_cfg, SPDHarvestConfig)
+        attr_inner = config.attributions.config.model_copy(
+            update={"harvest_subrun_id": harvest_result.subrun_id}
+        )
+        attr_slurm = config.attributions.model_copy(update={"config": attr_inner})
         attr_result = submit_attributions(
             wandb_path=decomp_cfg.wandb_path,
-            config=config.attributions,
+            config=attr_slurm,
             snapshot_branch=snapshot_branch,
             dependency_job_id=harvest_result.merge_result.job_id,
-            harvest_subrun_id=harvest_result.subrun_id,
         )
 
     # === 5. Graph interp (depends on harvest merge + attribution merge) ===
diff --git a/spd/utils/markdown.py b/spd/utils/markdown.py
new file mode 100644
index 000000000..0098c4d8e
--- /dev/null
+++ b/spd/utils/markdown.py
@@ -0,0 +1,44 @@
+"""Minimal Markdown document builder for prompt construction.
+
+Atomic unit is a block (paragraph, heading, list, etc.).
+build() joins blocks with double newlines.
+"""
+
+
+class Md:
+    """Accumulates Markdown blocks with a fluent API.
+
+    Each method appends a block and returns self for chaining.
+    Call .build() to get the final string (blocks joined by blank lines).
+    """
+
+    def __init__(self) -> None:
+        self._blocks: list[str] = []
+
+    def h(self, level: int, text: str) -> "Md":
+        self._blocks.append(f"{'#' * level} {text}")
+        return self
+
+    def p(self, text: str) -> "Md":
+        self._blocks.append(text)
+        return self
+
+    def bullets(self, items: list[str]) -> "Md":
+        self._blocks.append("\n".join(f"- {item}" for item in items))
+        return self
+
+    def labeled_list(self, label: str, items: list[str]) -> "Md":
+        lines = [label] + [f"- {item}" for item in items]
+        self._blocks.append("\n".join(lines))
+        return self
+
+    def numbered(self, items: list[str]) -> "Md":
+        self._blocks.append("\n".join(f"{i}. {item}" for i, item in enumerate(items, 1)))
+        return self
+
+    def extend(self, other: "Md") -> "Md":
+        self._blocks.extend(other._blocks)
+        return self
+
+    def build(self) -> str:
+        return "\n\n".join(self._blocks)
diff --git a/spd/utils/sqlite.py b/spd/utils/sqlite.py
index 7ac591780..a75fa1976 100644
--- a/spd/utils/sqlite.py
+++ b/spd/utils/sqlite.py
@@ -9,10 +9,10 @@
    - Readonly uses ?immutable=1 (no lock files created at all)
    - Write mode uses default DELETE journal
 
-2. **Local databases** (app prompt_attr.db):
-   - Live at REPO_ROOT/.data/ on local filesystem
-   - WAL mode is fine and preferred for concurrent read/write
-   - Don't use this helper — configure directly
+2. **App database** (prompt_attr.db):
+   - Lives at SPD_OUT_DIR/app/ on NFS (shared across team)
+   - Uses DELETE journal mode with fcntl.flock write locking
+   - Managed by PromptAttrDB in spd/app/backend/database.py
 """
 
 import sqlite3
@@ -26,9 +26,7 @@ def open_nfs_sqlite(path: Path, readonly: bool) -> sqlite3.Connection:
     Write: default DELETE journal (WAL breaks on NFS).
     """
     if readonly:
-        conn = sqlite3.connect(
-            f"file:{path}?immutable=1", uri=True, check_same_thread=False
-        )
+        conn = sqlite3.connect(f"file:{path}?immutable=1", uri=True, check_same_thread=False)
     else:
         conn = sqlite3.connect(str(path), check_same_thread=False)
     conn.row_factory = sqlite3.Row
diff --git a/tests/app/test_server_api.py b/tests/app/test_server_api.py
index f39ef385f..f64a140c7 100644
--- a/tests/app/test_server_api.py
+++ b/tests/app/test_server_api.py
@@ -16,6 +16,7 @@
 from spd.app.backend.app_tokenizer import AppTokenizer
 from spd.app.backend.database import PromptAttrDB
 from spd.app.backend.routers import graphs as graphs_router
+from spd.app.backend.routers import intervention as intervention_router
 from spd.app.backend.routers import runs as runs_router
 from spd.app.backend.server import app
 from spd.app.backend.state import RunState, StateManager
@@ -54,6 +55,7 @@ def app_with_state():
     # Patch DEVICE in all router modules to use CPU for tests
     with (
         mock.patch.object(graphs_router, "DEVICE", DEVICE),
+        mock.patch.object(intervention_router, "DEVICE", DEVICE),
         mock.patch.object(runs_router, "DEVICE", DEVICE),
     ):
         db = PromptAttrDB(db_path=Path(":memory:"), check_same_thread=False)
@@ -147,6 +149,7 @@ def app_with_state():
         harvest=None,
         interp=None,
         attributions=None,
+        graph_interp=None,
     )
 
     manager = StateManager.get()
@@ -231,6 +234,49 @@ def test_compute_graph(app_with_prompt: tuple[TestClient, int]):
     assert "outputProbs" in data
 
 
+def test_run_and_save_intervention_without_text(app_with_prompt: tuple[TestClient, int]):
+    """Run-and-save intervention should use graph-linked prompt tokens (no text in request)."""
+    client, prompt_id = app_with_prompt
+
+    graph_response = client.post(
+        "/api/graphs",
+        params={"prompt_id": prompt_id, "normalize": "none", "ci_threshold": 0.0},
+    )
+    assert graph_response.status_code == 200
+    events = [line for line in graph_response.text.strip().split("\n") if line.startswith("data:")]
+    final_data = json.loads(events[-1].replace("data: ", ""))
+    graph_data = final_data["data"]
+    graph_id = graph_data["id"]
+
+    selected_nodes = [
+        key
+        for key, ci in graph_data["nodeCiVals"].items()
+        if not key.startswith("embed:") and not key.startswith("output:") and ci > 0
+    ]
+    assert len(selected_nodes) > 0
+
+    request = {
+        "graph_id": graph_id,
+        "selected_nodes": selected_nodes[:5],
+        "top_k": 5,
+        "adv_pgd": {"n_steps": 1, "step_size": 1.0},
+    }
+    response = client.post("/api/intervention/run", json=request)
+    assert response.status_code == 200
+    body = response.json()
+    assert body["selected_nodes"] == request["selected_nodes"]
+    result = body["result"]
+    assert len(result["input_tokens"]) > 0
+    assert len(result["ci"]) > 0
+    assert len(result["stochastic"]) > 0
+    assert len(result["adversarial"]) > 0
+    assert result["target_sans"] is None
+    assert "ci_loss" in result
+    assert "stochastic_loss" in result
+    assert "adversarial_loss" in result
+    assert result["target_sans_loss"] is None
+
+
 # -----------------------------------------------------------------------------
 # Streaming: Prompt Generation
 # -----------------------------------------------------------------------------
diff --git a/tests/dataset_attributions/test_storage.py b/tests/dataset_attributions/test_storage.py
index 75841ef2c..a972c1107 100644
--- a/tests/dataset_attributions/test_storage.py
+++ b/tests/dataset_attributions/test_storage.py
@@ -195,7 +195,7 @@ def test_single_file(self, tmp_path: Path):
 
 
 def _deterministic_storage(
-    regular_val: float = 10.0,
+    _regular_val: float = 10.0,
     embed_val: float = 6.0,
     ci_sum_val: float = 50.0,
     act_sq_sum_val: float = 400.0,
@@ -316,10 +316,12 @@ class TestMergeNumericCorrectness:
     def test_merge_equals_sum_of_parts(self, tmp_path: Path):
         """Two workers with known values; merged queries should equal manual computation."""
-        s1 = _deterministic_storage(embed_val=4.0, ci_sum_val=20.0, act_sq_sum_val=100.0,
-                                    embed_count_val=80, n_tokens=40)
-        s2 = _deterministic_storage(embed_val=8.0, ci_sum_val=30.0, act_sq_sum_val=500.0,
-                                    embed_count_val=120, n_tokens=60)
+        s1 = _deterministic_storage(
+            embed_val=4.0, ci_sum_val=20.0, act_sq_sum_val=100.0, embed_count_val=80, n_tokens=40
+        )
+        s2 = _deterministic_storage(
+            embed_val=8.0, ci_sum_val=30.0, act_sq_sum_val=500.0, embed_count_val=120, n_tokens=60
+        )
         p1, p2 = tmp_path / "r0.pt", tmp_path / "r1.pt"
         s1.save(p1)
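Sanity check for reviewers: the fluent API of the new `Md` builder in `spd/utils/markdown.py` composes like this. The class body is reproduced from the patch (trimmed to the methods exercised) so the snippet runs standalone; the heading and bullet strings are made-up sample data.

```python
class Md:
    """Accumulates Markdown blocks; each method appends one block and returns self."""

    def __init__(self) -> None:
        self._blocks: list[str] = []

    def h(self, level: int, text: str) -> "Md":
        self._blocks.append(f"{'#' * level} {text}")
        return self

    def p(self, text: str) -> "Md":
        self._blocks.append(text)
        return self

    def bullets(self, items: list[str]) -> "Md":
        self._blocks.append("\n".join(f"- {item}" for item in items))
        return self

    def build(self) -> str:
        # Blocks are joined with a blank line, yielding valid Markdown paragraphs.
        return "\n\n".join(self._blocks)


doc = Md().h(2, "Findings").p("Two components dominate.").bullets(["c17", "c42"]).build()
```

Here `doc` is three blocks separated by blank lines: an `## Findings` heading, one paragraph, and a two-item bullet list.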
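The readonly branch of `open_nfs_sqlite` leans on SQLite's `immutable=1` URI parameter, which opens the file without creating any lock, journal, or WAL sidecar files and rejects all writes. A self-contained sketch of that connection pattern, with an invented table and values purely for illustration:

```python
import os
import sqlite3
import tempfile

# Build a throwaway database to read back.
path = os.path.join(tempfile.mkdtemp(), "demo.db")
conn = sqlite3.connect(path)
conn.execute("CREATE TABLE t (x INTEGER)")
conn.execute("INSERT INTO t VALUES (7)")
conn.commit()
conn.close()

# Readonly open, mirroring the immutable-URI branch: SQLite creates no
# -wal/-shm/journal files at all, which is what makes this safe on NFS.
ro = sqlite3.connect(f"file:{path}?immutable=1", uri=True, check_same_thread=False)
ro.row_factory = sqlite3.Row
value = ro.execute("SELECT x FROM t").fetchone()["x"]

# Any write attempt through an immutable connection fails.
write_blocked = False
try:
    ro.execute("INSERT INTO t VALUES (8)")
except sqlite3.OperationalError:
    write_blocked = True
ro.close()
```

The trade-off stated in the module docstring applies: `immutable=1` assumes nothing else modifies the file while the connection is open, which is why the write path in the patch falls back to a plain connection with the default DELETE journal.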