
Rewrite dataset attribution storage + 3 metrics #413

Open
ocg-goodfire wants to merge 17 commits into dev from feature/faster-dataset-attributions-2

Conversation

@ocg-goodfire
Collaborator

Summary

  • Dict-of-dicts storage: Replace flat-indexed source_to_component / source_to_out_residual matrices with attrs[target_layer][source_layer] = Tensor[target_d, source_d]
  • Canonical naming: Storage uses canonical layer names (embed, output, 0.glu.up). Harvester stays concrete internally; translation at storage boundary via topology.target_to_canon()
  • 3 attribution metrics: attr (signed mean), attr_abs (attribution to |target|, via backprop through .abs()), mean_squared_attr (pre-sqrt, mergeable across workers)
  • Split entrypoints: Separate run_worker.py and run_merge.py instead of combined run.py
  • Bug fixes: alive_targets bool→index, embed CI KeyError, scatter_add OOB (vocab_size vs num_embeddings), merge entrypoint missing config_json, attr_abs sign correctness (backprop through |y| instead of flipping by source sign)
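The dict-of-dicts layout described above can be sketched as follows (a minimal illustration with a hypothetical `accumulate` helper; the real Storage class has more structure):

```python
import torch

# Sketch: attrs[target_layer][source_layer] = Tensor[target_d, source_d],
# keyed by canonical layer names such as "embed", "output", "0.glu.up".
attrs: dict[str, dict[str, torch.Tensor]] = {}

def accumulate(target_layer: str, source_layer: str, block: torch.Tensor) -> None:
    """Add a [target_d, source_d] attribution block into the store."""
    per_target = attrs.setdefault(target_layer, {})
    if source_layer in per_target:
        per_target[source_layer] += block
    else:
        per_target[source_layer] = block.clone()

accumulate("0.glu.up", "embed", torch.ones(4, 8))
accumulate("0.glu.up", "embed", torch.ones(4, 8))
```

Compared to flat-indexed matrices, missing layer pairs simply have no entry, and each block's shape documents its own target/source dimensions.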

Test plan

  • Storage unit tests (17 tests: has_source, has_target, save/load roundtrip, merge weighted average)
  • Type checks pass (basedpyright)
  • End-to-end SLURM test on s-17805b61 (3 GPUs × 2 batches) — workers complete, merge pending with more memory
  • Benchmark: fast (scatter_add_) vs slow (per-element loop) comparison in progress

🤖 Generated with Claude Code

ocg-goodfire and others added 10 commits February 23, 2026 15:53
…3 metrics

Storage uses attrs[target_layer][source_layer] = Tensor[target_d, source_d]
with canonical layer names (embed, output, 0.glu.up, etc.). Harvester stays
concrete internally; translation at storage boundary via topology.target_to_canon.

Three attribution metrics accumulated:
- attr: E[grad*act] (signed mean)
- attr_abs: E[grad*|act|] (attribution to absolute target value)
- mean_squared_attr: E[(grad*act)²] (pre-sqrt, mergeable across workers)
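The three running sums can be sketched like this (toy variable names, not the harvester's actual fields; `grad` is the gradient of the target, `abs_grad` the gradient of `|target|`):

```python
import torch

# Running accumulators for the three metrics; sums (not means) so they
# merge cleanly across workers together with n_tokens.
n_tokens = 0
attr_sum = torch.zeros(3)
attr_abs_sum = torch.zeros(3)
sq_attr_sum = torch.zeros(3)

def update(grad: torch.Tensor, abs_grad: torch.Tensor, act: torch.Tensor) -> None:
    global n_tokens
    a = grad * act
    attr_sum.add_(a)                 # attr: signed mean (pre-division)
    attr_abs_sum.add_(abs_grad * act)  # attr_abs: attribution to |target|
    sq_attr_sum.add_(a ** 2)         # pre-sqrt: sums of squares stay mergeable
    n_tokens += 1

update(torch.tensor([1.0, -2.0, 0.0]), torch.tensor([1.0, 2.0, 0.0]), torch.ones(3))
mean_squared_attr = sq_attr_sum / n_tokens  # sqrt only at display time
```

Keeping `mean_squared_attr` pre-sqrt is what makes the worker outputs combinable: a mean of squares averages across workers, an RMS does not.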

Other changes:
- Fix filter bug: used "output" instead of concrete unembed path (e.g. "lm_head")
- Harvester parameterised with embed_path/unembed_path instead of magic strings
- Storage.merge() classmethod with correct weighted-average semantics
- Router simplified: no topology translation needed with canonical storage
- Query methods stubbed with ValueError (frontend not yet updated)
- Re-enable AttributionRepo.open() load
- Remove outdated test_harvester.py (uses old flat-index API)
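The weighted-average merge semantics can be sketched as follows (assuming each worker reports a per-token mean plus its token count; `merge_means` is a hypothetical stand-in for the Storage.merge() classmethod):

```python
import torch

def merge_means(means: list[torch.Tensor], n_tokens: list[int]) -> torch.Tensor:
    """Token-weighted average of per-worker means, accumulated in float64."""
    total = sum(n_tokens)
    acc = torch.zeros_like(means[0], dtype=torch.float64)
    for mean, n in zip(means, n_tokens):
        acc += mean.double() * n
    return acc / total

# Worker A saw 1 token with mean 1.0; worker B saw 2 tokens with mean 4.0.
merged = merge_means([torch.tensor([1.0]), torch.tensor([4.0])], [1, 2])
```

Weighting by `n_tokens` (rather than averaging the means directly) is what makes the result independent of how the batches were split across workers.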

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…list

alive_targets is a bool tensor; .tolist() gives [True, False, ...] not indices.
torch.autograd.grad needs a scalar output, so index with actual int indices.
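The fix boils down to converting the boolean mask to integer positions before calling `.tolist()`:

```python
import torch

alive_targets = torch.tensor([True, False, True])

# Bug: .tolist() on a bool tensor yields the mask itself, not positions.
as_bools = alive_targets.tolist()   # [True, False, True]

# Fix: torch.where gives the integer indices of the alive targets,
# which can then be used to index out one scalar at a time for autograd.
alive_idx = torch.where(alive_targets)[0].tolist()
```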

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Embed tokens have no CI (always active), so skip CI weighting for embed sources.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…r vocab_size

tokenizer.vocab_size (50254) < len(tokenizer) (50277) due to added tokens.
Token IDs >= vocab_size cause scatter_add_ index out of bounds in the embed
accumulator. Use embedding_module.num_embeddings which matches the actual
token ID space.
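A minimal reproduction of the mismatch (sizes taken from the commit message; the accumulator shape is illustrative):

```python
import torch

vocab_size = 50254      # tokenizer.vocab_size: base vocab only
num_embeddings = 50277  # embedding_module.num_embeddings: full token ID space

token_ids = torch.tensor([0, 50260])  # second ID is an added token >= vocab_size

# An accumulator sized by vocab_size would make scatter_add_ raise an
# index-out-of-bounds error for the added token; size by num_embeddings instead.
acc = torch.zeros(num_embeddings)
acc.scatter_add_(0, token_ids, torch.ones_like(token_ids, dtype=torch.float32))
```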

Also add Path type annotation to test tmp_path params.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Merge doesn't need config_json, worker does. Separate entrypoints avoid the
issue where Fire requires config_json for both paths.

Cherry-picked from feature/faster-dataset-attributions.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…natures

attr_abs now computed by backpropping through target_acts.abs() instead of
flipping by source activation sign. Requires 2 backward passes per target
component but is mathematically correct for cross-position (attention) interactions.
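A toy illustration of the difference (scalars, not the harvester code): backpropping through `y.abs()` picks up `sign(y)`, which per-source sign flipping cannot reproduce once several positions feed the same target.

```python
import torch

# One target depending on two sources; two backward passes per target.
x = torch.tensor([2.0, -3.0], requires_grad=True)
y = x.sum()  # y = -1.0

grad, = torch.autograd.grad(y, x, retain_graph=True)
abs_grad, = torch.autograd.grad(y.abs(), x)  # picks up sign(y) = -1

attr = (grad * x).detach()          # signed attribution
attr_abs = (abs_grad * x).detach()  # attribution to |y|, signs intact
```

Note that `attr_abs` here is `[-2.0, 3.0]`: the source with activation `-3.0` pushes `|y|` up, which flipping `attr` by the source's own sign would get wrong.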

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
3 metrics × dict-of-dicts makes rank files ~15GB each. Merge loads all of them
in double precision and needs much more than the default 10GB.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Backend: implement storage query methods with AttrMetric parameter, bulk
endpoint returns all 3 metrics (attr, attr_abs, mean_squared_attr), other
endpoints accept optional ?metric= query param.

Frontend: 3-way radio toggle (Signed / Abs Target / RMS) in
DatasetAttributionsSection. All metrics fetched at once, selection is local
state that switches which ComponentAttributions to display.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
parse_wandb_run_path now accepts "s-xxxxxxxx" and expands to goodfire/spd.
Handled in backend so it works for CLI, app, and any other consumer.
Frontend placeholder updated.
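The expansion can be sketched like this (`expand_run_path` is a hypothetical helper; the actual function is `parse_wandb_run_path`, and `goodfire/spd` is the default entity/project named in the commit):

```python
import re

def expand_run_path(path: str, default_prefix: str = "goodfire/spd") -> str:
    """Expand a bare run ID like "s-17805b61" to "goodfire/spd/s-17805b61"."""
    if re.fullmatch(r"s-[0-9a-f]{8}", path):
        return f"{default_prefix}/{path}"
    return path  # already a full entity/project/run path
```

Doing this in the backend means the CLI, the app, and any other consumer all accept the short form for free.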

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@ocg-goodfire ocg-goodfire force-pushed the feature/faster-dataset-attributions-2 branch from 9774be5 to 1bf9877 on February 23, 2026 19:25
Old subruns (da-timing-*, da-overnight-*) sort after da-YYYYMMDD_* and have
no dataset_attributions.pt. The old code only checked the last candidate
and returned None. Now iterates in reverse until finding one with the file.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@ocg-goodfire ocg-goodfire force-pushed the feature/faster-dataset-attributions-2 branch from e2fed0d to 627df2b on February 23, 2026 19:45
ocg-goodfire and others added 5 commits February 23, 2026 20:17
getTokenText did .find() over the full 50K vocab array for every
embed/output pill on each render. Build a Map<id, string> once via
$derived, making lookups O(1).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…ntend

Backend resolves embed/output token strings via tokenizer.decode() and
includes token_str in DatasetAttributionEntry. Frontend uses it directly
instead of scanning a 50K vocab array per pill.

Removes tokens/outputProbs passthrough from EdgeAttributionGrid/List —
token strings now flow through EdgeAttribution.tokenStr from the source.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
attr_abs and mean_squared_attr are non-negative by definition, so the
negative top-k is meaningless. Only show negative column for signed attr.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
mean_squared_attr only has positive_sources/positive_targets, so the three
paths in DatasetAttributionsSection are hardcoded to match each metric type.

Slow benchmark result: per-element loops >14x slower than scatter_add_
(>60 min vs 4.3 min per batch on s-17805b61, timed out at 1hr).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…trics

Rewrite docs to reflect dict-of-dicts storage, canonical naming, split
entrypoints, 3 metrics, and updated query method signatures.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@ocg-goodfire ocg-goodfire changed the base branch from main to dev February 23, 2026 21:07
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@ocg-goodfire
Collaborator Author

@claude please review this PR, specifically the exact correctness of the code. The interfaces and types are of course different, but that's intentional; I want to know whether I've made mistakes implementing what should be a fairly obvious goal.

@claude
Contributor

claude bot commented Feb 23, 2026

Claude finished @ocg-goodfire's task


Code Review Complete ✅

Reviewing the dataset attribution storage rewrite for correctness. Todo list:

  • Understand the new dict-of-dicts storage format vs old flat matrices
  • Review attribution metric calculations (attr, attr_abs, mean_squared_attr)
  • Check storage serialization/deserialization logic
  • Review worker/merge split implementation
  • Look for potential bugs and edge cases
  • Verify type correctness and consistency

Review Summary

The code looks correct and well-implemented. The rewrite successfully addresses the stated goals while fixing several important bugs. Here's my detailed analysis:

Storage Architecture

  • Dict-of-dicts format: Much cleaner than flat matrices. attrs[target_layer][source_layer] = Tensor[target_d, source_d] is intuitive.
  • Canonical naming: Smart translation boundary - concrete paths internally, canonical names (embed, output, 0.glu.up) for storage.
  • Output residual storage: Elegant solution storing attributions to pre-unembed residual rather than full vocab, reducing storage from O(V×C×V) to O(V×C×d_model).

Attribution Metrics

All three metrics are mathematically correct:

  • attr: E[∂y/∂x · x] - standard signed attribution
  • attr_abs: E[∂|y|/∂x · x] - correctly computed via backprop through .abs(), not sign flipping
  • mean_squared_attr: E[(∂y/∂x · x)²] - stored pre-sqrt for proper merging across workers

Bug Fixes Verified

  • alive_targets bool→index: Line 269 correctly uses torch.where(alive_targets)[0].tolist() instead of boolean indexing
  • scatter_add_ bounds: Uses embedding_module.num_embeddings instead of tokenizer.vocab_size, preventing OOB errors from added tokens
  • attr_abs sign: Correctly computes via abs_grad * act * source_ci from backprop through .abs()

Worker/Merge Implementation

  • Clean separation: run_worker.py for computation, run_merge.py for combining
  • Weighted averaging: Merge correctly weights by n_tokens, promotes to double precision to avoid numerical issues
  • Robust handling: Proper validation of CI thresholds and layer consistency across workers

Storage & Serialization

  • Save/load roundtrip: Comprehensive unit tests verify correctness
  • CPU tensor storage: Proper device handling
  • Type safety: Good use of Pydantic and jaxtyping

Performance Optimizations

  • scatter_add_ for embedding tokens vs vectorized .add_() for components (benchmarked at >14x speedup over the per-element loop)
  • Summed gradients: Reduces backward passes from O(positions × components) to O(components)

Minor Observations

  1. has_source/has_target implementation (lines 82-102 in storage.py): The logic correctly handles different layer types and bounds checking.

  2. Gradient connectivity discovery remains unchanged and looks solid - correctly tests pairs and handles unused gradients.

  3. Unit test coverage is excellent (17 tests covering storage, merge, bounds checking).

Verdict: LGTM 🚀

This is a solid rewrite that improves both the interface and fixes real bugs. The storage format is more intuitive, the attribution calculations are mathematically sound, and the implementation handles edge cases properly. The extensive unit tests give good confidence in correctness.
