diff --git a/HARVEST_DATA_SCHEMA_MIGRATION_NOTES.md b/HARVEST_DATA_SCHEMA_MIGRATION_NOTES.md new file mode 100644 index 000000000..826fa58b7 --- /dev/null +++ b/HARVEST_DATA_SCHEMA_MIGRATION_NOTES.md @@ -0,0 +1,141 @@ +# Harvest DB Schema Migration: `mean_ci` to `firing_density` + `mean_activations` + +**Date investigated**: 2026-02-18 +**Current branch**: `feature/attn_plots` +**Status**: Not yet migrated. Workaround in place for plotting scripts. + +## What happened + +A colleague is generalizing the harvest pipeline to work across decomposition methods (SPD, MOLT, CLT, SAE) on branches `feature/harvest-generic` and `feature/autointerp-generic`. As part of this, the harvest DB schema changed. The key commit is `70eceb8f` ("Generalize harvest pipeline over decomposition methods") by Claude SPD1, dated 2026-02-16. + +A new harvest sub-run for `s-275c8f21` was created on 2026-02-18 using code from `feature/harvest-generic`, producing data with the new schema. This data is incompatible with the code on `dev`, `main`, and `feature/attn_plots`, which still expect the old schema. + +## Schema change details + +### Old schema (on `dev`, `main`, `feature/attn_plots`) + +```sql +CREATE TABLE components ( + component_key TEXT PRIMARY KEY, + layer TEXT NOT NULL, + component_idx INTEGER NOT NULL, + mean_ci REAL NOT NULL, -- mean causal importance across all tokens + activation_examples TEXT NOT NULL, + input_token_pmi TEXT NOT NULL, + output_token_pmi TEXT NOT NULL +); +``` + +`ActivationExample` dataclass fields: `token_ids: list[int]`, `ci_values: list[float]`, `component_acts: list[float]` + +### New schema (on `feature/harvest-generic`, commit `70eceb8f`) + +```sql +CREATE TABLE components ( + component_key TEXT PRIMARY KEY, + layer TEXT NOT NULL, + component_idx INTEGER NOT NULL, + firing_density REAL NOT NULL, -- proportion of tokens where component fired (0-1) + mean_activations TEXT NOT NULL, -- JSON dict, e.g. 
{"causal_importance": 0.007} + activation_examples TEXT NOT NULL, + input_token_pmi TEXT NOT NULL, + output_token_pmi TEXT NOT NULL +); +``` + +`ActivationExample` dataclass fields: `token_ids: list[int]`, `firings: list[bool]`, `activations: dict[str, list[float]]` + +### Field mapping + +| Old field | New field | Notes | +|---|---|---| +| `mean_ci` (float) | `mean_activations["causal_importance"]` (float, inside JSON dict) | Same semantic meaning for SPD runs | +| *(not present)* | `firing_density` (float) | New metric: proportion of tokens where component fired | +| `ci_values` (on ActivationExample) | `activations["causal_importance"]` (on ActivationExample) | Per-token CI values, now keyed by activation type | +| `component_acts` (on ActivationExample) | `activations["component_activation"]` (on ActivationExample) | Per-token component activations | +| *(not present)* | `firings` (on ActivationExample) | Boolean per-token firing indicators | + +### Example new data row + +From `s-275c8f21` sub-run `h-20260218_000000`: +``` +firing_density: 0.011455078125 +mean_activations: {"causal_importance": 0.007389060687273741} +``` + +## Branches involved + +| Branch | Schema version | Status | +|---|---|---| +| `main` | Old (`mean_ci`) | Production | +| `dev` | Old (`mean_ci`) | Development | +| `feature/attn_plots` | Old (`mean_ci`) | Current work branch | +| `feature/harvest-generic` | New (`firing_density` + `mean_activations`) | Colleague's WIP | +| `feature/autointerp-generic` | New | Colleague's WIP | + +The schema change commits (`70eceb8f`, `5e66fd49`, `ad68187d`) exist **only** on `feature/harvest-generic` and `feature/autointerp-generic`. They are NOT on `dev` or `main`. 
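The field mapping above can be made concrete with a small sketch (illustrative only; per the repo's no-migration-code convention this should not ship, and `firing_density` plus per-token `firings` cannot be derived from old rows without re-harvesting):

```python
import json

# Hypothetical illustration of the old -> new field mapping for SPD runs.
# Not real migration code: firing_density and per-token firings cannot be
# recovered from old rows and require re-harvesting.
def map_old_component_row(old: dict) -> dict:
    new = {
        k: old[k]
        for k in (
            "component_key", "layer", "component_idx",
            "activation_examples", "input_token_pmi", "output_token_pmi",
        )
    }
    # mean_ci (float column) becomes one entry in the mean_activations JSON dict
    new["mean_activations"] = json.dumps({"causal_importance": old["mean_ci"]})
    return new
```

The point to note for consumers: `mean_activations` is stored as a JSON-encoded dict, so callers must decode it before comparing against thresholds.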
+ +## Current workaround + +The broken sub-run was renamed so the `HarvestRepo` skips it: + +``` +_h-20260218_000000.bak <-- new schema, renamed with _ prefix +h-20260212_150336 <-- old schema, now picked as "latest" by HarvestRepo +``` + +`HarvestRepo.open()` picks the latest `h-*` directory by lexicographic sort. The `_` prefix prevents the glob match on `d.name.startswith("h-")` (see `spd/harvest/repo.py:46`). + +## Files that need updating when migrating + +### Core harvest module (update schema definitions) + +1. **`spd/harvest/schemas.py`** — `ComponentSummary` and `ComponentData` dataclasses: replace `mean_ci: float` with `firing_density: float` + `mean_activations: dict[str, float]`. Also update `ActivationExample`. +2. **`spd/harvest/db.py`** — SQL schema, `_serialize_component()`, `_deserialize_component()`, `get_summary()`, `get_all_components()`. +3. **`spd/harvest/harvester.py`** — `build_results()` yields `ComponentData`; needs to compute `firing_density` and `mean_activations` dict. + +Reference implementation: `git show 70eceb8f:spd/harvest/schemas.py` and `git show 70eceb8f:spd/harvest/db.py`. + +### App backend (API schemas + endpoints) + +4. **`spd/app/backend/schemas.py`** — `SubcomponentMetadata.mean_ci` and `SubcomponentActivationContexts.mean_ci`. +5. **`spd/app/backend/routers/activation_contexts.py`** — Extracts and sorts by `mean_ci` in 3 endpoints. + +### Autointerp module + +6. **`spd/autointerp/interpret.py`** — Sorts components by `c.mean_ci` (line 116). +7. **`spd/autointerp/strategies/compact_skeptical.py`** — Uses `mean_ci * 100` and `1 / component.mean_ci` for LLM prompt formatting. + +### Dataset attributions + +8. **`spd/dataset_attributions/harvest.py`** — Filters alive components with `summary[key].mean_ci > ci_threshold` (line 90). + +### Plotting scripts (9 files, all follow same pattern) + +All have `MIN_MEAN_CI` constant and `_get_alive_indices()` that filters on `s.mean_ci > threshold`: + +9. 
`spd/scripts/plot_qk_c_attention_contributions/plot_qk_c_attention_contributions.py` +10. `spd/scripts/attention_stories/attention_stories.py` +11. `spd/scripts/characterize_induction_components/characterize_induction_components.py` +12. `spd/scripts/plot_kv_vt_similarity/plot_kv_vt_similarity.py` +13. `spd/scripts/plot_attention_weights/plot_attention_weights.py` +14. `spd/scripts/plot_kv_coactivation/plot_kv_coactivation.py` +15. `spd/scripts/plot_per_head_component_activations/plot_per_head_component_activations.py` +16. `spd/scripts/plot_head_spread/plot_head_spread.py` +17. `spd/scripts/plot_component_head_norms/plot_component_head_norms.py` + +### Dedicated CI visualization + +18. **`spd/scripts/plot_mean_ci/plot_mean_ci.py`** — Entire script dedicated to visualizing `mean_ci` distributions. May need renaming or deprecation. + +## Recommended migration approach + +Once `feature/harvest-generic` is stable and merged to `dev`: + +1. **Port core schema changes** from `feature/harvest-generic` (items 1-3 above). The reference implementation at commit `70eceb8f` has the complete updated `db.py` and `schemas.py`. + +2. **Update consumers** (items 4-18). For filtering "alive" components, the equivalent of `mean_ci > threshold` is `mean_activations["causal_importance"] > threshold`. Consider whether `firing_density` would be a better filter (it's a cleaner concept: "does this component fire often enough?"). + +3. **No legacy fallback needed** — per repo conventions (CLAUDE.md: "Don't add legacy fallbacks or migration code"). Old harvest data should be re-harvested with new code if needed. + +4. **Consider extracting `_get_alive_indices`** into a shared utility — it's duplicated across 9 plotting scripts with identical logic. 
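The shared utility suggested in item 4 could be as small as the sketch below (hypothetical names; `ComponentSummary` here is a stand-in for the real dataclass, with only the fields the filter needs):

```python
from dataclasses import dataclass, field

@dataclass
class ComponentSummary:  # stand-in for the real dataclass in spd/harvest/schemas.py
    firing_density: float
    mean_activations: dict[str, float] = field(default_factory=dict)

def get_alive_indices(
    summaries: dict[str, ComponentSummary],
    threshold: float,
    metric: str = "causal_importance",
) -> list[str]:
    """Filter 'alive' components: the new-schema equivalent of mean_ci > threshold."""
    return [
        key
        for key, s in summaries.items()
        if s.mean_activations.get(metric, 0.0) > threshold
    ]
```

This also gives one place to decide whether `firing_density` should replace or complement the causal-importance threshold, instead of nine.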
diff --git a/attention_head_report.md b/attention_head_report.md new file mode 100644 index 000000000..053a3314d --- /dev/null +++ b/attention_head_report.md @@ -0,0 +1,172 @@ +# Attention Head Characterization: s-275c8f21 + +Model: 4-layer, 6-head LlamaSimpleMLP (d_model=768, head_dim=128), pretrained model t-32d1bb3b. + +All scripts live in `spd/scripts/detect_*/` and output to `spd/scripts/detect_*/out/s-275c8f21/`. + +## Analyses + +### Previous-Token Heads + +**Method**: On real text (eval split, 100 batches of 32), extract the offset-1 diagonal of each head's attention matrix — i.e., `attn[i, i-1]` — and average across positions and batches. + +**Results**: +| Head | Score | +|------|-------| +| L1H1 | 0.604 | +| L0H5 | 0.308 | + +All other heads score below 0.1. L1H1 is a clear previous-token head, spending over 60% of its attention on the immediately preceding token. L0H5 is weaker but still notable. + +### Induction Heads + +**Method**: Synthetic data — repeated random token sequences `[A B C ... | A B C ...]`. Measures the "offset diagonal" of attention in the second half: at position `L+k`, how much attention goes to position `k+1` (the token that followed the current token's earlier occurrence). This is the textbook induction pattern. 100 batches of 32, half-sequence length 256. + +**Results**: +| Head | Score | +|------|-------| +| L2H4 | 0.629 | + +No other head scores above 0.1. L2H4 is a strong, clean induction head. + +The L1H1 → L2H4 pairing forms the classic two-layer induction circuit: L1H1 shifts information one position back (previous-token), composing with L2H4's key-query matching to attend to "what came after this token last time." + +### Duplicate-Token Heads + +**Method**: On real text, build a boolean mask of positions where a prior token has the same ID, then measure mean attention to those same-token positions. 
Only positions with at least one prior duplicate contribute to the score, and batches are weighted by the number of valid positions. + +**Results**: +| Head | Score | +|------|-------| +| L0H4 | 0.323 | +| L0H2 | 0.202 | + +All other heads below 0.05. Both heads are in layer 0, suggesting duplicate-token detection happens early. + +### Successor Heads + +**Method**: Constructs ordinal sequences (digits, letters, number words, days, months) as comma-separated lists and measures attention from each element to its ordinal predecessor (2 positions back, since commas intervene). A control condition uses random words in place of ordinals, with the same positional structure. The "signal" is ordinal score minus control score, isolating semantic successor attention from positional artifacts. + +**Results** (signal > 0.05): +| Head | Ordinal | Control | Signal | +|------|---------|---------|--------| +| L0H2 | 0.379 | 0.073 | +0.307 | +| L0H4 | 0.155 | 0.001 | +0.154 | +| L1H0 | 0.098 | 0.041 | +0.058 | +| L1H1 | 0.174 | 0.108 | +0.067 | +| L3H0 | 0.192 | 0.121 | +0.070 | +| L1H2 | 0.497 | 0.443 | +0.054 | + +L0H2 is the standout successor head. L0H4 is secondary. Several other heads show modest signals. + +Note that L1H2 has high ordinal attention (0.497) but nearly as high control attention (0.443), suggesting it attends strongly to position-2-back regardless of content. The control subtraction properly removes this. + +### S-Inhibition Heads + +**Method**: Two-pronged analysis using IOI (Indirect Object Identification) prompts of the form "When Alice and Bob went to the store, Bob gave a drink to" → Alice. + +1. **Data-driven**: Measures attention from the final position to the second occurrence of the subject name (S2). High S2 attention means the head is "looking at" the repeated name. +2. **Weight-based**: Computes the OV copy score `W_U[t] @ W_O_h @ W_V_h @ W_E[t]` averaged over name tokens. 
Negative values indicate the head suppresses (rather than promotes) the attended token's logit. + +An S-inhibition head should have high S2 attention *and* a negative copy score. + +**Results** (candidates: attn > 0.1 and copy < 0): +| Head | Attn to S2 | OV Copy | Assessment | +|------|-----------|---------|------------| +| L3H2 | 0.377 | -0.029 | Strongest candidate | +| L2H1 | 0.151 | -0.001 | Weak candidate | + +L3H2 is the clearest S-inhibition candidate: it strongly attends to the repeated subject and has a negative copy score (suppression). L2H1 attends to S2 moderately but its copy score is only marginally negative. + +Several other heads have high S2 attention but positive copy scores (e.g., L3H0 at attn=0.156, copy=+0.007), suggesting they *copy* the subject rather than inhibit it — a different role in the IOI circuit. + +### Delimiter Heads + +**Method**: On real text, identifies delimiter token IDs (`.` `,` `;` `:` `!` `?` `\n` and multi-char variants) via the tokenizer, then measures the mean fraction of each head's attention landing on delimiter tokens. Compares to the baseline delimiter frequency in the data (~10.7%). Reports the ratio over baseline. + +**Results**: No head exceeds 2.0x baseline. Highest ratios: +| Head | Raw Attn | Ratio | +|------|----------|-------| +| L0H5 | 0.187 | 1.74x | +| L1H0 | 0.184 | 1.72x | +| L1H4 | 0.178 | 1.66x | + +This model does not appear to have dedicated delimiter heads. Most heads sit in the 1.0-1.7x range — modestly above baseline but not specialized. This could reflect the model's small size, or it could mean delimiter attention is distributed across heads rather than concentrated. + +### Positional Heads + +**Method**: On real text, builds a mean attention profile by relative offset for each head (offset = query_pos - key_pos). Also measures attention to absolute position 0 (BOS). A "positional head" has high attention concentrated at a specific offset; a "BOS head" attends heavily to position 0. 
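A minimal NumPy sketch of these two measurements (array shapes are assumptions; the detection script's actual implementation may differ):

```python
import numpy as np

def offset_profile(attn: np.ndarray, max_offset: int = 8) -> np.ndarray:
    """Mean attention per relative offset (query_pos - key_pos).

    attn: [n_heads, seq, seq] causal attention pattern.
    Returns [n_heads, max_offset]; column d-1 holds the mean of attn[h, i, i-d].
    """
    cols = [
        np.diagonal(attn, offset=-d, axis1=-2, axis2=-1).mean(axis=-1)
        for d in range(1, max_offset + 1)
    ]
    return np.stack(cols, axis=-1)

def bos_score(attn: np.ndarray) -> np.ndarray:
    """Mean attention to absolute position 0, averaged over query positions."""
    return attn[..., 0].mean(axis=-1)
```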
+ +**Results — offset-based**: +| Head | Max Offset Score | Peak Offset | +|------|-----------------|-------------| +| L1H1 | 0.604 | 1 | +| L0H5 | 0.308 | 1 | +| L1H0 | 0.265 | 1 | +| L1H3 | 0.226 | 2 | + +L1H1 and L0H5 are the same heads already identified as previous-token heads, confirming the result from a different angle. L1H3 peaks at offset 2 — it preferentially attends two positions back. + +**Results — BOS attention**: +| Head | BOS Score | +|------|-----------| +| L2H4 | 0.489 | +| L2H5 | 0.355 | +| L2H3 | 0.337 | +| L2H1 | 0.318 | +| L2H0 | 0.232 | +| L2H2 | 0.217 | +| L3H3 | 0.248 | +| L3H2 | 0.208 | +| L3H4 | 0.206 | + +All six heads in layer 2 have substantial BOS attention (0.22–0.49). Layer 3 also shows moderate BOS attention in several heads. Layers 0–1 show negligible BOS attention (< 0.01). + +## Cross-Cutting Observations + +### Multi-functional early heads + +L0H2 and L0H4 both serve as duplicate-token *and* successor heads. These are layer-0 heads operating directly on token embeddings, suggesting the behaviors may share a mechanism: attending to tokens that are "similar" to the current one (exact match for duplicate-token, ordinal neighbor for successor). Whether this reflects a single underlying computation or two coincidentally co-located behaviors is unclear from these analyses alone. + +**Hypothesis**: L0H2 and L0H4 may implement a general "embedding similarity" attention pattern that manifests as duplicate-token detection on repeated tokens and successor detection on ordinal sequences. Testing this would require measuring the correlation between these heads' attention weights and embedding cosine similarity. + +### The induction circuit + +The L1H1 (previous-token) → L2H4 (induction) circuit is clean and well-separated. L1H1 scores 0.604 on previous-token and L2H4 scores 0.629 on induction, with no other head approaching either score. This is the textbook two-layer induction circuit. 
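For concreteness, the induction measurement described under Induction Heads can be sketched as follows (NumPy, with assumed shapes: attention `[n_heads, 2L, 2L]` over a repeated sequence of half-length `L`; the script's implementation may differ):

```python
import numpy as np

def induction_score(attn: np.ndarray, L: int) -> np.ndarray:
    """Mean attention from position L+k to position k+1, per head.

    Position L+k repeats the token at position k, so k+1 holds the token that
    followed its earlier occurrence. k ranges over queries whose target is
    still inside the first repeat.
    """
    scores = [attn[:, L + k, k + 1] for k in range(L - 1)]
    return np.mean(scores, axis=0)
```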
+### Layer 2 as a BOS sink
+
+The uniform high BOS attention across all of layer 2 is striking. L2H4 has the highest BOS score (0.489) despite also being the strongest induction head. This might seem contradictory, but BOS attention and induction attention operate on different token positions: BOS attention is measured as an average across *all* query positions, while induction attention is measured specifically at positions following repeated sequences. L2H4 likely defaults to BOS when there's no induction pattern to match, using position 0 as an attention sink.
+
+**Hypothesis**: Layer 2's BOS attention may serve as a "no-op" or default state. When a head doesn't have a strong content-based signal, it parks attention on BOS rather than distributing it noisily. This is a known phenomenon in transformer models (sometimes called "attention sinking"), and BOS is a natural sink since it's always available and semantically neutral in context.
+
+### S-inhibition is late and sparse
+
+Only L3H2 shows a convincing S-inhibition signal (layer 3, near the output). This makes architectural sense: S-inhibition requires first identifying the repeated subject (which depends on earlier duplicate-token and induction mechanisms) before suppressing it. The fact that it appears in the final layer is consistent with it being a downstream consumer of earlier head outputs.
+
+### No dedicated delimiter heads
+
+The absence of strong delimiter heads is a genuine null result, not a limitation of the method. The method would have detected them if present (the baseline-ratio approach has no inherent ceiling). This model apparently handles structural boundaries through other means, or distributes delimiter attention diffusely.
+
+### Caveats
+
+- All data-driven scores are averages. A head with a moderate average score might be strongly specialized on a subset of inputs and inactive on others.
Per-example distributions would be more informative but are not captured here.
+- The IOI template is a single fixed pattern. S-inhibition scores might differ with varied sentence structures.
+- The successor head control condition (random words) controls for positional patterns but not for all confounds — e.g., if the tokenizer assigns similar embeddings to ordinal tokens, heads might use embedding similarity rather than "knowing" ordinal structure.
+- OV copy scores (used in S-inhibition) are a linear approximation. They measure the direct path through one head and don't account for nonlinear interactions or composition with other heads/layers.
+
+## Summary Table
+
+| Head | Primary Role(s) | Evidence Strength |
+|------|----------------|-------------------|
+| L0H2 | Successor, duplicate-token | Strong (signal=0.307, dup=0.202) |
+| L0H4 | Duplicate-token, successor | Strong (dup=0.323, signal=0.154) |
+| L0H5 | Previous-token | Moderate (0.308) |
+| L1H1 | Previous-token | Strong (0.604) |
+| L1H3 | Offset-2 positional | Moderate (0.226 at offset 2) |
+| L2H1 | Weak S-inhibition candidate | Weak (attn=0.151, copy=-0.001) |
+| L2H4 | Induction, BOS sink | Strong induction (0.629), strong BOS (0.489) |
+| L2H* | BOS sink (all layer 2) | Strong (0.22–0.49 across all heads) |
+| L3H2 | S-inhibition | Moderate (attn=0.377, copy=-0.029) |
+
+Heads not listed individually (L0H0–L0H1, L0H3, L1H0, L1H2, L1H4, L1H5, L3H0–L3H1, L3H3–L3H5) did not show strong specialization in any analysis, though several had modest signals across multiple categories. Layer-2 heads without individual roles (L2H0, L2H2, L2H3, L2H5) are covered by the L2H* BOS sink row.
diff --git a/pyproject.toml b/pyproject.toml index 88c3405a8..aa310ff62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,7 +101,7 @@ known-third-party = ["wandb"] [tool.pyright] include = ["spd", "tests"] -exclude = ["**/wandb/**", "spd/utils/linear_sum_assignment.py", "spd/app/frontend"] +exclude = ["**/wandb/**", "spd/utils/linear_sum_assignment.py", "spd/app/frontend", "spd/scripts/detect_*"] stubPath = "typings" # Having type stubs for transformers shaves 10 seconds off basedpyright calls strictListInference = true diff --git a/spd/app/backend/server.py b/spd/app/backend/server.py index 3804ce756..64375eb94 100644 --- a/spd/app/backend/server.py +++ b/spd/app/backend/server.py @@ -16,6 +16,7 @@ import fire import torch import uvicorn +from dotenv import load_dotenv from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware @@ -42,6 +43,7 @@ from spd.log import logger from spd.utils.distributed_utils import get_device +load_dotenv() DEVICE = get_device() diff --git a/spd/autointerp/repo.py b/spd/autointerp/repo.py index cae089059..704602c89 100644 --- a/spd/autointerp/repo.py +++ b/spd/autointerp/repo.py @@ -49,7 +49,7 @@ def open(cls, run_id: str) -> "InterpRepo | None": if not db_path.exists(): return None return cls( - db=InterpDB(db_path, readonly=True), + db=InterpDB(db_path), subrun_dir=subrun_dir, run_id=run_id, ) diff --git a/spd/clustering/configs/crc/s-275c8f21-10k.json b/spd/clustering/configs/crc/s-275c8f21-10k.json new file mode 100644 index 000000000..8148cf544 --- /dev/null +++ b/spd/clustering/configs/crc/s-275c8f21-10k.json @@ -0,0 +1,20 @@ +{ + "merge_config": { + "activation_threshold": 0.1, + "alpha": 1, + "iters": 10000, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.001}, + "filter_dead_threshold": 0.1, + "module_name_filter": null + }, + "model_path": "wandb:goodfire/spd/runs/s-275c8f21", + "batch_size": 32, 
+ "wandb_project": null, + "logging_intervals": { + "stat": 10, + "tensor": 200, + "plot": 2000, + "artifact": 2000 + } + } diff --git a/spd/clustering/configs/crc/s-275c8f21-25k.json b/spd/clustering/configs/crc/s-275c8f21-25k.json new file mode 100644 index 000000000..b2195818f --- /dev/null +++ b/spd/clustering/configs/crc/s-275c8f21-25k.json @@ -0,0 +1,20 @@ +{ + "merge_config": { + "activation_threshold": 0.1, + "alpha": 1, + "iters": 25000, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.001}, + "filter_dead_threshold": 0.1, + "module_name_filter": null + }, + "model_path": "wandb:goodfire/spd/runs/s-275c8f21", + "batch_size": 32, + "wandb_project": null, + "logging_intervals": { + "stat": 10, + "tensor": 200, + "plot": 2000, + "artifact": 2000 + } + } diff --git a/spd/clustering/configs/crc/s-275c8f21-alpha0.1.json b/spd/clustering/configs/crc/s-275c8f21-alpha0.1.json new file mode 100644 index 000000000..48847ebc3 --- /dev/null +++ b/spd/clustering/configs/crc/s-275c8f21-alpha0.1.json @@ -0,0 +1,21 @@ +{ + "model_path": "wandb:goodfire/spd/runs/s-275c8f21", + "batch_size": 256, + "merge_config": { + "activation_threshold": 0.1, + "alpha": 0.1, + "iters": 30000, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.001}, + "filter_dead_threshold": 0.1, + "module_name_filter": null + }, + "logging_intervals": { + "stat": 10, + "tensor": 200, + "plot": 2000, + "artifact": 2000 + }, + "wandb_project": "spd", + "wandb_entity": "goodfire" +} diff --git a/spd/clustering/configs/crc/s-275c8f21-alpha0.5.json b/spd/clustering/configs/crc/s-275c8f21-alpha0.5.json new file mode 100644 index 000000000..b17f948ba --- /dev/null +++ b/spd/clustering/configs/crc/s-275c8f21-alpha0.5.json @@ -0,0 +1,21 @@ +{ + "model_path": "wandb:goodfire/spd/runs/s-275c8f21", + "batch_size": 256, + "merge_config": { + "activation_threshold": 0.1, + "alpha": 0.5, + "iters": 30000, + "merge_pair_sampling_method": 
"range", + "merge_pair_sampling_kwargs": {"threshold": 0.001}, + "filter_dead_threshold": 0.1, + "module_name_filter": null + }, + "logging_intervals": { + "stat": 10, + "tensor": 200, + "plot": 2000, + "artifact": 2000 + }, + "wandb_project": "spd", + "wandb_entity": "goodfire" +} diff --git a/spd/clustering/configs/crc/s-275c8f21-alpha1.json b/spd/clustering/configs/crc/s-275c8f21-alpha1.json new file mode 100644 index 000000000..4a6dd5431 --- /dev/null +++ b/spd/clustering/configs/crc/s-275c8f21-alpha1.json @@ -0,0 +1,21 @@ +{ + "model_path": "wandb:goodfire/spd/runs/s-275c8f21", + "batch_size": 256, + "merge_config": { + "activation_threshold": 0.1, + "alpha": 1.0, + "iters": 30000, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.001}, + "filter_dead_threshold": 0.1, + "module_name_filter": null + }, + "logging_intervals": { + "stat": 10, + "tensor": 200, + "plot": 2000, + "artifact": 2000 + }, + "wandb_project": "spd", + "wandb_entity": "goodfire" +} diff --git a/spd/clustering/configs/crc/s-275c8f21-alpha100.json b/spd/clustering/configs/crc/s-275c8f21-alpha100.json new file mode 100644 index 000000000..86f66473b --- /dev/null +++ b/spd/clustering/configs/crc/s-275c8f21-alpha100.json @@ -0,0 +1,21 @@ +{ + "model_path": "wandb:goodfire/spd/runs/s-275c8f21", + "batch_size": 256, + "merge_config": { + "activation_threshold": 0.1, + "alpha": 100.0, + "iters": 100000, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.001}, + "filter_dead_threshold": 0.1, + "module_name_filter": null + }, + "logging_intervals": { + "stat": 10, + "tensor": 200, + "plot": 2000, + "artifact": 2000 + }, + "wandb_project": "spd", + "wandb_entity": "goodfire" +} diff --git a/spd/clustering/configs/crc/s-275c8f21-alpha2.json b/spd/clustering/configs/crc/s-275c8f21-alpha2.json new file mode 100644 index 000000000..b3c3aa0d7 --- /dev/null +++ b/spd/clustering/configs/crc/s-275c8f21-alpha2.json @@ -0,0 +1,21 @@ 
+{ + "model_path": "wandb:goodfire/spd/runs/s-275c8f21", + "batch_size": 256, + "merge_config": { + "activation_threshold": 0.1, + "alpha": 2.0, + "iters": 30000, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.001}, + "filter_dead_threshold": 0.1, + "module_name_filter": null + }, + "logging_intervals": { + "stat": 10, + "tensor": 200, + "plot": 2000, + "artifact": 2000 + }, + "wandb_project": "spd", + "wandb_entity": "goodfire" +} diff --git a/spd/clustering/configs/crc/s-275c8f21-alpha3.json b/spd/clustering/configs/crc/s-275c8f21-alpha3.json new file mode 100644 index 000000000..83a81d2f8 --- /dev/null +++ b/spd/clustering/configs/crc/s-275c8f21-alpha3.json @@ -0,0 +1,21 @@ +{ + "model_path": "wandb:goodfire/spd/runs/s-275c8f21", + "batch_size": 256, + "merge_config": { + "activation_threshold": 0.1, + "alpha": 3.0, + "iters": 30000, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.001}, + "filter_dead_threshold": 0.1, + "module_name_filter": null + }, + "logging_intervals": { + "stat": 10, + "tensor": 200, + "plot": 2000, + "artifact": 2000 + }, + "wandb_project": "spd", + "wandb_entity": "goodfire" +} diff --git a/spd/clustering/configs/crc/s-275c8f21-alpha30.json b/spd/clustering/configs/crc/s-275c8f21-alpha30.json new file mode 100644 index 000000000..68db20a8c --- /dev/null +++ b/spd/clustering/configs/crc/s-275c8f21-alpha30.json @@ -0,0 +1,21 @@ +{ + "model_path": "wandb:goodfire/spd/runs/s-275c8f21", + "batch_size": 256, + "merge_config": { + "activation_threshold": 0.1, + "alpha": 30.0, + "iters": 30000, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.001}, + "filter_dead_threshold": 0.1, + "module_name_filter": null + }, + "logging_intervals": { + "stat": 10, + "tensor": 200, + "plot": 2000, + "artifact": 2000 + }, + "wandb_project": "spd", + "wandb_entity": "goodfire" +} diff --git 
a/spd/clustering/configs/crc/s-275c8f21-alpha5.json b/spd/clustering/configs/crc/s-275c8f21-alpha5.json new file mode 100644 index 000000000..53b59b5e7 --- /dev/null +++ b/spd/clustering/configs/crc/s-275c8f21-alpha5.json @@ -0,0 +1,21 @@ +{ + "model_path": "wandb:goodfire/spd/runs/s-275c8f21", + "batch_size": 256, + "merge_config": { + "activation_threshold": 0.1, + "alpha": 5.0, + "iters": 30000, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.001}, + "filter_dead_threshold": 0.1, + "module_name_filter": null + }, + "logging_intervals": { + "stat": 10, + "tensor": 200, + "plot": 2000, + "artifact": 2000 + }, + "wandb_project": "spd", + "wandb_entity": "goodfire" +} diff --git a/spd/clustering/configs/crc/s-275c8f21-alpha8.json b/spd/clustering/configs/crc/s-275c8f21-alpha8.json new file mode 100644 index 000000000..96cc562d5 --- /dev/null +++ b/spd/clustering/configs/crc/s-275c8f21-alpha8.json @@ -0,0 +1,21 @@ +{ + "model_path": "wandb:goodfire/spd/runs/s-275c8f21", + "batch_size": 256, + "merge_config": { + "activation_threshold": 0.1, + "alpha": 8.0, + "iters": 30000, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.001}, + "filter_dead_threshold": 0.1, + "module_name_filter": null + }, + "logging_intervals": { + "stat": 10, + "tensor": 200, + "plot": 2000, + "artifact": 2000 + }, + "wandb_project": "spd", + "wandb_entity": "goodfire" +} diff --git a/spd/clustering/configs/crc/s-275c8f21.json b/spd/clustering/configs/crc/s-275c8f21.json new file mode 100644 index 000000000..a557dcaf2 --- /dev/null +++ b/spd/clustering/configs/crc/s-275c8f21.json @@ -0,0 +1,20 @@ +{ + "merge_config": { + "activation_threshold": 0.1, + "alpha": 1, + "iters": 20000, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.001}, + "filter_dead_threshold": 0.1, + "module_name_filter": null + }, + "model_path": "wandb:goodfire/spd/runs/s-275c8f21", + "batch_size": 256, + 
"wandb_project": null, + "logging_intervals": { + "stat": 10, + "tensor": 200, + "plot": 2000, + "artifact": 2000 + } + } diff --git a/spd/clustering/configs/pipeline-s-275c8f21-10k.yaml b/spd/clustering/configs/pipeline-s-275c8f21-10k.yaml new file mode 100644 index 000000000..e6b6b6ee1 --- /dev/null +++ b/spd/clustering/configs/pipeline-s-275c8f21-10k.yaml @@ -0,0 +1,8 @@ +clustering_run_config_path: "spd/clustering/configs/crc/s-275c8f21-10k.json" +n_runs: 2 +distances_methods: ["perm_invariant_hamming"] +slurm_job_name_prefix: "spd" +slurm_partition: "h200-reserved" +wandb_project: "spd" +wandb_entity: "goodfire" +create_git_snapshot: false diff --git a/spd/clustering/configs/pipeline-s-275c8f21-25k.yaml b/spd/clustering/configs/pipeline-s-275c8f21-25k.yaml new file mode 100644 index 000000000..a7484a01c --- /dev/null +++ b/spd/clustering/configs/pipeline-s-275c8f21-25k.yaml @@ -0,0 +1,8 @@ +clustering_run_config_path: "spd/clustering/configs/crc/s-275c8f21-25k.json" +n_runs: 2 +distances_methods: ["perm_invariant_hamming"] +slurm_job_name_prefix: "spd" +slurm_partition: "h200-reserved" +wandb_project: "spd" +wandb_entity: "goodfire" +create_git_snapshot: false diff --git a/spd/clustering/configs/pipeline-s-275c8f21-alpha0.1.yaml b/spd/clustering/configs/pipeline-s-275c8f21-alpha0.1.yaml new file mode 100644 index 000000000..a44c8f377 --- /dev/null +++ b/spd/clustering/configs/pipeline-s-275c8f21-alpha0.1.yaml @@ -0,0 +1,8 @@ +clustering_run_config_path: "spd/clustering/configs/crc/s-275c8f21-alpha0.1.json" +n_runs: 2 +distances_methods: ["perm_invariant_hamming"] +slurm_job_name_prefix: "spd" +slurm_partition: "h200-reserved" +wandb_project: "spd" +wandb_entity: "goodfire" +create_git_snapshot: true diff --git a/spd/clustering/configs/pipeline-s-275c8f21-alpha0.5.yaml b/spd/clustering/configs/pipeline-s-275c8f21-alpha0.5.yaml new file mode 100644 index 000000000..fb0f50727 --- /dev/null +++ b/spd/clustering/configs/pipeline-s-275c8f21-alpha0.5.yaml @@ -0,0 
+1,8 @@ +clustering_run_config_path: "spd/clustering/configs/crc/s-275c8f21-alpha0.5.json" +n_runs: 2 +distances_methods: ["perm_invariant_hamming"] +slurm_job_name_prefix: "spd" +slurm_partition: "h200-reserved" +wandb_project: "spd" +wandb_entity: "goodfire" +create_git_snapshot: true diff --git a/spd/clustering/configs/pipeline-s-275c8f21-alpha1.yaml b/spd/clustering/configs/pipeline-s-275c8f21-alpha1.yaml new file mode 100644 index 000000000..2763d8377 --- /dev/null +++ b/spd/clustering/configs/pipeline-s-275c8f21-alpha1.yaml @@ -0,0 +1,8 @@ +clustering_run_config_path: "spd/clustering/configs/crc/s-275c8f21-alpha1.json" +n_runs: 2 +distances_methods: ["perm_invariant_hamming"] +slurm_job_name_prefix: "spd" +slurm_partition: "h200-reserved" +wandb_project: "spd" +wandb_entity: "goodfire" +create_git_snapshot: true diff --git a/spd/clustering/configs/pipeline-s-275c8f21-alpha100.yaml b/spd/clustering/configs/pipeline-s-275c8f21-alpha100.yaml new file mode 100644 index 000000000..731c604bb --- /dev/null +++ b/spd/clustering/configs/pipeline-s-275c8f21-alpha100.yaml @@ -0,0 +1,8 @@ +clustering_run_config_path: "spd/clustering/configs/crc/s-275c8f21-alpha100.json" +n_runs: 2 +distances_methods: ["perm_invariant_hamming"] +slurm_job_name_prefix: "spd" +slurm_partition: "h200-reserved" +wandb_project: "spd" +wandb_entity: "goodfire" +create_git_snapshot: true diff --git a/spd/clustering/configs/pipeline-s-275c8f21-alpha2.yaml b/spd/clustering/configs/pipeline-s-275c8f21-alpha2.yaml new file mode 100644 index 000000000..5500064b4 --- /dev/null +++ b/spd/clustering/configs/pipeline-s-275c8f21-alpha2.yaml @@ -0,0 +1,8 @@ +clustering_run_config_path: "spd/clustering/configs/crc/s-275c8f21-alpha2.json" +n_runs: 2 +distances_methods: ["perm_invariant_hamming"] +slurm_job_name_prefix: "spd" +slurm_partition: "h200-reserved" +wandb_project: "spd" +wandb_entity: "goodfire" +create_git_snapshot: true diff --git a/spd/clustering/configs/pipeline-s-275c8f21-alpha3.yaml 
b/spd/clustering/configs/pipeline-s-275c8f21-alpha3.yaml new file mode 100644 index 000000000..064bd29f5 --- /dev/null +++ b/spd/clustering/configs/pipeline-s-275c8f21-alpha3.yaml @@ -0,0 +1,8 @@ +clustering_run_config_path: "spd/clustering/configs/crc/s-275c8f21-alpha3.json" +n_runs: 2 +distances_methods: ["perm_invariant_hamming"] +slurm_job_name_prefix: "spd" +slurm_partition: "h200-reserved" +wandb_project: "spd" +wandb_entity: "goodfire" +create_git_snapshot: true diff --git a/spd/clustering/configs/pipeline-s-275c8f21-alpha30.yaml b/spd/clustering/configs/pipeline-s-275c8f21-alpha30.yaml new file mode 100644 index 000000000..5ec0177fd --- /dev/null +++ b/spd/clustering/configs/pipeline-s-275c8f21-alpha30.yaml @@ -0,0 +1,8 @@ +clustering_run_config_path: "spd/clustering/configs/crc/s-275c8f21-alpha30.json" +n_runs: 2 +distances_methods: ["perm_invariant_hamming"] +slurm_job_name_prefix: "spd" +slurm_partition: "h200-reserved" +wandb_project: "spd" +wandb_entity: "goodfire" +create_git_snapshot: true diff --git a/spd/clustering/configs/pipeline-s-275c8f21-alpha5.yaml b/spd/clustering/configs/pipeline-s-275c8f21-alpha5.yaml new file mode 100644 index 000000000..978373e7c --- /dev/null +++ b/spd/clustering/configs/pipeline-s-275c8f21-alpha5.yaml @@ -0,0 +1,8 @@ +clustering_run_config_path: "spd/clustering/configs/crc/s-275c8f21-alpha5.json" +n_runs: 2 +distances_methods: ["perm_invariant_hamming"] +slurm_job_name_prefix: "spd" +slurm_partition: "h200-reserved" +wandb_project: "spd" +wandb_entity: "goodfire" +create_git_snapshot: true diff --git a/spd/clustering/configs/pipeline-s-275c8f21-alpha8.yaml b/spd/clustering/configs/pipeline-s-275c8f21-alpha8.yaml new file mode 100644 index 000000000..0c5a8a52c --- /dev/null +++ b/spd/clustering/configs/pipeline-s-275c8f21-alpha8.yaml @@ -0,0 +1,8 @@ +clustering_run_config_path: "spd/clustering/configs/crc/s-275c8f21-alpha8.json" +n_runs: 2 +distances_methods: ["perm_invariant_hamming"] +slurm_job_name_prefix: "spd" 
+slurm_partition: "h200-reserved" +wandb_project: "spd" +wandb_entity: "goodfire" +create_git_snapshot: true diff --git a/spd/clustering/configs/pipeline-s-275c8f21.yaml b/spd/clustering/configs/pipeline-s-275c8f21.yaml new file mode 100644 index 000000000..4cff8f048 --- /dev/null +++ b/spd/clustering/configs/pipeline-s-275c8f21.yaml @@ -0,0 +1,8 @@ +clustering_run_config_path: "spd/clustering/configs/crc/s-275c8f21.json" +n_runs: 2 +distances_methods: ["perm_invariant_hamming"] +slurm_job_name_prefix: "spd" +slurm_partition: "h200-reserved" +wandb_project: "spd" +wandb_entity: "goodfire" +create_git_snapshot: false diff --git a/spd/clustering/dataset.py b/spd/clustering/dataset.py index c8e86f0fc..1cfdf857a 100644 --- a/spd/clustering/dataset.py +++ b/spd/clustering/dataset.py @@ -71,8 +71,8 @@ def _load_lm_batch( config_kwargs_: dict[str, Any] = { **dict( - is_tokenized=False, - streaming=False, + is_tokenized=cfg.task_config.is_tokenized, + streaming=cfg.task_config.streaming, ), **(config_kwargs or {}), } diff --git a/spd/scripts/attention_ablation_experiment/REPORT.md b/spd/scripts/attention_ablation_experiment/REPORT.md new file mode 100644 index 000000000..ed1b81e8b --- /dev/null +++ b/spd/scripts/attention_ablation_experiment/REPORT.md @@ -0,0 +1,124 @@ +# Attention Ablation Experiment — Technical Report + +## Overview + +The script runs two types of experiments: +1. **Head ablation**: Zero out specific attention heads in the *target model* (no SPD involved) +2. **Component ablation**: Zero out specific SPD parameter components within attention projections (q/k/v/o\_proj) + +For each, it does a baseline forward pass and an ablated forward pass, captures attention patterns from both, and compares logits. + +--- + +## Head Ablation + +**Model used:** The raw pretrained `LlamaSimpleMLP` target model, loaded independently — no SPD model involved. 
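In outline, "ablating a head" in this experiment means zeroing that head's slice of the per-head attention output `y = att @ v` before the heads are merged and passed through `o_proj`. A minimal sketch with toy shapes (numpy, illustrative only; the script itself operates on torch tensors inside the patched forward):

```python
import numpy as np

def ablate_head_output(y: np.ndarray, head: int) -> np.ndarray:
    """Zero one head's contribution to the attention output.

    y has shape (B, n_heads, T, head_dim), i.e. the per-head outputs att @ v
    before they are concatenated and fed to o_proj. Toy sketch, not the
    script's actual code.
    """
    y = y.copy()
    y[:, head, :, :] = 0.0
    return y

# With 2 heads, zeroing head 0 leaves head 1 untouched:
y = np.ones((1, 2, 3, 4))
out = ablate_head_output(y, head=0)
```

From `o_proj`'s perspective the ablated head then contributes nothing, while its attention weights are still computed normally.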
+ +**Mechanism:** The `patched_attention_forward` context manager monkey-patches every `CausalSelfAttention.forward` method. The patched forward: + +1. Calls `q_proj(x)`, `k_proj(x)`, `v_proj(x)` — these are the *original* target model linear layers (no component decomposition) +2. Reshapes to `(B, n_heads, T, head_dim)`, applies rotary embeddings, repeats KV for GQA +3. Computes explicit `softmax(QK^T / sqrt(d))` (flash attention is disabled) +4. Stores the attention pattern `(n_heads, T, T)` averaged over the batch dim +5. Computes `y = att @ v` giving `(B, n_heads, T, head_dim)` +6. **Ablation**: for specified `(layer, head)` pairs, zeros out `y[:, head, :, :]` (or at specific positions) +7. Reshapes and passes through `o_proj` + +**Baseline pass**: patched forward with no ablations — captures patterns, produces logits. + +**Ablated pass**: patched forward with head zeroing — captures patterns, produces logits. + +**What "ablating a head" means here**: The head's QKV computation and attention still happens normally. What's zeroed is the head's *contribution to the output* — `att @ v` for that head is set to 0 before the `o_proj` linear layer. So from `o_proj`'s perspective, that head contributes nothing. + +--- + +## Component Ablation — Deterministic Mode + +**Models used:** The SPD `ComponentModel` wrapping the target model. `target_model = spd_model.target_model` — same instance. + +**Mechanism:** + +1. **Mask construction**: For every module in the SPD model, creates a mask tensor of shape `(batch, C)` where `C` is the number of components. Baseline: all ones. Ablated: all ones except the target component indices are set to 0. + +2. These are wrapped via `make_mask_infos()` into `ComponentsMaskInfo` objects (containing `component_mask` and `routing_mask="all"`). + +3. **Forward pass**: `spd_model(input_ids, mask_infos=...)` registers PyTorch forward hooks on each target module (e.g. `h.1.attn.q_proj`). 
During the forward: + - The hook intercepts the module's input `x` and output + - Instead of using the original module output, it calls `components(x, mask=component_mask)` which computes: **`output = sum_c(mask[c] * outer(U[c], V[:, c]) @ x)`** — i.e. the reconstructed output is a masked sum of rank-1 component contributions + - With `routing_mask="all"`, *all positions* use the component reconstruction (not the original module) + +4. The `patched_attention_forward` context manager is also active on `target_model` (same instance), so it captures attention patterns from within the attention block. The q/k/v projections fire through SPD's hooks (producing component-masked outputs), then the patched attention forward computes softmax attention manually. + +**What "ablating a component" means here**: Setting `mask[c] = 0` removes that component's rank-1 contribution from the module's reconstructed weight matrix. With all-ones mask: `W_recon = sum_c(U[c] V[c]^T)`. With component `c` zeroed: `W_recon = sum_{i!=c}(U[i] V[i]^T)`. + +--- + +## Component Ablation — Stochastic Mode + +**Additional step**: Before masking, runs a forward pass with `cache_type="input"` to cache activations, then computes CI (causal importance) via `spd_model.calc_causal_importances()`. + +**Mask formula** (from `calc_stochastic_component_mask_info`): +``` +mask[c] = CI[c] + (1 - CI[c]) * random_source[c] +``` +Where `random_source` is `torch.rand_like(ci)` for "continuous" sampling. + +- If `CI[c] = 1`: mask is always 1 (component always fully active) +- If `CI[c] = 0`: mask is uniformly random in [0, 1) +- If `CI[c] = 0.5`: mask is in [0.5, 1) + +**Ablation**: For the target components, CI is forced to 0 *before* sampling, so their masks become purely random rather than biased toward 1. + +**Averaging**: Runs `n_mask_samples` (default 10) stochastic forward passes for both baseline and ablated. Averages logits and attention patterns across samples. 
This gives an expectation over the stochastic masking distribution. + +**Important detail**: Both baseline and ablated use stochastic masks — so the baseline is *not* the all-ones deterministic case. The baseline has stochastic noise too, just with the original CI values. The comparison isolates the effect of removing CI for specific components. + +--- + +## Component Ablation — Adversarial Mode + +Runs PGD (projected gradient descent) to find worst-case masks: + +1. Computes CI as in stochastic mode +2. Calls `pgd_masked_recon_loss_update()` which optimizes adversarial sources to maximize reconstruction loss subject to the CI constraint (`mask = CI + (1-CI) * source`) +3. Reports baseline and ablated PGD loss + +**However**, for the attention pattern visualization and prediction table, it **falls back to deterministic masks**. The PGD losses are logged but the plots show the same thing as deterministic mode. + +--- + +## Attention Pattern Capture + +**What's stored**: `att.float().mean(dim=0).detach().cpu()` — shape `(n_heads, T, T)`, averaged over the batch dimension. + +**Implication**: With `batch_size=1` (default), the batch mean is a no-op. With larger batches, patterns would be averaged across different sequences within the batch. + +**Across samples**: Patterns are accumulated and divided by `n_samples` to get the mean. + +--- + +## Metrics + +**`frac_top1_changed`**: Total positions where `argmax(baseline) != argmax(ablated)`, divided by total positions across all samples. Uses first item in batch only. + +**`mean_kl_divergence`**: `KL(baseline || ablated)` averaged over positions, then averaged over samples. Computed as: +```python +kl = F.kl_div(ablated_log_probs, baseline_log_probs.exp(), reduction="batchmean") +``` +This is `sum_pos sum_vocab p_baseline * log(p_baseline / p_ablated) / n_positions`. Also uses first batch item only. + +--- + +## Potential Validity Concerns + +1. 
**Adversarial mode plots don't show adversarial masks**: The PGD loss is computed but plots/predictions use deterministic masks. This means the adversarial mode's attention patterns and prediction tables are identical to deterministic mode. + +2. **Stochastic baseline also has noise**: The baseline in stochastic mode isn't the "true" model — it's a stochastic reconstruction. So the comparison measures "effect of removing this component from the stochastic ensemble" rather than "effect of removing this component from the faithful reconstruction." + +3. **Batch dimension in patterns**: Attention patterns are averaged over batch dim. With `batch_size=1` this is fine. With larger batches, different sequences' patterns would be mixed. + +4. **KL direction**: `KL(baseline || ablated)` — measures how surprised baseline would be by ablated. If ablated is more uniform, KL can be moderate even if top-1 changes a lot. + +5. **Head ablation zeros post-attention, not pre-attention**: The head still computes QKV and attention weights. Only its contribution to the residual stream (via `o_proj`) is zeroed. This is a standard choice (same as what TransformerLens does) but worth noting. + +6. **Component ablation in deterministic mode uses all-ones baseline**: This means the baseline is the SPD reconstruction `W = sum_c(U[c]V[c]^T)`, not the original target model weights. If the reconstruction isn't perfect, the baseline already differs from the original model. 
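The stochastic mask formula above can be sanity-checked with a scalar toy version (plain Python, illustrative only; the real implementation is `calc_stochastic_component_mask_info` and operates on CI tensors):

```python
import random

def stochastic_mask(ci: list[float], ablated_idxs: set[int]) -> list[float]:
    """mask[c] = CI[c] + (1 - CI[c]) * U[0, 1).

    Ablated components have their CI forced to 0 before sampling, so their
    masks become purely uniform random. Toy scalar version of the tensor code.
    """
    out = []
    for c, ci_c in enumerate(ci):
        if c in ablated_idxs:
            ci_c = 0.0
        out.append(ci_c + (1.0 - ci_c) * random.random())
    return out

# CI = 1 pins the mask at exactly 1; CI = 0.5 keeps it in [0.5, 1);
# ablating index 0 makes its mask uniform in [0, 1) regardless of its CI.
masks = stochastic_mask([1.0, 0.5, 0.0], ablated_idxs={0})
```

Note that the baseline in stochastic mode samples from this same distribution with the original CI values, which is why the comparison isolates the effect of zeroing CI rather than the effect of the noise itself.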
diff --git a/spd/scripts/attention_ablation_experiment/__init__.py b/spd/scripts/attention_ablation_experiment/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/scripts/attention_ablation_experiment/attention_ablation_experiment.py b/spd/scripts/attention_ablation_experiment/attention_ablation_experiment.py new file mode 100644 index 000000000..9c6c9e048 --- /dev/null +++ b/spd/scripts/attention_ablation_experiment/attention_ablation_experiment.py @@ -0,0 +1,1906 @@ +"""Measure effects of ablating attention heads or SPD parameter components. + +Supports three ablation modes for components: + - deterministic: all-ones masks as baseline, zero out target components + - stochastic: CI-based masks with stochastic sources, target CI forced to 0 + - adversarial: PGD-optimized worst-case masks, target CI forced to 0 + +Usage: + # Head ablation + python -m spd.scripts.attention_ablation_experiment.attention_ablation_experiment \ + wandb:goodfire/spd/runs/ --heads L0H3,L1H5 + + # Component ablation (deterministic) + python -m spd.scripts.attention_ablation_experiment.attention_ablation_experiment \ + wandb:goodfire/spd/runs/ --components "h.0.attn.q_proj:3,h.1.attn.k_proj:7" + + # Component ablation (stochastic) + python -m spd.scripts.attention_ablation_experiment.attention_ablation_experiment \ + wandb:goodfire/spd/runs/ --components "h.0.attn.q_proj:3" \ + --ablation_mode stochastic --n_mask_samples 10 + + # Component ablation (adversarial / PGD) + python -m spd.scripts.attention_ablation_experiment.attention_ablation_experiment \ + wandb:goodfire/spd/runs/ --components "h.0.attn.q_proj:3" \ + --ablation_mode adversarial --pgd_steps 50 --pgd_step_size 0.01 +""" + +import math +import random +import re +from collections.abc import Generator, Iterable +from contextlib import contextmanager +from dataclasses import dataclass, field +from pathlib import Path +from typing import Literal, NamedTuple + +import fire +import matplotlib.pyplot as plt 
+import torch +import torch.nn.functional as F +from jaxtyping import Float, Int +from torch import Tensor + +from spd.configs import LMTaskConfig, PGDConfig, SamplingType +from spd.data import DatasetConfig, create_data_loader +from spd.log import logger +from spd.models.component_model import ComponentModel, OutputWithCache, SPDRunInfo +from spd.models.components import ComponentsMaskInfo, make_mask_infos +from spd.pretrain.models.llama_simple_mlp import CausalSelfAttention, LlamaSimpleMLP +from spd.routing import AllLayersRouter +from spd.spd_types import ModelPath +from spd.utils.component_utils import calc_stochastic_component_mask_info +from spd.utils.wandb_utils import parse_wandb_run_path + +SCRIPT_DIR = Path(__file__).parent + +AblationMode = Literal["deterministic", "stochastic", "adversarial"] + + +@dataclass +class ComponentHeadAblation: + """Subtract a component's contribution from a specific head's q/k at a position.""" + + layer: int + qk: Literal["q", "k"] + v_col: Tensor # (d_in,) + u_row: Tensor # (d_out,) + head: int + pos: int + + +def _extract_component_vectors( + spd_model: ComponentModel, + module_name: str, + comp_idx: int, +) -> tuple[Tensor, Tensor]: + """Return (V_col, U_row) for a specific component. Both detached.""" + components = spd_model.components[module_name] + return components.V[:, comp_idx].detach(), components.U[comp_idx, :].detach() + + +# ────────────────────────────────────────────────────────────────────────────── +# Parsing helpers +# ────────────────────────────────────────────────────────────────────────────── + + +def parse_heads(spec: str) -> list[tuple[int, int]]: + """Parse "L0H3,L1H5" → [(0, 3), (1, 5)].""" + heads: list[tuple[int, int]] = [] + for token in spec.split(","): + token = token.strip() + m = re.fullmatch(r"L(\d+)H(\d+)", token) + assert m is not None, f"Bad head spec: {token!r}, expected e.g. 
L0H3" + heads.append((int(m.group(1)), int(m.group(2)))) + return heads + + +def parse_components(spec: str) -> list[tuple[str, int]]: + """Parse "h.0.attn.q_proj:3,h.1.attn.k_proj:7" → [("h.0.attn.q_proj", 3), ...].""" + components: list[tuple[str, int]] = [] + for token in spec.split(","): + token = token.strip() + parts = token.rsplit(":", 1) + assert len(parts) == 2, f"Bad component spec: {token!r}, expected e.g. h.0.attn.q_proj:3" + components.append((parts[0], int(parts[1]))) + return components + + +# ────────────────────────────────────────────────────────────────────────────── +# Patched attention forward (context manager) +# ────────────────────────────────────────────────────────────────────────────── + +AttentionPatterns = dict[int, Float[Tensor, "n_heads T T"]] +ValueVectors = dict[int, Float[Tensor, "n_heads T head_dim"]] +AttnOutputs = dict[int, Float[Tensor, "T d_model"]] + + +class AttentionData(NamedTuple): + patterns: AttentionPatterns # layer → (n_heads, T, T) + values: ValueVectors # layer → (n_heads, T, head_dim) + attn_outputs: AttnOutputs # layer → (T, d_model) + + +@contextmanager +def patched_attention_forward( + target_model: LlamaSimpleMLP, + head_pos_ablations: list[tuple[int, int, int]] | None = None, + value_pos_ablations: list[tuple[int, int]] | None = None, + value_head_pos_ablations: list[tuple[int, int, int]] | None = None, + component_head_ablations: list[ComponentHeadAblation] | None = None, +) -> Generator[AttentionData]: + """Replace each CausalSelfAttention.forward to capture attention patterns and values. 
+ + Yields AttentionData containing: + - patterns: layer_index → attention pattern tensor (n_heads, T, T), mean over batch + - values: layer_index → value vectors (n_heads, T, head_dim), mean over batch + """ + patterns: AttentionPatterns = {} + values: ValueVectors = {} + attn_outs: AttnOutputs = {} + originals: dict[int, object] = {} + + for layer_idx, block in enumerate(target_model._h): + attn: CausalSelfAttention = block.attn + originals[layer_idx] = attn.forward + + def _make_patched_forward(attn_module: CausalSelfAttention, li: int) -> object: + def _patched_forward( + x: Float[Tensor, "batch pos d_model"], + attention_mask: Int[Tensor, "batch offset_pos"] | None = None, + position_ids: Int[Tensor, "batch pos"] | None = None, + _past_key_value: tuple[Tensor, Tensor] | None = None, + ) -> Float[Tensor, "batch pos d_model"]: + B, T, C = x.size() + + q = attn_module.q_proj(x) + k = attn_module.k_proj(x) + v = attn_module.v_proj(x) + + if component_head_ablations is not None: + for abl in component_head_ablations: + if abl.layer == li and abl.pos < T: + comp_act = x[:, abl.pos, :] @ abl.v_col.to(x.device) + comp_out = comp_act.unsqueeze(-1) * abl.u_row.to(x.device) + hd = attn_module.head_dim + hs = slice(abl.head * hd, (abl.head + 1) * hd) + if abl.qk == "q": + q[:, abl.pos, hs] -= comp_out[:, hs] + else: + k[:, abl.pos, hs] -= comp_out[:, hs] + + q = q.view(B, T, attn_module.n_head, attn_module.head_dim).transpose(1, 2) + k = k.view(B, T, attn_module.n_key_value_heads, attn_module.head_dim).transpose( + 1, 2 + ) + v = v.view(B, T, attn_module.n_key_value_heads, attn_module.head_dim).transpose( + 1, 2 + ) + + if position_ids is None: + if attention_mask is not None: + position_ids = attn_module.get_offset_position_ids(0, attention_mask) + else: + position_ids = torch.arange(T, device=x.device).unsqueeze(0) + + position_ids = position_ids.clamp(max=attn_module.n_ctx - 1) + cos = attn_module.rotary_cos[position_ids].to(q.dtype) # pyright: ignore[reportIndexIssue] 
+ sin = attn_module.rotary_sin[position_ids].to(q.dtype) # pyright: ignore[reportIndexIssue] + q, k = attn_module.apply_rotary_pos_emb(q, k, cos, sin) + + if attn_module.use_grouped_query_attention and attn_module.repeat_kv_heads > 1: + k = k.repeat_interleave(attn_module.repeat_kv_heads, dim=1) + v = v.repeat_interleave(attn_module.repeat_kv_heads, dim=1) + + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(attn_module.head_dim)) + causal_mask = attn_module.bias[:, :, :T, :T] # pyright: ignore[reportIndexIssue] + att = att.masked_fill(causal_mask == 0, float("-inf")) + att = F.softmax(att, dim=-1) + + patterns[li] = att.float().mean(dim=0).detach().cpu() + + if value_pos_ablations is not None: + for abl_layer, abl_pos in value_pos_ablations: + if abl_layer == li and abl_pos < T: + v[:, :, abl_pos, :] = 0.0 + if value_head_pos_ablations is not None: + for abl_layer, abl_head, abl_pos in value_head_pos_ablations: + if abl_layer == li and abl_pos < T: + v[:, abl_head, abl_pos, :] = 0.0 + + y = att @ v # (B, n_head, T, head_dim) + + if head_pos_ablations is not None: + for abl_layer, abl_head, abl_pos in head_pos_ablations: + if abl_layer == li and abl_pos < T: + y[:, abl_head, abl_pos, :] = 0.0 + + values[li] = v.float().mean(dim=0).detach().cpu() + + y = y.transpose(1, 2).contiguous().view(B, T, C) + y = attn_module.o_proj(y) + + attn_outs[li] = y.float().mean(dim=0).detach().cpu() + + return y + + return _patched_forward + + attn.forward = _make_patched_forward(attn, layer_idx) # pyright: ignore[reportAttributeAccessIssue] + + try: + yield AttentionData(patterns, values, attn_outs) + finally: + for layer_idx, block in enumerate(target_model._h): + block.attn.forward = originals[layer_idx] # pyright: ignore[reportAttributeAccessIssue] + + +# ────────────────────────────────────────────────────────────────────────────── +# Plotting +# ────────────────────────────────────────────────────────────────────────────── + + +def plot_attention_grid( + patterns: 
AttentionPatterns, + title: str, + path: Path, + max_pos: int, +) -> None: + n_layers = len(patterns) + n_heads = patterns[0].shape[0] + + fig, axes = plt.subplots(n_layers, n_heads, figsize=(n_heads * 3, n_layers * 3), squeeze=False) + + for layer_idx in range(n_layers): + for h in range(n_heads): + ax = axes[layer_idx, h] + pat = patterns[layer_idx][h, :max_pos, :max_pos].numpy() + ax.imshow(pat, aspect="auto", cmap="viridis", vmin=0) + ax.set_title(f"L{layer_idx}H{h}", fontsize=9) + ax.set_xticks([]) + ax.set_yticks([]) + if h == 0: + ax.set_ylabel(f"Layer {layer_idx}", fontsize=9) + + fig.suptitle(title, fontsize=13, fontweight="bold") + fig.tight_layout() + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {path}") + + +def plot_attention_diff( + baseline: AttentionPatterns, + ablated: AttentionPatterns, + title: str, + path: Path, + max_pos: int, +) -> None: + n_layers = len(baseline) + n_heads = baseline[0].shape[0] + + fig, axes = plt.subplots(n_layers, n_heads, figsize=(n_heads * 3, n_layers * 3), squeeze=False) + + for layer_idx in range(n_layers): + for h in range(n_heads): + ax = axes[layer_idx, h] + diff = ( + ablated[layer_idx][h, :max_pos, :max_pos] + - baseline[layer_idx][h, :max_pos, :max_pos] + ).numpy() + vmax = max(abs(diff.min()), abs(diff.max()), 1e-8) + ax.imshow(diff, aspect="auto", cmap="RdBu_r", vmin=-vmax, vmax=vmax) + ax.set_title(f"L{layer_idx}H{h}", fontsize=9) + ax.set_xticks([]) + ax.set_yticks([]) + if h == 0: + ax.set_ylabel(f"Layer {layer_idx}", fontsize=9) + + fig.suptitle(title, fontsize=13, fontweight="bold") + fig.tight_layout() + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {path}") + + +def plot_value_norms( + values: ValueVectors, + title: str, + path: Path, + max_pos: int, +) -> None: + n_layers = len(values) + fig, axes = plt.subplots(n_layers, 1, figsize=(8, n_layers * 2.5), squeeze=False) + + for layer_idx in range(n_layers): + ax = 
axes[layer_idx, 0] + norms = values[layer_idx][:, :max_pos, :].norm(dim=-1).numpy() # (n_heads, max_pos) + im = ax.imshow(norms, aspect="auto", cmap="viridis") + ax.set_ylabel(f"Layer {layer_idx}", fontsize=9) + ax.set_xlabel("Position", fontsize=8) + n_heads = norms.shape[0] + ax.set_yticks(range(n_heads)) + ax.set_yticklabels([f"H{h}" for h in range(n_heads)], fontsize=8) + fig.colorbar(im, ax=ax, fraction=0.02, pad=0.02) + + fig.suptitle(title, fontsize=13, fontweight="bold") + fig.tight_layout() + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {path}") + + +def plot_value_norms_diff( + baseline_values: ValueVectors, + ablated_values: ValueVectors, + title: str, + path: Path, + max_pos: int, +) -> None: + n_layers = len(baseline_values) + fig, axes = plt.subplots(n_layers, 1, figsize=(8, n_layers * 2.5), squeeze=False) + + for layer_idx in range(n_layers): + ax = axes[layer_idx, 0] + baseline_norms = baseline_values[layer_idx][:, :max_pos, :].norm(dim=-1) + ablated_norms = ablated_values[layer_idx][:, :max_pos, :].norm(dim=-1) + diff = (ablated_norms - baseline_norms).numpy() + vmax = max(abs(diff.min()), abs(diff.max()), 1e-8) + im = ax.imshow(diff, aspect="auto", cmap="RdBu_r", vmin=-vmax, vmax=vmax) + ax.set_ylabel(f"Layer {layer_idx}", fontsize=9) + ax.set_xlabel("Position", fontsize=8) + n_heads = diff.shape[0] + ax.set_yticks(range(n_heads)) + ax.set_yticklabels([f"H{h}" for h in range(n_heads)], fontsize=8) + fig.colorbar(im, ax=ax, fraction=0.02, pad=0.02) + + fig.suptitle(title, fontsize=13, fontweight="bold") + fig.tight_layout() + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {path}") + + +def compute_ablation_metrics( + baseline_attn_outputs: AttnOutputs, + ablated_attn_outputs: AttnOutputs, +) -> tuple[AttnOutputs, AttnOutputs]: + """Compute per-position normalized inner product and cosine similarity. + + Normalized IP: dot(baseline, ablated) / ||baseline||². 
1.0 at unaffected positions. + Cosine sim: cos(baseline, ablated). 1.0 at unaffected positions. + + Returns (normalized_ips, cosine_sims) where each is layer → (T,). + """ + normalized_ips: AttnOutputs = {} + cosine_sims: AttnOutputs = {} + for layer_idx in baseline_attn_outputs: + baseline = baseline_attn_outputs[layer_idx] + ablated = ablated_attn_outputs[layer_idx] + ip = (baseline * ablated).sum(dim=-1) + baseline_norm_sq = (baseline * baseline).sum(dim=-1) + normalized_ips[layer_idx] = ip / baseline_norm_sq.clamp(min=1e-8) + norms_product = baseline.norm(dim=-1) * ablated.norm(dim=-1) + cosine_sims[layer_idx] = ip / norms_product.clamp(min=1e-8) + return normalized_ips, cosine_sims + + +def plot_per_position_line( + values: AttnOutputs, + title: str, + path: Path, + max_pos: int, + baseline_y: float = 0.0, + ylim: tuple[float, float] | None = None, +) -> None: + n_layers = len(values) + fig, axes = plt.subplots(n_layers, 1, figsize=(10, n_layers * 2), squeeze=False) + + for layer_idx in range(n_layers): + ax = axes[layer_idx, 0] + vals = values[layer_idx][:max_pos].numpy() + ax.plot(vals, linewidth=0.8) + ax.axhline(y=baseline_y, color="gray", linewidth=0.5, linestyle="--") + if ylim is not None: + ax.set_ylim(*ylim) + ax.set_ylabel(f"Layer {layer_idx}", fontsize=9) + ax.set_xlabel("Position", fontsize=8) + + fig.suptitle(title, fontsize=13, fontweight="bold") + fig.tight_layout() + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {path}") + + +def compute_ablation_metrics_at_pos( + baseline_attn_outputs: AttnOutputs, + ablated_attn_outputs: AttnOutputs, + pos: int, +) -> tuple[dict[int, float], dict[int, float]]: + """Compute normalized inner product and cosine similarity at a single position. + + Returns (normalized_ips, cosine_sims) — layer -> scalar. 
+ """ + normalized_ips: dict[int, float] = {} + cosine_sims: dict[int, float] = {} + for layer_idx in baseline_attn_outputs: + baseline_vec = baseline_attn_outputs[layer_idx][pos] + ablated_vec = ablated_attn_outputs[layer_idx][pos] + ip = (baseline_vec * ablated_vec).sum().item() + baseline_norm_sq = (baseline_vec * baseline_vec).sum().item() + normalized_ips[layer_idx] = ip / max(baseline_norm_sq, 1e-8) + norms_product = baseline_vec.norm().item() * ablated_vec.norm().item() + cosine_sims[layer_idx] = ip / max(norms_product, 1e-8) + return normalized_ips, cosine_sims + + +def plot_output_similarity_bars( + means: dict[int, float], + stds: dict[int, float], + title: str, + path: Path, +) -> None: + layers = sorted(means.keys()) + mean_vals = [means[li] for li in layers] + std_vals = [stds[li] for li in layers] + layer_labels = [f"L{li}" for li in layers] + + fig, ax = plt.subplots(figsize=(max(4, len(layers) * 1.2), 4)) + ax.bar(layer_labels, mean_vals, yerr=std_vals, capsize=4, color="steelblue", alpha=0.8) + ax.axhline(y=0, color="gray", linewidth=0.5, linestyle="--") + ax.set_xlabel("Layer") + ax.set_ylabel("Value") + ax.set_title(title, fontsize=11, fontweight="bold") + fig.tight_layout() + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {path}") + + +# ────────────────────────────────────────────────────────────────────────────── +# Token prediction table & stats +# ────────────────────────────────────────────────────────────────────────────── + + +def log_prediction_table( + input_ids: Tensor, + baseline_logits: Tensor, + ablated_logits: Tensor, + tokenizer: object, + last_n: int = 20, +) -> int: + """Log per-position prediction changes. 
Returns count of changed positions.""" + seq_len = input_ids.shape[0] + baseline_probs = F.softmax(baseline_logits, dim=-1) + ablated_probs = F.softmax(ablated_logits, dim=-1) + + baseline_top = baseline_probs.argmax(dim=-1) + ablated_top = ablated_probs.argmax(dim=-1) + + changed_mask = baseline_top != ablated_top + changed_positions = changed_mask.nonzero(as_tuple=True)[0].tolist() + show_positions = set(changed_positions) | set(range(max(0, seq_len - last_n), seq_len)) + + decode = tokenizer.decode # pyright: ignore[reportAttributeAccessIssue] + lines = [f"{'Pos':>4} | {'Token':>10} | {'Baseline (prob)':>20} | {'Ablated (prob)':>20} | Chg"] + lines.append("-" * 85) + + for pos in sorted(show_positions): + tok = decode([input_ids[pos].item()]).replace("\n", "\\n") + b_id = int(baseline_top[pos].item()) + a_id = int(ablated_top[pos].item()) + b_tok = decode([b_id]).replace("\n", "\\n") + a_tok = decode([a_id]).replace("\n", "\\n") + b_prob = baseline_probs[pos, b_id].item() + a_prob = ablated_probs[pos, a_id].item() + changed = " *" if pos in changed_positions else "" + lines.append( + f"{pos:>4} | {tok:>10} | {b_tok:>10} ({b_prob:.3f}) | {a_tok:>10} ({a_prob:.3f}) |{changed}" + ) + + logger.info("Prediction table:\n" + "\n".join(lines)) + return len(changed_positions) + + +def calc_mean_kl_divergence( + baseline_logits: Tensor, + ablated_logits: Tensor, +) -> float: + """KL(baseline || ablated) averaged over positions, for first item in batch.""" + baseline_log_probs = F.log_softmax(baseline_logits, dim=-1) + ablated_log_probs = F.log_softmax(ablated_logits, dim=-1) + kl = F.kl_div(ablated_log_probs, baseline_log_probs.exp(), reduction="batchmean") + return kl.item() + + +# ────────────────────────────────────────────────────────────────────────────── +# Component mask construction +# ────────────────────────────────────────────────────────────────────────────── + + +def _build_deterministic_masks( + model: ComponentModel, + ablated_components: list[tuple[str, 
int]], + batch_shape: tuple[int, int], + device: torch.device, + ablation_pos: int, +) -> tuple[dict[str, ComponentsMaskInfo], dict[str, ComponentsMaskInfo]]: + """Build all-ones baseline and ablated mask_infos for deterministic mode. + + Masks have shape (batch, seq_len, C) and the target component is zeroed only at + ablation_pos. + """ + baseline_masks: dict[str, Float[Tensor, "batch seq_len C"]] = {} + ablated_masks: dict[str, Float[Tensor, "batch seq_len C"]] = {} + + for module_name in model.target_module_paths: + c = model.module_to_c[module_name] + baseline_masks[module_name] = torch.ones(*batch_shape, c, device=device) + ablated_masks[module_name] = torch.ones(*batch_shape, c, device=device) + + for module_name, comp_idx in ablated_components: + assert module_name in ablated_masks, f"Module {module_name!r} not in model" + ablated_masks[module_name][:, ablation_pos, comp_idx] = 0.0 + + return make_mask_infos(baseline_masks), make_mask_infos(ablated_masks) + + +def _build_deterministic_masks_multi_pos( + model: ComponentModel, + component_positions: list[tuple[str, int, int]], + batch_shape: tuple[int, int], + device: torch.device, +) -> tuple[dict[str, ComponentsMaskInfo], dict[str, ComponentsMaskInfo]]: + """Build masks where each component is zeroed at its own position. + + component_positions: list of (module_name, comp_idx, pos). 
+ """ + baseline_masks: dict[str, Float[Tensor, "batch seq_len C"]] = {} + ablated_masks: dict[str, Float[Tensor, "batch seq_len C"]] = {} + + for module_name in model.target_module_paths: + c = model.module_to_c[module_name] + baseline_masks[module_name] = torch.ones(*batch_shape, c, device=device) + ablated_masks[module_name] = torch.ones(*batch_shape, c, device=device) + + for module_name, comp_idx, pos in component_positions: + assert module_name in ablated_masks, f"Module {module_name!r} not in model" + ablated_masks[module_name][:, pos, comp_idx] = 0.0 + + return make_mask_infos(baseline_masks), make_mask_infos(ablated_masks) + + +def _infer_layer_from_components(parsed_components: list[tuple[str, int]]) -> int: + """Extract layer index from component module paths (e.g. 'h.1.attn.q_proj' → 1).""" + layers: set[int] = set() + for module_name, _ in parsed_components: + m = re.search(r"h\.(\d+)\.", module_name) + assert m is not None, f"Cannot infer layer from {module_name!r}" + layers.add(int(m.group(1))) + assert len(layers) == 1, f"All components must be in the same layer, got layers {layers}" + return layers.pop() + + +def _build_prev_token_component_positions( + parsed_components: list[tuple[str, int]], + t: int, +) -> list[tuple[str, int, int]]: + """Assign positions based on module type: q_proj → t, k_proj → t-1.""" + positions: list[tuple[str, int, int]] = [] + for module_name, comp_idx in parsed_components: + if "q_proj" in module_name: + positions.append((module_name, comp_idx, t)) + elif "k_proj" in module_name: + positions.append((module_name, comp_idx, t - 1)) + else: + raise AssertionError( + f"prev_token_test only supports q_proj/k_proj components, got {module_name!r}" + ) + return positions + + +def _build_component_head_ablations( + spd_model: ComponentModel, + parsed_components: list[tuple[str, int]], + heads: list[tuple[int, int]], + t: int, +) -> list[ComponentHeadAblation]: + """Build per-head component ablations: q components at t, k 
components at t-1.""" + ablations: list[ComponentHeadAblation] = [] + for module_name, comp_idx in parsed_components: + v_col, u_row = _extract_component_vectors(spd_model, module_name, comp_idx) + if "q_proj" in module_name: + qk: Literal["q", "k"] = "q" + pos = t + elif "k_proj" in module_name: + qk = "k" + pos = t - 1 + else: + raise AssertionError( + f"per-head component ablation only supports q_proj/k_proj, got {module_name!r}" + ) + layer = _infer_layer_from_components([(module_name, comp_idx)]) + for _layer, head in heads: + ablations.append(ComponentHeadAblation(layer, qk, v_col, u_row, head, pos)) + return ablations + + +def _build_stochastic_masks( + _model: ComponentModel, + ci: dict[str, Float[Tensor, "batch C"]], + ablated_components: list[tuple[str, int]], + sampling: SamplingType, + ablation_pos: int, + seq_len: int, +) -> tuple[dict[str, ComponentsMaskInfo], dict[str, ComponentsMaskInfo]]: + """Build stochastic mask_infos: baseline uses original CI, ablated zeros target CIs. + + CI is expanded to (batch, seq_len, C) and the target component CI is zeroed only at + ablation_pos. + """ + router = AllLayersRouter() + + expanded_ci = {k: v.unsqueeze(1).expand(-1, seq_len, -1).clone() for k, v in ci.items()} + baseline_mask_infos = calc_stochastic_component_mask_info(expanded_ci, sampling, None, router) + + ablated_ci = {k: v.clone() for k, v in expanded_ci.items()} + for module_name, comp_idx in ablated_components: + assert module_name in ablated_ci, f"Module {module_name!r} not in model" + ablated_ci[module_name][:, ablation_pos, comp_idx] = 0.0 + ablated_mask_infos = calc_stochastic_component_mask_info(ablated_ci, sampling, None, router) + + return baseline_mask_infos, ablated_mask_infos + + +def _build_adversarial_masks( + model: ComponentModel, + batch: Int[Tensor, "batch pos"], + ci: dict[str, Float[Tensor, "... C"]], + target_out: Float[Tensor, "... 
vocab"], + ablated_components: list[tuple[str, int]], + pgd_config: PGDConfig, +) -> tuple[Float[Tensor, ""], Float[Tensor, ""]]: + """Run PGD for baseline and ablated, return (baseline_loss, ablated_loss).""" + from spd.metrics.pgd_utils import pgd_masked_recon_loss_update + + router = AllLayersRouter() + + baseline_sum_loss, baseline_n = pgd_masked_recon_loss_update( + model, batch, ci, None, target_out, "kl", router, pgd_config + ) + + ablated_ci = {k: v.clone() for k, v in ci.items()} + for module_name, comp_idx in ablated_components: + assert module_name in ablated_ci, f"Module {module_name!r} not in model" + ablated_ci[module_name][..., comp_idx] = 0.0 + + ablated_sum_loss, ablated_n = pgd_masked_recon_loss_update( + model, batch, ablated_ci, None, target_out, "kl", router, pgd_config + ) + + return baseline_sum_loss / baseline_n, ablated_sum_loss / ablated_n + + +# ────────────────────────────────────────────────────────────────────────────── +# Per-sample result + accumulation helpers +# ────────────────────────────────────────────────────────────────────────────── + + +class SampleResult(NamedTuple): + baseline_patterns: AttentionPatterns + ablated_patterns: AttentionPatterns + baseline_values: ValueVectors + ablated_values: ValueVectors + baseline_attn_outputs: AttnOutputs + ablated_attn_outputs: AttnOutputs + baseline_logits: Tensor # (batch, pos, vocab) + ablated_logits: Tensor # (batch, pos, vocab) + + +class PrevTokenSampleResult(NamedTuple): + baseline_attn_outputs: AttnOutputs + a_attn_outputs: AttnOutputs + b_all_attn_outputs: AttnOutputs + b_specific_attn_outputs: AttnOutputs + ab_all_attn_outputs: AttnOutputs + ab_specific_attn_outputs: AttnOutputs + baseline_logits: Tensor + a_logits: Tensor + a_b_all_logits: Tensor + ab_specific_logits: Tensor + + +def _add_patterns(accum: AttentionPatterns, new: AttentionPatterns) -> None: + for layer_idx, pat in new.items(): + if layer_idx in accum: + accum[layer_idx] = accum[layer_idx] + pat + else: + 
accum[layer_idx] = pat.clone() + + +def _scale_patterns(accum: AttentionPatterns, n: int) -> AttentionPatterns: + return {k: v / n for k, v in accum.items()} + + +# ────────────────────────────────────────────────────────────────────────────── +# Head ablation +# ────────────────────────────────────────────────────────────────────────────── + + +def _run_head_ablation( + target_model: LlamaSimpleMLP, + input_ids: Int[Tensor, "batch pos"], + parsed_heads: list[tuple[int, int]], + ablation_pos: int, +) -> SampleResult: + with patched_attention_forward(target_model) as baseline_data: + baseline_logits, _ = target_model(input_ids) + + pos_ablations = [(layer, head, ablation_pos) for layer, head in parsed_heads] + with patched_attention_forward(target_model, head_pos_ablations=pos_ablations) as ablated_data: + ablated_logits, _ = target_model(input_ids) + + assert baseline_logits is not None and ablated_logits is not None + return SampleResult( + baseline_data.patterns, + ablated_data.patterns, + baseline_data.values, + ablated_data.values, + baseline_data.attn_outputs, + ablated_data.attn_outputs, + baseline_logits, + ablated_logits, + ) + + +# ────────────────────────────────────────────────────────────────────────────── +# Component ablation +# ────────────────────────────────────────────────────────────────────────────── + + +def _run_component_ablation( + spd_model: ComponentModel, + target_model: LlamaSimpleMLP, + input_ids: Int[Tensor, "batch pos"], + parsed_components: list[tuple[str, int]], + ablation_mode: AblationMode, + n_mask_samples: int, + pgd_steps: int, + pgd_step_size: float, + ablation_pos: int, +) -> SampleResult: + match ablation_mode: + case "deterministic": + return _run_deterministic_component_ablation( + spd_model, target_model, input_ids, parsed_components, ablation_pos + ) + case "stochastic": + return _run_stochastic_component_ablation( + spd_model, target_model, input_ids, parsed_components, n_mask_samples, ablation_pos + ) + case 
"adversarial": + return _run_adversarial_component_ablation( + spd_model, + target_model, + input_ids, + parsed_components, + pgd_steps, + pgd_step_size, + ablation_pos, + ) + + +def _run_deterministic_component_ablation( + spd_model: ComponentModel, + target_model: LlamaSimpleMLP, + input_ids: Int[Tensor, "batch pos"], + parsed_components: list[tuple[str, int]], + ablation_pos: int, +) -> SampleResult: + batch_shape = (input_ids.shape[0], input_ids.shape[1]) + baseline_mask_infos, ablated_mask_infos = _build_deterministic_masks( + spd_model, parsed_components, batch_shape, input_ids.device, ablation_pos + ) + + with patched_attention_forward(target_model) as baseline_data: + baseline_out = spd_model(input_ids, mask_infos=baseline_mask_infos) + assert isinstance(baseline_out, Tensor) + + with patched_attention_forward(target_model) as ablated_data: + ablated_out = spd_model(input_ids, mask_infos=ablated_mask_infos) + assert isinstance(ablated_out, Tensor) + + return SampleResult( + baseline_data.patterns, + ablated_data.patterns, + baseline_data.values, + ablated_data.values, + baseline_data.attn_outputs, + ablated_data.attn_outputs, + baseline_out, + ablated_out, + ) + + +def _run_stochastic_component_ablation( + spd_model: ComponentModel, + target_model: LlamaSimpleMLP, + input_ids: Int[Tensor, "batch pos"], + parsed_components: list[tuple[str, int]], + n_mask_samples: int, + ablation_pos: int, +) -> SampleResult: + output_with_cache = spd_model(input_ids, cache_type="input") + assert isinstance(output_with_cache, OutputWithCache) + ci = spd_model.calc_causal_importances(output_with_cache.cache, "continuous").lower_leaky + + baseline_logits_accum: Tensor | None = None + ablated_logits_accum: Tensor | None = None + sample_baseline_patterns: AttentionPatterns = {} + sample_ablated_patterns: AttentionPatterns = {} + sample_baseline_values: ValueVectors = {} + sample_ablated_values: ValueVectors = {} + sample_baseline_attn_outs: AttnOutputs = {} + 
sample_ablated_attn_outs: AttnOutputs = {} + + stoch_seq_len = input_ids.shape[1] + for _s in range(n_mask_samples): + baseline_mask_infos, ablated_mask_infos = _build_stochastic_masks( + spd_model, ci, parsed_components, "continuous", ablation_pos, stoch_seq_len + ) + + with patched_attention_forward(target_model) as b_data: + b_out = spd_model(input_ids, mask_infos=baseline_mask_infos) + assert isinstance(b_out, Tensor) + + with patched_attention_forward(target_model) as a_data: + a_out = spd_model(input_ids, mask_infos=ablated_mask_infos) + assert isinstance(a_out, Tensor) + + if baseline_logits_accum is None: + baseline_logits_accum = b_out + ablated_logits_accum = a_out + else: + baseline_logits_accum = baseline_logits_accum + b_out + assert ablated_logits_accum is not None + ablated_logits_accum = ablated_logits_accum + a_out + + _add_patterns(sample_baseline_patterns, b_data.patterns) + _add_patterns(sample_ablated_patterns, a_data.patterns) + _add_patterns(sample_baseline_values, b_data.values) + _add_patterns(sample_ablated_values, a_data.values) + _add_patterns(sample_baseline_attn_outs, b_data.attn_outputs) + _add_patterns(sample_ablated_attn_outs, a_data.attn_outputs) + + assert baseline_logits_accum is not None and ablated_logits_accum is not None + return SampleResult( + _scale_patterns(sample_baseline_patterns, n_mask_samples), + _scale_patterns(sample_ablated_patterns, n_mask_samples), + _scale_patterns(sample_baseline_values, n_mask_samples), + _scale_patterns(sample_ablated_values, n_mask_samples), + _scale_patterns(sample_baseline_attn_outs, n_mask_samples), + _scale_patterns(sample_ablated_attn_outs, n_mask_samples), + baseline_logits_accum / n_mask_samples, + ablated_logits_accum / n_mask_samples, + ) + + +def _run_adversarial_component_ablation( + spd_model: ComponentModel, + target_model: LlamaSimpleMLP, + input_ids: Int[Tensor, "batch pos"], + parsed_components: list[tuple[str, int]], + pgd_steps: int, + pgd_step_size: float, + ablation_pos: 
int, +) -> SampleResult: + output_with_cache = spd_model(input_ids, cache_type="input") + assert isinstance(output_with_cache, OutputWithCache) + ci = spd_model.calc_causal_importances(output_with_cache.cache, "continuous").lower_leaky + + target_out = output_with_cache.output + + pgd_config = PGDConfig( + init="random", + step_size=pgd_step_size, + n_steps=pgd_steps, + mask_scope="unique_per_datapoint", + ) + + baseline_loss, ablated_loss = _build_adversarial_masks( + spd_model, input_ids, ci, target_out, parsed_components, pgd_config + ) + logger.info( + f"PGD losses — baseline: {baseline_loss.item():.4f}, ablated: {ablated_loss.item():.4f}" + ) + + # Capture attention patterns with deterministic masks for visualization + batch_shape = (input_ids.shape[0], input_ids.shape[1]) + baseline_mask_infos, ablated_mask_infos = _build_deterministic_masks( + spd_model, parsed_components, batch_shape, input_ids.device, ablation_pos + ) + + with patched_attention_forward(target_model) as baseline_data: + baseline_out = spd_model(input_ids, mask_infos=baseline_mask_infos) + assert isinstance(baseline_out, Tensor) + + with patched_attention_forward(target_model) as ablated_data: + ablated_out = spd_model(input_ids, mask_infos=ablated_mask_infos) + assert isinstance(ablated_out, Tensor) + + return SampleResult( + baseline_data.patterns, + ablated_data.patterns, + baseline_data.values, + ablated_data.values, + baseline_data.attn_outputs, + ablated_data.attn_outputs, + baseline_out, + ablated_out, + ) + + +# ────────────────────────────────────────────────────────────────────────────── +# Previous-token redundancy test +# ────────────────────────────────────────────────────────────────────────────── + + +def _capture_attn_outputs( + target_model: LlamaSimpleMLP, + input_ids: Int[Tensor, "batch pos"], + head_pos_ablations: list[tuple[int, int, int]] | None = None, + value_pos_ablations: list[tuple[int, int]] | None = None, + value_head_pos_ablations: list[tuple[int, int, int]] | 
None = None, + component_head_ablations: list[ComponentHeadAblation] | None = None, + spd_model: ComponentModel | None = None, + mask_infos: dict[str, ComponentsMaskInfo] | None = None, +) -> tuple[AttnOutputs, Tensor]: + """Run a forward pass capturing attention outputs and logits.""" + with patched_attention_forward( + target_model, + head_pos_ablations, + value_pos_ablations, + value_head_pos_ablations, + component_head_ablations, + ) as data: + if spd_model is not None: + out = spd_model(input_ids, mask_infos=mask_infos) + assert isinstance(out, Tensor) + else: + out, _ = target_model(input_ids) + assert out is not None + return data.attn_outputs, out + + +def _run_prev_token_head_ablation( + target_model: LlamaSimpleMLP, + input_ids: Int[Tensor, "batch pos"], + parsed_heads: list[tuple[int, int]], + value_heads: list[tuple[int, int]], + t: int, +) -> PrevTokenSampleResult: + head_abl = [(layer, head, t) for layer, head in parsed_heads] + layer = parsed_heads[0][0] + val_all = [(layer, t - 1)] + val_specific = [(layer, head, t - 1) for layer, head in value_heads] + + baseline_outs, baseline_logits = _capture_attn_outputs(target_model, input_ids) + a_outs, a_logits = _capture_attn_outputs(target_model, input_ids, head_pos_ablations=head_abl) + b_all_outs, _b_all_logits = _capture_attn_outputs( + target_model, input_ids, value_pos_ablations=val_all + ) + b_spec_outs, _b_spec_logits = _capture_attn_outputs( + target_model, input_ids, value_head_pos_ablations=val_specific + ) + ab_all_outs, a_b_all_logits = _capture_attn_outputs( + target_model, input_ids, head_pos_ablations=head_abl, value_pos_ablations=val_all + ) + ab_spec_outs, a_b_spec_logits = _capture_attn_outputs( + target_model, input_ids, head_pos_ablations=head_abl, value_head_pos_ablations=val_specific + ) + + return PrevTokenSampleResult( + baseline_outs, + a_outs, + b_all_outs, + b_spec_outs, + ab_all_outs, + ab_spec_outs, + baseline_logits, + a_logits, + a_b_all_logits, + a_b_spec_logits, + ) + + 
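The prev-token experiments above score each ablation by comparing attention outputs through `compute_ablation_metrics_at_pos`, which is imported from elsewhere in the repo. As a hedged sketch of the quantities being accumulated (the helper name `nip_and_cos_at_pos` and the exact normalization are assumptions for illustration, not the repo's implementation), the per-layer normalized inner product (NIP) and cosine similarity at a single position `t` can be computed like this:

```python
import torch
import torch.nn.functional as F


# Hypothetical stand-in for compute_ablation_metrics_at_pos: per-layer NIP and
# cosine similarity between baseline and ablated attention outputs at position t.
# Assumes each dict entry is a (batch, pos, d_model) tensor keyed by layer index.
def nip_and_cos_at_pos(
    baseline: dict[int, torch.Tensor],
    ablated: dict[int, torch.Tensor],
    t: int,
) -> tuple[dict[int, float], dict[int, float]]:
    nip: dict[int, float] = {}
    cos: dict[int, float] = {}
    for layer, b in baseline.items():
        b_t = b[:, t]  # (batch, d_model) baseline attn output at position t
        a_t = ablated[layer][:, t]
        # NIP normalizes by the baseline's squared norm, so an unchanged output
        # scores exactly 1.0 and an orthogonal one scores 0.0.
        nip[layer] = ((b_t * a_t).sum(-1) / (b_t * b_t).sum(-1)).mean().item()
        cos[layer] = F.cosine_similarity(b_t, a_t, dim=-1).mean().item()
    return nip, cos
```

Under this reading, identical baseline and ablated outputs score 1.0 on both metrics at every layer, which is consistent with the plotting helpers below drawing their reference line at y=1.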
+def _run_prev_token_component_ablation( + spd_model: ComponentModel, + target_model: LlamaSimpleMLP, + input_ids: Int[Tensor, "batch pos"], + parsed_components: list[tuple[str, int]], + value_heads: list[tuple[int, int]], + t: int, +) -> PrevTokenSampleResult: + layer = _infer_layer_from_components(parsed_components) + component_positions = _build_prev_token_component_positions(parsed_components, t) + batch_shape = (input_ids.shape[0], input_ids.shape[1]) + + baseline_masks, ablated_masks = _build_deterministic_masks_multi_pos( + spd_model, component_positions, batch_shape, input_ids.device + ) + val_all = [(layer, t - 1)] + val_specific = [(layer, head, t - 1) for layer, head in value_heads] + + baseline_outs, baseline_logits = _capture_attn_outputs( + target_model, input_ids, spd_model=spd_model, mask_infos=baseline_masks + ) + a_outs, a_logits = _capture_attn_outputs( + target_model, input_ids, spd_model=spd_model, mask_infos=ablated_masks + ) + b_all_outs, _b_all_logits = _capture_attn_outputs( + target_model, + input_ids, + value_pos_ablations=val_all, + spd_model=spd_model, + mask_infos=baseline_masks, + ) + b_spec_outs, _b_spec_logits = _capture_attn_outputs( + target_model, + input_ids, + value_head_pos_ablations=val_specific, + spd_model=spd_model, + mask_infos=baseline_masks, + ) + ab_all_outs, a_b_all_logits = _capture_attn_outputs( + target_model, + input_ids, + value_pos_ablations=val_all, + spd_model=spd_model, + mask_infos=ablated_masks, + ) + ab_spec_outs, a_b_spec_logits = _capture_attn_outputs( + target_model, + input_ids, + value_head_pos_ablations=val_specific, + spd_model=spd_model, + mask_infos=ablated_masks, + ) + + return PrevTokenSampleResult( + baseline_outs, + a_outs, + b_all_outs, + b_spec_outs, + ab_all_outs, + ab_spec_outs, + baseline_logits, + a_logits, + a_b_all_logits, + a_b_spec_logits, + ) + + +def _run_prev_token_head_restricted_component_ablation( + spd_model: ComponentModel, + target_model: LlamaSimpleMLP, + input_ids: 
Int[Tensor, "batch pos"], + parsed_components: list[tuple[str, int]], + restrict_heads: list[tuple[int, int]], + value_heads: list[tuple[int, int]], + t: int, +) -> PrevTokenSampleResult: + """Per-head component ablation: subtract component contributions from specific heads only.""" + layer = _infer_layer_from_components(parsed_components) + comp_head_abls = _build_component_head_ablations( + spd_model, parsed_components, restrict_heads, t + ) + batch_shape = (input_ids.shape[0], input_ids.shape[1]) + + # All-ones baseline masks (SPD model reconstructs original output) + baseline_masks, _ = _build_deterministic_masks(spd_model, [], batch_shape, input_ids.device, t) + val_all = [(layer, t - 1)] + val_specific = [(layer, head, t - 1) for layer, head in value_heads] + + baseline_outs, baseline_logits = _capture_attn_outputs( + target_model, input_ids, spd_model=spd_model, mask_infos=baseline_masks + ) + a_outs, a_logits = _capture_attn_outputs( + target_model, + input_ids, + component_head_ablations=comp_head_abls, + spd_model=spd_model, + mask_infos=baseline_masks, + ) + b_all_outs, _b_all_logits = _capture_attn_outputs( + target_model, + input_ids, + value_pos_ablations=val_all, + spd_model=spd_model, + mask_infos=baseline_masks, + ) + b_spec_outs, _b_spec_logits = _capture_attn_outputs( + target_model, + input_ids, + value_head_pos_ablations=val_specific, + spd_model=spd_model, + mask_infos=baseline_masks, + ) + ab_all_outs, a_b_all_logits = _capture_attn_outputs( + target_model, + input_ids, + component_head_ablations=comp_head_abls, + value_pos_ablations=val_all, + spd_model=spd_model, + mask_infos=baseline_masks, + ) + ab_spec_outs, a_b_spec_logits = _capture_attn_outputs( + target_model, + input_ids, + component_head_ablations=comp_head_abls, + value_head_pos_ablations=val_specific, + spd_model=spd_model, + mask_infos=baseline_masks, + ) + + return PrevTokenSampleResult( + baseline_outs, + a_outs, + b_all_outs, + b_spec_outs, + ab_all_outs, + ab_spec_outs, + 
baseline_logits, + a_logits, + a_b_all_logits, + a_b_spec_logits, + ) + + +# ────────────────────────────────────────────────────────────────────────────── +# Offset sweep +# ────────────────────────────────────────────────────────────────────────────── + + +def _run_offset_sweep( + target_model: LlamaSimpleMLP, + spd_model: ComponentModel | None, + loader: Iterable[dict[str, Tensor]], + is_head_ablation: bool, + parsed_heads: list[tuple[int, int]], + parsed_components: list[tuple[str, int]], + parsed_restrict_heads: list[tuple[int, int]], + n_samples: int, + max_offsets: int, + max_pos: int, + seq_len: int, + run_id: str, + label: str, + sim_dir: Path, + column_name: str, + device: torch.device, +) -> None: + """Sweep value ablation across offsets 1..max_offsets to profile which positions matter.""" + if is_head_ablation: + layer = parsed_heads[0][0] + else: + layer = _infer_layer_from_components(parsed_components) + + # offset -> layer -> list[float] + base_vs_a_nips: dict[str, dict[int, list[float]]] = {"nip": {}, "cos": {}} + base_vs_ab_by_offset: dict[int, dict[str, dict[int, list[float]]]] = { + offset: {"nip": {}, "cos": {}} for offset in range(1, max_offsets + 1) + } + n_processed = 0 + + with torch.no_grad(): + for i, batch_data in enumerate(loader): + if i >= n_samples: + break + + input_ids: Int[Tensor, "batch pos"] = batch_data[column_name][:, :seq_len].to(device) + + sample_seq_len = input_ids.shape[1] + rng = random.Random(i) + t = rng.randint(max_offsets, min(sample_seq_len, max_pos) - 1) + + # Compute baseline and A once + if is_head_ablation: + head_abl = [(ly, hd, t) for ly, hd in parsed_heads] + baseline_outs, _ = _capture_attn_outputs(target_model, input_ids) + a_outs, _ = _capture_attn_outputs( + target_model, input_ids, head_pos_ablations=head_abl + ) + elif parsed_restrict_heads: + assert spd_model is not None + comp_head_abls = _build_component_head_ablations( + spd_model, parsed_components, parsed_restrict_heads, t + ) + batch_shape = 
(input_ids.shape[0], input_ids.shape[1]) + baseline_masks, _ = _build_deterministic_masks( + spd_model, [], batch_shape, input_ids.device, t + ) + baseline_outs, _ = _capture_attn_outputs( + target_model, input_ids, spd_model=spd_model, mask_infos=baseline_masks + ) + a_outs, _ = _capture_attn_outputs( + target_model, + input_ids, + component_head_ablations=comp_head_abls, + spd_model=spd_model, + mask_infos=baseline_masks, + ) + else: + assert spd_model is not None + component_positions = _build_prev_token_component_positions(parsed_components, t) + batch_shape = (input_ids.shape[0], input_ids.shape[1]) + baseline_masks, ablated_masks = _build_deterministic_masks_multi_pos( + spd_model, component_positions, batch_shape, input_ids.device + ) + baseline_outs, _ = _capture_attn_outputs( + target_model, input_ids, spd_model=spd_model, mask_infos=baseline_masks + ) + a_outs, _ = _capture_attn_outputs( + target_model, input_ids, spd_model=spd_model, mask_infos=ablated_masks + ) + + base_vs_a_nip, base_vs_a_cos = compute_ablation_metrics_at_pos(baseline_outs, a_outs, t) + _accum_comparison(base_vs_a_nips, base_vs_a_nip, base_vs_a_cos) + + # Sweep offsets + for offset in range(1, max_offsets + 1): + val_pos = t - offset + assert val_pos >= 0 + val_all = [(layer, val_pos)] + + if is_head_ablation: + head_abl = [(ly, hd, t) for ly, hd in parsed_heads] + ab_outs, _ = _capture_attn_outputs( + target_model, + input_ids, + head_pos_ablations=head_abl, + value_pos_ablations=val_all, + ) + elif parsed_restrict_heads: + assert spd_model is not None + cha = _build_component_head_ablations( + spd_model, parsed_components, parsed_restrict_heads, t + ) + bs = (input_ids.shape[0], input_ids.shape[1]) + bm, _ = _build_deterministic_masks(spd_model, [], bs, input_ids.device, t) + ab_outs, _ = _capture_attn_outputs( + target_model, + input_ids, + component_head_ablations=cha, + value_pos_ablations=val_all, + spd_model=spd_model, + mask_infos=bm, + ) + else: + assert spd_model is not None 
+ cp = _build_prev_token_component_positions(parsed_components, t) + bs = (input_ids.shape[0], input_ids.shape[1]) + _, am = _build_deterministic_masks_multi_pos( + spd_model, cp, bs, input_ids.device + ) + ab_outs, _ = _capture_attn_outputs( + target_model, + input_ids, + value_pos_ablations=val_all, + spd_model=spd_model, + mask_infos=am, + ) + + base_vs_ab_nip, base_vs_ab_cos = compute_ablation_metrics_at_pos( + baseline_outs, ab_outs, t + ) + _accum_comparison(base_vs_ab_by_offset[offset], base_vs_ab_nip, base_vs_ab_cos) + + n_processed += 1 + if (i + 1) % 10 == 0: + logger.info(f"Processed {i + 1}/{n_samples} samples") + + assert n_processed > 0, "No samples processed" + + # Compute means + base_vs_a_nip_mean = { + li: torch.tensor(vs).mean().item() for li, vs in base_vs_a_nips["nip"].items() + } + + # Plot offset profile for each layer + layers = sorted(base_vs_a_nip_mean.keys()) + offsets = list(range(1, max_offsets + 1)) + + for metric_key, metric_name in [("nip", "NIP"), ("cos", "Cosine Sim")]: + fig, axes = plt.subplots(len(layers), 1, figsize=(10, len(layers) * 2.5), squeeze=False) + + for li_idx, li in enumerate(layers): + ax = axes[li_idx, 0] + means = [ + torch.tensor(base_vs_ab_by_offset[o][metric_key][li]).mean().item() for o in offsets + ] + stds = [ + torch.tensor(base_vs_ab_by_offset[o][metric_key][li]).std().item() for o in offsets + ] + + ax.errorbar(offsets, means, yerr=stds, fmt="o-", capsize=3, markersize=4) + base_val = torch.tensor(base_vs_a_nips[metric_key][li]).mean().item() + ax.axhline( + y=base_val, color="red", linewidth=0.8, linestyle="--", label="Baseline vs A" + ) + ax.axhline(y=1.0, color="gray", linewidth=0.5, linestyle=":") + ax.set_ylabel(f"Layer {li}", fontsize=9) + ax.set_xlabel("Offset (t - offset)", fontsize=8) + ax.legend(fontsize=7) + + fig.suptitle( + f"{run_id} | {metric_name} offset profile (n={n_processed})", + fontsize=13, + fontweight="bold", + ) + fig.tight_layout() + path = sim_dir / 
f"offset_profile_{metric_key}_{label}.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {path}") + + # Log summary + logger.section(f"Offset sweep (n={n_processed})") + logger.info(f"Baseline vs A NIP: {base_vs_a_nip_mean}") + for offset in offsets: + nip_means = { + li: torch.tensor(vs).mean().item() + for li, vs in base_vs_ab_by_offset[offset]["nip"].items() + } + logger.info(f" offset={offset}: Baseline vs AB NIP = {nip_means}") + + +# ────────────────────────────────────────────────────────────────────────────── +# Main entry point +# ────────────────────────────────────────────────────────────────────────────── + + +def _make_metric_bucket() -> dict[str, dict[int, list[float]]]: + return {"nip": {}, "cos": {}} + + +@dataclass +class _PrevTokenAggStats: + n_samples: int = 0 + base_vs_a: dict[str, dict[int, list[float]]] = field(default_factory=_make_metric_bucket) + base_vs_b_all: dict[str, dict[int, list[float]]] = field(default_factory=_make_metric_bucket) + base_vs_b_specific: dict[str, dict[int, list[float]]] = field( + default_factory=_make_metric_bucket + ) + base_vs_ab_all: dict[str, dict[int, list[float]]] = field(default_factory=_make_metric_bucket) + base_vs_ab_specific: dict[str, dict[int, list[float]]] = field( + default_factory=_make_metric_bucket + ) + a_vs_ab_all: dict[str, dict[int, list[float]]] = field(default_factory=_make_metric_bucket) + a_vs_ab_specific: dict[str, dict[int, list[float]]] = field(default_factory=_make_metric_bucket) + + +def _accum_comparison( + bucket: dict[str, dict[int, list[float]]], + nip: dict[int, float], + cos: dict[int, float], +) -> None: + for layer_idx, val in nip.items(): + bucket["nip"].setdefault(layer_idx, []).append(val) + for layer_idx, val in cos.items(): + bucket["cos"].setdefault(layer_idx, []).append(val) + + +def _run_prev_token_loop( + target_model: LlamaSimpleMLP, + spd_model: ComponentModel | None, + loader: Iterable[dict[str, Tensor]], + 
is_head_ablation: bool, + parsed_heads: list[tuple[int, int]], + parsed_components: list[tuple[str, int]], + parsed_value_heads: list[tuple[int, int]], + parsed_restrict_heads: list[tuple[int, int]], + n_samples: int, + max_plot_samples: int, + max_pos: int, + seq_len: int, + run_id: str, + label: str, + sim_dir: Path, + column_name: str, + device: torch.device, +) -> None: + stats = _PrevTokenAggStats() + comparisons = [ + ("base_vs_a", "Baseline vs A"), + ("base_vs_b_all", "Baseline vs B(all)"), + ("base_vs_b_specific", "Baseline vs B(specific)"), + ("base_vs_ab_all", "Baseline vs A+B(all)"), + ("base_vs_ab_specific", "Baseline vs A+B(specific)"), + ("a_vs_ab_all", "A vs A+B(all)"), + ("a_vs_ab_specific", "A vs A+B(specific)"), + ] + + with torch.no_grad(): + for i, batch_data in enumerate(loader): + if i >= n_samples: + break + + input_ids: Int[Tensor, "batch pos"] = batch_data[column_name][:, :seq_len].to(device) + + sample_seq_len = input_ids.shape[1] + rng = random.Random(i) + t = rng.randint(1, min(sample_seq_len, max_pos) - 1) + + if is_head_ablation: + result = _run_prev_token_head_ablation( + target_model, input_ids, parsed_heads, parsed_value_heads, t + ) + elif parsed_restrict_heads: + assert spd_model is not None + result = _run_prev_token_head_restricted_component_ablation( + spd_model, + target_model, + input_ids, + parsed_components, + parsed_restrict_heads, + parsed_value_heads, + t, + ) + else: + assert spd_model is not None + result = _run_prev_token_component_ablation( + spd_model, target_model, input_ids, parsed_components, parsed_value_heads, t + ) + + b = result.baseline_attn_outputs + pairs = [ + (b, result.a_attn_outputs), + (b, result.b_all_attn_outputs), + (b, result.b_specific_attn_outputs), + (b, result.ab_all_attn_outputs), + (b, result.ab_specific_attn_outputs), + (result.a_attn_outputs, result.ab_all_attn_outputs), + (result.a_attn_outputs, result.ab_specific_attn_outputs), + ] + + for (tag, desc), (out_a, out_b) in zip(comparisons, 
pairs, strict=True): + nip_at_t, cos_at_t = compute_ablation_metrics_at_pos(out_a, out_b, t) + _accum_comparison(getattr(stats, tag), nip_at_t, cos_at_t) + + if i < max_plot_samples: + sample_nip, sample_cos = compute_ablation_metrics(out_a, out_b) + plot_per_position_line( + sample_nip, + f"{run_id} | {desc} NIP sample {i} (t={t})", + sim_dir / f"{tag}_nip_sample{i}_{label}.png", + max_pos, + baseline_y=1.0, + ylim=(-1, 1), + ) + plot_per_position_line( + sample_cos, + f"{run_id} | {desc} cos sample {i} (t={t})", + sim_dir / f"{tag}_cosine_sim_sample{i}_{label}.png", + max_pos, + baseline_y=1.0, + ylim=(-1, 1), + ) + + stats.n_samples += 1 + if (i + 1) % 5 == 0: + logger.info(f"Processed {i + 1}/{n_samples} samples") + + assert stats.n_samples > 0, "No samples processed" + + for tag, desc in comparisons: + bucket = getattr(stats, tag) + nip_means = {li: torch.tensor(vs).mean().item() for li, vs in bucket["nip"].items()} + nip_stds = {li: torch.tensor(vs).std().item() for li, vs in bucket["nip"].items()} + cos_means = {li: torch.tensor(vs).mean().item() for li, vs in bucket["cos"].items()} + cos_stds = {li: torch.tensor(vs).std().item() for li, vs in bucket["cos"].items()} + + plot_output_similarity_bars( + nip_means, + nip_stds, + f"{run_id} | {desc} NIP (n={stats.n_samples})", + sim_dir / f"{tag}_nip_bars_{label}.png", + ) + plot_output_similarity_bars( + cos_means, + cos_stds, + f"{run_id} | {desc} cos (n={stats.n_samples})", + sim_dir / f"{tag}_cosine_sim_bars_{label}.png", + ) + + logger.section(f"{desc} at position t") + for layer_idx in sorted(nip_means.keys()): + logger.info( + f" Layer {layer_idx}: " + f"NIP = {nip_means[layer_idx]:.4f} ± {nip_stds[layer_idx]:.4f}, " + f"cos = {cos_means[layer_idx]:.4f} ± {cos_stds[layer_idx]:.4f}" + ) + + logger.info(f"All plots saved to {sim_dir}") + + +@dataclass +class _AggStats: + total_changed: int = 0 + total_positions: int = 0 + total_kl: float = 0.0 + n_samples: int = 0 + pos_ip_samples: dict[int, list[float]] = 
field(default_factory=dict) + pos_cos_samples: dict[int, list[float]] = field(default_factory=dict) + + +def run_attention_ablation( + wandb_path: ModelPath, + heads: str | None = None, + components: str | None = None, + ablation_mode: AblationMode = "deterministic", + n_samples: int = 10, + max_plot_samples: int = 6, + batch_size: int = 1, + n_mask_samples: int = 10, + pgd_steps: int = 50, + pgd_step_size: float = 0.01, + max_pos: int = 128, + prev_token_test: bool = False, + value_heads: str | None = None, + restrict_to_heads: str | None = None, + offset_sweep: int = 0, + seed: int = 42, +) -> None: + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + random.seed(seed) + + assert (heads is None) != (components is None), "Provide exactly one of --heads or --components" + if prev_token_test: + assert value_heads is not None, "--value_heads required when --prev_token_test is set" + if restrict_to_heads is not None: + assert components is not None, "--restrict_to_heads requires --components" + assert prev_token_test or offset_sweep > 0, ( + "--restrict_to_heads requires --prev_token_test or --offset_sweep" + ) + is_head_ablation = heads is not None + parsed_heads = parse_heads(heads) if heads else [] + parsed_components = parse_components(components) if components else [] + parsed_value_heads = parse_heads(value_heads) if value_heads else [] + parsed_restrict_heads = parse_heads(restrict_to_heads) if restrict_to_heads else [] + + _entity, _project, run_id = parse_wandb_run_path(str(wandb_path)) + run_info = SPDRunInfo.from_path(wandb_path) + config = run_info.config + + out_dir = SCRIPT_DIR / "out" / run_id + out_dir.mkdir(parents=True, exist_ok=True) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + spd_model: ComponentModel | None = None + if is_head_ablation: + assert config.pretrained_model_name is not None + target_model = LlamaSimpleMLP.from_pretrained(config.pretrained_model_name) + target_model.eval() + 
target_model.requires_grad_(False) + for block in target_model._h: + block.attn.flash_attention = False + target_model = target_model.to(device) + else: + spd_model = ComponentModel.from_run_info(run_info) + spd_model.eval() + spd_model = spd_model.to(device) + target_model = spd_model.target_model + assert isinstance(target_model, LlamaSimpleMLP) + for block in target_model._h: + block.attn.flash_attention = False + + seq_len = target_model.config.n_ctx + + # Data loader + task_config = config.task_config + assert isinstance(task_config, LMTaskConfig) + dataset_config = DatasetConfig( + name=task_config.dataset_name, + hf_tokenizer_path=config.tokenizer_name, + split=task_config.eval_data_split, + n_ctx=task_config.max_seq_len, + is_tokenized=task_config.is_tokenized, + streaming=task_config.streaming, + column_name=task_config.column_name, + shuffle_each_epoch=False, + ) + loader, tokenizer = create_data_loader( + dataset_config=dataset_config, + batch_size=batch_size, + buffer_size=1000, + ) + + def _short_module(m: str) -> str: + return m.replace("_proj", "") + + if is_head_ablation: + label = "_".join(f"L{layer}H{head}" for layer, head in parsed_heads) + else: + mode_suffix = "" if ablation_mode == "deterministic" else f"_{ablation_mode}" + label = "_".join(f"{_short_module(m)}:{c}" for m, c in parsed_components) + mode_suffix + logger.section(f"Attention ablation: {label}") + logger.info(f"run_id={run_id}, device={device}, n_samples={n_samples}") + + attn_dir = out_dir / "attention_patterns" + value_dir = out_dir / "value_norms" + sim_dir = out_dir / "output_similarity" + attn_dir.mkdir(parents=True, exist_ok=True) + value_dir.mkdir(parents=True, exist_ok=True) + sim_dir.mkdir(parents=True, exist_ok=True) + + if offset_sweep > 0: + _run_offset_sweep( + target_model=target_model, + spd_model=spd_model, + loader=loader, + is_head_ablation=is_head_ablation, + parsed_heads=parsed_heads, + parsed_components=parsed_components, + 
parsed_restrict_heads=parsed_restrict_heads, + n_samples=n_samples, + max_offsets=offset_sweep, + max_pos=max_pos, + seq_len=seq_len, + run_id=run_id, + label=label, + sim_dir=sim_dir, + column_name=task_config.column_name, + device=device, + ) + return + + if prev_token_test: + _run_prev_token_loop( + target_model=target_model, + spd_model=spd_model, + loader=loader, + is_head_ablation=is_head_ablation, + parsed_heads=parsed_heads, + parsed_components=parsed_components, + parsed_value_heads=parsed_value_heads, + parsed_restrict_heads=parsed_restrict_heads, + n_samples=n_samples, + max_plot_samples=max_plot_samples, + max_pos=max_pos, + seq_len=seq_len, + run_id=run_id, + label=label, + sim_dir=sim_dir, + column_name=task_config.column_name, + device=device, + ) + return + + accum_baseline_patterns: AttentionPatterns = {} + accum_ablated_patterns: AttentionPatterns = {} + accum_baseline_values: ValueVectors = {} + accum_ablated_values: ValueVectors = {} + stats = _AggStats() + + with torch.no_grad(): + for i, batch_data in enumerate(loader): + if i >= n_samples: + break + + input_ids: Int[Tensor, "batch pos"] = batch_data[task_config.column_name][ + :, :seq_len + ].to(device) + + sample_seq_len = input_ids.shape[1] + rng = random.Random(i) + ablation_pos = rng.randint(0, min(sample_seq_len, max_pos) - 1) + + if is_head_ablation: + result = _run_head_ablation(target_model, input_ids, parsed_heads, ablation_pos) + else: + assert spd_model is not None + result = _run_component_ablation( + spd_model, + target_model, + input_ids, + parsed_components, + ablation_mode, + n_mask_samples, + pgd_steps, + pgd_step_size, + ablation_pos, + ) + + if i < max_plot_samples: + # Per-sample attention plots + plot_attention_grid( + result.baseline_patterns, + f"{run_id} | Sample {i} baseline (pos={ablation_pos})", + attn_dir / f"baseline_sample{i}_{label}.png", + max_pos, + ) + plot_attention_grid( + result.ablated_patterns, + f"{run_id} | Sample {i} ablated (pos={ablation_pos})", + 
attn_dir / f"ablated_sample{i}_{label}.png", + max_pos, + ) + plot_attention_diff( + result.baseline_patterns, + result.ablated_patterns, + f"{run_id} | Sample {i} diff (pos={ablation_pos})", + attn_dir / f"diff_sample{i}_{label}.png", + max_pos, + ) + + # Per-sample value norm plots + plot_value_norms( + result.baseline_values, + f"{run_id} | Sample {i} value norms baseline", + value_dir / f"baseline_sample{i}_{label}.png", + max_pos, + ) + plot_value_norms( + result.ablated_values, + f"{run_id} | Sample {i} value norms ablated", + value_dir / f"ablated_sample{i}_{label}.png", + max_pos, + ) + plot_value_norms_diff( + result.baseline_values, + result.ablated_values, + f"{run_id} | Sample {i} value norms diff", + value_dir / f"diff_sample{i}_{label}.png", + max_pos, + ) + + # Per-sample per-position line plots (sanity check) + sample_ip, sample_cos = compute_ablation_metrics( + result.baseline_attn_outputs, result.ablated_attn_outputs + ) + plot_per_position_line( + sample_ip, + f"{run_id} | Sample {i} normalized IP (ablated pos={ablation_pos})", + sim_dir / f"normalized_ip_sample{i}_{label}.png", + max_pos, + baseline_y=1.0, + ylim=(-1, 1), + ) + plot_per_position_line( + sample_cos, + f"{run_id} | Sample {i} cosine sim (ablated pos={ablation_pos})", + sim_dir / f"cosine_sim_sample{i}_{label}.png", + max_pos, + baseline_y=1.0, + ylim=(-1, 1), + ) + + # Position-specific scalar measurement at ablated position + pos_ip, pos_cos = compute_ablation_metrics_at_pos( + result.baseline_attn_outputs, + result.ablated_attn_outputs, + ablation_pos, + ) + for layer_idx, val in pos_ip.items(): + stats.pos_ip_samples.setdefault(layer_idx, []).append(val) + for layer_idx, val in pos_cos.items(): + stats.pos_cos_samples.setdefault(layer_idx, []).append(val) + + # Per-sample prediction table + n_changed = log_prediction_table( + input_ids[0], result.baseline_logits[0], result.ablated_logits[0], tokenizer + ) + + # Accumulate stats + stats.total_changed += n_changed + 
stats.total_positions += sample_seq_len + stats.total_kl += calc_mean_kl_divergence( + result.baseline_logits[0], result.ablated_logits[0] + ) + stats.n_samples += 1 + + # Accumulate for mean plots + _add_patterns(accum_baseline_patterns, result.baseline_patterns) + _add_patterns(accum_ablated_patterns, result.ablated_patterns) + _add_patterns(accum_baseline_values, result.baseline_values) + _add_patterns(accum_ablated_values, result.ablated_values) + if (i + 1) % 5 == 0: + logger.info(f"Processed {i + 1}/{n_samples} samples") + + assert stats.n_samples > 0, "No samples processed" + + # Mean attention plots + mean_baseline_patterns = _scale_patterns(accum_baseline_patterns, stats.n_samples) + mean_ablated_patterns = _scale_patterns(accum_ablated_patterns, stats.n_samples) + + plot_attention_grid( + mean_baseline_patterns, + f"{run_id} | Baseline mean attention (n={stats.n_samples})", + attn_dir / f"baseline_mean_{label}.png", + max_pos, + ) + plot_attention_grid( + mean_ablated_patterns, + f"{run_id} | Ablated mean attention (n={stats.n_samples})", + attn_dir / f"ablated_mean_{label}.png", + max_pos, + ) + plot_attention_diff( + mean_baseline_patterns, + mean_ablated_patterns, + f"{run_id} | Attention diff mean (n={stats.n_samples})", + attn_dir / f"diff_mean_{label}.png", + max_pos, + ) + + # Mean value norm plots + mean_baseline_values = _scale_patterns(accum_baseline_values, stats.n_samples) + mean_ablated_values = _scale_patterns(accum_ablated_values, stats.n_samples) + + plot_value_norms( + mean_baseline_values, + f"{run_id} | Baseline mean value norms (n={stats.n_samples})", + value_dir / f"baseline_mean_{label}.png", + max_pos, + ) + plot_value_norms( + mean_ablated_values, + f"{run_id} | Ablated mean value norms (n={stats.n_samples})", + value_dir / f"ablated_mean_{label}.png", + max_pos, + ) + plot_value_norms_diff( + mean_baseline_values, + mean_ablated_values, + f"{run_id} | Value norms diff mean (n={stats.n_samples})", + value_dir / 
f"diff_mean_{label}.png", + max_pos, + ) + + # Position-specific bar charts (mean ± std across samples) + ip_means = {li: torch.tensor(vs).mean().item() for li, vs in stats.pos_ip_samples.items()} + ip_stds = {li: torch.tensor(vs).std().item() for li, vs in stats.pos_ip_samples.items()} + cos_means = {li: torch.tensor(vs).mean().item() for li, vs in stats.pos_cos_samples.items()} + cos_stds = {li: torch.tensor(vs).std().item() for li, vs in stats.pos_cos_samples.items()} + + plot_output_similarity_bars( + ip_means, + ip_stds, + f"{run_id} | Normalized IP at ablated pos (n={stats.n_samples})", + sim_dir / f"normalized_ip_bars_{label}.png", + ) + plot_output_similarity_bars( + cos_means, + cos_stds, + f"{run_id} | Cosine sim at ablated pos (n={stats.n_samples})", + sim_dir / f"cosine_sim_bars_{label}.png", + ) + + # Summary stats + frac_changed = stats.total_changed / stats.total_positions + mean_kl = stats.total_kl / stats.n_samples + logger.section("Summary") + logger.values( + { + "n_samples": stats.n_samples, + "frac_top1_changed": f"{frac_changed:.4f}", + "mean_kl_divergence": f"{mean_kl:.6f}", + } + ) + + logger.section("Position-specific similarity at ablated position") + for layer_idx in sorted(ip_means.keys()): + logger.info( + f" Layer {layer_idx}: " + f"NIP = {ip_means[layer_idx]:.4f} ± {ip_stds[layer_idx]:.4f}, " + f"cos = {cos_means[layer_idx]:.4f} ± {cos_stds[layer_idx]:.4f}" + ) + + logger.info(f"All plots saved to {out_dir}") + + +if __name__ == "__main__": + fire.Fire(run_attention_ablation) diff --git a/spd/scripts/attention_ablation_experiment/generate_with_ablation.py b/spd/scripts/attention_ablation_experiment/generate_with_ablation.py new file mode 100644 index 000000000..85bd5643a --- /dev/null +++ b/spd/scripts/attention_ablation_experiment/generate_with_ablation.py @@ -0,0 +1,524 @@ +"""Generate text completions with and without ablation at a single position. + +Produces an HTML comparison table with token-level alignment and color-coding. 
+
+Each sample picks a random position t, truncates the prompt to t+1 tokens, and
+runs one ablated forward pass per condition. Each condition therefore yields a
+single greedy next-token prediction (plus its full logit vector), which is
+compared against the target-model and SPD baselines.
+
+Usage:
+    python -m spd.scripts.attention_ablation_experiment.generate_with_ablation \
+        wandb:goodfire/spd/runs/s-275c8f21 \
+        --comp_sets '{"2c": "h.1.attn.q_proj:279,h.1.attn.k_proj:177"}' \
+        --heads L1H1 --restrict_to_heads L1H1 \
+        --n_samples 40 --prompt_len 16
+"""
+
+import html
+import json
+import random
+from pathlib import Path
+from typing import Any
+
+import fire
+import torch
+from jaxtyping import Int
+from torch import Tensor
+
+from spd.configs import LMTaskConfig
+from spd.data import DatasetConfig, create_data_loader
+from spd.log import logger
+from spd.models.component_model import ComponentModel, SPDRunInfo
+from spd.models.components import ComponentsMaskInfo, make_mask_infos
+from spd.pretrain.models.llama_simple_mlp import LlamaSimpleMLP
+from spd.scripts.attention_ablation_experiment.attention_ablation_experiment import (
+    _build_component_head_ablations,
+    _build_deterministic_masks_multi_pos,
+    _build_prev_token_component_positions,
+    parse_components,
+    parse_heads,
+    patched_attention_forward,
+)
+from spd.spd_types import ModelPath
+from spd.utils.wandb_utils import parse_wandb_run_path
+
+SCRIPT_DIR = Path(__file__).parent
+
+CRAFTED_PROMPTS = [
+    # Phrases where prev token strongly predicts next
+    ("Once upon a", "Phrase"),
+    ("The United States of", "Bigram"),
+    ("Thank you very", "Phrase"),
+    ("Dear Sir or", "Phrase"),
+    ("black and", "Phrase"),
+    ("war and", "Phrase"),
+    ("the king and", "Phrase"),
+    ("the end of the", "Phrase"),
+    ("open the", "Phrase"),
+    ("from A to", "Phrase"),
+    ("ready, set,", "Phrase"),
+    ("less than", "Comparison"),
+    # Prev token = key context for next word
+    ("New York", "Place"),
+    ("he said she",
"Narrative"), + ("What is your", "Question"), + ("input and", "Phrase"), + ("north south east", "Directions"), + # Code: prev token determines syntax + ("import numpy as", "Code"), + ("if x ==", "Code"), + ("return self.", "Code"), + ("def f(x):", "Code"), + ("is not", "Code"), + ("for i in", "Code"), + ("x = x +", "Code"), + # Sequences / repetition: prev token predicts pattern + ("2 + 2 =", "Math"), + ("10, 20, 30,", "Counting"), + ("A B C D E F", "Alphabet"), + ("dog cat dog cat dog", "Repetition"), + ("red blue red blue red", "Repetition"), + ("yes or no? yes or", "Repetition"), + ("1 2 3 4 5 6 7", "Counting"), + ("mon tue wed thu", "Days"), + # Structured: prev token signals format + ("", "HTML"), + ("http://www.", "URL"), + ("rock, paper,", "Game"), + # Bigrams where the pair is a fixed expression + ("pro and", "Phrase"), + ("trial and", "Phrase"), + ("more or", "Phrase"), + ("sooner or", "Phrase"), + ("back and", "Phrase"), +] + + +# ────────────────────────────────────────────────────────────────────────────── +# Generation +# ────────────────────────────────────────────────────────────────────────────── + + +def _build_baseline_mask_infos( + spd_model: ComponentModel, + device: torch.device, +) -> dict[str, ComponentsMaskInfo]: + """All-ones masks so the SPD model uses component reconstruction (not target passthrough).""" + masks = {name: torch.ones(1, c, device=device) for name, c in spd_model.module_to_c.items()} + return make_mask_infos(masks) + + +class Prediction: + __slots__ = ("token_id", "logits") + + def __init__(self, token_id: int, logits: Tensor): + self.token_id = token_id + self.logits = logits + + +def _predict_next_token( + target_model: LlamaSimpleMLP, + prompt_ids: Int[Tensor, "1 prompt_len"], + **ablation_kwargs: Any, +) -> Prediction: + """Run one forward pass with ablation and return the greedy next token + logits.""" + spd_model: ComponentModel | None = ablation_kwargs.pop("spd_model", None) + mask_infos: dict[str, ComponentsMaskInfo] 
| None = ablation_kwargs.pop("mask_infos", None) + + with patched_attention_forward(target_model, **ablation_kwargs): + if spd_model is not None: + baseline = _build_baseline_mask_infos(spd_model, prompt_ids.device) + out = spd_model(prompt_ids, mask_infos=mask_infos or baseline) + assert isinstance(out, Tensor) + logits = out + else: + logits, _ = target_model(prompt_ids) + assert logits is not None + + last_logits = logits[0, -1].detach().cpu() + return Prediction(int(last_logits.argmax().item()), last_logits) + + +# ────────────────────────────────────────────────────────────────────────────── +# Condition definitions +# ────────────────────────────────────────────────────────────────────────────── + + +def _head_label(heads: list[tuple[int, int]]) -> str: + return ",".join(f"L{ly}H{hd}" for ly, hd in heads) + + +ConditionResult = tuple[str, Prediction, str] # (name, prediction, baseline_name) + + +def _build_conditions( + target_model: LlamaSimpleMLP, + spd_model: ComponentModel, + prompt_ids: Int[Tensor, "1 seq_len"], + t: int, + parsed_heads: list[tuple[int, int]], + comp_sets: dict[str, list[tuple[str, int]]], + parsed_restrict_heads: list[tuple[int, int]], + n_layers: int, +) -> list[ConditionResult]: + """Run all conditions and return (name, prediction, baseline_name) triples.""" + assert t >= 1, f"t must be >= 1, got {t}" + seq_len = prompt_ids.shape[1] + conditions: list[ConditionResult] = [] + TARGET = "Target model" + SPD = "SPD baseline" + + def predict(**kwargs: Any) -> Prediction: + return _predict_next_token(target_model, prompt_ids, **kwargs) + + # --- Baselines --- + conditions.append((TARGET, predict(), TARGET)) + conditions.append((SPD, predict(spd_model=spd_model), SPD)) + + # --- Head ablation: zero head output at t --- + if parsed_heads: + head_abl = [(layer, head, t) for layer, head in parsed_heads] + conditions.append( + ( + f"Head ablated ({_head_label(parsed_heads)})", + predict(head_pos_ablations=head_abl), + TARGET, + ) + ) + + # --- 
Value ablations: zero values at specific positions --- + # Layer derived from parsed_heads (tests whether the head's layer uses values from t-1). + if parsed_heads: + val_layer = parsed_heads[0][0] + hl = _head_label(parsed_heads) + + conditions.append( + ( + f"Vals @t-1 (all heads, L{val_layer})", + predict(value_pos_ablations=[(val_layer, t - 1)]), + TARGET, + ) + ) + conditions.append( + ( + f"Vals @t-1 ({hl})", + predict(value_head_pos_ablations=[(ly, hd, t - 1) for ly, hd in parsed_heads]), + TARGET, + ) + ) + if t >= 2: + conditions.append( + ( + f"Vals @t-1,t-2 (all heads, L{val_layer})", + predict(value_pos_ablations=[(val_layer, t - 1), (val_layer, t - 2)]), + TARGET, + ) + ) + conditions.append( + ( + f"Vals @all prev (all heads, L{val_layer})", + predict(value_pos_ablations=[(val_layer, p) for p in range(seq_len)]), + TARGET, + ) + ) + + conditions.append( + ( + "Vals @all prev (ALL layers)", + predict( + value_pos_ablations=[(ly, p) for ly in range(n_layers) for p in range(seq_len)] + ), + TARGET, + ) + ) + + # --- Component ablations --- + # Full: zero component masks at t (q) / t-1 (k), affects all heads + # Per-head: subtract component contribution from restrict_heads' rows only + for set_name, comps in comp_sets.items(): + cp = _build_prev_token_component_positions(comps, t) + bs = (prompt_ids.shape[0], prompt_ids.shape[1]) + _, ablated_masks = _build_deterministic_masks_multi_pos( + spd_model, cp, bs, prompt_ids.device + ) + conditions.append( + ( + f"Full comp ({set_name})", + predict(spd_model=spd_model, mask_infos=ablated_masks), + SPD, + ) + ) + if parsed_restrict_heads: + cha = _build_component_head_ablations(spd_model, comps, parsed_restrict_heads, t) + conditions.append( + ( + f"Per-head {_head_label(parsed_restrict_heads)} ({set_name})", + predict(spd_model=spd_model, component_head_ablations=cha), + SPD, + ) + ) + + return conditions + + +# ────────────────────────────────────────────────────────────────────────────── +# HTML rendering +# 
──────────────────────────────────────────────────────────────────────────────
+
+HTML_HEADER = """\
+<html><head><style>
+body { font-family: monospace; font-size: 12px; }
+table { border-collapse: collapse; margin: 8px 0; }
+th, td { border: 1px solid #ccc; padding: 2px 6px; text-align: left; }
+.sample { margin-top: 24px; }
+.sample-label { font-weight: bold; }
+.prompt-tok-prev { background: #cce5ff; }
+.prompt-tok-abl { background: #ffe0b2; }
+.match { background: #c8e6c9; }
+.diff { background: #ffcdd2; }
+.logit-pos { color: #2e7d32; }
+.logit-neg { color: #c62828; }
+</style></head><body>
+"""
+
+
+def _render_sample_html(
+    prompt_tokens: list[str],
+    conditions: list[ConditionResult],
+    t: int,
+    label: str,
+    decode_tok: Any,
+    top_k: int = 20,
+) -> str:
+    # Build lookup from condition name to logits for baseline resolution
+    logits_by_name: dict[str, Tensor] = {}
+    for name, pred, _baseline in conditions:
+        logits_by_name[name] = pred.logits
+
+    ref_tok = decode_tok([conditions[0][1].token_id])
+
+    def _fmt_tok(tok: str) -> str:
+        return html.escape(tok).replace("\n", "\\n").replace(" ", "&nbsp;")
+
+    h: list[str] = []
+    token_spans = []
+    for i, tok in enumerate(prompt_tokens):
+        escaped = _fmt_tok(tok)
+        css = "prompt-tok-abl" if i == t else ("prompt-tok-prev" if i == t - 1 else "prompt-tok")
+        token_spans.append(f'<span class="{css}">{escaped}</span>')
+    prompt_html = "|".join(token_spans)
+    h.append(
+        f'<div class="sample"><div class="sample-label">{html.escape(label)} | t={t}</div>'
+        f'<div class="prompt">{prompt_html}</div>'
+    )
+
+    h.append("<table>")
+
+    # Header: label, predicted, baseline change, k increase cols, k decrease cols
+    h.append("<tr><th></th><th>predicted</th><th>base tok<br>logit Δ</th>")
+    for j in range(top_k):
+        h.append(f"<th>inc {j + 1}</th>")
+    for j in range(top_k):
+        h.append(f"<th>dec {j + 1}</th>")
+    h.append("</tr>")
+
+    def _logit_cell(tok_id: int, val: float, positive: bool) -> str:
+        tok = _fmt_tok(decode_tok([tok_id]))
+        css_class = "logit-pos" if positive else "logit-neg"
+        return f'<td class="{css_class}">{tok}<br>{val:+.1f}</td>'
+
+    for name, pred, baseline_name in conditions:
+        tok = decode_tok([pred.token_id])
+        css = "match" if tok == ref_tok else "diff"
+
+        if name == baseline_name:
+            empty_cells = "<td>-</td>" * (1 + 2 * top_k)
+            h.append(
+                f'<tr><td>{html.escape(name)}</td>'
+                f'<td class="{css}">{_fmt_tok(tok)}</td>'
+                f"{empty_cells}</tr>"
+            )
+            continue
+
+        # Logit diff vs appropriate baseline
+        base_logits = logits_by_name[baseline_name]
+        diff = pred.logits - base_logits
+
+        # Change in baseline's predicted token logit
+        base_pred_id = int(base_logits.argmax().item())
+        base_pred_tok = _fmt_tok(decode_tok([base_pred_id]))
+        base_pred_change = diff[base_pred_id].item()
+        change_css = "logit-neg" if base_pred_change < 0 else "logit-pos"
+
+        # Top-k increases and decreases
+        top_inc_vals, top_inc_ids = diff.topk(top_k)
+        top_dec_vals, top_dec_ids = (-diff).topk(top_k)
+
+        row = f'<tr><td>{html.escape(name)}</td>'
+        row += f'<td class="{css}">{_fmt_tok(tok)}</td>'
+        row += (
+            f'<td class="{change_css}">{base_pred_tok}<br>'
+            f'{base_pred_change:+.1f}</td>'
+        )
+        for j in range(top_k):
+            row += _logit_cell(int(top_inc_ids[j].item()), top_inc_vals[j].item(), True)
+        for j in range(top_k):
+            row += _logit_cell(int(top_dec_ids[j].item()), -top_dec_vals[j].item(), False)
+        row += "</tr>"
+        h.append(row)
+
+    h.append("</table></div>
") + return "\n".join(h) + + +# ────────────────────────────────────────────────────────────────────────────── +# Main +# ────────────────────────────────────────────────────────────────────────────── + + +def generate_with_ablation( + wandb_path: ModelPath, + comp_sets: str | dict[str, str] | None = None, + heads: str | None = None, + restrict_to_heads: str | None = None, + n_samples: int = 40, + prompt_len: int = 16, + include_crafted: bool = True, + seed: int = 42, +) -> None: + """Generate comparison HTML with multiple ablation conditions. + + Args: + comp_sets: JSON dict mapping set names to component specs, e.g. + '{"2c": "h.1.attn.q_proj:279,h.1.attn.k_proj:177"}' + heads: Head spec for head ablation, e.g. "L1H1" + restrict_to_heads: Head spec for per-head component ablation + """ + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + random.seed(seed) + + parsed_comp_sets: dict[str, list[tuple[str, int]]] = {} + if comp_sets is not None: + raw: dict[str, str] = json.loads(comp_sets) if isinstance(comp_sets, str) else comp_sets + for name, spec in raw.items(): + parsed_comp_sets[name] = parse_components(spec) + + parsed_heads = parse_heads(heads) if heads else [] + parsed_restrict_heads = parse_heads(restrict_to_heads) if restrict_to_heads else [] + + _entity, _project, run_id = parse_wandb_run_path(str(wandb_path)) + run_info = SPDRunInfo.from_path(wandb_path) + config = run_info.config + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + spd_model = ComponentModel.from_run_info(run_info) + spd_model.eval() + spd_model = spd_model.to(device) + target_model = spd_model.target_model + assert isinstance(target_model, LlamaSimpleMLP) + for block in target_model._h: + block.attn.flash_attention = False + n_layers = len(target_model._h) + + task_config = config.task_config + assert isinstance(task_config, LMTaskConfig) + dataset_config = DatasetConfig( + name=task_config.dataset_name, + hf_tokenizer_path=config.tokenizer_name, + 
split=task_config.eval_data_split, + n_ctx=task_config.max_seq_len, + is_tokenized=task_config.is_tokenized, + streaming=task_config.streaming, + column_name=task_config.column_name, + shuffle_each_epoch=False, + ) + loader, tokenizer = create_data_loader( + dataset_config=dataset_config, batch_size=1, buffer_size=1000 + ) + encode = tokenizer.encode + decode_tok = tokenizer.decode # pyright: ignore[reportAttributeAccessIssue] + + out_dir = SCRIPT_DIR / "out" / run_id / "generations" + out_dir.mkdir(parents=True, exist_ok=True) + + all_tables: list[str] = [] + + def run_sample(prompt_ids: Tensor, t: int, label: str) -> None: + prompt_tokens = [decode_tok([tid]) for tid in prompt_ids[0].tolist()] + conditions = _build_conditions( + target_model, + spd_model, + prompt_ids, + t, + parsed_heads, + parsed_comp_sets, + parsed_restrict_heads, + n_layers, + ) + all_tables.append(_render_sample_html(prompt_tokens, conditions, t, label, decode_tok)) + + with torch.no_grad(): + # Dataset samples: take first prompt_len tokens, pick random t, truncate to t+1 + n_collected = 0 + for i, batch_data in enumerate(loader): + if n_collected >= n_samples: + break + input_ids: Int[Tensor, "1 seq"] = batch_data[task_config.column_name][ + :, :prompt_len + ].to(device) + # Skip non-ASCII samples (non-English text) + text = decode_tok(input_ids[0].tolist()) + if not text.isascii(): + continue + rng = random.Random(i) + t = rng.randint(1, min(input_ids.shape[1], prompt_len) - 1) + run_sample(input_ids[:, : t + 1], t, f"Dataset sample {i}") + n_collected += 1 + if n_collected % 10 == 0: + logger.info(f"Dataset: {n_collected}/{n_samples}") + + # Crafted prompts: use full text, ablate at last token + if include_crafted: + for idx, (text, desc) in enumerate(CRAFTED_PROMPTS): + token_ids = encode(text) + ids_list: list[int] = ( + token_ids if isinstance(token_ids, list) else token_ids.ids # pyright: ignore[reportAttributeAccessIssue] + ) + ids_tensor = torch.tensor([ids_list], device=device) + 
run_sample(ids_tensor, ids_tensor.shape[1] - 1, f"Crafted: {desc}") + if (idx + 1) % 10 == 0: + logger.info(f"Crafted: {idx + 1}/{len(CRAFTED_PROMPTS)}") + + # Write HTML + comp_desc = ", ".join( + f"{name} ({len(comps)})" for name, comps in parsed_comp_sets.items() + ) + html_parts = [ + HTML_HEADER, + "

<h1>Generation Comparison: Ablation Effects</h1>",
+        f'<div class="meta">Model: {run_id} | {n_layers} layers</div>',
+        f'<div class="meta">Component sets: {comp_desc}</div>' if comp_desc else "",
+        '<div class="note">All ablations apply to the single predicted next token. '
+        "Green = matches the target-model prediction. Red = differs.</div>

", + *all_tables, + "", + ] + html_path = out_dir / "comparison.html" + html_path.write_text("\n".join(html_parts)) + logger.info(f"Saved {html_path} ({len(all_tables)} samples)") + + +if __name__ == "__main__": + fire.Fire(generate_with_ablation) diff --git a/spd/scripts/attention_ablation_experiment/plot_attn_pattern_diffs.py b/spd/scripts/attention_ablation_experiment/plot_attn_pattern_diffs.py new file mode 100644 index 000000000..4f0cf0e2f --- /dev/null +++ b/spd/scripts/attention_ablation_experiment/plot_attn_pattern_diffs.py @@ -0,0 +1,277 @@ +"""Plot attention pattern changes from component ablation across heads. + +Compares four conditions at layer 1: + - Target model baseline + - SPD model baseline (all-ones masks) + - Full component ablation (q/k components zeroed at t/t-1) + - Per-head component ablation (restricted to specific heads) + +Produces two plots: + - Raw attention distributions at query position t, averaged over samples + - Attention differences (ablated - SPD baseline) + +Usage: + python -m spd.scripts.attention_ablation_experiment.plot_attn_pattern_diffs \ + wandb:goodfire/spd/runs/s-275c8f21 \ + --components "h.1.attn.q_proj:279,h.1.attn.k_proj:177" \ + --restrict_to_heads L1H1 \ + --n_samples 1024 +""" + +import random +from pathlib import Path + +import fire +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import torch +from jaxtyping import Int +from torch import Tensor + +from spd.configs import LMTaskConfig +from spd.data import DatasetConfig, create_data_loader +from spd.log import logger +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.pretrain.models.llama_simple_mlp import LlamaSimpleMLP +from spd.scripts.attention_ablation_experiment.attention_ablation_experiment import ( + _build_component_head_ablations, + _build_deterministic_masks_multi_pos, + _build_prev_token_component_positions, + _infer_layer_from_components, + parse_components, + parse_heads, + 
patched_attention_forward, +) +from spd.spd_types import ModelPath +from spd.utils.wandb_utils import parse_wandb_run_path + +matplotlib.use("Agg") + +SCRIPT_DIR = Path(__file__).parent + + +def plot_attn_pattern_diffs( + wandb_path: ModelPath, + components: str, + restrict_to_heads: str, + n_samples: int = 1024, + max_offset_show: int = 20, + seed: int = 42, +) -> None: + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + random.seed(seed) + + parsed_components = parse_components(components) + parsed_restrict_heads = parse_heads(restrict_to_heads) + layer = _infer_layer_from_components(parsed_components) + + _entity, _project, run_id = parse_wandb_run_path(str(wandb_path)) + run_info = SPDRunInfo.from_path(wandb_path) + config = run_info.config + + spd_model = ComponentModel.from_run_info(run_info) + spd_model.eval() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + spd_model = spd_model.to(device) + target_model = spd_model.target_model + assert isinstance(target_model, LlamaSimpleMLP) + for block in target_model._h: + block.attn.flash_attention = False + + seq_len = target_model.config.n_ctx + task_config = config.task_config + assert isinstance(task_config, LMTaskConfig) + dataset_config = DatasetConfig( + name=task_config.dataset_name, + hf_tokenizer_path=config.tokenizer_name, + split=task_config.eval_data_split, + n_ctx=task_config.max_seq_len, + is_tokenized=task_config.is_tokenized, + streaming=task_config.streaming, + column_name=task_config.column_name, + shuffle_each_epoch=False, + ) + loader, _ = create_data_loader(dataset_config=dataset_config, batch_size=1, buffer_size=1000) + + out_dir = SCRIPT_DIR / "out" / run_id / "attn_pattern_diffs" + out_dir.mkdir(parents=True, exist_ok=True) + + n_heads = target_model.config.n_head + conditions = ["target_baseline", "spd_baseline", "full_comp", "perhead_comp"] + accum: dict[str, dict[int, dict[int, list[float]]]] = { + c: {h: {o: [] for o in range(max_offset_show + 1)} for h 
in range(n_heads)} + for c in conditions + } + + restrict_label = "_".join(f"L{ly}H{hd}" for ly, hd in parsed_restrict_heads) + logger.section(f"Attention pattern diffs (n={n_samples}, restrict={restrict_label})") + + with torch.no_grad(): + for i, batch_data in enumerate(loader): + if i >= n_samples: + break + input_ids: Int[Tensor, "batch pos"] = batch_data[task_config.column_name][ + :, :seq_len + ].to(device) + + sample_seq_len = input_ids.shape[1] + rng = random.Random(i) + t = rng.randint(max_offset_show, min(sample_seq_len, 128) - 1) + + bs = (input_ids.shape[0], input_ids.shape[1]) + cp = _build_prev_token_component_positions(parsed_components, t) + baseline_masks, full_ablated_masks = _build_deterministic_masks_multi_pos( + spd_model, cp, bs, input_ids.device + ) + comp_head_abls = _build_component_head_ablations( + spd_model, parsed_components, parsed_restrict_heads, t + ) + + with patched_attention_forward(target_model) as d: + target_model(input_ids) + target_pat = d.patterns + + with patched_attention_forward(target_model) as d: + spd_model(input_ids, mask_infos=baseline_masks) + spd_pat = d.patterns + + with patched_attention_forward(target_model) as d: + spd_model(input_ids, mask_infos=full_ablated_masks) + full_pat = d.patterns + + with patched_attention_forward( + target_model, component_head_ablations=comp_head_abls + ) as d: + spd_model(input_ids, mask_infos=baseline_masks) + perhead_pat = d.patterns + + pats = { + "target_baseline": target_pat, + "spd_baseline": spd_pat, + "full_comp": full_pat, + "perhead_comp": perhead_pat, + } + for cond, pat in pats.items(): + for h in range(n_heads): + for o in range(max_offset_show + 1): + kp = t - o + if kp >= 0: + accum[cond][h][o].append(pat[layer][h, t, kp].item()) + + if (i + 1) % 100 == 0: + logger.info(f"Processed {i + 1}/{n_samples}") + + offsets = list(range(max_offset_show + 1)) + + # --- Plot 1: Raw attention values --- + styles = { + "target_baseline": ("k", "-", 1.5, "Target baseline"), + 
"spd_baseline": ("b", "-", 1.5, "SPD baseline"), + "full_comp": ("r", "-", 1.5, "Full comp ablation"), + "perhead_comp": ("g", "--", 1.5, f"Per-head comp ({restrict_label})"), + } + + all_means = [ + np.mean(accum[c][h][o]) for c in conditions for h in range(n_heads) for o in offsets + ] + raw_ymax = max(all_means) * 1.1 + + fig, axes = plt.subplots(n_heads, 1, figsize=(14, n_heads * 2.5), squeeze=False) + for h in range(n_heads): + ax = axes[h, 0] + for cond, (color, ls, lw, label) in styles.items(): + means = [np.mean(accum[cond][h][o]) for o in offsets] + stds = [np.std(accum[cond][h][o]) for o in offsets] + ax.plot(offsets, means, color=color, linestyle=ls, linewidth=lw, label=label) + ax.fill_between( + offsets, + [m - s for m, s in zip(means, stds, strict=True)], + [m + s for m, s in zip(means, stds, strict=True)], + alpha=0.1, + color=color, + ) + ax.set_ylim(-0.02, raw_ymax) + ax.set_ylabel(f"H{h}", fontsize=10, fontweight="bold") + ax.set_xlim(-0.5, max_offset_show + 0.5) + ax.set_xticks(offsets) + if h == 0: + ax.legend(fontsize=7, loc="upper right") + if h == n_heads - 1: + ax.set_xlabel("Offset from query position", fontsize=9) + + fig.suptitle( + f"Layer {layer} mean attention at query pos t (n={n_samples})", + fontsize=14, + fontweight="bold", + ) + fig.tight_layout() + path = out_dir / f"attn_dist_mean_n{n_samples}.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {path}") + + # --- Plot 2: Differences from SPD baseline --- + diff_styles = { + "full_comp": ("r", "-", 1.5, "Full comp - SPD baseline"), + "perhead_comp": ("g", "--", 1.5, "Per-head comp - SPD baseline"), + } + + all_diff_means = [] + for cond in ["full_comp", "perhead_comp"]: + for h in range(n_heads): + for o in offsets: + diffs = [ + a - b + for a, b in zip(accum[cond][h][o], accum["spd_baseline"][h][o], strict=True) + ] + all_diff_means.append(np.mean(diffs)) + diff_ymin = min(all_diff_means) * 1.15 + diff_ymax = max(max(all_diff_means) * 
1.15, 0.05) + + fig, axes = plt.subplots(n_heads, 1, figsize=(14, n_heads * 2.5), squeeze=False) + for h in range(n_heads): + ax = axes[h, 0] + for cond, (color, ls, lw, label) in diff_styles.items(): + diffs_by_offset = [] + for o in offsets: + sample_diffs = [ + a - b + for a, b in zip(accum[cond][h][o], accum["spd_baseline"][h][o], strict=True) + ] + diffs_by_offset.append(sample_diffs) + means = [np.mean(d) for d in diffs_by_offset] + stds = [np.std(d) for d in diffs_by_offset] + ax.plot(offsets, means, color=color, linestyle=ls, linewidth=lw, label=label) + ax.fill_between( + offsets, + [m - s for m, s in zip(means, stds, strict=True)], + [m + s for m, s in zip(means, stds, strict=True)], + alpha=0.15, + color=color, + ) + ax.axhline(y=0, color="gray", linewidth=0.5, linestyle=":") + ax.set_ylim(diff_ymin, diff_ymax) + ax.set_ylabel(f"H{h}", fontsize=10, fontweight="bold") + ax.set_xlim(-0.5, max_offset_show + 0.5) + ax.set_xticks(offsets) + if h == 0: + ax.legend(fontsize=7, loc="upper right") + if h == n_heads - 1: + ax.set_xlabel("Offset from query position", fontsize=9) + + fig.suptitle( + f"Layer {layer} attention change from ablation (n={n_samples})", + fontsize=14, + fontweight="bold", + ) + fig.tight_layout() + path = out_dir / f"attn_diff_mean_n{n_samples}.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {path}") + + +if __name__ == "__main__": + fire.Fire(plot_attn_pattern_diffs) diff --git a/spd/scripts/attention_stories/__init__.py b/spd/scripts/attention_stories/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/scripts/attention_stories/attention_stories.py b/spd/scripts/attention_stories/attention_stories.py new file mode 100644 index 000000000..c4f40d8c3 --- /dev/null +++ b/spd/scripts/attention_stories/attention_stories.py @@ -0,0 +1,708 @@ +"""Generate PDF reports tracing Q -> K -> V attention component interaction chains. 
+ +For each layer, produces a multi-page PDF: + - Page 1: Layer overview with Q->K attention contributions at multiple RoPE offsets + and K->V CI co-occurrence heatmaps + - Pages 2+: Individual Q component "stories" with positive and negative attention separated + into two columns, showing which K components the Q looks for vs avoids, and what V + information those K components carry forward + +The Q->K attention contribution is a weight-only measure (V-norm-scaled U dot products +with RoPE applied at specified relative position offsets, summed across heads). The K->V +association uses CI co-occurrence counts (number of tokens where both components are +causally important). + +Usage: + python -m spd.scripts.attention_stories.attention_stories \ + wandb:goodfire/spd/runs/ +""" + +import textwrap +from dataclasses import dataclass +from pathlib import Path + +import fire +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib.backends.backend_pdf import PdfPages +from matplotlib.gridspec import GridSpec +from numpy.typing import NDArray + +from spd.autointerp.repo import InterpRepo +from spd.autointerp.schemas import InterpretationResult +from spd.harvest.repo import HarvestRepo +from spd.harvest.schemas import ComponentSummary +from spd.harvest.storage import CorrelationStorage +from spd.log import logger +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.models.components import LinearComponents +from spd.pretrain.models.llama_simple_mlp import LlamaSimpleMLP +from spd.scripts.rope_aware_qk import compute_qk_rope_coefficients, evaluate_qk_at_offsets +from spd.spd_types import ModelPath +from spd.utils.wandb_utils import parse_wandb_run_path + +SCRIPT_DIR = Path(__file__).parent +MIN_MEAN_CI = 0.01 +N_STORIES_PER_LAYER = 10 +TOP_K_PER_SIDE = 5 +TOP_V_PER_K = 3 +N_K_TEXT_PER_SIDE = 3 +TEXT_WRAP_WIDTH = 75 +LINE_HEIGHT = 0.013 +STORY_OFFSETS = [0, 1, 2, 4, 8] + + +@dataclass +class ComponentInfo: + idx: int + 
causal_importance: float + label: str | None + reasoning: str | None + + +@dataclass +class KPartner: + info: ComponentInfo + attention_contribution: float # peak W(Δ) value (at offset with max |W|) + contributions_by_offset: list[tuple[int, float]] # (offset, value) for each offset + v_partners: list[tuple[ComponentInfo, float]] # (v_info, cooccurrence_count) + + +def _get_alive_indices(summary: dict[str, ComponentSummary], module_path: str) -> list[int]: + """Return component indices sorted by CI descending, filtered to alive.""" + components = [ + (s.component_idx, s.mean_activations["causal_importance"]) + for s in summary.values() + if s.layer == module_path and s.mean_activations["causal_importance"] > MIN_MEAN_CI + ] + components.sort(key=lambda t: t[1], reverse=True) + return [idx for idx, _ in components] + + +def _get_component_info( + component_idx: int, + module_path: str, + summary: dict[str, ComponentSummary], + interp: dict[str, InterpretationResult], +) -> ComponentInfo: + key = f"{module_path}:{component_idx}" + ci = summary[key].mean_activations["causal_importance"] + result = interp.get(key) + return ComponentInfo( + idx=component_idx, + causal_importance=ci, + label=result.label if result else None, + reasoning=result.reasoning if result else None, + ) + + +def _compute_attention_contributions( + q_component: LinearComponents, + k_component: LinearComponents, + q_alive: list[int], + k_alive: list[int], + n_q_heads: int, + n_kv_heads: int, + head_dim: int, + rotary_cos: torch.Tensor, + rotary_sin: torch.Tensor, +) -> NDArray[np.floating]: + """Compute (n_offsets, n_q_alive, n_k_alive) summed attention contributions at each offset. + + V-norm-scaled U dot products with RoPE at STORY_OFFSETS, summed across heads. 
+ """ + V_q_norms = torch.linalg.norm(q_component.V[:, q_alive], dim=0).float() + V_k_norms = torch.linalg.norm(k_component.V[:, k_alive], dim=0).float() + + U_q = q_component.U[q_alive].float() * V_q_norms[:, None] + U_q = U_q.reshape(len(q_alive), n_q_heads, head_dim) + + U_k = k_component.U[k_alive].float() * V_k_norms[:, None] + U_k = U_k.reshape(len(k_alive), n_kv_heads, head_dim) + + g = n_q_heads // n_kv_heads + U_k_expanded = U_k.repeat_interleave(g, dim=1) + + head_results = [] + for h in range(n_q_heads): + A, B = compute_qk_rope_coefficients(U_q[:, h, :], U_k_expanded[:, h, :]) + W_h = evaluate_qk_at_offsets(A, B, rotary_cos, rotary_sin, STORY_OFFSETS, head_dim) + head_results.append(W_h) # (n_offsets, n_q, n_k) + + # (n_heads, n_offsets, n_q, n_k) -> sum across heads -> (n_offsets, n_q, n_k) + return torch.stack(head_results).sum(dim=0).cpu().numpy() + + +def _compute_cooccurrence_matrix( + corr: CorrelationStorage, + k_path: str, + v_path: str, + k_alive: list[int], + v_alive: list[int], +) -> NDArray[np.floating]: + """Compute (n_v_alive, n_k_alive) CI co-occurrence count matrix.""" + k_corr_idx = [corr.key_to_idx[f"{k_path}:{idx}"] for idx in k_alive] + v_corr_idx = [corr.key_to_idx[f"{v_path}:{idx}"] for idx in v_alive] + + k_idx = torch.tensor(k_corr_idx) + v_idx = torch.tensor(v_corr_idx) + return corr.count_ij[v_idx[:, None], k_idx[None, :]].float().numpy() + + +# -- Page renderers ----------------------------------------------------------- + + +def _render_overview_page( + pdf: PdfPages, + W_by_offset: NDArray[np.floating], + cooccur: NDArray[np.floating] | None, + q_alive: list[int], + k_alive: list[int], + v_alive: list[int], + layer_idx: int, + run_id: str, +) -> None: + overview_offsets = STORY_OFFSETS[:2] # Show Δ=0 and Δ=1 + n_qk = len(overview_offsets) + n_panels = n_qk + (1 if cooccur is not None else 0) + fig, axes = plt.subplots(1, n_panels, figsize=(7 * n_panels, 8.5), squeeze=False) + + # Shared color scale across QK panels + qk_vmax 
= float(max(np.abs(W_by_offset[idx]).max() for idx in range(n_qk))) or 1.0 + + for panel_idx in range(n_qk): + ax = axes[0, panel_idx] + W = W_by_offset[panel_idx] + im = ax.imshow(W, aspect="auto", cmap="RdBu_r", vmin=-qk_vmax, vmax=qk_vmax) + fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02) + ax.set_title(f"Q\u2192K attention (\u0394={overview_offsets[panel_idx]})", fontsize=11) + ax.set_xlabel("K component") + ax.set_ylabel("Q component") + ax.set_xticks(range(len(k_alive))) + ax.set_xticklabels([f"C{idx}" for idx in k_alive], fontsize=5, rotation=90) + ax.set_yticks(range(len(q_alive))) + ax.set_yticklabels([f"C{idx}" for idx in q_alive], fontsize=5) + + # K->V CI co-occurrence + if cooccur is not None: + ax_kv = axes[0, n_qk] + im2 = ax_kv.imshow(cooccur, aspect="auto", cmap="Purples", vmin=0) + fig.colorbar(im2, ax=ax_kv, shrink=0.8, pad=0.02, label="CI co-occurrence count") + ax_kv.set_title("K\u2192V CI co-occurrence", fontsize=11) + ax_kv.set_xlabel("K component") + ax_kv.set_ylabel("V component") + ax_kv.set_xticks(range(len(k_alive))) + ax_kv.set_xticklabels([f"C{idx}" for idx in k_alive], fontsize=5, rotation=90) + ax_kv.set_yticks(range(len(v_alive))) + ax_kv.set_yticklabels([f"C{idx}" for idx in v_alive], fontsize=5) + + fig.suptitle( + f"{run_id} | Layer {layer_idx} \u2014 Overview (ci>{MIN_MEAN_CI})\n" + f"Q: {len(q_alive)} alive | K: {len(k_alive)} alive | V: {len(v_alive)} alive", + fontsize=14, + fontweight="bold", + ) + fig.tight_layout(rect=(0, 0, 1, 0.93)) + pdf.savefig(fig) + plt.close(fig) + + +def _render_bar_chart( + ax: plt.Axes, + partners: list[KPartner], + color: str, + title: str, + labels_on_right: bool, +) -> None: + """Render a horizontal bar chart of K partners.""" + if not partners: + ax.set_visible(False) + return + + y_pos = np.arange(len(partners)) + values = [abs(kp.attention_contribution) for kp in partners] + + ax.barh(y_pos, values, color=color, height=0.6) + + ytick_labels = [f"K C{kp.info.idx}" for kp in partners] + 
ax.set_yticks(y_pos) + ax.set_yticklabels(ytick_labels, fontsize=7, fontweight="bold") + ax.invert_yaxis() + ax.set_xlabel("|attention contribution| (peak)", fontsize=7) + ax.set_title(title, fontsize=10, fontweight="bold") + ax.tick_params(axis="x", labelsize=7) + + if labels_on_right: + ax.yaxis.tick_right() + ax.yaxis.set_label_position("right") + + +def _render_kv_text( + ax: plt.Axes, + partners: list[KPartner], +) -> None: + """Render K->V text with bold headers and regular reasoning.""" + ax.axis("off") + if not partners: + return + + y = 0.98 + for kp in partners[:N_K_TEXT_PER_SIDE]: + # K component header with multi-offset breakdown + offset_str = ", ".join(f"\u0394={d}: {v:+.2f}" for d, v in kp.contributions_by_offset) + k_header = f"K C{kp.info.idx} (ci={kp.info.causal_importance:.3f}) [{offset_str}]" + if kp.info.label: + k_header += f' \u2014 "{kp.info.label}"' + ax.text( + 0.01, + y, + k_header, + fontsize=7, + fontweight="bold", + fontfamily="monospace", + va="top", + transform=ax.transAxes, + ) + y -= LINE_HEIGHT + + # K reasoning (regular) + if kp.info.reasoning: + wrapped = textwrap.fill( + kp.info.reasoning, + width=TEXT_WRAP_WIDTH, + initial_indent=" ", + subsequent_indent=" ", + ) + n_lines = wrapped.count("\n") + 1 + ax.text( + 0.01, + y, + wrapped, + fontsize=6, + fontfamily="monospace", + va="top", + transform=ax.transAxes, + ) + y -= LINE_HEIGHT * n_lines + + # V partners + for v_info, count in kp.v_partners: + v_header = ( + f" \u2192 V C{v_info.idx} (co-occ={count:.0f}, ci={v_info.causal_importance:.3f})" + ) + if v_info.label: + v_header += f' \u2014 "{v_info.label}"' + ax.text( + 0.01, + y, + v_header, + fontsize=7, + fontweight="bold", + fontfamily="monospace", + va="top", + transform=ax.transAxes, + ) + y -= LINE_HEIGHT + + if v_info.reasoning: + wrapped = textwrap.fill( + v_info.reasoning, + width=TEXT_WRAP_WIDTH, + initial_indent=" ", + subsequent_indent=" ", + ) + n_lines = wrapped.count("\n") + 1 + ax.text( + 0.01, + y, + wrapped, 
+ fontsize=6, + fontfamily="monospace", + va="top", + transform=ax.transAxes, + ) + y -= LINE_HEIGHT * n_lines + + y -= LINE_HEIGHT * 0.3 # gap between K blocks + + +def _render_story_page( + pdf: PdfPages, + q_info: ComponentInfo, + pos_partners: list[KPartner], + neg_partners: list[KPartner], + layer_idx: int, + run_id: str, +) -> None: + fig = plt.figure(figsize=(16, 18)) + gs = GridSpec( + 3, + 2, + figure=fig, + height_ratios=[0.08, 0.15, 0.77], + wspace=0.25, + hspace=0.15, + ) + + # -- Header (spans both columns) ----------------------------------------- + ax_header = fig.add_subplot(gs[0, :]) + ax_header.axis("off") + + header_line = f"Q Component C{q_info.idx} | ci = {q_info.causal_importance:.4f}" + if q_info.label: + header_line += f' | "{q_info.label}"' + ax_header.text( + 0.5, + 0.7, + header_line, + fontsize=12, + fontweight="bold", + ha="center", + va="center", + transform=ax_header.transAxes, + ) + if q_info.reasoning: + reasoning_wrapped = textwrap.fill(q_info.reasoning, width=130) + ax_header.text( + 0.5, + 0.15, + reasoning_wrapped, + fontsize=9, + ha="center", + va="center", + transform=ax_header.transAxes, + ) + + # -- Positive bar chart (left) -------------------------------------------- + ax_pos_bars = fig.add_subplot(gs[1, 0]) + _render_bar_chart( + ax_pos_bars, + pos_partners, + "#4477AA", + "Looks for (positive attention)", + labels_on_right=False, + ) + + # -- Negative bar chart (right) ------------------------------------------- + ax_neg_bars = fig.add_subplot(gs[1, 1]) + _render_bar_chart( + ax_neg_bars, + neg_partners, + "#CC6677", + "Avoids (negative attention)", + labels_on_right=True, + ) + + # -- Positive K->V text (left) -------------------------------------------- + ax_pos_text = fig.add_subplot(gs[2, 0]) + ax_pos_text.set_title("K \u2192 V associations", fontsize=9) + _render_kv_text(ax_pos_text, pos_partners) + + # -- Negative K->V text (right) ------------------------------------------- + ax_neg_text = fig.add_subplot(gs[2, 
1]) + ax_neg_text.set_title("K \u2192 V associations", fontsize=9) + _render_kv_text(ax_neg_text, neg_partners) + + fig.suptitle( + f"{run_id} | Layer {layer_idx}", + fontsize=10, + fontstyle="italic", + ) + fig.subplots_adjust( + left=0.06, + right=0.94, + top=0.97, + bottom=0.02, + wspace=0.25, + hspace=0.12, + ) + pdf.savefig(fig) + plt.close(fig) + + +# -- Markdown output ---------------------------------------------------------- + + +def _md_component(info: ComponentInfo, prefix: str, extra: str = "") -> str: + line = f"**{prefix} C{info.idx}** (ci={info.causal_importance:.3f})" + if extra: + line += f" {extra}" + if info.label: + line += f' \u2014 **"{info.label}"**' + if info.reasoning: + line += f"\n {info.reasoning}" + return line + + +def _md_k_partners(partners: list[KPartner], section_title: str) -> str: + if not partners: + return f"### {section_title}\n\n(none)\n" + lines = [f"### {section_title}\n"] + for kp in partners: + offset_str = " ".join(f"\u0394={d}: {v:+.3f}" for d, v in kp.contributions_by_offset) + lines.append(_md_component(kp.info, "K", extra=f"[{offset_str}]")) + if kp.v_partners: + for v_info, count in kp.v_partners: + lines.append( + " " + _md_component(v_info, "\u2192 V", extra=f"(co-occ={count:.0f})") + ) + else: + lines.append(" (no strong V associations)") + lines.append("") + return "\n".join(lines) + + +def _write_layer_markdown( + md_path: Path, + run_id: str, + layer_idx: int, + q_alive: list[int], + k_alive: list[int], + v_alive: list[int], + stories: list[tuple[ComponentInfo, list[KPartner], list[KPartner]]], +) -> None: + lines = [ + f"# {run_id} \u2014 Layer {layer_idx} Attention Stories\n", + f"Q: {len(q_alive)} alive | K: {len(k_alive)} alive | V: {len(v_alive)} alive" + f" (ci>{MIN_MEAN_CI})\n", + "---\n", + ] + + for q_info, pos_partners, neg_partners in stories: + header = f"## Q Component C{q_info.idx} (ci={q_info.causal_importance:.4f})" + if q_info.label: + header += f' \u2014 "{q_info.label}"' + 
lines.append(header + "\n") + if q_info.reasoning: + lines.append(f"{q_info.reasoning}\n") + + lines.append(_md_k_partners(pos_partners, "Looks for (positive attention)")) + lines.append(_md_k_partners(neg_partners, "Avoids (negative attention)")) + lines.append("---\n") + + md_path.write_text("\n".join(lines)) + logger.info(f"Saved {md_path}") + + +# -- Main --------------------------------------------------------------------- + + +def _build_k_partners( + k_ranks: NDArray[np.integer], + peak_values: NDArray[np.floating], + q_offset_slice: NDArray[np.floating], + k_alive: list[int], + v_alive: list[int], + k_path: str, + v_path: str, + summary: dict[str, ComponentSummary], + interp: dict[str, InterpretationResult], + cooccur: NDArray[np.floating] | None, +) -> list[KPartner]: + partners: list[KPartner] = [] + for k_rank in k_ranks: + k_idx = k_alive[int(k_rank)] + k_info = _get_component_info(k_idx, k_path, summary, interp) + k_contrib = float(peak_values[int(k_rank)]) + contrib_by_offset = [ + (STORY_OFFSETS[i], float(q_offset_slice[i, int(k_rank)])) + for i in range(len(STORY_OFFSETS)) + ] + + v_partners: list[tuple[ComponentInfo, float]] = [] + if cooccur is not None and v_alive: + cooccur_col = cooccur[:, int(k_rank)] + top_v_ranks = np.argsort(-cooccur_col)[:TOP_V_PER_K] + for v_rank in top_v_ranks: + count = float(cooccur_col[int(v_rank)]) + if count <= 0: + continue + v_idx = v_alive[int(v_rank)] + v_info = _get_component_info(v_idx, v_path, summary, interp) + v_partners.append((v_info, count)) + + partners.append( + KPartner( + info=k_info, + attention_contribution=k_contrib, + contributions_by_offset=contrib_by_offset, + v_partners=v_partners, + ) + ) + return partners + + +def generate_attention_stories(wandb_path: ModelPath) -> None: + _entity, _project, run_id = parse_wandb_run_path(str(wandb_path)) + run_info = SPDRunInfo.from_path(wandb_path) + + out_dir = SCRIPT_DIR / "out" / run_id + out_dir.mkdir(parents=True, exist_ok=True) + + model = 
ComponentModel.from_run_info(run_info) + model.eval() + + repo = HarvestRepo.open_most_recent(run_id) + assert repo is not None, f"No harvest data found for {run_id}" + summary = repo.get_summary() + corr = repo.get_correlations() + + # Autointerp data (optional) + interp: dict[str, InterpretationResult] = {} + interp_repo = InterpRepo.open(run_id) + if interp_repo is not None: + interp = interp_repo.get_all_interpretations() + logger.info(f"Loaded {len(interp)} autointerp interpretations") + else: + logger.info("No autointerp data found (labels will be omitted)") + + if corr is not None: + logger.info( + f"Loaded correlations: {len(corr.component_keys)} components, {corr.count_total} tokens" + ) + else: + logger.info("No correlation data found (K\u2192V associations will be omitted)") + + target_model = model.target_model + assert isinstance(target_model, LlamaSimpleMLP) + blocks = target_model._h + assert not blocks[0].attn.rotary_adjacent_pairs, "RoPE math assumes non-adjacent pairs layout" + head_dim = blocks[0].attn.head_dim + n_q_heads = blocks[0].attn.n_head + n_kv_heads = blocks[0].attn.n_key_value_heads + n_layers = len(blocks) + logger.info( + f"Model: {n_layers} layers, head_dim={head_dim}, " + f"n_q_heads={n_q_heads}, n_kv_heads={n_kv_heads}" + ) + + with torch.no_grad(): + for layer_idx in range(n_layers): + q_path = f"h.{layer_idx}.attn.q_proj" + k_path = f"h.{layer_idx}.attn.k_proj" + v_path = f"h.{layer_idx}.attn.v_proj" + + q_alive = _get_alive_indices(summary, q_path) + k_alive = _get_alive_indices(summary, k_path) + v_alive = _get_alive_indices(summary, v_path) + logger.info(f"Layer {layer_idx}: Q={len(q_alive)}, K={len(k_alive)}, V={len(v_alive)}") + + if not q_alive or not k_alive: + logger.info(f"Layer {layer_idx}: skipping (no alive Q or K)") + continue + + q_component = model.components[q_path] + k_component = model.components[k_path] + assert isinstance(q_component, LinearComponents) + assert isinstance(k_component, LinearComponents) + + 
rotary_cos = blocks[layer_idx].attn.rotary_cos + rotary_sin = blocks[layer_idx].attn.rotary_sin + assert isinstance(rotary_cos, torch.Tensor) + assert isinstance(rotary_sin, torch.Tensor) + + W_by_offset = _compute_attention_contributions( + q_component, + k_component, + q_alive, + k_alive, + n_q_heads, + n_kv_heads, + head_dim, + rotary_cos, + rotary_sin, + ) + + cooccur: NDArray[np.floating] | None = None + if corr is not None and v_alive: + cooccur = _compute_cooccurrence_matrix(corr, k_path, v_path, k_alive, v_alive) + + # Build all stories for this layer + stories: list[tuple[ComponentInfo, list[KPartner], list[KPartner]]] = [] + for q_rank, q_idx in enumerate(q_alive[:N_STORIES_PER_LAYER]): + q_info = _get_component_info(q_idx, q_path, summary, interp) + + # (n_offsets, n_k_alive) for this Q component + q_offset_slice = W_by_offset[:, q_rank, :] + + # Rank K partners by peak |W(Δ)| across offsets + peak_offset_idx = np.argmax(np.abs(q_offset_slice), axis=0) # (n_k_alive,) + peak_values = q_offset_slice[peak_offset_idx, np.arange(q_offset_slice.shape[1])] + + pos_mask = peak_values > 0 + neg_mask = peak_values < 0 + + pos_ranks = np.where(pos_mask)[0] + pos_ranks = pos_ranks[np.argsort(-peak_values[pos_ranks])][:TOP_K_PER_SIDE] + + neg_ranks = np.where(neg_mask)[0] + neg_ranks = neg_ranks[np.argsort(peak_values[neg_ranks])][:TOP_K_PER_SIDE] + + pos_partners = _build_k_partners( + pos_ranks, + peak_values, + q_offset_slice, + k_alive, + v_alive, + k_path, + v_path, + summary, + interp, + cooccur, + ) + neg_partners = _build_k_partners( + neg_ranks, + peak_values, + q_offset_slice, + k_alive, + v_alive, + k_path, + v_path, + summary, + interp, + cooccur, + ) + stories.append((q_info, pos_partners, neg_partners)) + + # Write PDF + pdf_path = out_dir / f"layer{layer_idx}.pdf" + with PdfPages(pdf_path) as pdf: + _render_overview_page( + pdf, + W_by_offset, + cooccur, + q_alive, + k_alive, + v_alive, + layer_idx, + run_id, + ) + for q_info, pos_partners, 
neg_partners in stories: + _render_story_page( + pdf, + q_info, + pos_partners, + neg_partners, + layer_idx, + run_id, + ) + logger.info(f"Saved {pdf_path}") + + # Write companion markdown + md_path = out_dir / f"layer{layer_idx}.md" + _write_layer_markdown( + md_path, + run_id, + layer_idx, + q_alive, + k_alive, + v_alive, + stories, + ) + + logger.info(f"All reports saved to {out_dir}") + + +if __name__ == "__main__": + fire.Fire(generate_attention_stories) diff --git a/spd/scripts/characterize_induction_components/__init__.py b/spd/scripts/characterize_induction_components/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/spd/scripts/characterize_induction_components/__init__.py @@ -0,0 +1 @@ + diff --git a/spd/scripts/characterize_induction_components/characterize_induction_components.py b/spd/scripts/characterize_induction_components/characterize_induction_components.py new file mode 100644 index 000000000..0d99cb80a --- /dev/null +++ b/spd/scripts/characterize_induction_components/characterize_induction_components.py @@ -0,0 +1,603 @@ +"""Characterize which SPD components mediate L2H4's induction behavior. + +Bridges head-level analysis (detect_induction_heads) with component-level decomposition +by measuring each component's causal contribution to the induction attention pattern. + +Four phases: + 1. Weight-based component-head mapping (Frobenius norms) + 2. Per-component induction score via ablation + 3. Cross-head analysis of top induction components + 4. "Why not perfect?" 
analysis of attention mass allocation + +Usage: + python -m spd.scripts.characterize_induction_components.characterize_induction_components \ + wandb:goodfire/spd/runs/ +""" + +import math +from io import StringIO +from pathlib import Path + +import fire +import numpy as np +import torch +from numpy.typing import NDArray +from torch.nn import functional as F + +from spd.harvest.repo import HarvestRepo +from spd.harvest.schemas import ComponentSummary +from spd.log import logger +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.models.components import LinearComponents + +# Suppress buffer access issues (rotary_cos, rotary_sin, bias) on CausalSelfAttention +# pyright: reportIndexIssue=false +from spd.pretrain.models.llama_simple_mlp import CausalSelfAttention, LlamaSimpleMLP +from spd.spd_types import ModelPath +from spd.utils.wandb_utils import parse_wandb_run_path + +SCRIPT_DIR = Path(__file__).parent +MIN_MEAN_CI = 0.01 +_default_n_batches = 20 +_default_half_seq_len = 256 +BATCH_SIZE = 32 +TARGET_LAYER = 2 +TARGET_HEAD = 4 +TOP_N = 10 + + +def _get_alive_indices(summary: dict[str, ComponentSummary], module_path: str) -> list[int]: + """Return component indices with CI > MIN_MEAN_CI, sorted descending.""" + components = [ + (s.component_idx, s.mean_activations["causal_importance"]) + for s in summary.values() + if s.layer == module_path and s.mean_activations["causal_importance"] > MIN_MEAN_CI + ] + components.sort(key=lambda t: t[1], reverse=True) + return [idx for idx, _ in components] + + +def _head_entropy(fracs: NDArray[np.floating]) -> float: + """Shannon entropy of a distribution (in bits). 
Zero entries are dropped before the log to avoid log(0)."""
+    fracs = fracs[fracs > 0]
+    return float(-np.sum(fracs * np.log2(fracs)))
+
+
+# ── Phase 1: Weight-based component-head mapping ─────────────────────────────
+
+
+def _compute_head_norm_fractions(
+    component: LinearComponents,
+    alive_indices: list[int],
+    proj_name: str,
+    head_dim: int,
+    n_heads: int,
+) -> NDArray[np.floating]:
+    """Compute (n_alive, n_heads) array of per-head norm fractions for each component.
+
+    For q/k/v_proj: head h uses rows [h*head_dim:(h+1)*head_dim] of U.
+    For o_proj: head h uses columns [h*head_dim:(h+1)*head_dim] of V.
+
+    Returns fractions (each row sums to 1).
+    """
+    n_alive = len(alive_indices)
+    norms = np.zeros((n_alive, n_heads), dtype=np.float32)
+
+    for row, c_idx in enumerate(alive_indices):
+        if proj_name in ("q_proj", "k_proj", "v_proj"):
+            u_c = component.U[c_idx].float()
+            v_norm = torch.linalg.norm(component.V[:, c_idx].float()).item()
+            for h in range(n_heads):
+                head_norm = torch.linalg.norm(u_c[h * head_dim : (h + 1) * head_dim]).item()
+                norms[row, h] = head_norm * v_norm
+        else:
+            v_c = component.V[:, c_idx].float()
+            u_norm = torch.linalg.norm(component.U[c_idx].float()).item()
+            for h in range(n_heads):
+                head_norm = torch.linalg.norm(v_c[h * head_dim : (h + 1) * head_dim]).item()
+                norms[row, h] = head_norm * u_norm
+
+    row_totals = norms.sum(axis=1, keepdims=True)
+    row_totals = np.maximum(row_totals, 1e-12)
+    fracs = norms / row_totals
+    return fracs
+
+
+def _run_phase1(
+    model: ComponentModel,
+    summary: dict[str, ComponentSummary],
+    head_dim: int,
+    n_heads: int,
+    out: StringIO,
+) -> dict[str, tuple[list[int], NDArray[np.floating]]]:
+    """Phase 1: Weight-based component-head mapping.
+
+    Returns {proj_name: (alive_indices, head_norm_fracs)} for L2 q_proj and k_proj.
+ """ + out.write("=" * 80 + "\n") + out.write("PHASE 1: Weight-based component-head mapping\n") + out.write("=" * 80 + "\n\n") + + results: dict[str, tuple[list[int], NDArray[np.floating]]] = {} + + for proj_name in ("q_proj", "k_proj"): + module_path = f"h.{TARGET_LAYER}.attn.{proj_name}" + component = model.components[module_path] + assert isinstance(component, LinearComponents) + + alive = _get_alive_indices(summary, module_path) + fracs = _compute_head_norm_fractions(component, alive, proj_name, head_dim, n_heads) + results[proj_name] = (alive, fracs) + + out.write(f"── {module_path} ({len(alive)} alive components) ──\n") + out.write(f"{'Comp':>6} {'H4 frac':>8} {'Dom head':>9} {'Entropy':>8} {'Class':>16}\n") + out.write("-" * 55 + "\n") + + n_concentrated = n_involved = n_minor = 0 + for i, c_idx in enumerate(alive): + h4_frac = fracs[i, TARGET_HEAD] + dom_head = int(np.argmax(fracs[i])) + entropy = _head_entropy(fracs[i]) + if h4_frac > 0.5: + cls = "H4-concentrated" + n_concentrated += 1 + elif h4_frac > 0.1: + cls = "H4-involved" + n_involved += 1 + else: + cls = "H4-minor" + n_minor += 1 + out.write( + f"C{c_idx:>4} {h4_frac:>8.3f} {'H' + str(dom_head):>9} {entropy:>8.3f} {cls:>16}\n" + ) + + out.write( + f"\nSummary: {n_concentrated} concentrated, {n_involved} involved, {n_minor} minor\n\n" + ) + + return results + + +# ── Phase 2: Per-component induction score via ablation ────────────────────── + + +def _run_layers_0_to_1( + target_model: LlamaSimpleMLP, + input_ids: torch.Tensor, +) -> torch.Tensor: + """Run layers 0-1 and return residual stream at L2 input.""" + x = target_model.wte(input_ids) + for block in target_model._h[:TARGET_LAYER]: + x = x + block.attn(block.rms_1(x)) + x = x + block.mlp(block.rms_2(x)) + return x + + +def _compute_attention_weights( + attn_input: torch.Tensor, + attn: CausalSelfAttention, + q_full: torch.Tensor | None = None, + k_full: torch.Tensor | None = None, +) -> torch.Tensor: + """Compute softmax attention weights 
for a given layer's attention module. + + If q_full/k_full are provided, use those instead of computing from attn_input. + Returns (batch, n_heads, seq_len, seq_len). + """ + B, T, _ = attn_input.shape + + q_proj = q_full if q_full is not None else attn.q_proj(attn_input) + k_proj = k_full if k_full is not None else attn.k_proj(attn_input) + + q = q_proj.view(B, T, attn.n_head, attn.head_dim).transpose(1, 2) + k = k_proj.view(B, T, attn.n_key_value_heads, attn.head_dim).transpose(1, 2) + + position_ids = torch.arange(T, device=attn_input.device).unsqueeze(0) + cos = attn.rotary_cos[position_ids].to(q.dtype) + sin = attn.rotary_sin[position_ids].to(q.dtype) + q, k = attn.apply_rotary_pos_emb(q, k, cos, sin) + + if attn.repeat_kv_heads > 1: + k = k.repeat_interleave(attn.repeat_kv_heads, dim=1) + + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(attn.head_dim)) + att = att.masked_fill(attn.bias[:, :, :T, :T] == 0, float("-inf")) + att = F.softmax(att, dim=-1) + return att + + +def _induction_score_from_attn( + att: torch.Tensor, + half_len: int, +) -> NDArray[np.floating]: + """Compute induction score per head from attention weights. + + att: (batch, n_heads, seq_len, seq_len) + Returns: (n_heads,) array of mean induction scores. + """ + src = torch.arange(1, half_len, device=att.device) + dst = torch.arange(half_len, 2 * half_len - 1, device=att.device) + induction_attn = att[:, :, dst, src] # (batch, n_heads, half_len-1) + return induction_attn.float().mean(dim=(0, 2)).cpu().numpy() + + +def _run_phase2( + target_model: LlamaSimpleMLP, + model: ComponentModel, + summary: dict[str, ComponentSummary], + device: torch.device, + n_batches: int, + half_seq_len: int, + out: StringIO, +) -> dict[str, list[tuple[int, float, NDArray[np.floating]]]]: + """Phase 2: Per-component induction score via ablation. + + Returns {proj_name: [(component_idx, delta_h4, delta_all_heads), ...]}. 
+ """ + out.write("=" * 80 + "\n") + out.write("PHASE 2: Per-component induction score via ablation\n") + out.write("=" * 80 + "\n\n") + + attn = target_model._h[TARGET_LAYER].attn + rms = target_model._h[TARGET_LAYER].rms_1 + n_heads = attn.n_head + vocab_size = target_model.config.vocab_size + + # Accumulate baseline scores and per-component deltas + baseline_accum = np.zeros(n_heads, dtype=np.float64) + + ablation_results: dict[str, dict[int, np.ndarray]] = {} + for proj_name in ("q_proj", "k_proj"): + module_path = f"h.{TARGET_LAYER}.attn.{proj_name}" + alive = _get_alive_indices(summary, module_path) + ablation_results[proj_name] = {c: np.zeros(n_heads, dtype=np.float64) for c in alive} + + for batch_i in range(n_batches): + first_half = torch.randint(100, vocab_size - 100, (BATCH_SIZE, half_seq_len), device=device) + input_ids = torch.cat([first_half, first_half], dim=1) + + l2_input = _run_layers_0_to_1(target_model, input_ids) + attn_input = rms(l2_input) + + q_full = attn.q_proj(attn_input) + k_full = attn.k_proj(attn_input) + + baseline_att = _compute_attention_weights(attn_input, attn, q_full, k_full) + baseline_scores = _induction_score_from_attn(baseline_att, half_seq_len) + baseline_accum += baseline_scores + + for proj_name in ("q_proj", "k_proj"): + module_path = f"h.{TARGET_LAYER}.attn.{proj_name}" + component = model.components[module_path] + assert isinstance(component, LinearComponents) + alive = _get_alive_indices(summary, module_path) + + full_proj = q_full if proj_name == "q_proj" else k_full + + for c_idx in alive: + v_c = component.V[:, c_idx] # (d_in,) + u_c = component.U[c_idx] # (d_out,) + scalar_c = (attn_input @ v_c).unsqueeze(-1) # (B, T, 1) + contribution_c = scalar_c * u_c.unsqueeze(0).unsqueeze(0) # (B, T, d_out) + + ablated_proj = full_proj - contribution_c + + if proj_name == "q_proj": + att_ablated = _compute_attention_weights( + attn_input, attn, q_full=ablated_proj, k_full=k_full + ) + else: + att_ablated = 
_compute_attention_weights( + attn_input, attn, q_full=q_full, k_full=ablated_proj + ) + + ablated_scores = _induction_score_from_attn(att_ablated, half_seq_len) + delta = baseline_scores - ablated_scores # positive = component helps induction + ablation_results[proj_name][c_idx] += delta + + if (batch_i + 1) % 5 == 0: + logger.info(f"Phase 2: processed {batch_i + 1}/{n_batches} batches") + + baseline_accum /= n_batches + for proj_name in ablation_results: + for c_idx in ablation_results[proj_name]: + ablation_results[proj_name][c_idx] /= n_batches + + out.write(f"Baseline induction scores (n={n_batches} batches of {BATCH_SIZE}):\n") + for h in range(n_heads): + marker = " <-- TARGET" if h == TARGET_HEAD else "" + out.write(f" H{h}: {baseline_accum[h]:.4f}{marker}\n") + out.write("\n") + + # Build sorted result lists + phase2_results: dict[str, list[tuple[int, float, NDArray[np.floating]]]] = {} + for proj_name in ("q_proj", "k_proj"): + items: list[tuple[int, float, NDArray[np.floating]]] = [] + for c_idx, delta_all in ablation_results[proj_name].items(): + delta_h4 = float(delta_all[TARGET_HEAD]) + items.append((c_idx, delta_h4, delta_all)) + items.sort(key=lambda t: t[1], reverse=True) + phase2_results[proj_name] = items + + out.write( + f"── h.{TARGET_LAYER}.attn.{proj_name}: Top components by H4 induction contribution ──\n" + ) + head_labels = " ".join(f"{'H' + str(h):>7}" for h in range(n_heads)) + out.write(f"{'Comp':>6} {'dH4':>8} {head_labels}\n") + out.write("-" * (20 + n_heads * 9) + "\n") + for c_idx, delta_h4, delta_all in items[:30]: + deltas_str = " ".join(f"{delta_all[h]:>+7.4f}" for h in range(n_heads)) + out.write(f"C{c_idx:>4} {delta_h4:>+8.4f} {deltas_str}\n") + out.write("\n") + + total_h4 = sum(delta_h4 for _, delta_h4, _ in items) + out.write(f"Sum of all {proj_name} component deltas for H4: {total_h4:+.4f}\n") + out.write(f"Baseline H4 induction score: {baseline_accum[TARGET_HEAD]:.4f}\n\n") + + return phase2_results + + +# ── Phase 3: 
Cross-head analysis of top induction components ───────────────── + + +def _run_phase3( + phase1_results: dict[str, tuple[list[int], NDArray[np.floating]]], + phase2_results: dict[str, list[tuple[int, float, NDArray[np.floating]]]], + n_heads: int, + out: StringIO, +) -> None: + """Phase 3: Cross-head analysis of top induction components.""" + out.write("=" * 80 + "\n") + out.write("PHASE 3: Cross-head analysis of top induction components\n") + out.write("=" * 80 + "\n\n") + + for proj_name in ("q_proj", "k_proj"): + alive_indices, head_fracs = phase1_results[proj_name] + idx_to_row = {c: i for i, c in enumerate(alive_indices)} + + top_components = phase2_results[proj_name][:TOP_N] + + out.write( + f"── h.{TARGET_LAYER}.attn.{proj_name}: Top {TOP_N} by induction contribution ──\n\n" + ) + + for c_idx, delta_h4, delta_all in top_components: + out.write(f"Component C{c_idx} (dH4 = {delta_h4:+.4f}):\n") + + # Weight norm distribution + if c_idx in idx_to_row: + row = idx_to_row[c_idx] + fracs = head_fracs[row] + out.write(" Weight norm fraction per head:\n ") + out.write( + " ".join( + f"H{h}: {fracs[h]:.3f}{'*' if h == TARGET_HEAD else ''}" + for h in range(n_heads) + ) + ) + out.write("\n") + + # Ablation effect per head + out.write(" Induction score change per head when ablated:\n ") + out.write( + " ".join( + f"H{h}: {delta_all[h]:+.4f}{'*' if h == TARGET_HEAD else ''}" + for h in range(n_heads) + ) + ) + out.write("\n") + + # Cross-head effects + significant_other = [ + (h, float(delta_all[h])) + for h in range(n_heads) + if h != TARGET_HEAD and abs(delta_all[h]) > 0.005 + ] + if significant_other: + out.write(" Significant cross-head effects (|delta| > 0.005):\n") + for h, d in significant_other: + direction = "increases" if d < 0 else "decreases" + out.write( + f" H{h}: {d:+.4f} (ablating this component {direction} H{h} induction)\n" + ) + + out.write("\n") + + +# ── Phase 4: "Why not perfect?" 
analysis ───────────────────────────────────── + + +def _run_phase4( + target_model: LlamaSimpleMLP, + phase1_results: dict[str, tuple[list[int], NDArray[np.floating]]], + phase2_results: dict[str, list[tuple[int, float, NDArray[np.floating]]]], + device: torch.device, + n_heads: int, + half_seq_len: int, + out: StringIO, +) -> None: + """Phase 4: Analyze what prevents L2H4 from having a perfect induction score.""" + out.write("=" * 80 + "\n") + out.write("PHASE 4: Why not perfect? Analysis\n") + out.write("=" * 80 + "\n\n") + + attn = target_model._h[TARGET_LAYER].attn + rms = target_model._h[TARGET_LAYER].rms_1 + vocab_size = target_model.config.vocab_size + seq_len = half_seq_len * 2 + + # 4a: BOS attention competition + out.write("── 4a: Attention mass allocation in H4 on induction data ──\n\n") + + attn_to_bos_accum = np.zeros(n_heads, dtype=np.float64) + attn_to_induction_accum = np.zeros(n_heads, dtype=np.float64) + attn_to_other_accum = np.zeros(n_heads, dtype=np.float64) + + n_batches_phase4 = 10 + for _ in range(n_batches_phase4): + first_half = torch.randint(100, vocab_size - 100, (BATCH_SIZE, half_seq_len), device=device) + input_ids = torch.cat([first_half, first_half], dim=1) + l2_input = _run_layers_0_to_1(target_model, input_ids) + attn_input = rms(l2_input) + att = _compute_attention_weights(attn_input, attn) + + second_half_att = att[:, :, half_seq_len:, :] # (B, n_heads, half_len, seq_len) + bos_att = second_half_att[:, :, :, 0] # (B, n_heads, half_len) + attn_to_bos_accum += bos_att.float().mean(dim=(0, 2)).cpu().numpy() + + # Induction target: for query at pos half_len+k, target is pos k+1 + src = torch.arange(1, half_seq_len + 1, device=device).clamp(max=seq_len - 1) + dst_range = torch.arange(half_seq_len, device=device) + induction_att = second_half_att[:, :, dst_range, src] # (B, n_heads, half_len) + attn_to_induction_accum += induction_att.float().mean(dim=(0, 2)).cpu().numpy() + + other_att = 1.0 - bos_att - induction_att + 
attn_to_other_accum += other_att.float().mean(dim=(0, 2)).cpu().numpy() + + attn_to_bos_accum /= n_batches_phase4 + attn_to_induction_accum /= n_batches_phase4 + attn_to_other_accum /= n_batches_phase4 + + out.write(f"{'Head':>6} {'Induction':>10} {'BOS':>8} {'Other':>8}\n") + out.write("-" * 38 + "\n") + for h in range(n_heads): + marker = " <-- TARGET" if h == TARGET_HEAD else "" + out.write( + f" H{h} {attn_to_induction_accum[h]:>10.4f} {attn_to_bos_accum[h]:>8.4f} " + f"{attn_to_other_accum[h]:>8.4f}{marker}\n" + ) + out.write("\n") + + # 4b: Non-induction components in H4 + out.write("── 4b: H4-concentrated components with low induction contribution ──\n\n") + out.write( + "These components have significant weight in H4 but contribute little to induction.\n" + ) + out.write("They may drive BOS attention or other non-induction patterns.\n\n") + + for proj_name in ("q_proj", "k_proj"): + alive_indices, head_fracs = phase1_results[proj_name] + idx_to_row = {c: i for i, c in enumerate(alive_indices)} + + ablation_map = {c_idx: delta_h4 for c_idx, delta_h4, _ in phase2_results[proj_name]} + + non_induction_h4: list[tuple[int, float, float]] = [] + for c_idx in alive_indices: + row = idx_to_row[c_idx] + h4_frac = float(head_fracs[row, TARGET_HEAD]) + delta_h4 = ablation_map.get(c_idx, 0.0) + if h4_frac > 0.1 and delta_h4 < 0.01: + non_induction_h4.append((c_idx, h4_frac, delta_h4)) + + non_induction_h4.sort(key=lambda t: t[1], reverse=True) + + out.write( + f" {proj_name}: {len(non_induction_h4)} components with H4 frac > 0.1 and dH4 < 0.01:\n" + ) + out.write(f" {'Comp':>6} {'H4 frac':>8} {'dH4':>8}\n") + out.write(" " + "-" * 28 + "\n") + for c_idx, h4_frac, delta_h4 in non_induction_h4[:15]: + out.write(f" C{c_idx:>4} {h4_frac:>8.3f} {delta_h4:>+8.4f}\n") + out.write("\n") + + # 4c: Induction leakage / competition across heads + out.write("── 4c: Cross-head induction leakage ──\n\n") + out.write("Components whose ablation *increases* another head's induction 
score,\n") + out.write("suggesting competitive dynamics.\n\n") + + for proj_name in ("q_proj", "k_proj"): + leakage: list[tuple[int, int, float, float]] = [] + for c_idx, delta_h4, delta_all in phase2_results[proj_name]: + for h in range(n_heads): + if h != TARGET_HEAD and delta_all[h] < -0.005: + leakage.append((c_idx, h, float(delta_all[h]), delta_h4)) + + if leakage: + leakage.sort(key=lambda t: t[2]) + out.write( + f" {proj_name}: {len(leakage)} cases of ablation increasing other heads' induction:\n" + ) + out.write(f" {'Comp':>6} {'Head':>6} {'dHead':>8} {'dH4':>8}\n") + out.write(" " + "-" * 34 + "\n") + for c_idx, h, d_other, d_h4 in leakage[:15]: + out.write(f" C{c_idx:>4} H{h:>4} {d_other:>+8.4f} {d_h4:>+8.4f}\n") + out.write("\n") + else: + out.write(f" {proj_name}: No significant cross-head leakage detected.\n\n") + + +# ── Main ───────────────────────────────────────────────────────────────────── + + +def characterize_induction_components( + wandb_path: ModelPath, + n_batches: int = _default_n_batches, + half_seq_len: int = _default_half_seq_len, +) -> None: + _entity, _project, run_id = parse_wandb_run_path(str(wandb_path)) + run_info = SPDRunInfo.from_path(wandb_path) + + out_dir = SCRIPT_DIR / "out" / run_id + out_dir.mkdir(parents=True, exist_ok=True) + + spd_model = ComponentModel.from_run_info(run_info) + spd_model.eval() + + target_model = spd_model.target_model + assert isinstance(target_model, LlamaSimpleMLP) + for block in target_model._h: + block.attn.flash_attention = False + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + spd_model = spd_model.to(device) + target_model = target_model.to(device) + + repo = HarvestRepo.open_most_recent(run_id) + assert repo is not None, f"No harvest data for {run_id}" + summary = repo.get_summary() + + n_heads = target_model._h[TARGET_LAYER].attn.n_head + head_dim = target_model._h[TARGET_LAYER].attn.head_dim + + logger.info(f"Model: {len(target_model._h)} layers, {n_heads} heads, 
head_dim={head_dim}") + logger.info(f"Target: L{TARGET_LAYER}H{TARGET_HEAD}") + + report = StringIO() + report.write("Induction Component Characterization Report\n") + report.write(f"Run: {run_id}\n") + report.write(f"Target: L{TARGET_LAYER}H{TARGET_HEAD}\n") + report.write(f"Batches: {n_batches} x {BATCH_SIZE}, half_seq_len={half_seq_len}\n") + report.write(f"Device: {device}\n\n") + + with torch.no_grad(): + phase1_results = _run_phase1(spd_model, summary, head_dim, n_heads, report) + phase2_results = _run_phase2( + target_model, + spd_model, + summary, + device, + n_batches, + half_seq_len, + report, + ) + _run_phase3(phase1_results, phase2_results, n_heads, report) + _run_phase4( + target_model, + phase1_results, + phase2_results, + device, + n_heads, + half_seq_len, + report, + ) + + report_text = report.getvalue() + print(report_text) + + report_path = out_dir / "induction_component_report.txt" + report_path.write_text(report_text) + logger.info(f"Report saved to {report_path}") + + +if __name__ == "__main__": + fire.Fire(characterize_induction_components) diff --git a/spd/scripts/collect_attention_patterns.py b/spd/scripts/collect_attention_patterns.py new file mode 100644 index 000000000..848506294 --- /dev/null +++ b/spd/scripts/collect_attention_patterns.py @@ -0,0 +1,62 @@ +"""Shared utility for collecting attention weights from LlamaSimpleMLP models. + +Used by the detect_* head characterization scripts. +""" + +import math + +import torch +from torch.nn import functional as F + +from spd.pretrain.models.llama_simple_mlp import LlamaSimpleMLP + + +def collect_attention_patterns( + model: LlamaSimpleMLP, + input_ids: torch.Tensor, +) -> list[torch.Tensor]: + """Run forward pass and return attention weights for each layer. + + Returns list of (batch, n_heads, seq_len, seq_len) tensors. 
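The explicit scale-mask-softmax implemented in the function body can be sanity-checked against PyTorch's reference causal attention. A minimal sketch with toy shapes; RoPE and grouped-query head repetition are omitted here (simplifications, not the full path this function takes):

```python
import math

import torch
from torch.nn import functional as F

B, H, T, D = 2, 1, 5, 8  # toy batch / heads / sequence / head_dim
torch.manual_seed(0)
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

# Same math as the loop body: scale, causal mask, softmax, weighted sum.
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(D))
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
att = att.masked_fill(~causal, float("-inf"))
att = F.softmax(att, dim=-1)
y = att @ v

# Reference implementation with the same causal semantics.
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```

Each softmax row sums to 1 even at position 0, where the causal mask leaves only the self key unmasked.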
+ """ + B, T = input_ids.shape + x = model.wte(input_ids) + patterns: list[torch.Tensor] = [] + + for block in model._h: + attn_input = block.rms_1(x) + attn = block.attn + + q = attn.q_proj(attn_input).view(B, T, attn.n_head, attn.head_dim).transpose(1, 2) + k = ( + attn.k_proj(attn_input) + .view(B, T, attn.n_key_value_heads, attn.head_dim) + .transpose(1, 2) + ) + v = ( + attn.v_proj(attn_input) + .view(B, T, attn.n_key_value_heads, attn.head_dim) + .transpose(1, 2) + ) + + position_ids = torch.arange(T, device=input_ids.device).unsqueeze(0) + cos = attn.rotary_cos[position_ids].to(q.dtype) # pyright: ignore[reportIndexIssue] + sin = attn.rotary_sin[position_ids].to(q.dtype) # pyright: ignore[reportIndexIssue] + q, k = attn.apply_rotary_pos_emb(q, k, cos, sin) + + if attn.repeat_kv_heads > 1: + k = k.repeat_interleave(attn.repeat_kv_heads, dim=1) + v = v.repeat_interleave(attn.repeat_kv_heads, dim=1) + + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(attn.head_dim)) + att = att.masked_fill(attn.bias[:, :, :T, :T] == 0, float("-inf")) # pyright: ignore[reportIndexIssue] + att = F.softmax(att, dim=-1) + patterns.append(att) + + y = att @ v + y = y.transpose(1, 2).contiguous().view(B, T, attn.n_embd) + y = attn.o_proj(y) + x = x + y + x = x + block.mlp(block.rms_2(x)) + + return patterns diff --git a/spd/scripts/detect_delimiter_heads/__init__.py b/spd/scripts/detect_delimiter_heads/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/scripts/detect_delimiter_heads/detect_delimiter_heads.py b/spd/scripts/detect_delimiter_heads/detect_delimiter_heads.py new file mode 100644 index 000000000..d10c02c93 --- /dev/null +++ b/spd/scripts/detect_delimiter_heads/detect_delimiter_heads.py @@ -0,0 +1,222 @@ +"""Detect delimiter-attending attention heads. + +For each head, measures what fraction of attention weight lands on structural +delimiter tokens (periods, commas, semicolons, etc.) and compares to the +baseline delimiter frequency. 
Heads with a high ratio over baseline
+disproportionately target structural markers.
+
+Usage:
+    python -m spd.scripts.detect_delimiter_heads.detect_delimiter_heads \
+        wandb:goodfire/spd/runs/
+"""
+
+from pathlib import Path
+
+import fire
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from numpy.typing import NDArray
+from transformers import AutoTokenizer
+
+from spd.configs import LMTaskConfig
+from spd.data import DatasetConfig, create_data_loader
+from spd.log import logger
+from spd.models.component_model import SPDRunInfo
+from spd.pretrain.models.llama_simple_mlp import LlamaSimpleMLP
+from spd.scripts.collect_attention_patterns import collect_attention_patterns
+from spd.spd_types import ModelPath
+from spd.utils.wandb_utils import parse_wandb_run_path
+
+SCRIPT_DIR = Path(__file__).parent
+N_BATCHES = 100
+BATCH_SIZE = 32
+
+DELIMITER_CHARS = [".", ",", ";", ":", "!", "?", "\n"]
+DELIMITER_MULTI = [".\n", ".\n\n", ",\n", ";\n"]
+
+
+def _get_delimiter_token_ids(tokenizer: AutoTokenizer) -> set[int]:
+    """Collect token IDs for delimiter characters and common multi-char delimiters."""
+    delimiter_ids: set[int] = set()
+    for char in DELIMITER_CHARS:
+        # add_special_tokens=False: tokenizers that prepend BOS would otherwise
+        # make every single-character delimiter encode to two tokens, and the
+        # len == 1 filter below would discard all of them.
+        token_ids = tokenizer.encode(char, add_special_tokens=False)
+        if len(token_ids) == 1:
+            delimiter_ids.add(token_ids[0])
+    for multi in DELIMITER_MULTI:
+        token_ids = tokenizer.encode(multi, add_special_tokens=False)
+        if len(token_ids) == 1:
+            delimiter_ids.add(token_ids[0])
+    return delimiter_ids
+
+
+def _compute_delimiter_scores(
+    patterns: list[torch.Tensor],
+    input_ids: torch.Tensor,
+    delimiter_ids: set[int],
+) -> tuple[NDArray[np.floating], float]:
+    """Compute mean fraction of attention on delimiter tokens per head.
+
+    Returns (raw_scores of shape (n_layers, n_heads), baseline_fraction).
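The broadcasted delimiter mask built in the function body can be illustrated on toy data (the token ids and the single-delimiter set below are hypothetical). With a uniform attention pattern, the per-head delimiter fraction recovers exactly the baseline delimiter frequency:

```python
import torch

input_ids = torch.tensor([[5, 13, 7, 13], [13, 2, 2, 9]])  # (B=2, T=4), toy ids
delim_set = torch.tensor([13])  # pretend id 13 is the only delimiter

# (B, T, 1) vs (1, 1, n_delims) -> (B, T, n_delims); any() collapses to (B, T)
is_delim = (input_ids.unsqueeze(-1) == delim_set.unsqueeze(0).unsqueeze(0)).any(dim=-1)
baseline_fraction = is_delim.float().mean().item()  # 3 delimiters / 8 tokens

# Uniform attention should place exactly the baseline fraction on delimiter keys.
T = input_ids.shape[1]
uniform_att = torch.full((2, 1, T, T), 1.0 / T)  # (B, H=1, T_q, T_k)
is_delim_key = is_delim.unsqueeze(1).unsqueeze(2).float()  # (B, 1, 1, T_k)
attn_to_delim = (uniform_att * is_delim_key).sum(dim=-1)  # (B, H, T_q)
```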
+ """ + B, T = input_ids.shape + n_layers = len(patterns) + n_heads = patterns[0].shape[1] + + delim_set = torch.tensor(sorted(delimiter_ids), device=input_ids.device) + is_delim = (input_ids.unsqueeze(-1) == delim_set.unsqueeze(0).unsqueeze(0)).any(dim=-1) + baseline_fraction = is_delim.float().mean().item() + + # (B, 1, 1, T) for broadcasting with attention (B, H, T_q, T_k) + is_delim_key = is_delim.unsqueeze(1).unsqueeze(2).float() + + raw_scores = np.zeros((n_layers, n_heads)) + for layer_idx, att in enumerate(patterns): + attn_to_delim = (att.float() * is_delim_key).sum(dim=-1) # (B, H, T) + raw_scores[layer_idx] = attn_to_delim.mean(dim=(0, 2)).cpu().numpy() + + return raw_scores, baseline_fraction + + +def _plot_dual_heatmap( + raw_scores: NDArray[np.floating], + ratio_scores: NDArray[np.floating], + run_id: str, + n_samples: int, + baseline_frac: float, + out_path: Path, +) -> None: + n_layers, n_heads = raw_scores.shape + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(max(12, n_heads * 2.4), max(4, n_layers * 1.0))) + + im1 = ax1.imshow(raw_scores, aspect="auto", cmap="viridis", vmin=0) + fig.colorbar(im1, ax=ax1, shrink=0.8, pad=0.02, label="Attn fraction on delimiters") + ax1.set_title("Attention to delimiters", fontsize=11, fontweight="bold") + + im2 = ax2.imshow(ratio_scores, aspect="auto", cmap="viridis", vmin=0) + fig.colorbar( + im2, ax=ax2, shrink=0.8, pad=0.02, label=f"Ratio over baseline ({baseline_frac:.3f})" + ) + ax2.set_title("Ratio over baseline", fontsize=11, fontweight="bold") + + for ax in (ax1, ax2): + ax.set_xticks(range(n_heads)) + ax.set_xticklabels([f"H{h}" for h in range(n_heads)], fontsize=10) + ax.set_yticks(range(n_layers)) + ax.set_yticklabels([f"L{li}" for li in range(n_layers)], fontsize=10) + ax.set_xlabel("Head") + ax.set_ylabel("Layer") + + data_sets = [raw_scores, ratio_scores] + for ax, data in zip([ax1, ax2], data_sets, strict=True): + for layer_idx in range(n_layers): + for h in range(n_heads): + val = data[layer_idx, h] 
+ threshold = abs(data).max() * 0.6 + color = "white" if abs(val) < threshold else "black" + ax.text( + h, layer_idx, f"{val:.3f}", ha="center", va="center", fontsize=8, color=color + ) + + fig.suptitle( + f"{run_id} | Delimiter head scores (n={n_samples} batches, baseline={baseline_frac:.3f})", + fontsize=13, + fontweight="bold", + ) + fig.tight_layout() + fig.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {out_path}") + + +def detect_delimiter_heads(wandb_path: ModelPath, n_batches: int = N_BATCHES) -> None: + _entity, _project, run_id = parse_wandb_run_path(str(wandb_path)) + run_info = SPDRunInfo.from_path(wandb_path) + + out_dir = SCRIPT_DIR / "out" / run_id + out_dir.mkdir(parents=True, exist_ok=True) + + config = run_info.config + target_model = LlamaSimpleMLP.from_pretrained(config.pretrained_model_name) + target_model.eval() + + for block in target_model._h: + block.attn.flash_attention = False + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + target_model = target_model.to(device) + + tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) + delimiter_ids = _get_delimiter_token_ids(tokenizer) + logger.info(f"Found {len(delimiter_ids)} delimiter token IDs: {sorted(delimiter_ids)}") + + n_layers = len(target_model._h) + n_heads = target_model._h[0].attn.n_head + seq_len = target_model.config.n_ctx + logger.info(f"Model: {n_layers} layers, {n_heads} heads, seq_len={seq_len}") + + task_config = config.task_config + assert isinstance(task_config, LMTaskConfig) + dataset_config = DatasetConfig( + name=task_config.dataset_name, + hf_tokenizer_path=config.tokenizer_name, + split=task_config.eval_data_split, + n_ctx=task_config.max_seq_len, + is_tokenized=task_config.is_tokenized, + streaming=task_config.streaming, + column_name=task_config.column_name, + shuffle_each_epoch=False, + ) + loader, _ = create_data_loader( + dataset_config=dataset_config, + batch_size=BATCH_SIZE, + 
buffer_size=1000, + ) + + accum_raw = np.zeros((n_layers, n_heads)) + accum_baseline = 0.0 + n_processed = 0 + + with torch.no_grad(): + for i, batch in enumerate(loader): + if i >= n_batches: + break + input_ids = batch[task_config.column_name][:, :seq_len].to(device) + patterns = collect_attention_patterns(target_model, input_ids) + raw_scores, baseline_fraction = _compute_delimiter_scores( + patterns, input_ids, delimiter_ids + ) + + accum_raw += raw_scores + accum_baseline += baseline_fraction + n_processed += 1 + if (i + 1) % 25 == 0: + logger.info(f"Processed {i + 1}/{n_batches} batches") + + assert n_processed > 0 + accum_raw /= n_processed + mean_baseline = accum_baseline / n_processed + ratio_scores = accum_raw / max(mean_baseline, 1e-8) + + logger.info(f"Delimiter head scores (n={n_processed} batches, baseline={mean_baseline:.4f}):") + logger.info(" Raw attn | Ratio over baseline") + for layer_idx in range(n_layers): + for h in range(n_heads): + raw_val = accum_raw[layer_idx, h] + ratio_val = ratio_scores[layer_idx, h] + marker = " <-- delimiter head" if ratio_val > 2.0 else "" + logger.info(f" L{layer_idx}H{h}: raw={raw_val:.4f} ratio={ratio_val:.2f}{marker}") + + _plot_dual_heatmap( + accum_raw, + ratio_scores, + run_id, + n_processed, + mean_baseline, + out_dir / "delimiter_scores.png", + ) + logger.info(f"All plots saved to {out_dir}") + + +if __name__ == "__main__": + fire.Fire(detect_delimiter_heads) diff --git a/spd/scripts/detect_duplicate_token_heads/__init__.py b/spd/scripts/detect_duplicate_token_heads/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/scripts/detect_duplicate_token_heads/detect_duplicate_token_heads.py b/spd/scripts/detect_duplicate_token_heads/detect_duplicate_token_heads.py new file mode 100644 index 000000000..5720fd6b4 --- /dev/null +++ b/spd/scripts/detect_duplicate_token_heads/detect_duplicate_token_heads.py @@ -0,0 +1,178 @@ +"""Detect duplicate-token attention heads. 
+ +For each head, measures the mean attention weight going to previous positions +that contain the exact same token as the current position, conditioned on +positions where at least one prior duplicate exists. + +Usage: + python -m spd.scripts.detect_duplicate_token_heads.detect_duplicate_token_heads \ + wandb:goodfire/spd/runs/ +""" + +from pathlib import Path + +import fire +import matplotlib.pyplot as plt +import numpy as np +import torch +from numpy.typing import NDArray + +from spd.configs import LMTaskConfig +from spd.data import DatasetConfig, create_data_loader +from spd.log import logger +from spd.models.component_model import SPDRunInfo +from spd.pretrain.models.llama_simple_mlp import LlamaSimpleMLP +from spd.scripts.collect_attention_patterns import collect_attention_patterns +from spd.spd_types import ModelPath +from spd.utils.wandb_utils import parse_wandb_run_path + +SCRIPT_DIR = Path(__file__).parent +N_BATCHES = 100 +BATCH_SIZE = 32 + + +def _compute_duplicate_token_scores( + patterns: list[torch.Tensor], + input_ids: torch.Tensor, +) -> tuple[NDArray[np.floating], int]: + """Compute mean attention to prior same-token positions per head. + + Only positions where the current token has at least one prior duplicate are + included. Returns (scores of shape (n_layers, n_heads), n_valid_positions). 
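The duplicate mask built in the function body can be checked on a toy sequence (token ids are hypothetical): `dup_mask[b, i, j]` is True iff `j < i` and the tokens match, and `has_dup` flags query positions with at least one prior duplicate.

```python
import torch

input_ids = torch.tensor([[4, 7, 4, 4]])  # (B=1, T=4), toy ids
T = input_ids.shape[1]

same_token = input_ids.unsqueeze(2) == input_ids.unsqueeze(1)  # (B, T, T)
# diagonal=-1: strictly earlier positions only, never the query itself
causal = torch.tril(torch.ones(T, T, dtype=torch.bool), diagonal=-1)
dup_mask = same_token & causal
has_dup = dup_mask.any(dim=-1)  # (B, T)
```

Position 2 has one prior copy of token 4 (position 0); position 3 has two (positions 0 and 2); positions 0 and 1 have none.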
+ """ + B, T = input_ids.shape + n_layers = len(patterns) + n_heads = patterns[0].shape[1] + + # mask[b, i, j] = True iff j < i and input_ids[b, j] == input_ids[b, i] + same_token = input_ids.unsqueeze(2) == input_ids.unsqueeze(1) # (B, T, T) + causal = torch.tril(torch.ones(T, T, device=input_ids.device, dtype=torch.bool), diagonal=-1) + dup_mask = same_token & causal + has_dup = dup_mask.any(dim=-1) # (B, T) + n_valid = has_dup.sum().item() + + scores = np.zeros((n_layers, n_heads)) + if n_valid == 0: + return scores, 0 + + dup_mask_float = dup_mask.unsqueeze(1).float() # (B, 1, T, T) + has_dup_float = has_dup.unsqueeze(1).float() # (B, 1, T) + + for layer_idx, att in enumerate(patterns): + dup_attn = (att.float() * dup_mask_float).sum(dim=-1) # (B, H, T) + valid_sum = (dup_attn * has_dup_float).sum(dim=(0, 2)) # (H,) + scores[layer_idx] = valid_sum.cpu().numpy() / n_valid + + return scores, n_valid + + +def _plot_score_heatmap( + scores: NDArray[np.floating], + run_id: str, + n_samples: int, + out_path: Path, +) -> None: + n_layers, n_heads = scores.shape + fig, ax = plt.subplots(figsize=(max(6, n_heads * 1.2), max(4, n_layers * 1.0))) + + im = ax.imshow(scores, aspect="auto", cmap="viridis", vmin=0) + fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02, label="Mean attn to same-token pos") + + ax.set_xticks(range(n_heads)) + ax.set_xticklabels([f"H{h}" for h in range(n_heads)], fontsize=10) + ax.set_yticks(range(n_layers)) + ax.set_yticklabels([f"L{li}" for li in range(n_layers)], fontsize=10) + ax.set_xlabel("Head") + ax.set_ylabel("Layer") + + for layer_idx in range(n_layers): + for h in range(n_heads): + val = scores[layer_idx, h] + color = "white" if val < scores.max() * 0.6 else "black" + ax.text(h, layer_idx, f"{val:.3f}", ha="center", va="center", fontsize=9, color=color) + + fig.suptitle( + f"{run_id} | Duplicate-token head scores (n={n_samples} batches)", + fontsize=13, + fontweight="bold", + ) + fig.tight_layout() + fig.savefig(out_path, dpi=150, 
bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {out_path}") + + +def detect_duplicate_token_heads(wandb_path: ModelPath, n_batches: int = N_BATCHES) -> None: + _entity, _project, run_id = parse_wandb_run_path(str(wandb_path)) + run_info = SPDRunInfo.from_path(wandb_path) + + out_dir = SCRIPT_DIR / "out" / run_id + out_dir.mkdir(parents=True, exist_ok=True) + + config = run_info.config + target_model = LlamaSimpleMLP.from_pretrained(config.pretrained_model_name) + target_model.eval() + + for block in target_model._h: + block.attn.flash_attention = False + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + target_model = target_model.to(device) + + n_layers = len(target_model._h) + n_heads = target_model._h[0].attn.n_head + seq_len = target_model.config.n_ctx + logger.info(f"Model: {n_layers} layers, {n_heads} heads, seq_len={seq_len}") + + task_config = config.task_config + assert isinstance(task_config, LMTaskConfig) + dataset_config = DatasetConfig( + name=task_config.dataset_name, + hf_tokenizer_path=config.tokenizer_name, + split=task_config.eval_data_split, + n_ctx=task_config.max_seq_len, + is_tokenized=task_config.is_tokenized, + streaming=task_config.streaming, + column_name=task_config.column_name, + shuffle_each_epoch=False, + ) + loader, _ = create_data_loader( + dataset_config=dataset_config, + batch_size=BATCH_SIZE, + buffer_size=1000, + ) + + accum_scores = np.zeros((n_layers, n_heads)) + total_valid = 0 + n_processed = 0 + + with torch.no_grad(): + for i, batch in enumerate(loader): + if i >= n_batches: + break + input_ids = batch[task_config.column_name][:, :seq_len].to(device) + patterns = collect_attention_patterns(target_model, input_ids) + scores, n_valid = _compute_duplicate_token_scores(patterns, input_ids) + + accum_scores += scores * n_valid + total_valid += n_valid + n_processed += 1 + if (i + 1) % 25 == 0: + logger.info(f"Processed {i + 1}/{n_batches} batches") + + assert n_processed > 0 and total_valid > 
0 + accum_scores /= total_valid + + logger.info(f"Duplicate-token scores (n={n_processed} batches, {total_valid} valid positions):") + for layer_idx in range(n_layers): + for h in range(n_heads): + score = accum_scores[layer_idx, h] + marker = " <-- dup-token head" if score > 0.3 else "" + logger.info(f" L{layer_idx}H{h}: {score:.4f}{marker}") + + _plot_score_heatmap(accum_scores, run_id, n_processed, out_dir / "duplicate_token_scores.png") + logger.info(f"All plots saved to {out_dir}") + + +if __name__ == "__main__": + fire.Fire(detect_duplicate_token_heads) diff --git a/spd/scripts/detect_induction_heads/__init__.py b/spd/scripts/detect_induction_heads/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/scripts/detect_induction_heads/detect_induction_heads.py b/spd/scripts/detect_induction_heads/detect_induction_heads.py new file mode 100644 index 000000000..19dd81ef7 --- /dev/null +++ b/spd/scripts/detect_induction_heads/detect_induction_heads.py @@ -0,0 +1,209 @@ +"""Detect induction heads using repeated random token sequences. + +For repeated sequences [A B C ... | A B C ...], an induction head at position L+k +in the second half should attend to position k+1 in the first half (the token after +the first occurrence of the current token). The induction score is the mean attention +weight at this offset diagonal. 
+ +Usage: + python -m spd.scripts.detect_induction_heads.detect_induction_heads \ + wandb:goodfire/spd/runs/ +""" + +from pathlib import Path + +import fire +import matplotlib.pyplot as plt +import numpy as np +import torch +from numpy.typing import NDArray + +from spd.log import logger +from spd.models.component_model import SPDRunInfo +from spd.pretrain.models.llama_simple_mlp import LlamaSimpleMLP +from spd.scripts.collect_attention_patterns import collect_attention_patterns +from spd.spd_types import ModelPath +from spd.utils.wandb_utils import parse_wandb_run_path + +SCRIPT_DIR = Path(__file__).parent +N_BATCHES = 50 +BATCH_SIZE = 32 +HALF_SEQ_LEN = 256 + + +def _compute_induction_scores( + patterns: list[torch.Tensor], + half_len: int, +) -> NDArray[np.floating]: + """Compute induction score for each layer and head. + + For repeated sequences [t_0..t_{L-1} t_0..t_{L-1}], the induction pattern at + position L+k attends to position k+1. We average over k in [0, L-2]. + + Returns shape (n_layers, n_heads). 
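The zipped advanced indexing used in the function body, `att[:, :, dst, src]`, pairs `dst[i]` with `src[i]` rather than taking a cross product, so each second-half query is matched with its single induction target. A toy check with a hand-planted perfect induction pattern (sizes are arbitrary):

```python
import torch

half_len = 3
seq_len = 2 * half_len
att = torch.zeros(1, 1, seq_len, seq_len)
# Plant the ideal pattern: query at pos half_len + k attends fully to pos k + 1.
for k in range(half_len - 1):
    att[0, 0, half_len + k, k + 1] = 1.0

src = torch.arange(1, half_len)
dst = torch.arange(half_len, 2 * half_len - 1)
induction_attn = att[:, :, dst, src]  # (B, H, half_len - 1), zipped not outer
score = induction_attn.mean().item()
```

A perfect induction head scores 1.0 under this metric; real heads land somewhere in [0, 1].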
+ """ + src = torch.arange(1, half_len, device=patterns[0].device) + dst = torch.arange(half_len, 2 * half_len - 1, device=patterns[0].device) + + n_layers = len(patterns) + n_heads = patterns[0].shape[1] + scores = np.zeros((n_layers, n_heads)) + + for layer_idx, att in enumerate(patterns): + # att[:, :, dst, src] zips indices: att[b, h, dst[i], src[i]] + induction_attn = att[:, :, dst, src] # (batch, n_heads, half_len-1) + scores[layer_idx] = induction_attn.float().mean(dim=(0, 2)).cpu().numpy() + + return scores + + +def _plot_score_heatmap( + scores: NDArray[np.floating], + run_id: str, + n_samples: int, + out_path: Path, +) -> None: + n_layers, n_heads = scores.shape + fig, ax = plt.subplots(figsize=(max(6, n_heads * 1.2), max(4, n_layers * 1.0))) + + im = ax.imshow(scores, aspect="auto", cmap="viridis", vmin=0) + fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02, label="Induction score") + + ax.set_xticks(range(n_heads)) + ax.set_xticklabels([f"H{h}" for h in range(n_heads)], fontsize=10) + ax.set_yticks(range(n_layers)) + ax.set_yticklabels([f"L{li}" for li in range(n_layers)], fontsize=10) + ax.set_xlabel("Head") + ax.set_ylabel("Layer") + + for layer_idx in range(n_layers): + for h in range(n_heads): + val = scores[layer_idx, h] + color = "white" if val < scores.max() * 0.6 else "black" + ax.text(h, layer_idx, f"{val:.3f}", ha="center", va="center", fontsize=9, color=color) + + fig.suptitle( + f"{run_id} | Induction head scores (n={n_samples} batches)", + fontsize=13, + fontweight="bold", + ) + fig.tight_layout() + fig.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {out_path}") + + +def _plot_mean_attention_patterns( + mean_patterns: list[torch.Tensor], + run_id: str, + n_samples: int, + half_len: int, + out_path: Path, +) -> None: + """Plot grid of mean attention patterns on repeated random sequences.""" + n_layers = len(mean_patterns) + n_heads = mean_patterns[0].shape[0] + + fig, axes = plt.subplots( + n_layers, + 
n_heads, + figsize=(n_heads * 3, n_layers * 3), + squeeze=False, + ) + + for layer_idx in range(n_layers): + for h in range(n_heads): + ax = axes[layer_idx, h] + pattern = mean_patterns[layer_idx][h].numpy() + ax.imshow(pattern, aspect="auto", cmap="viridis", vmin=0) + ax.axhline(y=half_len - 0.5, color="red", linewidth=0.5, linestyle="--") + ax.axvline(x=half_len - 0.5, color="red", linewidth=0.5, linestyle="--") + ax.set_title(f"L{layer_idx}H{h}", fontsize=9) + ax.set_xticks([]) + ax.set_yticks([]) + if h == 0: + ax.set_ylabel(f"Layer {layer_idx}", fontsize=9) + + fig.suptitle( + f"{run_id} | Mean attention (repeated random seqs) (n={n_samples})", + fontsize=13, + fontweight="bold", + ) + fig.tight_layout() + fig.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {out_path}") + + +def detect_induction_heads( + wandb_path: ModelPath, + n_batches: int = N_BATCHES, + half_seq_len: int = HALF_SEQ_LEN, +) -> None: + _entity, _project, run_id = parse_wandb_run_path(str(wandb_path)) + run_info = SPDRunInfo.from_path(wandb_path) + + out_dir = SCRIPT_DIR / "out" / run_id + out_dir.mkdir(parents=True, exist_ok=True) + + config = run_info.config + target_model = LlamaSimpleMLP.from_pretrained(config.pretrained_model_name) + target_model.eval() + + for block in target_model._h: + block.attn.flash_attention = False + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + target_model = target_model.to(device) + + n_layers = len(target_model._h) + n_heads = target_model._h[0].attn.n_head + vocab_size = target_model.config.vocab_size + seq_len = half_seq_len * 2 + logger.info(f"Model: {n_layers} layers, {n_heads} heads") + logger.info(f"Induction test: half_len={half_seq_len}, total_len={seq_len}") + + accum_scores = np.zeros((n_layers, n_heads)) + accum_patterns = [torch.zeros(n_heads, seq_len, seq_len) for _ in range(n_layers)] + n_processed = 0 + + with torch.no_grad(): + for i in range(n_batches): + first_half = 
torch.randint( + 100, vocab_size - 100, (BATCH_SIZE, half_seq_len), device=device + ) + input_ids = torch.cat([first_half, first_half], dim=1) + + patterns = collect_attention_patterns(target_model, input_ids) + scores = _compute_induction_scores(patterns, half_seq_len) + accum_scores += scores + + for layer_idx in range(n_layers): + accum_patterns[layer_idx] += patterns[layer_idx].float().mean(dim=0).cpu() + + n_processed += 1 + if (i + 1) % 10 == 0: + logger.info(f"Processed {i + 1}/{n_batches} batches") + + assert n_processed > 0 + accum_scores /= n_processed + for layer_idx in range(n_layers): + accum_patterns[layer_idx] /= n_processed + + logger.info(f"Induction scores (n={n_processed} batches):") + for layer_idx in range(n_layers): + for h in range(n_heads): + score = accum_scores[layer_idx, h] + marker = " <-- induction head" if score > 0.3 else "" + logger.info(f" L{layer_idx}H{h}: {score:.4f}{marker}") + + _plot_score_heatmap(accum_scores, run_id, n_processed, out_dir / "induction_scores.png") + _plot_mean_attention_patterns( + accum_patterns, run_id, n_processed, half_seq_len, out_dir / "mean_attention_repeated.png" + ) + + logger.info(f"All plots saved to {out_dir}") + + +if __name__ == "__main__": + fire.Fire(detect_induction_heads) diff --git a/spd/scripts/detect_positional_heads/__init__.py b/spd/scripts/detect_positional_heads/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/scripts/detect_positional_heads/detect_positional_heads.py b/spd/scripts/detect_positional_heads/detect_positional_heads.py new file mode 100644 index 000000000..4697fefe4 --- /dev/null +++ b/spd/scripts/detect_positional_heads/detect_positional_heads.py @@ -0,0 +1,260 @@ +"""Detect positional attention heads that attend to fixed relative offsets. + +For each head, builds a histogram of mean attention weight by relative offset +(offset = query_pos - key_pos). 
A positional head shows a sharp peak at one +or a few specific offsets regardless of token content. Also measures attention +to BOS (absolute position 0). + +Usage: + python -m spd.scripts.detect_positional_heads.detect_positional_heads \ + wandb:goodfire/spd/runs/ +""" + +from pathlib import Path + +import fire +import matplotlib.pyplot as plt +import numpy as np +import torch +from numpy.typing import NDArray + +from spd.configs import LMTaskConfig +from spd.data import DatasetConfig, create_data_loader +from spd.log import logger +from spd.models.component_model import SPDRunInfo +from spd.pretrain.models.llama_simple_mlp import LlamaSimpleMLP +from spd.scripts.collect_attention_patterns import collect_attention_patterns +from spd.spd_types import ModelPath +from spd.utils.wandb_utils import parse_wandb_run_path + +SCRIPT_DIR = Path(__file__).parent +N_BATCHES = 100 +BATCH_SIZE = 32 +MAX_OFFSET = 128 + + +def _compute_positional_profiles( + patterns: list[torch.Tensor], + max_offset: int, +) -> tuple[NDArray[np.floating], NDArray[np.floating], NDArray[np.floating]]: + """Compute per-head mean attention by relative offset, max-offset score, and BOS score. 
+ + Returns: + profiles: (n_layers, n_heads, max_offset) mean attention at each relative offset + max_offset_scores: (n_layers, n_heads) max attention at any offset >= 1 + bos_scores: (n_layers, n_heads) mean attention to absolute position 0 + """ + n_layers = len(patterns) + n_heads = patterns[0].shape[1] + B, _, T, _ = patterns[0].shape + + profiles = np.zeros((n_layers, n_heads, max_offset)) + bos_scores = np.zeros((n_layers, n_heads)) + + for layer_idx, att in enumerate(patterns): + # BOS: mean attention to position 0 across all query positions + bos_scores[layer_idx] = att[:, :, :, 0].float().mean(dim=(0, 2)).cpu().numpy() + + # Positional profile: for each offset d, average att[b, h, q, q-d] over valid q + for d in range(min(max_offset, T)): + diag = torch.diagonal(att, offset=-d, dim1=-2, dim2=-1) # (B, H, T-d) + profiles[layer_idx, :, d] = diag.float().mean(dim=(0, 2)).cpu().numpy() + + max_offset_scores = ( + profiles[:, :, 1:].max(axis=2) if max_offset > 1 else np.zeros((n_layers, n_heads)) + ) + + return profiles, max_offset_scores, bos_scores + + +def _plot_score_heatmap( + scores: NDArray[np.floating], + run_id: str, + n_samples: int, + out_path: Path, + title: str, + colorbar_label: str, +) -> None: + n_layers, n_heads = scores.shape + fig, ax = plt.subplots(figsize=(max(6, n_heads * 1.2), max(4, n_layers * 1.0))) + + im = ax.imshow(scores, aspect="auto", cmap="viridis", vmin=0) + fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02, label=colorbar_label) + + ax.set_xticks(range(n_heads)) + ax.set_xticklabels([f"H{h}" for h in range(n_heads)], fontsize=10) + ax.set_yticks(range(n_layers)) + ax.set_yticklabels([f"L{li}" for li in range(n_layers)], fontsize=10) + ax.set_xlabel("Head") + ax.set_ylabel("Layer") + + for layer_idx in range(n_layers): + for h in range(n_heads): + val = scores[layer_idx, h] + color = "white" if val < scores.max() * 0.6 else "black" + ax.text(h, layer_idx, f"{val:.3f}", ha="center", va="center", fontsize=9, color=color) + + 
fig.suptitle( + f"{run_id} | {title} (n={n_samples} batches)", + fontsize=13, + fontweight="bold", + ) + fig.tight_layout() + fig.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {out_path}") + + +def _plot_positional_profiles( + profiles: NDArray[np.floating], + run_id: str, + n_samples: int, + out_path: Path, + max_display_offset: int = 64, +) -> None: + """Plot positional profiles as a grid of line charts (one per head).""" + n_layers, n_heads, n_offsets = profiles.shape + display_offsets = min(max_display_offset, n_offsets) + + fig, axes = plt.subplots( + n_layers, + n_heads, + figsize=(n_heads * 3, n_layers * 2.5), + squeeze=False, + ) + + for layer_idx in range(n_layers): + for h in range(n_heads): + ax = axes[layer_idx, h] + ax.plot(range(display_offsets), profiles[layer_idx, h, :display_offsets], linewidth=1.0) + ax.set_title(f"L{layer_idx}H{h}", fontsize=9) + ax.set_xlim(0, display_offsets) + if layer_idx == n_layers - 1: + ax.set_xlabel("Offset", fontsize=8) + if h == 0: + ax.set_ylabel("Mean attn", fontsize=8) + ax.tick_params(labelsize=7) + + fig.suptitle( + f"{run_id} | Positional profiles (n={n_samples} batches)", + fontsize=13, + fontweight="bold", + ) + fig.tight_layout() + fig.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {out_path}") + + +def detect_positional_heads( + wandb_path: ModelPath, n_batches: int = N_BATCHES, max_offset: int = MAX_OFFSET +) -> None: + _entity, _project, run_id = parse_wandb_run_path(str(wandb_path)) + run_info = SPDRunInfo.from_path(wandb_path) + + out_dir = SCRIPT_DIR / "out" / run_id + out_dir.mkdir(parents=True, exist_ok=True) + + config = run_info.config + target_model = LlamaSimpleMLP.from_pretrained(config.pretrained_model_name) + target_model.eval() + + for block in target_model._h: + block.attn.flash_attention = False + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + target_model = target_model.to(device) + + 
n_layers = len(target_model._h) + n_heads = target_model._h[0].attn.n_head + seq_len = target_model.config.n_ctx + logger.info(f"Model: {n_layers} layers, {n_heads} heads, seq_len={seq_len}") + + task_config = config.task_config + assert isinstance(task_config, LMTaskConfig) + dataset_config = DatasetConfig( + name=task_config.dataset_name, + hf_tokenizer_path=config.tokenizer_name, + split=task_config.eval_data_split, + n_ctx=task_config.max_seq_len, + is_tokenized=task_config.is_tokenized, + streaming=task_config.streaming, + column_name=task_config.column_name, + shuffle_each_epoch=False, + ) + loader, _ = create_data_loader( + dataset_config=dataset_config, + batch_size=BATCH_SIZE, + buffer_size=1000, + ) + + accum_profiles = np.zeros((n_layers, n_heads, max_offset)) + accum_max_offset = np.zeros((n_layers, n_heads)) + accum_bos = np.zeros((n_layers, n_heads)) + n_processed = 0 + + with torch.no_grad(): + for i, batch in enumerate(loader): + if i >= n_batches: + break + input_ids = batch[task_config.column_name][:, :seq_len].to(device) + patterns = collect_attention_patterns(target_model, input_ids) + profiles, max_offset_scores, bos_scores = _compute_positional_profiles( + patterns, max_offset + ) + + accum_profiles += profiles + accum_max_offset += max_offset_scores + accum_bos += bos_scores + n_processed += 1 + if (i + 1) % 25 == 0: + logger.info(f"Processed {i + 1}/{n_batches} batches") + + assert n_processed > 0 + accum_profiles /= n_processed + accum_max_offset /= n_processed + accum_bos /= n_processed + + # Find the peak offset for each head + peak_offsets = accum_profiles[:, :, 1:].argmax(axis=2) + 1 # offset >= 1 + + logger.info(f"Positional head scores (n={n_processed} batches):") + logger.info(" Max-offset score | Peak offset | BOS score") + for layer_idx in range(n_layers): + for h in range(n_heads): + max_val = accum_max_offset[layer_idx, h] + peak = peak_offsets[layer_idx, h] + bos_val = accum_bos[layer_idx, h] + marker = "" + if max_val > 0.3: + 
marker = f" <-- positional head (offset={peak})" + elif bos_val > 0.1: + marker = " <-- BOS head" + logger.info( + f" L{layer_idx}H{h}: max={max_val:.4f} (peak@{peak}) bos={bos_val:.4f}{marker}" + ) + + _plot_score_heatmap( + accum_max_offset, + run_id, + n_processed, + out_dir / "positional_max_offset_scores.png", + "Max-offset positional scores", + "Max mean attn at any offset", + ) + _plot_score_heatmap( + accum_bos, + run_id, + n_processed, + out_dir / "bos_attention_scores.png", + "BOS attention scores", + "Mean attn to position 0", + ) + _plot_positional_profiles( + accum_profiles, run_id, n_processed, out_dir / "positional_profiles.png" + ) + logger.info(f"All plots saved to {out_dir}") + + +if __name__ == "__main__": + fire.Fire(detect_positional_heads) diff --git a/spd/scripts/detect_prev_token_heads/__init__.py b/spd/scripts/detect_prev_token_heads/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/scripts/detect_prev_token_heads/detect_prev_token_heads.py b/spd/scripts/detect_prev_token_heads/detect_prev_token_heads.py new file mode 100644 index 000000000..652dfd5ba --- /dev/null +++ b/spd/scripts/detect_prev_token_heads/detect_prev_token_heads.py @@ -0,0 +1,207 @@ +"""Detect previous-token attention heads by measuring mean attention to position i-1. + +For each layer and head, computes the average attention weight from position i to +position i-1 across many data batches. Heads with a high score consistently attend +to the previous token, a key building block of induction circuits. 
+ +Usage: + python -m spd.scripts.detect_prev_token_heads.detect_prev_token_heads \ + wandb:goodfire/spd/runs/ +""" + +from pathlib import Path + +import fire +import matplotlib.pyplot as plt +import numpy as np +import torch +from numpy.typing import NDArray + +from spd.configs import LMTaskConfig +from spd.data import DatasetConfig, create_data_loader +from spd.log import logger +from spd.models.component_model import SPDRunInfo +from spd.pretrain.models.llama_simple_mlp import LlamaSimpleMLP +from spd.scripts.collect_attention_patterns import collect_attention_patterns +from spd.spd_types import ModelPath +from spd.utils.wandb_utils import parse_wandb_run_path + +SCRIPT_DIR = Path(__file__).parent +N_BATCHES = 100 +BATCH_SIZE = 32 + + +def _plot_score_heatmap( + scores: NDArray[np.floating], + run_id: str, + n_samples: int, + out_path: Path, +) -> None: + n_layers, n_heads = scores.shape + fig, ax = plt.subplots(figsize=(max(6, n_heads * 1.2), max(4, n_layers * 1.0))) + + im = ax.imshow(scores, aspect="auto", cmap="Blues", vmin=0, vmax=0.95) + fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02, label="Mean attn to pos i-1") + + ax.set_xticks(range(n_heads)) + ax.set_xticklabels([f"H{h}" for h in range(n_heads)], fontsize=10) + ax.set_yticks(range(n_layers)) + ax.set_yticklabels([f"L{li}" for li in range(n_layers)], fontsize=10) + ax.set_xlabel("Head") + ax.set_ylabel("Layer") + + for layer_idx in range(n_layers): + for h in range(n_heads): + val = scores[layer_idx, h] + text_color = "white" if val > 0.65 else "black" + ax.text( + h, layer_idx, f"{val:.3f}", ha="center", va="center", fontsize=9, color=text_color + ) + + fig.suptitle( + f"{run_id} | Previous-token head scores (n={n_samples} batches)", + fontsize=13, + fontweight="bold", + ) + fig.tight_layout() + fig.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {out_path}") + + +def _plot_attention_patterns( + patterns: list[torch.Tensor], + run_id: str, + title: str, + 
out_path: Path, + max_pos: int = 128, +) -> None: + """Plot grid of attention patterns (one per head, truncated to max_pos).""" + n_layers = len(patterns) + n_heads = patterns[0].shape[0] + + fig, axes = plt.subplots( + n_layers, + n_heads, + figsize=(n_heads * 3, n_layers * 3), + squeeze=False, + ) + + for layer_idx in range(n_layers): + for h in range(n_heads): + ax = axes[layer_idx, h] + pattern = patterns[layer_idx][h, :max_pos, :max_pos].numpy() + ax.imshow(pattern, aspect="auto", cmap="viridis", vmin=0) + ax.set_title(f"L{layer_idx}H{h}", fontsize=9) + ax.set_xticks([]) + ax.set_yticks([]) + if h == 0: + ax.set_ylabel(f"Layer {layer_idx}", fontsize=9) + + fig.suptitle( + f"{run_id} | {title} (pos 0-{max_pos})", + fontsize=13, + fontweight="bold", + ) + fig.tight_layout() + fig.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {out_path}") + + +def detect_prev_token_heads(wandb_path: ModelPath, n_batches: int = N_BATCHES) -> None: + _entity, _project, run_id = parse_wandb_run_path(str(wandb_path)) + run_info = SPDRunInfo.from_path(wandb_path) + + out_dir = SCRIPT_DIR / "out" / run_id + out_dir.mkdir(parents=True, exist_ok=True) + + config = run_info.config + target_model = LlamaSimpleMLP.from_pretrained(config.pretrained_model_name) + target_model.eval() + + for block in target_model._h: + block.attn.flash_attention = False + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + target_model = target_model.to(device) + + n_layers = len(target_model._h) + n_heads = target_model._h[0].attn.n_head + seq_len = target_model.config.n_ctx + logger.info(f"Model: {n_layers} layers, {n_heads} heads, seq_len={seq_len}") + + task_config = config.task_config + assert isinstance(task_config, LMTaskConfig) + dataset_config = DatasetConfig( + name=task_config.dataset_name, + hf_tokenizer_path=config.tokenizer_name, + split=task_config.eval_data_split, + n_ctx=task_config.max_seq_len, + 
is_tokenized=task_config.is_tokenized, + streaming=task_config.streaming, + column_name=task_config.column_name, + shuffle_each_epoch=False, + ) + loader, _ = create_data_loader( + dataset_config=dataset_config, + batch_size=BATCH_SIZE, + buffer_size=1000, + ) + + accum_scores = np.zeros((n_layers, n_heads)) + accum_patterns = [torch.zeros(n_heads, seq_len, seq_len) for _ in range(n_layers)] + single_patterns: list[torch.Tensor] | None = None + n_processed = 0 + + with torch.no_grad(): + for i, batch in enumerate(loader): + if i >= n_batches: + break + input_ids = batch[task_config.column_name][:, :seq_len].to(device) + patterns = collect_attention_patterns(target_model, input_ids) + + if i == 0: + single_patterns = [att[0].float().cpu() for att in patterns] + + for layer_idx, att in enumerate(patterns): + diag = torch.diagonal(att, offset=-1, dim1=-2, dim2=-1) # (batch, heads, T-1) + accum_scores[layer_idx] += diag.float().mean(dim=(0, 2)).cpu().numpy() + accum_patterns[layer_idx] += att.float().mean(dim=0).cpu() + + n_processed += 1 + if (i + 1) % 25 == 0: + logger.info(f"Processed {i + 1}/{n_batches} batches") + + assert n_processed > 0 + accum_scores /= n_processed + for layer_idx in range(n_layers): + accum_patterns[layer_idx] /= n_processed + + logger.info(f"Previous-token scores (n={n_processed} batches):") + for layer_idx in range(n_layers): + for h in range(n_heads): + score = accum_scores[layer_idx, h] + marker = " <-- prev-token head" if score > 0.3 else "" + logger.info(f" L{layer_idx}H{h}: {score:.4f}{marker}") + + _plot_score_heatmap(accum_scores, run_id, n_processed, out_dir / "prev_token_scores.png") + _plot_attention_patterns( + accum_patterns, + run_id, + f"Mean attention patterns (n={n_processed})", + out_dir / "mean_attention_patterns.png", + ) + assert single_patterns is not None + _plot_attention_patterns( + single_patterns, + run_id, + "Single-datapoint attention patterns", + out_dir / "single_attention_patterns.png", + ) + + 
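
The scoring above hinges on `torch.diagonal` with `offset=-1`, which extracts each query position's attention to the immediately preceding position. A standalone sketch (toy tensor, not part of the script) of what that call returns:

```python
import torch

# For a (batch, heads, T, T) attention tensor, torch.diagonal with offset=-1
# picks out att[..., i+1, i], i.e. each query's attention to position i-1.
att = torch.zeros(2, 4, 8, 8)
att[:, :, 1:, :-1] += torch.eye(7)  # put all mass on the i-1 sub-diagonal
diag = torch.diagonal(att, offset=-1, dim1=-2, dim2=-1)  # shape (2, 4, 7)
assert diag.shape == (2, 4, 7)
assert torch.all(diag == 1.0)  # a perfect prev-token head scores 1.0
```

Averaging `diag` over the batch and sequence dims, as the script does, then yields one score per head.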
logger.info(f"All plots saved to {out_dir}") + + +if __name__ == "__main__": + fire.Fire(detect_prev_token_heads) diff --git a/spd/scripts/detect_prev_token_heads/detect_prev_token_heads_random_tokens.py b/spd/scripts/detect_prev_token_heads/detect_prev_token_heads_random_tokens.py new file mode 100644 index 000000000..d0de6a15c --- /dev/null +++ b/spd/scripts/detect_prev_token_heads/detect_prev_token_heads_random_tokens.py @@ -0,0 +1,114 @@ +"""Detect previous-token attention heads using random token sequences. + +Same analysis as detect_prev_token_heads but with random (uniform) token IDs instead +of real text. Heads that score high here attend to the previous position regardless +of token content, indicating purely positional attention behavior. + +Usage: + python -m spd.scripts.detect_prev_token_heads.detect_prev_token_heads_random_tokens \ + wandb:goodfire/spd/runs/ +""" + +from pathlib import Path + +import fire +import numpy as np +import torch + +from spd.log import logger +from spd.models.component_model import SPDRunInfo +from spd.pretrain.models.llama_simple_mlp import LlamaSimpleMLP +from spd.scripts.collect_attention_patterns import collect_attention_patterns +from spd.scripts.detect_prev_token_heads.detect_prev_token_heads import ( + _plot_attention_patterns, + _plot_score_heatmap, +) +from spd.spd_types import ModelPath +from spd.utils.wandb_utils import parse_wandb_run_path + +SCRIPT_DIR = Path(__file__).parent +N_BATCHES = 100 +BATCH_SIZE = 32 + + +def detect_prev_token_heads_random_tokens( + wandb_path: ModelPath, n_batches: int = N_BATCHES +) -> None: + _entity, _project, run_id = parse_wandb_run_path(str(wandb_path)) + run_info = SPDRunInfo.from_path(wandb_path) + + out_dir = SCRIPT_DIR / "out" / run_id + out_dir.mkdir(parents=True, exist_ok=True) + + config = run_info.config + target_model = LlamaSimpleMLP.from_pretrained(config.pretrained_model_name) + target_model.eval() + + for block in target_model._h: + block.attn.flash_attention = False + 
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + target_model = target_model.to(device) + + n_layers = len(target_model._h) + n_heads = target_model._h[0].attn.n_head + seq_len = target_model.config.n_ctx + vocab_size = target_model.config.vocab_size + logger.info(f"Model: {n_layers} layers, {n_heads} heads, seq_len={seq_len}, vocab={vocab_size}") + + accum_scores = np.zeros((n_layers, n_heads)) + accum_patterns = [torch.zeros(n_heads, seq_len, seq_len) for _ in range(n_layers)] + single_patterns: list[torch.Tensor] | None = None + + with torch.no_grad(): + for i in range(n_batches): + input_ids = torch.randint(0, vocab_size, (BATCH_SIZE, seq_len), device=device) + patterns = collect_attention_patterns(target_model, input_ids) + + if i == 0: + single_patterns = [att[0].float().cpu() for att in patterns] + + for layer_idx, att in enumerate(patterns): + diag = torch.diagonal(att, offset=-1, dim1=-2, dim2=-1) + accum_scores[layer_idx] += diag.float().mean(dim=(0, 2)).cpu().numpy() + accum_patterns[layer_idx] += att.float().mean(dim=0).cpu() + + if (i + 1) % 25 == 0: + logger.info(f"Processed {i + 1}/{n_batches} batches") + + accum_scores /= n_batches + for layer_idx in range(n_layers): + accum_patterns[layer_idx] /= n_batches + + logger.info(f"Previous-token scores on random tokens (n={n_batches} batches):") + for layer_idx in range(n_layers): + for h in range(n_heads): + score = accum_scores[layer_idx, h] + marker = " <-- prev-token head" if score > 0.3 else "" + logger.info(f" L{layer_idx}H{h}: {score:.4f}{marker}") + + _plot_score_heatmap( + accum_scores, + run_id, + n_batches, + out_dir / "prev_token_scores_random_tokens.png", + ) + _plot_attention_patterns( + accum_patterns, + run_id, + f"Mean attention patterns, random tokens (n={n_batches})", + out_dir / "mean_attention_patterns_random_tokens.png", + ) + assert single_patterns is not None + _plot_attention_patterns( + single_patterns, + run_id, + "Single-datapoint attention patterns, 
random tokens", + out_dir / "single_attention_patterns_random_tokens.png", + ) + + logger.info(f"All plots saved to {out_dir}") + + +if __name__ == "__main__": + fire.Fire(detect_prev_token_heads_random_tokens) diff --git a/spd/scripts/detect_s_inhibition_heads/__init__.py b/spd/scripts/detect_s_inhibition_heads/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/scripts/detect_s_inhibition_heads/detect_s_inhibition_heads.py b/spd/scripts/detect_s_inhibition_heads/detect_s_inhibition_heads.py new file mode 100644 index 000000000..f185f3d84 --- /dev/null +++ b/spd/scripts/detect_s_inhibition_heads/detect_s_inhibition_heads.py @@ -0,0 +1,272 @@ +"""Detect S-inhibition heads using IOI (Indirect Object Identification) prompts. + +S-inhibition heads attend from the end of IOI sentences to the repeated subject +(S2) and suppress it from being predicted. We measure two signals: + 1. Attention from the final position to the S2 position (data-driven) + 2. OV copy score: whether the head's OV circuit promotes or suppresses the + attended token's logit (weight-based, negative = inhibition) + +Prompts follow the IOI pattern: + "When [IO] and [S] went to the store, [S] gave a drink to" -> answer: [IO] + +Usage: + python -m spd.scripts.detect_s_inhibition_heads.detect_s_inhibition_heads \ + wandb:goodfire/spd/runs/ +""" + +import random +from pathlib import Path + +import fire +import matplotlib.pyplot as plt +import numpy as np +import torch +from numpy.typing import NDArray +from transformers import AutoTokenizer + +from spd.log import logger +from spd.models.component_model import SPDRunInfo +from spd.pretrain.models.llama_simple_mlp import LlamaSimpleMLP +from spd.scripts.collect_attention_patterns import collect_attention_patterns +from spd.spd_types import ModelPath +from spd.utils.wandb_utils import parse_wandb_run_path + +SCRIPT_DIR = Path(__file__).parent +N_PROMPTS = 500 + +TEMPLATE = "When{io} and{s} went to the store,{s} gave a drink to" + 
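
Because every prompt fills the same template with single-token names, all prompts tokenize to the same length, and the S2 and end positions can be located by token identity. A string-level sketch of the construction (the names here are illustrative; the script samples from `CANDIDATE_NAMES` and verifies single-token encoding at runtime):

```python
# Hypothetical example of one IOI prompt built from the template.
TEMPLATE = "When{io} and{s} went to the store,{s} gave a drink to"
io_name, s_name = " Mary", " John"  # IO = indirect object, S = subject
prompt = TEMPLATE.format(io=io_name, s=s_name)
assert prompt == "When Mary and John went to the store, John gave a drink to"
assert prompt.count(s_name) == 2   # S1 and S2 occurrences of the subject
assert prompt.count(io_name) == 1  # the correct next-token answer
```

The second occurrence of the subject token is S2; the final token position is where attention to S2 is measured.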
+CANDIDATE_NAMES = [ + " Alice", " Bob", " Mary", " John", " Tom", " Sam", " Dan", " Jim", " Amy", + " Eve", " Max", " Ben", " Ann", " Joe", " Kate", " Bill", " Jack", " Mark", + " Paul", " Dave", " Luke", " Jill", " Brad", " Emma", " Alex", " Ryan", + " Meg", " Zoe", " Beth", " Fred", +] # fmt: skip + + +def _get_single_token_names( + tokenizer: AutoTokenizer, +) -> list[tuple[str, int]]: + """Return (name_string, token_id) for names that encode as a single token.""" + valid = [] + for name in CANDIDATE_NAMES: + token_ids = tokenizer.encode(name) + if len(token_ids) == 1: + valid.append((name, token_ids[0])) + return valid + + +def _create_ioi_batch( + tokenizer: AutoTokenizer, + names: list[tuple[str, int]], + batch_size: int, +) -> tuple[torch.Tensor, list[int], list[int], list[int], list[int]]: + """Create a batch of IOI prompts. + + Returns (input_ids, s2_positions, end_positions, io_token_ids, s_token_ids). + All prompts use the same template, so they have identical token counts. + """ + all_tokens: list[list[int]] = [] + s2_positions: list[int] = [] + end_positions: list[int] = [] + io_token_ids: list[int] = [] + s_token_ids: list[int] = [] + + for _ in range(batch_size): + (io_name, io_tid), (s_name, s_tid) = random.sample(names, 2) + text = TEMPLATE.format(io=io_name, s=s_name) + tokens = tokenizer.encode(text) + + s_positions = [idx for idx, t in enumerate(tokens) if t == s_tid] + assert len(s_positions) == 2, ( + f"Expected 2 occurrences of '{s_name}', got {len(s_positions)}" + ) + + all_tokens.append(tokens) + s2_positions.append(s_positions[1]) + end_positions.append(len(tokens) - 1) + io_token_ids.append(io_tid) + s_token_ids.append(s_tid) + + # All prompts should be the same length (same template, single-token names) + assert all(len(t) == len(all_tokens[0]) for t in all_tokens) + + return ( + torch.tensor(all_tokens), + s2_positions, + end_positions, + io_token_ids, + s_token_ids, + ) + + +def _compute_ov_copy_scores( + model: LlamaSimpleMLP, + 
name_token_ids: list[int], +) -> NDArray[np.floating]: + """Compute per-head OV copy score averaged over name tokens. + + copy_score(h, t) = W_U[t] @ W_O_h @ W_V_h @ W_E[t] + Positive means the head promotes token t when attending to it (copying). + Negative means the head suppresses it (inhibition). + + Returns shape (n_layers, n_heads). + """ + W_E = model.wte.weight.float() # (vocab, d_model) + W_U = model.lm_head.weight.float() # (vocab, d_model) — tied with W_E + + n_layers = len(model._h) + head_dim = model._h[0].attn.head_dim + n_heads = model._h[0].attn.n_head + scores = np.zeros((n_layers, n_heads)) + + name_embeds = W_E[name_token_ids] # (n_names, d_model) + name_unembed = W_U[name_token_ids] # (n_names, d_model) + + for layer_idx, block in enumerate(model._h): + attn = block.attn + W_V = attn.v_proj.weight.float() # (n_kv_heads * head_dim, d_model) + W_O = attn.o_proj.weight.float() # (d_model, d_model) + + for h in range(n_heads): + kv_idx = (h // attn.repeat_kv_heads) * head_dim + q_idx = h * head_dim + W_V_h = W_V[kv_idx : kv_idx + head_dim, :] # (head_dim, d_model) + W_O_h = W_O[:, q_idx : q_idx + head_dim] # (d_model, head_dim) + W_OV_h = W_O_h @ W_V_h # (d_model, d_model) + + # copy_score for each name: unembed[t] @ W_OV @ embed[t] + ov_output = name_embeds @ W_OV_h.T # (n_names, d_model) + copy = (name_unembed * ov_output).sum(dim=-1) # (n_names,) + scores[layer_idx, h] = copy.mean().item() + + return scores + + +def _plot_dual_heatmap( + attn_scores: NDArray[np.floating], + copy_scores: NDArray[np.floating], + run_id: str, + n_prompts: int, + out_path: Path, +) -> None: + n_layers, n_heads = attn_scores.shape + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(max(12, n_heads * 2.4), max(4, n_layers * 1.0))) + + # Left: attention to S2 + im1 = ax1.imshow(attn_scores, aspect="auto", cmap="viridis", vmin=0) + fig.colorbar(im1, ax=ax1, shrink=0.8, pad=0.02, label="Attn to S2") + ax1.set_title("Attention to S2 from end", fontsize=11, fontweight="bold") + 
+ # Right: OV copy scores + vabs = max(abs(copy_scores.min()), abs(copy_scores.max())) or 1.0 + im2 = ax2.imshow(copy_scores, aspect="auto", cmap="RdBu_r", vmin=-vabs, vmax=vabs) + fig.colorbar(im2, ax=ax2, shrink=0.8, pad=0.02, label="Copy score") + ax2.set_title("OV copy score (neg = inhibit)", fontsize=11, fontweight="bold") + + for ax in (ax1, ax2): + ax.set_xticks(range(n_heads)) + ax.set_xticklabels([f"H{h}" for h in range(n_heads)], fontsize=10) + ax.set_yticks(range(n_layers)) + ax.set_yticklabels([f"L{li}" for li in range(n_layers)], fontsize=10) + ax.set_xlabel("Head") + ax.set_ylabel("Layer") + + data_sets = [attn_scores, copy_scores] + for ax, data in zip([ax1, ax2], data_sets, strict=True): + for layer_idx in range(n_layers): + for h in range(n_heads): + val = data[layer_idx, h] + threshold = abs(data).max() * 0.6 + color = "white" if abs(val) < threshold else "black" + ax.text( + h, layer_idx, f"{val:.3f}", ha="center", va="center", fontsize=8, color=color + ) + + fig.suptitle( + f"{run_id} | S-inhibition analysis (n={n_prompts} IOI prompts)", + fontsize=13, + fontweight="bold", + ) + fig.tight_layout() + fig.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {out_path}") + + +def detect_s_inhibition_heads( + wandb_path: ModelPath, n_prompts: int = N_PROMPTS, batch_size: int = 50 +) -> None: + _entity, _project, run_id = parse_wandb_run_path(str(wandb_path)) + run_info = SPDRunInfo.from_path(wandb_path) + + out_dir = SCRIPT_DIR / "out" / run_id + out_dir.mkdir(parents=True, exist_ok=True) + + config = run_info.config + target_model = LlamaSimpleMLP.from_pretrained(config.pretrained_model_name) + target_model.eval() + + for block in target_model._h: + block.attn.flash_attention = False + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + target_model = target_model.to(device) + + tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) + names = _get_single_token_names(tokenizer) + 
assert len(names) >= 4, f"Need at least 4 single-token names, got {len(names)}" + logger.info(f"Found {len(names)} single-token names: {[n for n, _ in names]}") + + n_layers = len(target_model._h) + n_heads = target_model._h[0].attn.n_head + logger.info(f"Model: {n_layers} layers, {n_heads} heads") + + # Weight-based OV copy scores + all_name_tids = [tid for _, tid in names] + copy_scores = _compute_ov_copy_scores(target_model, all_name_tids) + logger.info("OV copy scores computed") + + # Data-driven attention to S2 + accum_attn_to_s2 = np.zeros((n_layers, n_heads)) + n_processed = 0 + + with torch.no_grad(): + for start in range(0, n_prompts, batch_size): + bs = min(batch_size, n_prompts - start) + input_ids, s2_positions, end_positions, _, _ = _create_ioi_batch(tokenizer, names, bs) + input_ids = input_ids.to(device) + patterns = collect_attention_patterns(target_model, input_ids) + + for layer_idx, att in enumerate(patterns): + for b in range(bs): + # Attention from end position to S2 position, per head + accum_attn_to_s2[layer_idx] += ( + att[b, :, end_positions[b], s2_positions[b]].float().cpu().numpy() + ) + + n_processed += bs + logger.info(f"Processed {n_processed}/{n_prompts} IOI prompts") + + assert n_processed > 0 + accum_attn_to_s2 /= n_processed + + logger.info(f"S-inhibition analysis (n={n_processed} IOI prompts):") + logger.info(" Attention to S2 | OV copy score") + for layer_idx in range(n_layers): + for h in range(n_heads): + attn_val = accum_attn_to_s2[layer_idx, h] + copy_val = copy_scores[layer_idx, h] + marker = "" + if attn_val > 0.1 and copy_val < 0: + marker = " <-- S-inhibition candidate" + logger.info(f" L{layer_idx}H{h}: attn={attn_val:.4f} copy={copy_val:.4f}{marker}") + + _plot_dual_heatmap( + accum_attn_to_s2, copy_scores, run_id, n_processed, out_dir / "s_inhibition_scores.png" + ) + logger.info(f"All plots saved to {out_dir}") + + +if __name__ == "__main__": + fire.Fire(detect_s_inhibition_heads) diff --git 
a/spd/scripts/detect_successor_heads/__init__.py b/spd/scripts/detect_successor_heads/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/scripts/detect_successor_heads/detect_successor_heads.py b/spd/scripts/detect_successor_heads/detect_successor_heads.py new file mode 100644 index 000000000..50261a110 --- /dev/null +++ b/spd/scripts/detect_successor_heads/detect_successor_heads.py @@ -0,0 +1,263 @@ +"""Detect successor heads by measuring attention to ordinal predecessors. + +Creates comma-separated ordinal sequences (digits, letters, number words, days) +and measures how much each head attends from element[k] to element[k-1]. Since +elements are separated by commas, the predecessor is 2 positions back -- not the +immediately preceding token -- which separates this signal from previous-token +heads. + +A random-word control measures the same positional attention on non-ordinal +sequences. The successor-specific signal is ordinal_score - control_score. + +Usage: + python -m spd.scripts.detect_successor_heads.detect_successor_heads \ + wandb:goodfire/spd/runs/ +""" + +import random +from pathlib import Path + +import fire +import matplotlib.pyplot as plt +import numpy as np +import torch +from numpy.typing import NDArray +from transformers import AutoTokenizer + +from spd.log import logger +from spd.models.component_model import SPDRunInfo +from spd.pretrain.models.llama_simple_mlp import LlamaSimpleMLP +from spd.scripts.collect_attention_patterns import collect_attention_patterns +from spd.spd_types import ModelPath +from spd.utils.wandb_utils import parse_wandb_run_path + +SCRIPT_DIR = Path(__file__).parent +N_PROMPTS = 500 + +ORDINAL_SEQUENCES: list[list[str]] = [ + [f" {i}" for i in range(10)], + [f" {chr(i)}" for i in range(ord("A"), ord("Z") + 1)], + [" one", " two", " three", " four", " five", " six", " seven", " eight", " nine", " ten"], + [" Monday", " Tuesday", " Wednesday", " Thursday", " Friday", " Saturday", " Sunday"], + [" 
January", " February", " March", " April", " May", " June", + " July", " August", " September", " October", " November", " December"], +] # fmt: skip + +RANDOM_WORDS = [ + " cat", " dog", " red", " big", " cup", " hat", " sun", " box", " pen", " map", + " key", " bag", " top", " old", " hot", " new", " run", " sit", " eat", " fly", + " car", " bus", " bed", " arm", " egg", " ice", " oil", " tea", " war", " sky", +] # fmt: skip + + +def _filter_single_token(elements: list[str], tokenizer: AutoTokenizer) -> list[tuple[str, int]]: + """Return (element_string, token_id) for elements that are single tokens.""" + valid = [] + for elem in elements: + token_ids = tokenizer.encode(elem) + if len(token_ids) == 1: + valid.append((elem, token_ids[0])) + return valid + + +def _build_comma_sequence( + elements: list[tuple[str, int]], comma_tid: int +) -> tuple[list[int], list[int]]: + """Build token sequence: elem0, comma, elem1, comma, ... + + Returns (token_ids, element_positions). + """ + tokens: list[int] = [] + positions: list[int] = [] + for idx, (_, tid) in enumerate(elements): + if idx > 0: + tokens.append(comma_tid) + positions.append(len(tokens)) + tokens.append(tid) + return tokens, positions + + +def _generate_prompts( + tokenizer: AutoTokenizer, + ordinal: bool, + n_prompts: int, +) -> list[tuple[list[int], list[int]]]: + """Generate comma-separated sequence prompts. + + Returns list of (token_ids, element_positions). + For ordinal=True, uses ordinal sequences. For ordinal=False, uses random words. 
+ """ + comma_tids = tokenizer.encode(",") + assert len(comma_tids) == 1 + comma_tid = comma_tids[0] + + if ordinal: + valid_sequences = [] + for seq in ORDINAL_SEQUENCES: + valid = _filter_single_token(seq, tokenizer) + if len(valid) >= 3: + valid_sequences.append(valid) + logger.info( + f" Ordinal sequence: {len(valid)} elements ({valid[0][0].strip()}, ...)" + ) + assert valid_sequences, "No valid ordinal sequences found" + else: + random_valid = _filter_single_token(RANDOM_WORDS, tokenizer) + assert len(random_valid) >= 5, f"Need at least 5 random words, got {len(random_valid)}" + + prompts: list[tuple[list[int], list[int]]] = [] + for _ in range(n_prompts): + if ordinal: + seq = random.choice(valid_sequences) + min_len = 3 + max_len = min(len(seq), 12) + subseq_len = random.randint(min_len, max_len) + start = random.randint(0, len(seq) - subseq_len) + elements = seq[start : start + subseq_len] + else: + subseq_len = random.randint(3, 12) + elements = random.sample(random_valid, min(subseq_len, len(random_valid))) + + tokens, positions = _build_comma_sequence(elements, comma_tid) + prompts.append((tokens, positions)) + + return prompts + + +def _compute_predecessor_scores( + model: LlamaSimpleMLP, + prompts: list[tuple[list[int], list[int]]], + device: torch.device, +) -> NDArray[np.floating]: + """Compute mean attention from element[k] to element[k-1] for each head. + + Returns shape (n_layers, n_heads). 
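A minimal numpy sketch of the accumulation, assuming a single attention pattern of shape (n_heads, seq, seq) and made-up element positions; values here are random placeholders, not real model attention.

```python
import numpy as np

# For each adjacent element pair, read attention from element k's position to
# element k-1's position, then average over pairs to get a per-head score.
rng = np.random.default_rng(0)
att = rng.random((4, 5, 5))  # (n_heads, dst_pos, src_pos), placeholder values
positions = [0, 2, 4]        # element positions in the comma sequence

accum = np.zeros(4)
n_pairs = 0
for k in range(1, len(positions)):
    accum += att[:, positions[k], positions[k - 1]]
    n_pairs += 1
scores = accum / n_pairs  # mean predecessor attention per head
```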
+ """ + n_layers = len(model._h) + n_heads = model._h[0].attn.n_head + accum = np.zeros((n_layers, n_heads)) + n_pairs = 0 + + with torch.no_grad(): + for tokens, positions in prompts: + input_ids = torch.tensor([tokens], device=device) + patterns = collect_attention_patterns(model, input_ids) + + for k in range(1, len(positions)): + dst_pos = positions[k] + src_pos = positions[k - 1] + for layer_idx, att in enumerate(patterns): + accum[layer_idx] += att[0, :, dst_pos, src_pos].float().cpu().numpy() + n_pairs += 1 + + assert n_pairs > 0 + return accum / n_pairs + + +def _plot_triple_heatmap( + ordinal_scores: NDArray[np.floating], + control_scores: NDArray[np.floating], + run_id: str, + n_prompts: int, + out_path: Path, +) -> None: + successor_signal = ordinal_scores - control_scores + n_layers, n_heads = ordinal_scores.shape + fig, axes = plt.subplots(1, 3, figsize=(max(18, n_heads * 3.6), max(4, n_layers * 1.0))) + + titles = ["Ordinal predecessor attn", "Random-word control", "Successor signal (diff)"] + data_list = [ordinal_scores, control_scores, successor_signal] + cmaps = ["viridis", "viridis", "RdBu_r"] + + for ax, data, title, cmap in zip(axes, data_list, titles, cmaps, strict=True): + if cmap == "RdBu_r": + vabs = max(abs(data.min()), abs(data.max())) or 1.0 + im = ax.imshow(data, aspect="auto", cmap=cmap, vmin=-vabs, vmax=vabs) + else: + im = ax.imshow(data, aspect="auto", cmap=cmap, vmin=0) + fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02) + ax.set_title(title, fontsize=10, fontweight="bold") + ax.set_xticks(range(n_heads)) + ax.set_xticklabels([f"H{h}" for h in range(n_heads)], fontsize=9) + ax.set_yticks(range(n_layers)) + ax.set_yticklabels([f"L{li}" for li in range(n_layers)], fontsize=9) + ax.set_xlabel("Head") + ax.set_ylabel("Layer") + + for layer_idx in range(n_layers): + for h in range(n_heads): + val = data[layer_idx, h] + threshold = max(abs(data).max() * 0.6, 1e-6) + color = "white" if abs(val) < threshold else "black" + ax.text( + h, 
layer_idx, f"{val:.3f}", ha="center", va="center", fontsize=8, color=color + ) + + fig.suptitle( + f"{run_id} | Successor head analysis (n={n_prompts} prompts per condition)", + fontsize=13, + fontweight="bold", + ) + fig.tight_layout() + fig.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {out_path}") + + +def detect_successor_heads(wandb_path: ModelPath, n_prompts: int = N_PROMPTS) -> None: + _entity, _project, run_id = parse_wandb_run_path(str(wandb_path)) + run_info = SPDRunInfo.from_path(wandb_path) + + out_dir = SCRIPT_DIR / "out" / run_id + out_dir.mkdir(parents=True, exist_ok=True) + + config = run_info.config + target_model = LlamaSimpleMLP.from_pretrained(config.pretrained_model_name) + target_model.eval() + + for block in target_model._h: + block.attn.flash_attention = False + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + target_model = target_model.to(device) + + tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) + + n_layers = len(target_model._h) + n_heads = target_model._h[0].attn.n_head + logger.info(f"Model: {n_layers} layers, {n_heads} heads") + + logger.info("Generating ordinal prompts...") + ordinal_prompts = _generate_prompts(tokenizer, ordinal=True, n_prompts=n_prompts) + logger.info(f"Generated {len(ordinal_prompts)} ordinal prompts") + + logger.info("Generating random-word control prompts...") + control_prompts = _generate_prompts(tokenizer, ordinal=False, n_prompts=n_prompts) + logger.info(f"Generated {len(control_prompts)} control prompts") + + logger.info("Computing ordinal predecessor scores...") + ordinal_scores = _compute_predecessor_scores(target_model, ordinal_prompts, device) + + logger.info("Computing control predecessor scores...") + control_scores = _compute_predecessor_scores(target_model, control_prompts, device) + + successor_signal = ordinal_scores - control_scores + + logger.info(f"Successor head analysis (n={n_prompts} prompts per condition):") 
+ logger.info(" Ordinal | Control | Signal") + for layer_idx in range(n_layers): + for h in range(n_heads): + o_val = ordinal_scores[layer_idx, h] + c_val = control_scores[layer_idx, h] + s_val = successor_signal[layer_idx, h] + marker = " <-- successor head" if s_val > 0.05 else "" + logger.info(f" L{layer_idx}H{h}: {o_val:.4f} {c_val:.4f} {s_val:+.4f}{marker}") + + _plot_triple_heatmap( + ordinal_scores, control_scores, run_id, n_prompts, out_dir / "successor_scores.png" + ) + logger.info(f"All plots saved to {out_dir}") + + +if __name__ == "__main__": + fire.Fire(detect_successor_heads) diff --git a/spd/scripts/plot_attention_weights/__init__.py b/spd/scripts/plot_attention_weights/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/scripts/plot_attention_weights/plot_attention_weights.py b/spd/scripts/plot_attention_weights/plot_attention_weights.py new file mode 100644 index 000000000..1133df86e --- /dev/null +++ b/spd/scripts/plot_attention_weights/plot_attention_weights.py @@ -0,0 +1,206 @@ +"""Plot target vs reconstructed attention weight matrices from an SPD run. + +For each layer and attention projection (q/k/v/o_proj), produces multiple 4x4 grids. +Each grid shows the target weight, reconstructed (UV^T), and 14 subcomponent weights. +Successive grids page through all alive components (ranked by mean_ci descending). +All grids for a given layer/projection share the same color scale. + +Alive components are determined from harvest data (mean_ci > 0). 
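The rank-1 reconstruction visualized here can be sketched in isolation, under the layout this script assumes (U of shape (C, d_out), V of shape (d_in, C)): each subcomponent c is the rank-1 matrix outer(U[c], V[:, c]), and the full reconstruction is the sum of those terms.

```python
import torch

# Toy shapes; each subcomponent weight is (d_out, d_in).
C, d_out, d_in = 3, 4, 5
torch.manual_seed(0)
U = torch.randn(C, d_out)
V = torch.randn(d_in, C)

sub_weights = [torch.outer(U[c], V[:, c]) for c in range(C)]
recon = torch.stack(sub_weights).sum(dim=0)  # (d_out, d_in)

# Equivalent closed form: (U.T @ V.T)[i, j] = sum_c U[c, i] * V[j, c]
assert torch.allclose(recon, U.T @ V.T, atol=1e-5)
```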
+ +Usage: + python -m spd.scripts.plot_attention_weights.plot_attention_weights wandb:goodfire/spd/runs/ +""" + +import math +from datetime import datetime +from pathlib import Path + +import fire +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib.axes import Axes +from numpy.typing import NDArray +from torch import Tensor + +from spd.harvest.repo import HarvestRepo +from spd.harvest.schemas import ComponentSummary +from spd.log import logger +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.models.components import LinearComponents +from spd.pretrain.models.llama_simple_mlp import LlamaSimpleMLP +from spd.spd_types import ModelPath +from spd.utils.wandb_utils import parse_wandb_run_path + +SCRIPT_DIR = Path(__file__).parent +PROJ_NAMES = ("q_proj", "k_proj", "v_proj", "o_proj") +GRID_SIZE = 4 +COMPONENTS_PER_PAGE = GRID_SIZE * GRID_SIZE - 2 # 14 (2 cells for target + recon) + + +def _get_alive_indices(summary: dict[str, ComponentSummary], module_path: str) -> list[int]: + """Return component indices for a module sorted by CI descending, filtered to alive.""" + components = [ + (s.component_idx, s.mean_activations["causal_importance"]) + for s in summary.values() + if s.layer == module_path and s.mean_activations["causal_importance"] > 0 + ] + components.sort(key=lambda t: t[1], reverse=True) + return [idx for idx, _ in components] + + +def _add_head_lines(ax: Axes, proj_name: str, head_dim: int, shape: tuple[int, int]) -> None: + n_rows, n_cols = shape + + if proj_name in ("q_proj", "k_proj", "v_proj"): + n_heads = n_rows // head_dim + for i in range(1, n_heads): + ax.axhline( + y=i * head_dim - 0.5, color="black", linewidth=0.5, linestyle="--", alpha=0.6 + ) + else: + n_heads = n_cols // head_dim + for i in range(1, n_heads): + ax.axvline( + x=i * head_dim - 0.5, color="black", linewidth=0.5, linestyle="--", alpha=0.6 + ) + + +def _style_ax( + fig: plt.Figure, + ax: Axes, + data: NDArray[np.floating], + 
vmax: float, + title: str, + proj_name: str, + head_dim: int, +) -> None: + im = ax.imshow(data, aspect="auto", cmap="RdBu_r", vmin=-vmax, vmax=vmax) + ax.set_title(title, fontsize=10) + ax.set_yticks([]) + ax.set_xticks([]) + _add_head_lines(ax, proj_name, head_dim, data.shape) + fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02) + + +def _component_weight(component: LinearComponents, idx: int) -> Tensor: + """Weight matrix for a single subcomponent: outer(U[c], V[:, c]) -> (d_out, d_in).""" + return torch.outer(component.U[idx], component.V[:, idx]) + + +def _plot_grids( + target_weight: Tensor, + recon_weight: Tensor, + component: LinearComponents, + alive_indices: list[int], + layer_idx: int, + proj_name: str, + head_dim: int, + run_id: str, + timestamp: str, + out_dir: Path, +) -> None: + target_np: NDArray[np.floating] = target_weight.float().cpu().numpy() + recon_np: NDArray[np.floating] = recon_weight.float().cpu().numpy() + + all_sub_weights = [_component_weight(component, i).float().cpu().numpy() for i in alive_indices] + + vmax = float(max(abs(target_np).max(), abs(recon_np).max())) + if all_sub_weights: + vmax = max(vmax, float(max(abs(w).max() for w in all_sub_weights))) + + n_pages = max(1, math.ceil(len(alive_indices) / COMPONENTS_PER_PAGE)) + + for page in range(n_pages): + start = page * COMPONENTS_PER_PAGE + page_indices = alive_indices[start : start + COMPONENTS_PER_PAGE] + page_weights = all_sub_weights[start : start + COMPONENTS_PER_PAGE] + + fig, axes = plt.subplots(GRID_SIZE, GRID_SIZE, figsize=(24, 22)) + + _style_ax(fig, axes[0, 0], target_np, vmax, "Target weight", proj_name, head_dim) + _style_ax(fig, axes[0, 1], recon_np, vmax, "Reconstructed (UV\u1d40)", proj_name, head_dim) + + cell = 0 + for r in range(GRID_SIZE): + for c in range(GRID_SIZE): + if r == 0 and c < 2: + continue + ax = axes[r, c] + if cell < len(page_indices): + c_idx = page_indices[cell] + _style_ax(fig, ax, page_weights[cell], vmax, f"C{c_idx}", proj_name, head_dim) + else: 
+ ax.set_visible(False) + cell += 1 + + page_label = f"({page + 1}/{n_pages})" if n_pages > 1 else "" + fig.suptitle( + f"{run_id} | Layer {layer_idx} — {proj_name} {page_label}\n{timestamp}", + fontsize=14, + fontweight="bold", + ) + fig.subplots_adjust(hspace=0.25, wspace=0.15) + + suffix = f"_p{page + 1}" if n_pages > 1 else "" + path = out_dir / f"layer{layer_idx}_{proj_name}{suffix}.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {path}") + + +def plot_attention_weights(wandb_path: ModelPath) -> None: + _entity, _project, run_id = parse_wandb_run_path(str(wandb_path)) + run_info = SPDRunInfo.from_path(wandb_path) + + out_dir = SCRIPT_DIR / "out" / run_id + out_dir.mkdir(parents=True, exist_ok=True) + + model = ComponentModel.from_run_info(run_info) + model.eval() + + repo = HarvestRepo.open_most_recent(run_id) + assert repo is not None, f"No harvest data found for {run_id}" + summary = repo.get_summary() + + target_model = model.target_model + assert isinstance(target_model, LlamaSimpleMLP) + blocks = target_model._h + head_dim = blocks[0].attn.head_dim + n_layers = len(blocks) + logger.info(f"Model: {n_layers} layers, head_dim={head_dim}") + + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + with torch.no_grad(): + for layer_idx in range(n_layers): + attn = blocks[layer_idx].attn + for proj_name in PROJ_NAMES: + module_path = f"h.{layer_idx}.attn.{proj_name}" + + target_weight = getattr(attn, proj_name).weight + component = model.components[module_path] + assert isinstance(component, LinearComponents) + recon_weight = component.weight + + alive_indices = _get_alive_indices(summary, module_path) + logger.info(f"{module_path}: {len(alive_indices)} alive components") + + _plot_grids( + target_weight, + recon_weight, + component, + alive_indices, + layer_idx, + proj_name, + head_dim, + run_id, + timestamp, + out_dir, + ) + + logger.info(f"All plots saved to {out_dir}") + + +if __name__ == "__main__": + 
fire.Fire(plot_attention_weights) diff --git a/spd/scripts/plot_component_head_norms/__init__.py b/spd/scripts/plot_component_head_norms/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/scripts/plot_component_head_norms/plot_component_head_norms.py b/spd/scripts/plot_component_head_norms/plot_component_head_norms.py new file mode 100644 index 000000000..5ff4ebf6e --- /dev/null +++ b/spd/scripts/plot_component_head_norms/plot_component_head_norms.py @@ -0,0 +1,161 @@ +"""Plot subcomponent-to-head Frobenius norm heatmaps from an SPD run. + +For each layer and attention projection (q/k/v/o_proj), produces a heatmap where: + - x-axis: attention head index + - y-axis: alive subcomponents (sorted by mean_ci descending, top N) + - color: Frobenius norm of the subcomponent's weight slice for that head + +Head slicing: + - q/k/v_proj: heads partition rows (output dim). Head h = rows [h*head_dim : (h+1)*head_dim]. + - o_proj: heads partition columns (input dim). Head h = cols [h*head_dim : (h+1)*head_dim]. 
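A toy version of that slicing convention, with made-up shapes (head_dim=2, so two row-heads and three column-heads on this toy weight):

```python
import torch

# q/k/v heads partition rows (output dim); o_proj heads partition columns
# (input dim). Each slice gets a Frobenius norm.
head_dim = 2
w = torch.arange(24, dtype=torch.float32).reshape(4, 6)  # (d_out=4, d_in=6)

h = 0
qkv_slice = w[h * head_dim : (h + 1) * head_dim, :]  # rows of head 0 -> (2, 6)
o_slice = w[:, h * head_dim : (h + 1) * head_dim]    # cols of head 0 -> (4, 2)

fro = torch.linalg.norm(qkv_slice)  # sqrt of the sum of squared entries
```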
+ +Usage: + python -m spd.scripts.plot_component_head_norms.plot_component_head_norms wandb:goodfire/spd/runs/ +""" + +from pathlib import Path + +import fire +import matplotlib.pyplot as plt +import numpy as np +import torch +from numpy.typing import NDArray + +from spd.harvest.repo import HarvestRepo +from spd.harvest.schemas import ComponentSummary +from spd.log import logger +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.models.components import LinearComponents +from spd.pretrain.models.llama_simple_mlp import LlamaSimpleMLP +from spd.spd_types import ModelPath +from spd.utils.wandb_utils import parse_wandb_run_path + +SCRIPT_DIR = Path(__file__).parent +PROJ_NAMES = ("q_proj", "k_proj", "v_proj", "o_proj") +MIN_MEAN_CI = 0.01 + + +def _get_alive_indices( + summary: dict[str, ComponentSummary], module_path: str, min_mean_ci: float +) -> list[int]: + """Return component indices for a module sorted by CI descending, filtered by threshold.""" + components = [ + (s.component_idx, s.mean_activations["causal_importance"]) + for s in summary.values() + if s.layer == module_path and s.mean_activations["causal_importance"] > min_mean_ci + ] + components.sort(key=lambda t: t[1], reverse=True) + return [idx for idx, _ in components] + + +def _component_weight(component: LinearComponents, idx: int) -> torch.Tensor: + """Weight matrix for a single subcomponent: outer(U[c], V[:, c]) -> (d_out, d_in).""" + return torch.outer(component.U[idx], component.V[:, idx]) + + +def _head_norms( + component: LinearComponents, + alive_indices: list[int], + proj_name: str, + head_dim: int, + n_heads: int, +) -> NDArray[np.floating]: + """Compute (n_alive, n_heads) array of Frobenius norms per subcomponent per head.""" + norms = np.zeros((len(alive_indices), n_heads), dtype=np.float32) + for row, c_idx in enumerate(alive_indices): + w = _component_weight(component, c_idx).float() # (d_out, d_in) + for h in range(n_heads): + if proj_name in ("q_proj", 
"k_proj", "v_proj"): + head_slice = w[h * head_dim : (h + 1) * head_dim, :] + else: + head_slice = w[:, h * head_dim : (h + 1) * head_dim] + norms[row, h] = torch.linalg.norm(head_slice).item() + return norms + + +def _plot_heatmap( + norms: NDArray[np.floating], + alive_indices: list[int], + n_heads: int, + layer_idx: int, + proj_name: str, + run_id: str, + out_dir: Path, +) -> None: + fig, ax = plt.subplots(figsize=(max(8, n_heads * 0.8), max(6, len(alive_indices) * 0.25))) + + im = ax.imshow(norms, aspect="auto", cmap="Purples", vmin=0) + fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02, label="Frobenius norm") + + ax.set_xticks(range(n_heads)) + ax.set_xticklabels([f"H{h}" for h in range(n_heads)], fontsize=8) + ax.set_xlabel("Head") + + ax.set_yticks(range(len(alive_indices))) + ax.set_yticklabels([f"C{idx}" for idx in alive_indices], fontsize=7) + ax.set_ylabel("Component (sorted by CI)") + + fig.suptitle(f"{run_id} | Layer {layer_idx} — {proj_name}", fontsize=14, fontweight="bold") + fig.subplots_adjust(left=0.12, right=0.95, top=0.93, bottom=0.08) + + path = out_dir / f"layer{layer_idx}_{proj_name}.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {path}") + + +def plot_component_head_norms(wandb_path: ModelPath) -> None: + _entity, _project, run_id = parse_wandb_run_path(str(wandb_path)) + run_info = SPDRunInfo.from_path(wandb_path) + + out_dir = SCRIPT_DIR / "out" / run_id + out_dir.mkdir(parents=True, exist_ok=True) + + model = ComponentModel.from_run_info(run_info) + model.eval() + + repo = HarvestRepo.open_most_recent(run_id) + assert repo is not None, f"No harvest data found for {run_id}" + summary = repo.get_summary() + + target_model = model.target_model + assert isinstance(target_model, LlamaSimpleMLP) + blocks = target_model._h + head_dim = blocks[0].attn.head_dim + n_layers = len(blocks) + logger.info(f"Model: {n_layers} layers, head_dim={head_dim}") + + with torch.no_grad(): + for layer_idx in 
range(n_layers): + for proj_name in PROJ_NAMES: + module_path = f"h.{layer_idx}.attn.{proj_name}" + + component = model.components[module_path] + assert isinstance(component, LinearComponents) + + alive_indices = _get_alive_indices(summary, module_path, MIN_MEAN_CI) + logger.info( + f"{module_path}: {len(alive_indices)} components with mean_ci > {MIN_MEAN_CI}" + ) + + if proj_name in ("q_proj", "k_proj", "v_proj"): + n_heads = component.U.shape[1] // head_dim + else: + n_heads = component.V.shape[0] // head_dim + norms = _head_norms(component, alive_indices, proj_name, head_dim, n_heads) + + _plot_heatmap( + norms, + alive_indices, + n_heads, + layer_idx, + proj_name, + run_id, + out_dir, + ) + + logger.info(f"All plots saved to {out_dir}") + + +if __name__ == "__main__": + fire.Fire(plot_component_head_norms) diff --git a/spd/scripts/plot_head_spread/__init__.py b/spd/scripts/plot_head_spread/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/scripts/plot_head_spread/plot_head_spread.py b/spd/scripts/plot_head_spread/plot_head_spread.py new file mode 100644 index 000000000..830b2fcbb --- /dev/null +++ b/spd/scripts/plot_head_spread/plot_head_spread.py @@ -0,0 +1,181 @@ +"""Plot head-spread entropy histograms for subcomponents of an SPD run. + +For each layer, produces a figure with 4 subplots (q/k/v/o_proj). Each subplot is a bar chart: + - x-axis: subcomponent index (mean_ci > threshold) + - y-axis: head-spread entropy H = -sum(p_i * ln(p_i)), where p_i = norm_i / sum(norms) + +Higher entropy means the subcomponent's weight is spread across many heads; +lower entropy means it is concentrated in one or few heads. 
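The entropy in question can be computed standalone; a sketch under the same 0 * log(0) = 0 convention the script uses. Uniform spread across n heads hits the ln(n) ceiling; a single-head component scores 0.

```python
import numpy as np

def head_entropy(norms: np.ndarray) -> float:
    # p_i = norm_i / sum(norms); guard against an all-zero row.
    p = norms / max(float(norms.sum()), 1e-12)
    log_p = np.where(p > 0, np.log(p), 0.0)  # 0 * log(0) = 0 convention
    return float(-(p * log_p).sum())

uniform = head_entropy(np.ones(4))                      # ln(4), maximal spread
single = head_entropy(np.array([0.0, 0.0, 3.0, 0.0]))   # 0.0, one head only
```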
+ +Usage: + python -m spd.scripts.plot_head_spread.plot_head_spread wandb:goodfire/spd/runs/ +""" + +from pathlib import Path + +import fire +import matplotlib.cm as mpl_cm +import matplotlib.colors as mpl_colors +import matplotlib.pyplot as plt +import numpy as np +import torch +from numpy.typing import NDArray + +from spd.harvest.repo import HarvestRepo +from spd.harvest.schemas import ComponentSummary +from spd.log import logger +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.models.components import LinearComponents +from spd.pretrain.models.llama_simple_mlp import LlamaSimpleMLP +from spd.spd_types import ModelPath +from spd.utils.wandb_utils import parse_wandb_run_path + +SCRIPT_DIR = Path(__file__).parent +PROJ_NAMES = ("q_proj", "k_proj", "v_proj", "o_proj") +MIN_MEAN_CI = 0.01 + + +def _get_alive_indices( + summary: dict[str, ComponentSummary], module_path: str, min_mean_ci: float +) -> list[int]: + """Return component indices for a module sorted by CI descending, filtered by threshold.""" + components = [ + (s.component_idx, s.mean_activations["causal_importance"]) + for s in summary.values() + if s.layer == module_path and s.mean_activations["causal_importance"] > min_mean_ci + ] + components.sort(key=lambda t: t[1], reverse=True) + return [idx for idx, _ in components] + + +def _head_norms( + component: LinearComponents, + alive_indices: list[int], + proj_name: str, + head_dim: int, + n_heads: int, +) -> NDArray[np.floating]: + """Compute (n_alive, n_heads) array of Frobenius norms per subcomponent per head.""" + norms = np.zeros((len(alive_indices), n_heads), dtype=np.float32) + for row, c_idx in enumerate(alive_indices): + w = torch.outer(component.U[c_idx], component.V[:, c_idx]).float() + for h in range(n_heads): + if proj_name in ("q_proj", "k_proj", "v_proj"): + head_slice = w[h * head_dim : (h + 1) * head_dim, :] + else: + head_slice = w[:, h * head_dim : (h + 1) * head_dim] + norms[row, h] = 
torch.linalg.norm(head_slice).item() + return norms + + +def _total_norms(component: LinearComponents, alive_indices: list[int]) -> NDArray[np.floating]: + """Frobenius norm of each subcomponent's full weight matrix.""" + result = np.zeros(len(alive_indices), dtype=np.float32) + for row, c_idx in enumerate(alive_indices): + u_c = component.U[c_idx] + v_c = component.V[:, c_idx] + # ||u_c v_c^T||_F = ||u_c|| * ||v_c|| + result[row] = (torch.linalg.norm(u_c) * torch.linalg.norm(v_c)).item() + return result + + +def _entropy(norms: NDArray[np.floating]) -> NDArray[np.floating]: + """Compute head-spread entropy for each subcomponent from its per-head norms. + + For each row: p_i = norm_i / sum(norms), H = -sum(p_i * ln(p_i)). + """ + totals = norms.sum(axis=1, keepdims=True) + # Avoid division by zero for components with all-zero norms + totals = np.maximum(totals, 1e-12) + p = norms / totals + # Use 0 * log(0) = 0 convention + log_p = np.where(p > 0, np.log(p), 0.0) + return -(p * log_p).sum(axis=1) + + +def plot_head_spread(wandb_path: ModelPath) -> None: + _entity, _project, run_id = parse_wandb_run_path(str(wandb_path)) + run_info = SPDRunInfo.from_path(wandb_path) + + out_dir = SCRIPT_DIR / "out" / run_id + out_dir.mkdir(parents=True, exist_ok=True) + + model = ComponentModel.from_run_info(run_info) + model.eval() + + repo = HarvestRepo.open_most_recent(run_id) + assert repo is not None, f"No harvest data found for {run_id}" + summary = repo.get_summary() + + target_model = model.target_model + assert isinstance(target_model, LlamaSimpleMLP) + blocks = target_model._h + head_dim = blocks[0].attn.head_dim + n_layers = len(blocks) + logger.info(f"Model: {n_layers} layers, head_dim={head_dim}") + + cmap = plt.get_cmap("Purples") + + with torch.no_grad(): + for layer_idx in range(n_layers): + fig, axes = plt.subplots(4, 1, figsize=(14, 16)) + + for ax_idx, proj_name in enumerate(PROJ_NAMES): + ax = axes[ax_idx] + module_path = f"h.{layer_idx}.attn.{proj_name}" + + 
component = model.components[module_path] + assert isinstance(component, LinearComponents) + + alive_indices = _get_alive_indices(summary, module_path, MIN_MEAN_CI) + logger.info( + f"{module_path}: {len(alive_indices)} components with mean_ci > {MIN_MEAN_CI}" + ) + + if len(alive_indices) == 0: + ax.set_visible(False) + continue + + if proj_name in ("q_proj", "k_proj", "v_proj"): + n_heads = component.U.shape[1] // head_dim + else: + n_heads = component.V.shape[0] // head_dim + max_entropy = np.log(n_heads) + + norms = _head_norms(component, alive_indices, proj_name, head_dim, n_heads) + entropies = _entropy(norms) + total = _total_norms(component, alive_indices) + + norm_obj = mpl_colors.Normalize(vmin=0, vmax=total.max()) + colors = cmap(norm_obj(total)) + + x = np.arange(len(alive_indices)) + ax.bar(x, entropies, width=1.0, color=colors, edgecolor="none") + ax.axhline(y=max_entropy, color="grey", linestyle="--", linewidth=0.8, alpha=0.7) + ax.set_ylabel("Entropy (Param norm spread across heads)") + ax.set_title(proj_name, fontsize=11) + ax.set_ylim(0, max_entropy * 1.05) + + ax.set_xticks(x) + ax.set_xticklabels([f"C{idx}" for idx in alive_indices], fontsize=5, rotation=90) + + sm = mpl_cm.ScalarMappable(cmap=cmap, norm=norm_obj) + fig.colorbar(sm, ax=ax, shrink=0.8, pad=0.01, label="Total norm") + + fig.suptitle( + f"{run_id} | Layer {layer_idx} — Head-spread entropy", + fontsize=14, + fontweight="bold", + ) + fig.tight_layout(rect=(0, 0, 1, 0.96)) + + path = out_dir / f"layer{layer_idx}.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {path}") + + logger.info(f"All plots saved to {out_dir}") + + +if __name__ == "__main__": + fire.Fire(plot_head_spread) diff --git a/spd/scripts/plot_kv_coactivation/__init__.py b/spd/scripts/plot_kv_coactivation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/scripts/plot_kv_coactivation/plot_kv_coactivation.py 
b/spd/scripts/plot_kv_coactivation/plot_kv_coactivation.py new file mode 100644 index 000000000..7bc384e37 --- /dev/null +++ b/spd/scripts/plot_kv_coactivation/plot_kv_coactivation.py @@ -0,0 +1,290 @@ +"""Plot k-v component co-activation heatmaps from harvest co-occurrence data. + +For each layer, produces five heatmaps showing how k_proj and v_proj components +co-activate across the dataset: + - Raw co-occurrence count (how many tokens where both fired) + - Phi coefficient (correlation of binary firing indicators) + - Jaccard similarity (intersection over union of firing sets) + - P(V | K) conditional probability (fraction of K-active tokens where V is also active) + - P(K | V) conditional probability (fraction of V-active tokens where K is also active) + +All metrics are derived from the pre-computed CorrelationStorage in the harvest data. + +Usage: + python -m spd.scripts.plot_kv_coactivation.plot_kv_coactivation \ + wandb:goodfire/spd/runs/ +""" + +import re +from pathlib import Path + +import fire +import matplotlib.pyplot as plt +import numpy as np +import torch +from numpy.typing import NDArray + +from spd.harvest.repo import HarvestRepo +from spd.harvest.schemas import ComponentSummary +from spd.harvest.storage import CorrelationStorage +from spd.log import logger +from spd.spd_types import ModelPath +from spd.utils.wandb_utils import parse_wandb_run_path + +SCRIPT_DIR = Path(__file__).parent +MIN_MEAN_CI = 0.01 + + +def _get_alive_indices(summary: dict[str, ComponentSummary], module_path: str) -> list[int]: + """Return component indices for a module sorted by CI descending, filtered by threshold.""" + components = [ + (s.component_idx, s.mean_activations["causal_importance"]) + for s in summary.values() + if s.layer == module_path and s.mean_activations["causal_importance"] > MIN_MEAN_CI + ] + components.sort(key=lambda t: t[1], reverse=True) + return [idx for idx, _ in components] + + +def _correlation_indices( + corr: CorrelationStorage, module_path: 
str, alive_indices: list[int] +) -> list[int]: + """Map (module_path, component_idx) pairs to indices in CorrelationStorage.""" + return [corr.key_to_idx[f"{module_path}:{idx}"] for idx in alive_indices] + + +def _compute_raw_cooccurrence( + count_ij: torch.Tensor, k_corr_idx: list[int], v_corr_idx: list[int] +) -> NDArray[np.floating]: + k_idx = torch.tensor(k_corr_idx) + v_idx = torch.tensor(v_corr_idx) + sub = count_ij[v_idx[:, None], k_idx[None, :]].float() + return sub.numpy() + + +def _compute_phi_coefficient( + count_ij: torch.Tensor, + count_i: torch.Tensor, + count_total: int, + k_corr_idx: list[int], + v_corr_idx: list[int], +) -> NDArray[np.floating]: + k_idx = torch.tensor(k_corr_idx) + v_idx = torch.tensor(v_corr_idx) + a = count_ij[v_idx[:, None], k_idx[None, :]].float() + n_k = count_i[k_idx].float() # (n_k_alive,) + n_v = count_i[v_idx].float() # (n_v_alive,) + n = float(count_total) + + numerator = n * a - n_v[:, None] * n_k[None, :] + denominator = torch.sqrt(n_v[:, None] * (n - n_v[:, None]) * n_k[None, :] * (n - n_k[None, :])) + phi = torch.where(denominator > 0, numerator / denominator, torch.zeros_like(a)) + return phi.numpy() + + +def _compute_jaccard( + count_ij: torch.Tensor, + count_i: torch.Tensor, + k_corr_idx: list[int], + v_corr_idx: list[int], +) -> NDArray[np.floating]: + k_idx = torch.tensor(k_corr_idx) + v_idx = torch.tensor(v_corr_idx) + intersection = count_ij[v_idx[:, None], k_idx[None, :]].float() + union = count_i[v_idx].float()[:, None] + count_i[k_idx].float()[None, :] - intersection + jaccard = torch.where(union > 0, intersection / union, torch.zeros_like(intersection)) + return jaccard.numpy() + + +def _compute_conditional_prob( + count_ij: torch.Tensor, + count_i: torch.Tensor, + k_corr_idx: list[int], + v_corr_idx: list[int], + condition_on: str, +) -> NDArray[np.floating]: + """P(V|K) when condition_on="k", P(K|V) when condition_on="v".""" + k_idx = torch.tensor(k_corr_idx) + v_idx = torch.tensor(v_corr_idx) + # 
count_ij[v, k] = number of tokens where both v and k are active + joint = count_ij[v_idx[:, None], k_idx[None, :]].float() + if condition_on == "k": + denom = count_i[k_idx].float()[None, :] + else: + denom = count_i[v_idx].float()[:, None] + return torch.where(denom > 0, joint / denom, torch.zeros_like(joint)).numpy() + + +def _plot_heatmap( + data: NDArray[np.floating], + k_alive: list[int], + v_alive: list[int], + layer_idx: int, + run_id: str, + metric_name: str, + cmap: str, + vmin: float | None, + vmax: float | None, + out_dir: Path, +) -> None: + n_v, n_k = data.shape + fig, ax = plt.subplots(figsize=(max(8, n_k * 0.25), max(6, n_v * 0.25))) + + im = ax.imshow(data, aspect="auto", cmap=cmap, vmin=vmin, vmax=vmax) + fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02, label=metric_name) + + ax.set_xticks(range(n_k)) + ax.set_xticklabels([f"C{idx}" for idx in k_alive], fontsize=7, rotation=90) + ax.set_xlabel("k_proj component (sorted by CI)") + + ax.set_yticks(range(n_v)) + ax.set_yticklabels([f"C{idx}" for idx in v_alive], fontsize=7) + ax.set_ylabel("v_proj component (sorted by CI)") + + fig.suptitle( + f"{run_id} | Layer {layer_idx} — k/v {metric_name} (ci>{MIN_MEAN_CI})", + fontsize=14, + fontweight="bold", + ) + fig.subplots_adjust(left=0.12, right=0.95, top=0.93, bottom=0.12) + + path = out_dir / f"layer{layer_idx}.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {path}") + + +def _get_n_layers(summary: dict[str, ComponentSummary]) -> int: + """Infer number of layers from summary keys like 'h.0.attn.k_proj'.""" + layer_indices = { + int(m.group(1)) for s in summary.values() if (m := re.match(r"h\.(\d+)\.", s.layer)) + } + assert layer_indices, "No layer indices found in summary" + return max(layer_indices) + 1 + + +def plot_kv_coactivation(wandb_path: ModelPath) -> None: + _entity, _project, run_id = parse_wandb_run_path(str(wandb_path)) + + out_base = SCRIPT_DIR / "out" / run_id + raw_dir = out_base / 
"ci_cooccurrence" + phi_dir = out_base / "phi_coefficient" + jaccard_dir = out_base / "jaccard" + p_v_given_k_dir = out_base / "p_v_given_k" + p_k_given_v_dir = out_base / "p_k_given_v" + for d in (raw_dir, phi_dir, jaccard_dir, p_v_given_k_dir, p_k_given_v_dir): + d.mkdir(parents=True, exist_ok=True) + + repo = HarvestRepo.open_most_recent(run_id) + assert repo is not None, f"No harvest data found for {run_id}" + summary = repo.get_summary() + + corr = repo.get_correlations() + assert corr is not None, f"No correlation data found for {run_id}" + logger.info( + f"Loaded correlations: {len(corr.component_keys)} components, {corr.count_total} tokens" + ) + + n_layers = _get_n_layers(summary) + for layer_idx in range(n_layers): + k_path = f"h.{layer_idx}.attn.k_proj" + v_path = f"h.{layer_idx}.attn.v_proj" + + k_alive = _get_alive_indices(summary, k_path) + v_alive = _get_alive_indices(summary, v_path) + logger.info(f"Layer {layer_idx}: {len(k_alive)} k components, {len(v_alive)} v components") + + if not k_alive or not v_alive: + logger.info(f"Layer {layer_idx}: skipping (no alive k or v components)") + continue + + k_corr_idx = _correlation_indices(corr, k_path, k_alive) + v_corr_idx = _correlation_indices(corr, v_path, v_alive) + + # CI co-occurrence + raw = _compute_raw_cooccurrence(corr.count_ij, k_corr_idx, v_corr_idx) + _plot_heatmap( + raw, + k_alive, + v_alive, + layer_idx, + run_id, + "CI co-occurrence", + "Purples", + 0, + None, + raw_dir, + ) + + # Phi coefficient + phi = _compute_phi_coefficient( + corr.count_ij, corr.count_i, corr.count_total, k_corr_idx, v_corr_idx + ) + phi_abs_max = float(np.abs(phi).max()) or 1.0 + _plot_heatmap( + phi, + k_alive, + v_alive, + layer_idx, + run_id, + "phi coefficient", + "RdBu_r", + -phi_abs_max, + phi_abs_max, + phi_dir, + ) + + # Jaccard + jacc = _compute_jaccard(corr.count_ij, corr.count_i, k_corr_idx, v_corr_idx) + _plot_heatmap( + jacc, + k_alive, + v_alive, + layer_idx, + run_id, + "Jaccard similarity", + 
"Purples", + 0, + None, + jaccard_dir, + ) + + # P(V | K) + p_v_given_k = _compute_conditional_prob( + corr.count_ij, corr.count_i, k_corr_idx, v_corr_idx, condition_on="k" + ) + _plot_heatmap( + p_v_given_k, + k_alive, + v_alive, + layer_idx, + run_id, + "P(V | K)", + "Purples", + 0, + None, + p_v_given_k_dir, + ) + + # P(K | V) + p_k_given_v = _compute_conditional_prob( + corr.count_ij, corr.count_i, k_corr_idx, v_corr_idx, condition_on="v" + ) + _plot_heatmap( + p_k_given_v, + k_alive, + v_alive, + layer_idx, + run_id, + "P(K | V)", + "Purples", + 0, + None, + p_k_given_v_dir, + ) + + logger.info(f"All plots saved to {out_base}") + + +if __name__ == "__main__": + fire.Fire(plot_kv_coactivation) diff --git a/spd/scripts/plot_kv_vt_similarity/__init__.py b/spd/scripts/plot_kv_vt_similarity/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/scripts/plot_kv_vt_similarity/plot_kv_vt_similarity.py b/spd/scripts/plot_kv_vt_similarity/plot_kv_vt_similarity.py new file mode 100644 index 000000000..d25109bdc --- /dev/null +++ b/spd/scripts/plot_kv_vt_similarity/plot_kv_vt_similarity.py @@ -0,0 +1,149 @@ +"""Plot k-v component input-direction similarity heatmaps from SPD weight decomposition. + +For each layer, produces a heatmap of cosine similarity between the V (input-direction) +vectors of k_proj and v_proj components. High similarity means a k component and v component +respond to the same input directions. 
+ + cos_sim(k_c, v_c') = dot(V_k[:, c], V_v[:, c']) / (||V_k[:, c]|| * ||V_v[:, c']||) + +Usage: + python -m spd.scripts.plot_kv_vt_similarity.plot_kv_vt_similarity \ + wandb:goodfire/spd/runs/ +""" + +from pathlib import Path + +import fire +import matplotlib.pyplot as plt +import numpy as np +import torch +from numpy.typing import NDArray + +from spd.harvest.repo import HarvestRepo +from spd.harvest.schemas import ComponentSummary +from spd.log import logger +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.models.components import LinearComponents +from spd.pretrain.models.llama_simple_mlp import LlamaSimpleMLP +from spd.spd_types import ModelPath +from spd.utils.wandb_utils import parse_wandb_run_path + +SCRIPT_DIR = Path(__file__).parent +MIN_MEAN_CI = 0.01 + + +def _get_alive_indices(summary: dict[str, ComponentSummary], module_path: str) -> list[int]: + """Return component indices for a module sorted by CI descending, filtered by threshold.""" + components = [ + (s.component_idx, s.mean_activations["causal_importance"]) + for s in summary.values() + if s.layer == module_path and s.mean_activations["causal_importance"] > MIN_MEAN_CI + ] + components.sort(key=lambda t: t[1], reverse=True) + return [idx for idx, _ in components] + + +def _compute_cosine_similarity( + k_component: LinearComponents, + v_component: LinearComponents, + k_alive: list[int], + v_alive: list[int], +) -> NDArray[np.floating]: + """Cosine similarity between V columns of k_proj and v_proj components. + + Returns (n_v_alive, n_k_alive) array. 
+ """ + V_k = k_component.V[:, k_alive].float() # (d_in, n_k_alive) + V_v = v_component.V[:, v_alive].float() # (d_in, n_v_alive) + + V_k_normed = V_k / torch.linalg.norm(V_k, dim=0, keepdim=True).clamp(min=1e-12) + V_v_normed = V_v / torch.linalg.norm(V_v, dim=0, keepdim=True).clamp(min=1e-12) + + # (n_v_alive, n_k_alive) + sim = (V_v_normed.T @ V_k_normed).cpu().numpy() + return sim + + +def _plot_heatmap( + data: NDArray[np.floating], + k_alive: list[int], + v_alive: list[int], + layer_idx: int, + run_id: str, + out_dir: Path, +) -> None: + n_v, n_k = data.shape + fig, ax = plt.subplots(figsize=(max(8, n_k * 0.25), max(6, n_v * 0.25))) + + abs_max = float(np.abs(data).max()) or 1.0 + im = ax.imshow(data, aspect="auto", cmap="PiYG", vmin=-abs_max, vmax=abs_max) + fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02, label="Cosine similarity") + + ax.set_xticks(range(n_k)) + ax.set_xticklabels([f"C{idx}" for idx in k_alive], fontsize=7, rotation=90) + ax.set_xlabel("k_proj component (sorted by CI)") + + ax.set_yticks(range(n_v)) + ax.set_yticklabels([f"C{idx}" for idx in v_alive], fontsize=7) + ax.set_ylabel("v_proj component (sorted by CI)") + + fig.suptitle( + f"{run_id} | Layer {layer_idx} — k/v V-direction cosine similarity (ci>{MIN_MEAN_CI})", + fontsize=14, + fontweight="bold", + ) + fig.subplots_adjust(left=0.12, right=0.95, top=0.93, bottom=0.12) + + path = out_dir / f"layer{layer_idx}.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {path}") + + +def plot_kv_vt_similarity(wandb_path: ModelPath) -> None: + _entity, _project, run_id = parse_wandb_run_path(str(wandb_path)) + run_info = SPDRunInfo.from_path(wandb_path) + + out_dir = SCRIPT_DIR / "out" / run_id + out_dir.mkdir(parents=True, exist_ok=True) + + model = ComponentModel.from_run_info(run_info) + model.eval() + + repo = HarvestRepo.open_most_recent(run_id) + assert repo is not None, f"No harvest data found for {run_id}" + summary = repo.get_summary() + + 
target_model = model.target_model + assert isinstance(target_model, LlamaSimpleMLP) + n_layers = len(target_model._h) + logger.info(f"Model: {n_layers} layers") + + with torch.no_grad(): + for layer_idx in range(n_layers): + k_path = f"h.{layer_idx}.attn.k_proj" + v_path = f"h.{layer_idx}.attn.v_proj" + + k_alive = _get_alive_indices(summary, k_path) + v_alive = _get_alive_indices(summary, v_path) + logger.info( + f"Layer {layer_idx}: {len(k_alive)} k components, {len(v_alive)} v components" + ) + + if not k_alive or not v_alive: + logger.info(f"Layer {layer_idx}: skipping (no alive k or v components)") + continue + + k_component = model.components[k_path] + v_component = model.components[v_path] + assert isinstance(k_component, LinearComponents) + assert isinstance(v_component, LinearComponents) + + sim = _compute_cosine_similarity(k_component, v_component, k_alive, v_alive) + _plot_heatmap(sim, k_alive, v_alive, layer_idx, run_id, out_dir) + + logger.info(f"All plots saved to {out_dir}") + + +if __name__ == "__main__": + fire.Fire(plot_kv_vt_similarity) diff --git a/spd/scripts/plot_mean_ci/__init__.py b/spd/scripts/plot_mean_ci/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/scripts/plot_mean_ci/plot_mean_ci.py b/spd/scripts/plot_mean_ci/plot_mean_ci.py new file mode 100644 index 000000000..fa1d2daff --- /dev/null +++ b/spd/scripts/plot_mean_ci/plot_mean_ci.py @@ -0,0 +1,98 @@ +"""Plot mean CI per component from harvested data. + +For each module, produces a scatter plot of mean CI values sorted descending by component. +Two figures are generated: one with a linear y-scale and one with a log y-scale. +Modules are arranged in a grid (max 6 rows, filling column by column), matching the +layout used by the training-time eval figures in spd/plotting.py. 
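The column-by-column placement described above can be sketched as a tiny index mapping (illustrative only; `grid_position` is a hypothetical helper, not a function in this script):

```python
# Column-major grid placement: module i lands at (i % MAX_ROWS, i // MAX_ROWS),
# so the first MAX_ROWS modules fill column 0 before column 1 starts.
MAX_ROWS = 6

def grid_position(i: int) -> tuple[int, int]:
    """Return (row, col) for the i-th module when filling column by column."""
    return i % MAX_ROWS, i // MAX_ROWS

positions = [grid_position(i) for i in range(8)]
# Modules 0-5 fill column 0 top to bottom; modules 6-7 start column 1.
```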
+ +Usage: + python -m spd.scripts.plot_mean_ci.plot_mean_ci wandb:goodfire/spd/runs/ +""" + +from collections import defaultdict +from pathlib import Path + +import fire +import matplotlib.pyplot as plt + +from spd.harvest.repo import HarvestRepo +from spd.harvest.schemas import ComponentSummary +from spd.log import logger +from spd.spd_types import ModelPath +from spd.utils.wandb_utils import parse_wandb_run_path + +SCRIPT_DIR = Path(__file__).parent +MAX_ROWS = 6 + + +def _sorted_mean_cis_by_module( + summary: dict[str, ComponentSummary], +) -> dict[str, list[float]]: + """Group mean_ci values by module, sorted descending within each module.""" + by_module = defaultdict[str, list[float]](list) + for s in summary.values(): + by_module[s.layer].append(s.mean_activations["causal_importance"]) + return {k: sorted(v, reverse=True) for k, v in sorted(by_module.items())} + + +def _plot_grid( + mean_cis_by_module: dict[str, list[float]], + log_y: bool, + out_dir: Path, +) -> None: + n_modules = len(mean_cis_by_module) + n_cols = (n_modules + MAX_ROWS - 1) // MAX_ROWS + n_rows = min(n_modules, MAX_ROWS) + + fig, axs = plt.subplots( + n_rows, n_cols, figsize=(8 * n_cols, 3 * n_rows), squeeze=False, dpi=200 + ) + + for i in range(n_modules, n_rows * n_cols): + axs[i % n_rows, i // n_rows].set_visible(False) + + for i, (module_name, cis) in enumerate(mean_cis_by_module.items()): + row = i % n_rows + col = i // n_rows + ax = axs[row, col] + + if log_y: + ax.set_yscale("log") + + ax.scatter(range(len(cis)), cis, marker="x", s=10) + + if row == n_rows - 1 or i == n_modules - 1: + ax.set_xlabel("Component") + ax.set_ylabel("mean CI") + ax.set_title(module_name, fontsize=10) + + fig.tight_layout() + + scale = "log" if log_y else "linear" + path = out_dir / f"mean_ci_{scale}.png" + fig.savefig(path, dpi=200, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {path}") + + +def plot_mean_ci(wandb_path: ModelPath) -> None: + _entity, _project, run_id = 
parse_wandb_run_path(str(wandb_path)) + + out_dir = SCRIPT_DIR / "out" / run_id + out_dir.mkdir(parents=True, exist_ok=True) + + repo = HarvestRepo.open_most_recent(run_id) + assert repo is not None, f"No harvest data found for {run_id}" + summary = repo.get_summary() + + mean_cis_by_module = _sorted_mean_cis_by_module(summary) + logger.info(f"Modules: {len(mean_cis_by_module)}") + + _plot_grid(mean_cis_by_module, log_y=False, out_dir=out_dir) + _plot_grid(mean_cis_by_module, log_y=True, out_dir=out_dir) + + logger.info(f"All plots saved to {out_dir}") + + +if __name__ == "__main__": + fire.Fire(plot_mean_ci) diff --git a/spd/scripts/plot_per_head_component_activations/__init__.py b/spd/scripts/plot_per_head_component_activations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/scripts/plot_per_head_component_activations/plot_per_head_component_activations.py b/spd/scripts/plot_per_head_component_activations/plot_per_head_component_activations.py new file mode 100644 index 000000000..31105e839 --- /dev/null +++ b/spd/scripts/plot_per_head_component_activations/plot_per_head_component_activations.py @@ -0,0 +1,175 @@ +"""Plot per-head component activation heatmaps from an SPD run. + +For each layer and attention projection (q/k/v_proj), produces a heatmap where: + - y-axis: alive subcomponents (sorted by mean_ci descending, mean_ci > threshold) + - x-axis: attention head index + - color: mean |component activation| * ||u_c[head_h]|| / ||u_c|| + +Component activations are stored by harvest as (v_c^T @ x) * ||u_c||. Since v_c is +shared across all heads for q/k/v_proj, the per-head contribution is exact: + |v_c^T @ x| * ||u_c[head_h]|| = |stored_act| * ||u_c[head_h]|| / ||u_c|| + +o_proj is excluded because its head structure is in V (input), not U (output), +so the stored scalar activation can't be cleanly decomposed per head. 
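The exactness claim above can be checked numerically with made-up values (a minimal sketch; none of these numbers come from a real run, and `u`/`vTx` are placeholder inputs):

```python
# Identity from the docstring: stored_act = (v^T x) * ||u||, so
# |stored_act| * ||u_head|| / ||u|| recovers |v^T x| * ||u_head|| exactly.
import math

u = [3.0, 4.0, 0.0, 12.0]  # concatenated heads, head_dim = 2
head_dim = 2
vTx = 0.7                  # hypothetical scalar v_c^T @ x
u_norm = math.sqrt(sum(c * c for c in u))  # sqrt(9 + 16 + 0 + 144) = 13.0
stored_act = vTx * u_norm                  # what harvest stores

for h in range(len(u) // head_dim):
    head = u[h * head_dim : (h + 1) * head_dim]
    head_norm = math.sqrt(sum(c * c for c in head))
    recovered = abs(stored_act) * head_norm / u_norm
    exact = abs(vTx) * head_norm
    assert math.isclose(recovered, exact)
```

Because `v_c` is shared across heads, this split is exact rather than an approximation, which is why o_proj (head structure on the input side) is excluded.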
+ +Usage: + python -m spd.scripts.plot_per_head_component_activations.plot_per_head_component_activations wandb:goodfire/spd/runs/ +""" + +from pathlib import Path + +import fire +import matplotlib.pyplot as plt +import numpy as np +import torch +from numpy.typing import NDArray + +from spd.harvest.repo import HarvestRepo +from spd.harvest.schemas import ComponentData, ComponentSummary +from spd.log import logger +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.models.components import LinearComponents +from spd.pretrain.models.llama_simple_mlp import LlamaSimpleMLP +from spd.spd_types import ModelPath +from spd.utils.wandb_utils import parse_wandb_run_path + +SCRIPT_DIR = Path(__file__).parent +PROJ_NAMES = ("q_proj", "k_proj", "v_proj") +MIN_MEAN_CI = 0.01 + + +def _get_alive_components( + summary: dict[str, ComponentSummary], module_path: str, min_mean_ci: float +) -> list[tuple[str, int]]: + """Return (component_key, component_idx) pairs sorted by CI descending.""" + components = [ + (key, s.component_idx, s.mean_activations["causal_importance"]) + for key, s in summary.items() + if s.layer == module_path and s.mean_activations["causal_importance"] > min_mean_ci + ] + components.sort(key=lambda t: t[2], reverse=True) + return [(key, idx) for key, idx, _ in components] + + +def _mean_abs_act_on_firing_data(comp_data: ComponentData) -> float: + """Compute mean absolute component activation on firing positions.""" + acts: list[float] = [] + for example in comp_data.activation_examples: + comp_acts = example.activations["component_activation"] + for firing, stored_act in zip(example.firings, comp_acts, strict=True): + if firing: + acts.append(abs(stored_act)) + if not acts: + return 0.0 + return sum(acts) / len(acts) + + +def _per_head_activations( + component: LinearComponents, + alive: list[tuple[str, int]], + comp_data_map: dict[str, ComponentData], + head_dim: int, + n_heads: int, +) -> NDArray[np.floating]: + """Compute (n_alive, 
n_heads) array: mean |activation| * ||u_c[head_h]|| / ||u_c||.""" + result = np.zeros((len(alive), n_heads), dtype=np.float32) + for row, (key, c_idx) in enumerate(alive): + comp_data = comp_data_map.get(key) + if comp_data is None: + continue + mean_act = _mean_abs_act_on_firing_data(comp_data) + u_c = component.U[c_idx].float() + u_norm = torch.linalg.norm(u_c).item() + if u_norm == 0: + continue + for h in range(n_heads): + u_head = u_c[h * head_dim : (h + 1) * head_dim] + u_head_norm = torch.linalg.norm(u_head).item() + result[row, h] = mean_act * u_head_norm / u_norm + return result + + +def _plot_heatmap( + norms: NDArray[np.floating], + alive_indices: list[int], + n_heads: int, + layer_idx: int, + proj_name: str, + run_id: str, + out_dir: Path, +) -> None: + fig, ax = plt.subplots(figsize=(max(8, n_heads * 0.8), max(6, len(alive_indices) * 0.25))) + + im = ax.imshow(norms, aspect="auto", cmap="Reds", vmin=0) + fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02, label="Per-head component activation") + + ax.set_xticks(range(n_heads)) + ax.set_xticklabels([f"H{h}" for h in range(n_heads)], fontsize=8) + ax.set_xlabel("Head") + + ax.set_yticks(range(len(alive_indices))) + ax.set_yticklabels([f"C{idx}" for idx in alive_indices], fontsize=7) + ax.set_ylabel("Component (sorted by CI)") + + fig.suptitle( + f"{run_id} | Layer {layer_idx} — {proj_name}\nPer-head component activation on CI-important data", + fontsize=12, + fontweight="bold", + ) + fig.subplots_adjust(left=0.12, right=0.95, top=0.90, bottom=0.08) + + path = out_dir / f"layer{layer_idx}_{proj_name}.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {path}") + + +def plot_per_head_component_activations(wandb_path: ModelPath) -> None: + _entity, _project, run_id = parse_wandb_run_path(str(wandb_path)) + run_info = SPDRunInfo.from_path(wandb_path) + + out_dir = SCRIPT_DIR / "out" / run_id + out_dir.mkdir(parents=True, exist_ok=True) + + model = 
ComponentModel.from_run_info(run_info) + model.eval() + + repo = HarvestRepo.open_most_recent(run_id) + assert repo is not None, f"No harvest data found for {run_id}" + summary = repo.get_summary() + + target_model = model.target_model + assert isinstance(target_model, LlamaSimpleMLP) + blocks = target_model._h + head_dim = blocks[0].attn.head_dim + n_layers = len(blocks) + logger.info(f"Model: {n_layers} layers, head_dim={head_dim}") + + with torch.no_grad(): + for layer_idx in range(n_layers): + for proj_name in PROJ_NAMES: + module_path = f"h.{layer_idx}.attn.{proj_name}" + + component = model.components[module_path] + assert isinstance(component, LinearComponents) + + alive = _get_alive_components(summary, module_path, MIN_MEAN_CI) + logger.info(f"{module_path}: {len(alive)} components with mean_ci > {MIN_MEAN_CI}") + + if not alive: + continue + + keys = [key for key, _ in alive] + comp_data_map = repo.get_components_bulk(keys) + indices = [idx for _, idx in alive] + + n_heads = component.U.shape[1] // head_dim + norms = _per_head_activations(component, alive, comp_data_map, head_dim, n_heads) + + _plot_heatmap(norms, indices, n_heads, layer_idx, proj_name, run_id, out_dir) + + logger.info(f"All plots saved to {out_dir}") + + +if __name__ == "__main__": + fire.Fire(plot_per_head_component_activations) diff --git a/spd/scripts/plot_qk_c_attention_contributions/__init__.py b/spd/scripts/plot_qk_c_attention_contributions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/scripts/plot_qk_c_attention_contributions/plot_qk_c_attention_contributions.py b/spd/scripts/plot_qk_c_attention_contributions/plot_qk_c_attention_contributions.py new file mode 100644 index 000000000..db292652a --- /dev/null +++ b/spd/scripts/plot_qk_c_attention_contributions/plot_qk_c_attention_contributions.py @@ -0,0 +1,908 @@ +"""Plot weight-only attention contribution heatmaps between q and k subcomponents. 
+ +For each layer and relative position offset, produces a single grid containing: + - First cell: summed (all heads) q·k attention contributions + - Remaining cells: one per query head + +Uses V-norm-scaled U dot products with RoPE applied at specified relative position offsets. + +Usage: + python -m spd.scripts.plot_qk_c_attention_contributions.plot_qk_c_attention_contributions \ + wandb:goodfire/spd/runs/ +""" + +import math +from dataclasses import dataclass +from pathlib import Path + +import fire +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib.lines import Line2D +from numpy.typing import NDArray + +from spd.harvest.repo import HarvestRepo +from spd.harvest.schemas import ComponentSummary +from spd.log import logger +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.models.components import LinearComponents +from spd.pretrain.models.llama_simple_mlp import LlamaSimpleMLP +from spd.scripts.rope_aware_qk import compute_qk_rope_coefficients, evaluate_qk_at_offsets +from spd.spd_types import ModelPath +from spd.utils.wandb_utils import parse_wandb_run_path + +SCRIPT_DIR = Path(__file__).parent +MIN_MEAN_CI = 0.01 +DEFAULT_OFFSETS = tuple(range(17)) + + +def _get_alive_indices( + summary: dict[str, ComponentSummary], module_path: str, min_mean_ci: float +) -> list[int]: + """Return component indices for a module sorted by CI descending, filtered by threshold.""" + components = [ + (s.component_idx, s.mean_activations["causal_importance"]) + for s in summary.values() + if s.layer == module_path and s.mean_activations["causal_importance"] > min_mean_ci + ] + components.sort(key=lambda t: t[1], reverse=True) + return [idx for idx, _ in components] + + +def _compute_per_head_attention_contributions( + q_component: LinearComponents, + k_component: LinearComponents, + q_alive: list[int], + k_alive: list[int], + n_q_heads: int, + n_kv_heads: int, + head_dim: int, + 
rotary_cos: torch.Tensor, + rotary_sin: torch.Tensor, + offsets: tuple[int, ...], +) -> NDArray[np.floating]: + """Compute (n_offsets, n_q_heads, n_q_alive, n_k_alive) per-head attention contributions. + + Scales U vectors by their corresponding V norms before the dot product, so that the + result accounts for the unnormalized magnitude split between U and V. + """ + V_q_norms = torch.linalg.norm(q_component.V[:, q_alive], dim=0).float() # (n_q_alive,) + V_k_norms = torch.linalg.norm(k_component.V[:, k_alive], dim=0).float() # (n_k_alive,) + + U_q = q_component.U[q_alive].float() * V_q_norms[:, None] # (n_q_alive, n_q_heads * head_dim) + U_q = U_q.reshape(len(q_alive), n_q_heads, head_dim) + + U_k = k_component.U[k_alive].float() * V_k_norms[:, None] # (n_k_alive, n_kv_heads * head_dim) + U_k = U_k.reshape(len(k_alive), n_kv_heads, head_dim) + + g = n_q_heads // n_kv_heads + U_k_expanded = U_k.repeat_interleave(g, dim=1) # (n_k_alive, n_q_heads, head_dim) + + head_results = [] + for h in range(n_q_heads): + A, B = compute_qk_rope_coefficients(U_q[:, h, :], U_k_expanded[:, h, :]) + W_h = evaluate_qk_at_offsets(A, B, rotary_cos, rotary_sin, offsets, head_dim) + head_results.append(W_h) # (n_offsets, n_q, n_k) + + # (n_heads, n_offsets, n_q, n_k) -> (n_offsets, n_heads, n_q, n_k) + return torch.stack(head_results).permute(1, 0, 2, 3).cpu().numpy() + + +def _plot_heatmaps( + W_per_head: NDArray[np.floating], + q_alive: list[int], + k_alive: list[int], + n_q_heads: int, + layer_idx: int, + run_id: str, + out_dir: Path, + offset: int, + vmax: float, +) -> None: + n_cells = n_q_heads + 1 # summed + per-head + n_cols = math.ceil(math.sqrt(n_cells)) + n_rows = math.ceil(n_cells / n_cols) + n_q, n_k = W_per_head.shape[1], W_per_head.shape[2] + + W_summed = W_per_head.sum(axis=0) + + cell_w = max(4, n_k * 0.15) + cell_h = max(3, n_q * 0.15) + fig, axes = plt.subplots( + n_rows, n_cols, figsize=(cell_w * n_cols, cell_h * n_rows), squeeze=False + ) + + # All cells: summed 
first, then per-head + all_data = [W_summed] + [W_per_head[h] for h in range(n_q_heads)] + titles = ["Sum"] + [f"H{h}" for h in range(n_q_heads)] + + for cell, (data, title) in enumerate(zip(all_data, titles, strict=True)): + row, col = divmod(cell, n_cols) + ax = axes[row, col] + im = ax.imshow(data, aspect="auto", cmap="RdBu_r", vmin=-vmax, vmax=vmax) + fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02) + ax.set_title(title, fontsize=10, fontweight="bold" if cell == 0 else "normal") + ax.set_yticks(range(n_q)) + ax.set_yticklabels([f"Q C{idx}" for idx in q_alive], fontsize=5) + ax.set_xticks(range(n_k)) + ax.set_xticklabels([f"K C{idx}" for idx in k_alive], fontsize=5, rotation=90) + + # Hide unused cells + for i in range(n_cells, n_rows * n_cols): + row, col = divmod(i, n_cols) + axes[row, col].set_visible(False) + + fig.suptitle( + f"{run_id} | Layer {layer_idx} \u2014 q\u00b7k attention contributions" + f" (\u0394={offset}) (ci>{MIN_MEAN_CI})", + fontsize=14, + fontweight="bold", + ) + fig.subplots_adjust(hspace=0.3, wspace=0.4) + + path = out_dir / f"layer{layer_idx}_qk_attention_contributions_offset{offset}.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {path}") + + +def _plot_diff_heatmaps( + D_per_head: NDArray[np.floating], + q_alive: list[int], + k_alive: list[int], + n_q_heads: int, + layer_idx: int, + run_id: str, + out_dir: Path, + offset: int, + vmax: float, +) -> None: + """Plot W(Δ=offset) - W(Δ=0) heatmaps.""" + n_cells = n_q_heads + 1 + n_cols = math.ceil(math.sqrt(n_cells)) + n_rows = math.ceil(n_cells / n_cols) + n_q, n_k = D_per_head.shape[1], D_per_head.shape[2] + + D_summed = D_per_head.sum(axis=0) + + cell_w = max(4, n_k * 0.15) + cell_h = max(3, n_q * 0.15) + fig, axes = plt.subplots( + n_rows, n_cols, figsize=(cell_w * n_cols, cell_h * n_rows), squeeze=False + ) + + all_data = [D_summed] + [D_per_head[h] for h in range(n_q_heads)] + titles = ["Sum"] + [f"H{h}" for h in range(n_q_heads)] + + for 
cell, (data, title) in enumerate(zip(all_data, titles, strict=True)): + row, col = divmod(cell, n_cols) + ax = axes[row, col] + im = ax.imshow(data, aspect="auto", cmap="PiYG", vmin=-vmax, vmax=vmax) + fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02) + ax.set_title(title, fontsize=10, fontweight="bold" if cell == 0 else "normal") + ax.set_yticks(range(n_q)) + ax.set_yticklabels([f"Q C{idx}" for idx in q_alive], fontsize=5) + ax.set_xticks(range(n_k)) + ax.set_xticklabels([f"K C{idx}" for idx in k_alive], fontsize=5, rotation=90) + + for i in range(n_cells, n_rows * n_cols): + row, col = divmod(i, n_cols) + axes[row, col].set_visible(False) + + fig.suptitle( + f"{run_id} | Layer {layer_idx} \u2014 q\u00b7k attention diff" + f" (\u0394={offset} \u2212 \u0394=0) (ci>{MIN_MEAN_CI})", + fontsize=14, + fontweight="bold", + ) + fig.subplots_adjust(hspace=0.3, wspace=0.4) + + path = out_dir / f"layer{layer_idx}_qk_attention_diff_offset{offset}.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {path}") + + +def _plot_heatmaps_per_head( + W: NDArray[np.floating], + q_alive: list[int], + k_alive: list[int], + n_q_heads: int, + layer_idx: int, + run_id: str, + out_dir: Path, + offsets: tuple[int, ...], + vmax: float, +) -> None: + """For each head (and Sum), plot a grid of heatmaps across all offsets.""" + n_offsets = len(offsets) + n_cols = math.ceil(math.sqrt(n_offsets)) + n_rows = math.ceil(n_offsets / n_cols) + n_q, n_k = len(q_alive), len(k_alive) + + per_head_dir = out_dir / "heatmap_offsets_per_head" + per_head_dir.mkdir(parents=True, exist_ok=True) + + cell_w = max(4, n_k * 0.15) + cell_h = max(3, n_q * 0.15) + + # W shape: (n_offsets, n_q_heads, n_q, n_k) + # Build list of (label, data) where data is (n_offsets, n_q, n_k) + head_series: list[tuple[str, NDArray[np.floating]]] = [ + ("Sum", W.sum(axis=1)), + ] + [(f"H{h}", W[:, h]) for h in range(n_q_heads)] + + for label, data in head_series: + fig, axes = plt.subplots( + 
n_rows, n_cols, figsize=(cell_w * n_cols, cell_h * n_rows), squeeze=False + ) + + for cell, offset in enumerate(offsets): + row, col = divmod(cell, n_cols) + ax = axes[row, col] + im = ax.imshow(data[cell], aspect="auto", cmap="RdBu_r", vmin=-vmax, vmax=vmax) + fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02) + ax.set_title(f"\u0394={offset}", fontsize=10) + ax.set_yticks(range(n_q)) + ax.set_yticklabels([f"Q C{idx}" for idx in q_alive], fontsize=5) + ax.set_xticks(range(n_k)) + ax.set_xticklabels([f"K C{idx}" for idx in k_alive], fontsize=5, rotation=90) + + for i in range(n_offsets, n_rows * n_cols): + row, col = divmod(i, n_cols) + axes[row, col].set_visible(False) + + fig.suptitle( + f"{run_id} | Layer {layer_idx} {label} \u2014 q\u00b7k attention contributions" + f" (ci>{MIN_MEAN_CI})", + fontsize=14, + fontweight="bold", + ) + fig.subplots_adjust(hspace=0.3, wspace=0.4) + + path = per_head_dir / f"layer{layer_idx}_qk_attention_{label.lower()}.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {path}") + + +HEAD_CMAPS = [ + "Reds", + "Blues", + "Greens", + "Oranges", + "Purples", + "Greys", + "YlOrBr", + "BuPu", + "PuRd", + "GnBu", + "OrRd", + "YlGn", +] + + +def _plot_head_vs_sum_scatter( + W: NDArray[np.floating], + n_q_heads: int, + layer_idx: int, + run_id: str, + out_dir: Path, + offsets: tuple[int, ...], +) -> None: + """Scatter: x = sum-across-heads contribution, y = per-head contribution. + + Each head uses a distinct sequential colormap; within that colormap each + offset maps to a different shade (darker = larger offset index). 
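The shade mapping works out as follows (a minimal pure-Python sketch of the `mcolors.Normalize(vmin=-0.5, vmax=n_offsets - 0.5)` arithmetic this function uses; `n_offsets = 4` is an arbitrary illustration):

```python
# Normalize(-0.5, n_offsets - 0.5) sends index i to (i + 0.5) / n_offsets;
# rescaling into 0.3 + 0.6 * (...) keeps every shade strictly inside
# (0.3, 0.9), so even the first offset is visibly dark.
n_offsets = 4
offset_vals = [0.3 + 0.6 * (i + 0.5) / n_offsets for i in range(n_offsets)]
# Monotonically darker shades, none too light to see.
```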
+ """ + # W shape: (n_offsets, n_q_heads, n_q, n_k) + W_summed = W.sum(axis=1) # (n_offsets, n_q, n_k) + n_offsets = len(offsets) + + # Map offset indices to colormap values in [0.3, 0.9] so no shade is too light + norm = mcolors.Normalize(vmin=-0.5, vmax=n_offsets - 0.5) + offset_vals = [0.3 + 0.6 * norm(i) for i in range(n_offsets)] + + fig, ax = plt.subplots(figsize=(8, 8)) + + for h in range(n_q_heads): + cmap = plt.get_cmap(HEAD_CMAPS[h % len(HEAD_CMAPS)]) + for oi, offset in enumerate(offsets): + x = W_summed[oi].ravel() + y = W[oi, h].ravel() + color = cmap(offset_vals[oi]) + ax.scatter( + x, + y, + s=6, + alpha=0.5, + color=color, + label=f"H{h} \u0394={offset}", + edgecolors="none", + rasterized=True, + ) + + # x=y reference line + lo = min(ax.get_xlim()[0], ax.get_ylim()[0]) + hi = max(ax.get_xlim()[1], ax.get_ylim()[1]) + ax.plot([lo, hi], [lo, hi], "k--", linewidth=0.5, alpha=0.4) + ax.set_xlim((lo, hi)) + ax.set_ylim((lo, hi)) + + ax.set_xlabel("Summed (all heads) attention contribution") + ax.set_ylabel("Per-head attention contribution") + ax.set_aspect("equal") + + # Legend: one entry per head (color patch at mid-shade), one per offset (grey shades) + head_handles = [] + for h in range(n_q_heads): + cmap = plt.get_cmap(HEAD_CMAPS[h % len(HEAD_CMAPS)]) + head_handles.append( + Line2D( + [0], + [0], + marker="o", + color="none", + markerfacecolor=cmap(0.6), + markersize=6, + label=f"H{h}", + ) + ) + offset_handles = [] + for oi, offset in enumerate(offsets): + grey = str(1.0 - offset_vals[oi]) # darker for higher offset index + offset_handles.append( + Line2D( + [0], + [0], + marker="o", + color="none", + markerfacecolor=grey, + markersize=6, + label=f"\u0394={offset}", + ) + ) + + leg1 = ax.legend( + handles=head_handles, + loc="upper left", + fontsize=7, + title="Head", + title_fontsize=8, + framealpha=0.8, + ) + ax.add_artist(leg1) + ax.legend( + handles=offset_handles, + loc="lower right", + fontsize=7, + title="Offset", + title_fontsize=8, + 
framealpha=0.8, + ) + + fig.suptitle( + f"{run_id} | Layer {layer_idx} \u2014 per-head vs summed q\u00b7k contributions" + f" (ci>{MIN_MEAN_CI})", + fontsize=12, + fontweight="bold", + ) + + scatter_dir = out_dir / "scatter_head_vs_sum" + scatter_dir.mkdir(parents=True, exist_ok=True) + path = scatter_dir / f"layer{layer_idx}_head_vs_sum_scatter.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {path}") + + +def _plot_pair_lines( + W_summed: NDArray[np.floating], + offsets: tuple[int, ...], + q_alive: list[int], + k_alive: list[int], + layer_idx: int, + run_id: str, + out_dir: Path, + top_n_pairs: int, +) -> None: + """Line plot of attention contribution vs offset for top (q_c, k_c) pairs. + + Top-N pairs are plotted in color; a wider set of background pairs is plotted + in faint gray for context. + """ + _n_offsets, n_q, n_k = W_summed.shape + peak_abs = np.abs(W_summed).max(axis=0) # (n_q, n_k) + flat_ranked = np.argsort(peak_abs.ravel())[::-1] + + top_pairs = [divmod(int(idx), n_k) for idx in flat_ranked[:top_n_pairs]] + top_pair_set = set(top_pairs) + + fig, ax = plt.subplots(figsize=(10, 6)) + x = list(offsets) + + plotted_gray = False + for qi in range(n_q): + for ki in range(n_k): + if (qi, ki) in top_pair_set: + continue + label = "other" if not plotted_gray else None + ax.plot(x, W_summed[:, qi, ki], color="0.80", linewidth=0.8, alpha=0.45, label=label) + plotted_gray = True + + for qi, ki in top_pairs: + ax.plot( + x, + W_summed[:, qi, ki], + marker="o", + markersize=3, + label=f"Q C{q_alive[qi]} \u2192 K C{k_alive[ki]}", + ) + + ax.axhline(0, color="black", linewidth=0.5, linestyle="--", alpha=0.4) + ax.set_xlabel("Offset (\u0394)") + ax.set_ylabel("Attention contribution (summed across heads)") + ax.set_xticks(x) + ax.legend(fontsize=6, loc="center left", bbox_to_anchor=(1.02, 0.5)) + + fig.suptitle( + f"{run_id} | Layer {layer_idx} \u2014 q\u00b7k pair contributions vs offset" + f" (top {len(top_pairs)} 
pairs, ci>{MIN_MEAN_CI})", + fontsize=12, + fontweight="bold", + ) + + lines_dir = out_dir / "lines" + lines_dir.mkdir(parents=True, exist_ok=True) + path = lines_dir / f"layer{layer_idx}_qk_pair_lines.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {path}") + + +def _plot_pair_lines_per_head( + W: NDArray[np.floating], + offsets: tuple[int, ...], + q_alive: list[int], + k_alive: list[int], + layer_idx: int, + run_id: str, + out_dir: Path, + top_n: int, +) -> None: + """Line plot of per-head attention contribution vs offset for top (head, q_c, k_c) triples.""" + peak_abs = np.abs(W).max(axis=0) # (n_q_heads, n_q, n_k) + + flat_indices = np.argsort(peak_abs.ravel())[::-1][:top_n] + n_k = len(k_alive) + triples = [] + for idx in flat_indices: + h, rem = divmod(int(idx), len(q_alive) * n_k) + qi, ki = divmod(rem, n_k) + triples.append((h, qi, ki)) + + fig, ax = plt.subplots(figsize=(10, 6)) + x = list(offsets) + + for h, qi, ki in triples: + y = W[:, h, qi, ki] + ax.plot( + x, + y, + marker="o", + markersize=3, + label=f"H{h}: Q C{q_alive[qi]} \u2192 K C{k_alive[ki]}", + ) + + ax.axhline(0, color="black", linewidth=0.5, linestyle="--", alpha=0.4) + ax.set_xlabel("Offset (\u0394)") + ax.set_ylabel("Attention contribution (single head)") + ax.set_xticks(x) + ax.legend(fontsize=6, loc="center left", bbox_to_anchor=(1.02, 0.5)) + + fig.suptitle( + f"{run_id} | Layer {layer_idx} \u2014 per-head q\u00b7k pair contributions vs offset" + f" (top {len(triples)}, ci>{MIN_MEAN_CI})", + fontsize=12, + fontweight="bold", + ) + + lines_dir = out_dir / "lines_per_head" + lines_dir.mkdir(parents=True, exist_ok=True) + path = lines_dir / f"layer{layer_idx}_qk_pair_lines_per_head.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved {path}") + + +def _plot_pair_lines_single_head( + W: NDArray[np.floating], + offsets: tuple[int, ...], + q_alive: list[int], + k_alive: list[int], + n_q_heads: int, + 
+    layer_idx: int,
+    run_id: str,
+    out_dir: Path,
+    top_n: int,
+) -> None:
+    """Grid of per-head line plots (3 columns) with consistent pair colors across heads.
+
+    Global top-K pairs (ranked by peak |W| across all heads and offsets) are plotted in
+    color; all remaining pairs are plotted in faint gray for context.
+    """
+    # W shape: (n_offsets, n_q_heads, n_q, n_k)
+    n_k = len(k_alive)
+    x = list(offsets)
+
+    # Global top-K: rank (q, k) pairs by peak absolute value across all heads and offsets
+    global_peak = np.abs(W).max(axis=(0, 1))  # (n_q, n_k)
+    global_flat = np.argsort(global_peak.ravel())[::-1][:top_n]
+    global_pairs = [divmod(int(idx), n_k) for idx in global_flat]
+    global_pair_set = set(global_pairs)
+
+    cmap = plt.get_cmap("tab20")
+    pair_colors = {pair: cmap(i % 20) for i, pair in enumerate(global_pairs)}
+
+    n_cols = 3
+    n_rows = math.ceil(n_q_heads / n_cols)
+    fig, axes = plt.subplots(n_rows, n_cols, figsize=(24, 5 * n_rows), squeeze=False)
+
+    for h in range(n_q_heads):
+        row, col = divmod(h, n_cols)
+        ax = axes[row, col]
+
+        W_h = W[:, h]  # (n_offsets, n_q, n_k)
+        n_q = W_h.shape[1]
+
+        plotted_gray = False
+        for qi in range(n_q):
+            for ki in range(n_k):
+                if (qi, ki) in global_pair_set:
+                    continue
+                label = "other" if not plotted_gray else None
+                ax.plot(x, W_h[:, qi, ki], color="0.80", linewidth=0.8, alpha=0.45, label=label)
+                plotted_gray = True
+
+        # Plot global top-K pairs in color
+        for qi, ki in global_pairs:
+            ax.plot(
+                x,
+                W_h[:, qi, ki],
+                color=pair_colors[(qi, ki)],
+                marker="o",
+                markersize=3,
+                label=f"Q C{q_alive[qi]} \u2192 K C{k_alive[ki]}",
+            )
+
+        # Sum of all (q, k) pair contributions within this head
+        total = W_h.sum(axis=(1, 2))  # (n_offsets,)
+        ax.plot(x, total, color="black", linewidth=2, label="sum (all pairs)")
+
+        ax.axhline(0, color="black", linewidth=0.5, linestyle="--", alpha=0.4)
+        ax.set_xlabel("Offset (\u0394)")
+        ax.set_ylabel("Attention contribution")
+        ax.set_title(f"H{h}", fontsize=11, fontweight="bold")
+        ax.set_xticks(x)
+
+    # Hide unused cells
+    for i in range(n_q_heads, n_rows * n_cols):
+        row, col = divmod(i, n_cols)
+        axes[row, col].set_visible(False)
+
+    # Shared legend from first subplot
+    handles, labels = axes[0, 0].get_legend_handles_labels()
+    fig.legend(
+        handles,
+        labels,
+        fontsize=6,
+        loc="center left",
+        bbox_to_anchor=(1.0, 0.5),
+        ncol=1,
+    )
+
+    fig.suptitle(
+        f"{run_id} | Layer {layer_idx} \u2014 per-head q\u00b7k pair contributions vs offset"
+        f" (top {top_n}, ci>{MIN_MEAN_CI})",
+        fontsize=13,
+        fontweight="bold",
+    )
+    fig.subplots_adjust(hspace=0.3, wspace=0.3)
+
+    lines_dir = out_dir / "lines_single_head"
+    lines_dir.mkdir(parents=True, exist_ok=True)
+    path = lines_dir / f"layer{layer_idx}_qk_pair_lines_grid.png"
+    fig.savefig(path, dpi=150, bbox_inches="tight")
+    plt.close(fig)
+    logger.info(f"Saved {path}")
+
+
+PLOT_TYPES = (
+    "heatmaps",
+    "heatmaps_per_head",
+    "scatter",
+    "diffs",
+    "lines",
+    "lines_per_head",
+    "lines_single_head",
+)
+
+
+@dataclass
+class _LayerCache:
+    W: NDArray[np.floating]
+    q_alive: list[int]
+    k_alive: list[int]
+    offsets: tuple[int, ...]
+    n_q_heads: int
+
+
+def _cache_path(out_dir: Path, layer_idx: int) -> Path:
+    return out_dir / "cache" / f"layer{layer_idx}.npz"
+
+
+def _save_layer_cache(out_dir: Path, layer_idx: int, cache: _LayerCache) -> None:
+    cache_dir = out_dir / "cache"
+    cache_dir.mkdir(parents=True, exist_ok=True)
+    np.savez(
+        _cache_path(out_dir, layer_idx),
+        W=cache.W,
+        q_alive=np.array(cache.q_alive),
+        k_alive=np.array(cache.k_alive),
+        offsets=np.array(cache.offsets),
+        n_q_heads=np.array(cache.n_q_heads),
+    )
+
+
+def _load_layer_cache(out_dir: Path, layer_idx: int) -> _LayerCache | None:
+    path = _cache_path(out_dir, layer_idx)
+    if not path.exists():
+        return None
+    data = np.load(path)
+    return _LayerCache(
+        W=data["W"],
+        q_alive=data["q_alive"].tolist(),
+        k_alive=data["k_alive"].tolist(),
+        offsets=tuple(data["offsets"].tolist()),
+        n_q_heads=int(data["n_q_heads"]),
+    )
+
+
+def _compute_and_cache_all_layers(
+    wandb_path: ModelPath,
+    offsets: tuple[int, ...],
+    out_dir: Path,
+    run_id: str,
+) -> None:
+    run_info = SPDRunInfo.from_path(wandb_path)
+    model = ComponentModel.from_run_info(run_info)
+    model.eval()
+
+    repo = HarvestRepo.open_most_recent(run_id)
+    assert repo is not None, f"No harvest data found for {run_id}"
+    summary = repo.get_summary()
+
+    target_model = model.target_model
+    assert isinstance(target_model, LlamaSimpleMLP)
+    blocks = target_model._h
+    assert not blocks[0].attn.rotary_adjacent_pairs, "RoPE math assumes non-adjacent pairs layout"
+    head_dim = blocks[0].attn.head_dim
+    n_q_heads = blocks[0].attn.n_head
+    n_kv_heads = blocks[0].attn.n_key_value_heads
+    n_layers = len(blocks)
+    logger.info(
+        f"Model: {n_layers} layers, head_dim={head_dim}, "
+        f"n_q_heads={n_q_heads}, n_kv_heads={n_kv_heads}"
+    )
+
+    with torch.no_grad():
+        for layer_idx in range(n_layers):
+            q_path = f"h.{layer_idx}.attn.q_proj"
+            k_path = f"h.{layer_idx}.attn.k_proj"
+
+            q_component = model.components[q_path]
+            k_component = model.components[k_path]
+            assert isinstance(q_component, LinearComponents)
+            assert isinstance(k_component, LinearComponents)
+
+            q_alive = _get_alive_indices(summary, q_path, MIN_MEAN_CI)
+            k_alive = _get_alive_indices(summary, k_path, MIN_MEAN_CI)
+            logger.info(
+                f"Layer {layer_idx}: {len(q_alive)} q components, {len(k_alive)} k components"
+            )
+
+            if not q_alive or not k_alive:
+                logger.info(f"Layer {layer_idx}: skipping (no alive q or k components)")
+                continue
+
+            rotary_cos = blocks[layer_idx].attn.rotary_cos
+            rotary_sin = blocks[layer_idx].attn.rotary_sin
+            assert isinstance(rotary_cos, torch.Tensor)
+            assert isinstance(rotary_sin, torch.Tensor)
+
+            W = _compute_per_head_attention_contributions(
+                q_component,
+                k_component,
+                q_alive,
+                k_alive,
+                n_q_heads,
+                n_kv_heads,
+                head_dim,
+                rotary_cos,
+                rotary_sin,
+                offsets,
+            )
+
+            cache = _LayerCache(
+                W=W, q_alive=q_alive, k_alive=k_alive, offsets=offsets, n_q_heads=n_q_heads
+            )
+            _save_layer_cache(out_dir, layer_idx, cache)
+            logger.info(f"Cached layer {layer_idx}")
+
+
+def _get_layer_caches(
+    wandb_path: ModelPath,
+    offsets: tuple[int, ...],
+    out_dir: Path,
+    run_id: str,
+    recompute: bool,
+) -> list[tuple[int, _LayerCache]]:
+    """Load caches if they exist and offsets match, otherwise recompute all layers."""
+    if not recompute:
+        caches: list[tuple[int, _LayerCache]] = []
+        layer_idx = 0
+        while True:
+            cached = _load_layer_cache(out_dir, layer_idx)
+            if cached is None:
+                break
+            if cached.offsets == offsets:
+                caches.append((layer_idx, cached))
+            else:
+                logger.info(f"Cache offsets mismatch at layer {layer_idx}, recomputing all")
+                caches = []
+                break
+            layer_idx += 1
+
+        if caches:
+            logger.info(f"Loaded {len(caches)} layers from cache")
+            return caches
+
+    _compute_and_cache_all_layers(wandb_path, offsets, out_dir, run_id)
+
+    caches = []
+    layer_idx = 0
+    while True:
+        cached = _load_layer_cache(out_dir, layer_idx)
+        if cached is None:
+            break
+        caches.append((layer_idx, cached))
+        layer_idx += 1
+    return caches
+
+
+def _plot_layer(
+    cache: _LayerCache,
+    layer_idx: int,
+    run_id: str,
+    out_dir: Path,
+    top_n_pairs: int,
+    plots: set[str],
+) -> None:
+    W = cache.W
+    q_alive = cache.q_alive
+    k_alive = cache.k_alive
+    offsets = cache.offsets
+    n_q_heads = cache.n_q_heads
+
+    W_summed_all = W.sum(axis=1)  # (n_offsets, n_q, n_k)
+    vmax = float(max(np.abs(W_summed_all).max(), np.abs(W).max())) or 1.0
+
+    if "heatmaps" in plots:
+        for offset_idx, offset in enumerate(offsets):
+            _plot_heatmaps(
+                W[offset_idx],
+                q_alive,
+                k_alive,
+                n_q_heads,
+                layer_idx,
+                run_id,
+                out_dir,
+                offset,
+                vmax,
+            )
+
+    if "heatmaps_per_head" in plots:
+        _plot_heatmaps_per_head(
+            W, q_alive, k_alive, n_q_heads, layer_idx, run_id, out_dir, offsets, vmax
+        )
+
+    if "scatter" in plots:
+        _plot_head_vs_sum_scatter(W, n_q_heads, layer_idx, run_id, out_dir, offsets)
+
+    if "diffs" in plots:
+        assert offsets[0] == 0, "First offset must be 0 for diff computation"
+        W_base = W[0]
+        non_zero_offsets = [(idx, o) for idx, o in enumerate(offsets) if o != 0]
+        if non_zero_offsets:
+            diffs = np.stack([W[idx] - W_base for idx, _ in non_zero_offsets])
+            D_summed_all = diffs.sum(axis=1)
+            diff_vmax = float(max(np.abs(D_summed_all).max(), np.abs(diffs).max())) or 1.0
+
+            diff_dir = out_dir / "diffs"
+            diff_dir.mkdir(parents=True, exist_ok=True)
+            for i, (_, offset) in enumerate(non_zero_offsets):
+                _plot_diff_heatmaps(
+                    diffs[i],
+                    q_alive,
+                    k_alive,
+                    n_q_heads,
+                    layer_idx,
+                    run_id,
+                    diff_dir,
+                    offset,
+                    diff_vmax,
+                )
+
+    if "lines" in plots:
+        _plot_pair_lines(
+            W_summed_all, offsets, q_alive, k_alive, layer_idx, run_id, out_dir, top_n_pairs
+        )
+
+    if "lines_per_head" in plots:
+        _plot_pair_lines_per_head(
+            W, offsets, q_alive, k_alive, layer_idx, run_id, out_dir, top_n_pairs
+        )
+
+    if "lines_single_head" in plots:
+        _plot_pair_lines_single_head(
+            W,
+            offsets,
+            q_alive,
+            k_alive,
+            n_q_heads,
+            layer_idx,
+            run_id,
+            out_dir,
+            top_n_pairs,
+        )
+
+
+def plot_qk_c_attention_contributions(
+    wandb_path: ModelPath,
+    offsets: tuple[int, ...] = DEFAULT_OFFSETS,
+    top_n_pairs: int = 10,
+    plots: str = "all",
+    recompute: bool = False,
+) -> None:
+    """Plot weight-only attention contribution analyses.
+
+    Args:
+        wandb_path: WandB run path.
+        offsets: Relative position offsets to evaluate.
+        top_n_pairs: Number of top (q, k) pairs to highlight in line plots.
+        plots: Comma-separated plot types, or "all". Options: heatmaps, heatmaps_per_head,
+            scatter, diffs, lines, lines_per_head, lines_single_head.
+        recompute: Force recomputation even if cached data exists.
+    """
+    plot_set: set[str] = (
+        set(PLOT_TYPES) if plots == "all" else {s.strip() for s in plots.split(",")}
+    )
+    unknown = plot_set - set(PLOT_TYPES)
+    assert not unknown, f"Unknown plot types: {unknown}. Valid: {PLOT_TYPES}"
+
+    _entity, _project, run_id = parse_wandb_run_path(str(wandb_path))
+    out_dir = SCRIPT_DIR / "out" / run_id
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    layer_caches = _get_layer_caches(wandb_path, offsets, out_dir, run_id, recompute)
+
+    for layer_idx, cache in layer_caches:
+        _plot_layer(cache, layer_idx, run_id, out_dir, top_n_pairs, plot_set)
+
+    logger.info(f"All plots saved to {out_dir}")
+
+
+if __name__ == "__main__":
+    fire.Fire(plot_qk_c_attention_contributions)
diff --git a/spd/scripts/rope_aware_qk.py b/spd/scripts/rope_aware_qk.py
new file mode 100644
index 000000000..b75daf6b8
--- /dev/null
+++ b/spd/scripts/rope_aware_qk.py
@@ -0,0 +1,69 @@
+"""RoPE-aware Q-K dot product helpers for SPD component analysis scripts.
+
+Decomposes the RoPE-modulated dot product into content-aligned and cross-half
+coefficients, allowing evaluation at arbitrary relative position offsets:
+
+    W(Δ) = (1/√d_head) Σ_d [A_d · cos(Δ·θ_d) + B_d · sin(Δ·θ_d)]
+
+Assumes non-adjacent-pairs RoPE layout (first-half/second-half dimension split).
+""" + +import math +from collections.abc import Sequence + +import torch +from einops import einsum +from torch import Tensor + + +def compute_qk_rope_coefficients( + U_q: Tensor, + U_k: Tensor, +) -> tuple[Tensor, Tensor]: + """Compute A_d and B_d RoPE coefficients for all (q, k) pairs. + + Args: + U_q: (n_q, head_dim) query weight vectors for a single head + U_k: (n_k, head_dim) key weight vectors for a single head + + Returns: + A: (n_q, n_k, half_dim) content-aligned coefficients + B: (n_q, n_k, half_dim) cross-half coefficients (zero contribution at Δ=0) + """ + half = U_q.shape[-1] // 2 + q1, q2 = U_q[..., :half], U_q[..., half:] + k1, k2 = U_k[..., :half], U_k[..., half:] + A = einsum(q1, k1, "q d, k d -> q k d") + einsum(q2, k2, "q d, k d -> q k d") + B = einsum(q1, k2, "q d, k d -> q k d") - einsum(q2, k1, "q d, k d -> q k d") + return A, B + + +def evaluate_qk_at_offsets( + A: Tensor, + B: Tensor, + rotary_cos: Tensor, + rotary_sin: Tensor, + offsets: Sequence[int], + head_dim: int, +) -> Tensor: + """Evaluate W(Δ) at specified relative position offsets. + + Args: + A: (n_q, n_k, half_dim) content-aligned coefficients + B: (n_q, n_k, half_dim) cross-half coefficients + rotary_cos: (n_ctx, head_dim) precomputed cos buffer from model + rotary_sin: (n_ctx, head_dim) precomputed sin buffer from model + offsets: relative position offsets Δ to evaluate + head_dim: head dimension (used for 1/√d scaling) + + Returns: + (n_offsets, n_q, n_k) dot product values at each offset + """ + half = head_dim // 2 + results = [] + for delta in offsets: + cos_d = rotary_cos[delta, :half].float() + sin_d = rotary_sin[delta, :half].float() + W = (A * cos_d + B * sin_d).sum(dim=-1) / math.sqrt(head_dim) + results.append(W) + return torch.stack(results)