Merged

447 commits
8906628
refactor to use MergeRunConfig everywhere
mivanit Aug 9, 2025
b197aad
Wip
mivanit Aug 10, 2025
65e7b93
wip
mivanit Aug 10, 2025
4fa6b2d
wip
mivanit Aug 10, 2025
78ea395
some info to yaml
mivanit Aug 10, 2025
21a244c
todos
mivanit Aug 11, 2025
82e38c6
LFGgsgsgs
mivanit Aug 12, 2025
99993f5
wip
mivanit Aug 12, 2025
5a3f78f
refactor configs. new system with ability to choose sampling mechanis…
mivanit Aug 12, 2025
69f1ac7
refactor makefile to use new configs
mivanit Aug 12, 2025
d28b1c1
add tests. code still not comitted yet
mivanit Aug 12, 2025
cfb17bb
format
mivanit Aug 12, 2025
6bcad5c
fix tests
mivanit Aug 12, 2025
acadfac
test configs
mivanit Aug 12, 2025
6294631
TESTS PASSIN LFGGGG
mivanit Aug 12, 2025
543a457
format
mivanit Aug 12, 2025
d984d3d
pyright fixes
mivanit Aug 12, 2025
d44c176
Merge branch 'dev' into feature/clustering
mivanit Aug 12, 2025
d2796fe
fix some pyright issues
mivanit Aug 13, 2025
b50c08f
parallelizing tests
mivanit Aug 13, 2025
11a25fe
distributed tests in CI
mivanit Aug 13, 2025
ed719f1
fix action
mivanit Aug 13, 2025
06ffad4
try to make tests faster
mivanit Aug 13, 2025
f96e864
remove experiment with no canonical run
mivanit Aug 13, 2025
fdb2572
try to debug issue with normalizing ensemble
mivanit Aug 13, 2025
1fe9ad3
[important] remove old files
mivanit Aug 13, 2025
d7a3343
move the merge pair samplers code to math folder
mivanit Aug 13, 2025
4e04542
fix import in tests
mivanit Aug 13, 2025
9bff3a1
default to cpu if no cuda in spd-cluster
mivanit Aug 13, 2025
9c31c1b
default to cpu if no cuda in spd-cluster
mivanit Aug 13, 2025
863ae60
wip wandb logging for spd-cluster refactor
mivanit Aug 13, 2025
f3e0b48
more wip wandb logging for spd-cluster refactor
mivanit Aug 13, 2025
edae24c
format?
mivanit Aug 14, 2025
a159f60
wandb log tensor info
mivanit Aug 14, 2025
9c2cea6
wandb log tensor info wip
mivanit Aug 14, 2025
7e73cf7
wandb log tensor info wip
mivanit Aug 14, 2025
969243a
wandb log tensor info wip
mivanit Aug 14, 2025
b7b8f61
some figs on wandb
mivanit Aug 14, 2025
4b1a9fb
format
mivanit Aug 14, 2025
b25823c
wip
mivanit Aug 14, 2025
14b896c
[temp] ignore config rename/deprecated warns
mivanit Aug 14, 2025
7c51471
wip
mivanit Aug 14, 2025
c6f5b44
wip
mivanit Aug 14, 2025
3d03fed
wip
mivanit Aug 14, 2025
1d278b7
wip
mivanit Aug 14, 2025
74fe2f5
wip
mivanit Aug 14, 2025
fae9ca7
wip
mivanit Aug 14, 2025
6eeb11f
wip
mivanit Aug 14, 2025
4e54996
wip
mivanit Aug 14, 2025
ad4fc74
wip
mivanit Aug 14, 2025
54d1921
wip
mivanit Aug 14, 2025
8293a15
Merge branch 'dev' into feature/clustering
mivanit Aug 14, 2025
4c310f1
[hack] lack of canonical runs causing error
mivanit Aug 14, 2025
60dc336
wip
mivanit Aug 14, 2025
b76c93c
wip, cleaning up s2
mivanit Aug 14, 2025
481beda
wip
mivanit Aug 15, 2025
4c6446c
nathan approved MDL cost computation
mivanit Aug 15, 2025
e377f07
swap to log2 in mdl
mivanit Aug 15, 2025
8e91e6e
log MDL cost to wandb, other minor fixes
mivanit Aug 15, 2025
162f229
wip
mivanit Aug 15, 2025
bc7f178
wip
mivanit Aug 15, 2025
852ece2
wip
mivanit Aug 15, 2025
5f78a7e
some logging stuff fixed. now tryna figure out why way too many merge…
mivanit Aug 15, 2025
d00e7b9
[important] got run loading to work right!!!
mivanit Aug 15, 2025
4ee62db
wip
mivanit Aug 15, 2025
2a0edfd
format
mivanit Aug 15, 2025
f6c62d2
[important] FEATURE COMPLETE??
mivanit Aug 15, 2025
b5f7483
format
mivanit Aug 15, 2025
82602c0
minor changes to wandb stuff
mivanit Aug 15, 2025
5d0a5d0
fix tests?
mivanit Aug 15, 2025
69c239e
format
mivanit Aug 15, 2025
6c8ba33
fix pyright errors, some warnings remain
mivanit Aug 15, 2025
0a57602
pyright passing!!!
mivanit Aug 15, 2025
26a5631
removed old ignores in pyproject.toml, fixed the resulting errors
mivanit Aug 15, 2025
e1ee44e
?
mivanit Aug 15, 2025
1fc039c
fix
mivanit Aug 15, 2025
2117700
Merge branch 'dev' into feature/clustering
mivanit Aug 15, 2025
32d6e9f
fix registry
mivanit Aug 15, 2025
057d82e
numprocess to auto for pytest with xdist
mivanit Aug 15, 2025
b4ee5ee
move the TaskName definition
mivanit Aug 15, 2025
a035627
remove pyright ignore
mivanit Aug 15, 2025
820b6cf
move clustering stuff into clustering folders
mivanit Aug 15, 2025
87dcd44
format
mivanit Aug 15, 2025
0adbd18
fix type error
mivanit Aug 15, 2025
11e6091
no more dep graph
mivanit Aug 15, 2025
ad47db2
remove outdated TODO.md
mivanit Aug 15, 2025
5e13e96
remove old ruff exclude
mivanit Aug 15, 2025
79b953e
remove unused function
mivanit Aug 15, 2025
3ab600e
remove `from __future__ import annotations`
mivanit Aug 17, 2025
3f288b2
factor out component filtering, add tests for it
mivanit Aug 17, 2025
2c8270b
tensor stats randomized for big matrix
mivanit Aug 17, 2025
70de4e1
wip
mivanit Aug 17, 2025
9b55cd1
minor fix related to passing filtered labels
mivanit Aug 18, 2025
485b3c7
histograms as wandb Images instead of plotly
mivanit Aug 18, 2025
354ea73
wip
mivanit Aug 18, 2025
64b3d30
factor out some logging stuff, lots of minor changes
mivanit Aug 18, 2025
3777b1d
wandb log semilog of merge pair cost
mivanit Aug 18, 2025
39613f7
oops, fix semilog
mivanit Aug 18, 2025
eb6ffb8
log hists for stuff
mivanit Aug 18, 2025
4a2cdac
close figures to avoid memory leaks, add some logging to help profile
mivanit Aug 19, 2025
de5e7ad
wip, added some logging
mivanit Aug 19, 2025
ce96406
logging stuff
mivanit Aug 19, 2025
6640430
pyright fixes
mivanit Aug 19, 2025
01059b2
make the merge history more lightweight
mivanit Aug 19, 2025
c03969f
wip
mivanit Aug 19, 2025
f6ec48e
format
mivanit Aug 19, 2025
371e015
intervals dict
mivanit Aug 19, 2025
58b50f3
fix issue with zip file closed
mivanit Aug 19, 2025
a0e621d
wip
mivanit Aug 19, 2025
f735c5a
remove merge profiling code
mivanit Aug 19, 2025
f8f98e4
get rid of some old comments
mivanit Sep 4, 2025
7247800
more explanation around popping
mivanit Sep 4, 2025
ae3ec43
components_in_pop_grp -> n_components_in_pop_grp
mivanit Sep 4, 2025
8e86541
Merge branch 'main' into feature/clustering
mivanit Sep 6, 2025
14f44e7
remove ignore deprecated config warnings todo
mivanit Sep 6, 2025
1a900ef
some typing fixes
mivanit Sep 6, 2025
4de2154
fix pyright issues by pinning `transformers` package
mivanit Sep 6, 2025
038ad58
fix annoying mock test issue
mivanit Sep 10, 2025
069dffc
Merge branch 'main' into feature/clustering
mivanit Sep 10, 2025
a99c2b1
fix pyright issue
mivanit Sep 10, 2025
8b7ecfd
remove call to deprecated plotting function
mivanit Sep 10, 2025
faca7bb
Merge branch 'main' into feature/clustering
mivanit Sep 10, 2025
933a71d
remove dead test
oclivegriffin Sep 23, 2025
c051b69
Merge branch 'main' into feature/clustering
oclivegriffin Sep 24, 2025
7d2885d
give sigmoid type
oclivegriffin Sep 24, 2025
f1799b3
update default ss decomp run
mivanit Sep 24, 2025
f192b69
Merge branch 'main' into feature/clustering
mivanit Sep 24, 2025
91242aa
wip
oclivegriffin Sep 24, 2025
c631c61
give sigmoid type
oclivegriffin Sep 24, 2025
0cafdd6
wip
oclivegriffin Sep 24, 2025
5c45477
merge
oclivegriffin Sep 24, 2025
3026407
fix interface, store ComponentModel.module_paths
mivanit Sep 24, 2025
14a0c89
patched model -> target model when using config
mivanit Sep 24, 2025
95bbdb5
fix: no more patched_model, use target_model for config
mivanit Sep 25, 2025
675fb5e
comments
mivanit Sep 25, 2025
7ff6246
fix cuda_memory_used.py
mivanit Sep 25, 2025
caf09f1
removed `spd/clustering/math/dev.py`
mivanit Sep 28, 2025
2b6449f
`StatsKeys` -> `StatsKey`
mivanit Sep 28, 2025
4c57042
remove `plot_merge_history_costs`
mivanit Sep 28, 2025
6617a1d
remove old js-embedding-vis dep
mivanit Sep 28, 2025
6e6da1d
Merge branch 'main' into feature/clustering
mivanit Sep 28, 2025
8b2a58b
Oli clustering refactor (#172)
mivanit Sep 30, 2025
df72249
fixes/improvements to dist_utils
mivanit Sep 30, 2025
b00243e
Merge branch 'main' into feature/clustering
mivanit Sep 30, 2025
667c836
Merge branch 'main' into feature/clustering
mivanit Oct 6, 2025
9efcbca
type fixes?
mivanit Oct 6, 2025
24eff75
wip
mivanit Oct 6, 2025
c29e868
Sync everything from feature/clustering-dashboard except spd/clusteri…
mivanit Oct 6, 2025
961894f
minimizing diff
mivanit Oct 6, 2025
cf45a57
minimize pyprojec.toml diff
mivanit Oct 6, 2025
d3a21e9
minimizing diff, removed deps
mivanit Oct 6, 2025
6789ca8
uv sync
mivanit Oct 6, 2025
1ab117d
fixing state dict mappings
mivanit Oct 6, 2025
9baa5e3
Merge branch 'fix/state-dict-key-mapping' into feature/clustering
mivanit Oct 6, 2025
d45df85
test parallelization
mivanit Oct 6, 2025
86b6f95
parallelize tests
mivanit Oct 6, 2025
7c3a2f0
device getting utils
mivanit Oct 6, 2025
61f0482
add TaskName type
mivanit Oct 6, 2025
7144f18
uv sync (pytest-xdist dep)
mivanit Oct 6, 2025
2130566
Merge branch 'main' into feature/clustering
mivanit Oct 6, 2025
7951b31
remove old junk from Makefile
mivanit Oct 6, 2025
4c9b6be
Merge branch 'main' into refactor/clustering-prereqs
mivanit Oct 6, 2025
44f8c94
globally unique ports in tests to allow parallelization
mivanit Oct 6, 2025
88022ea
comments explaining port allocation in tests
mivanit Oct 6, 2025
5bcd8a3
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit Oct 6, 2025
0e38131
Merge branch 'main' into feature/clustering
danbraunai-goodfire Oct 6, 2025
3f55ffa
add distributed marker, rull all distributed tests on same worker
mivanit Oct 6, 2025
29c2738
Revert "add distributed marker, rull all distributed tests on same wo…
mivanit Oct 6, 2025
f5b3288
add distributed marker, rull all distributed tests on same worker
mivanit Oct 6, 2025
c3dca4a
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit Oct 6, 2025
3b116f5
Merge branch 'main' into refactor/clustering-prereqs
mivanit Oct 6, 2025
2e60193
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit Oct 6, 2025
deb59a8
refactor: use general_utils methods for getting device everywhere
mivanit Oct 6, 2025
bac464c
Merge branch 'main' into refactor/clustering-prereqs
mivanit Oct 6, 2025
2e29b2f
Merge branch 'main' into feature/clustering
mivanit Oct 6, 2025
2eb96fc
wip jaccard
mivanit Oct 6, 2025
c71c696
wip jaccard
mivanit Oct 6, 2025
f9b228c
Merge branch 'main' into refactor/clustering-prereqs
mivanit Oct 6, 2025
8a7a423
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit Oct 6, 2025
a25a9a9
wip jaccard
mivanit Oct 6, 2025
d23823d
wip jaccard (plotting)
mivanit Oct 6, 2025
6280591
found where to increase timeout
mivanit Oct 6, 2025
85f789b
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit Oct 6, 2025
db59ce6
wip jaccard
mivanit Oct 6, 2025
e218796
make format
mivanit Oct 6, 2025
6225e0b
fixes
mivanit Oct 6, 2025
9f4c347
typing fixes
mivanit Oct 6, 2025
b12403a
claude doing a bunch of type hinting
mivanit Oct 6, 2025
9a06d29
Merge branch 'main' into refactor/clustering-prereqs
mivanit Oct 6, 2025
7f605e5
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit Oct 6, 2025
f112af9
trying to get pyright passing?
mivanit Oct 6, 2025
cfc03ed
pyright works both locally and in CI
mivanit Oct 6, 2025
a92ae11
allow installing cpu-only torch in CI
mivanit Oct 6, 2025
5ad324c
figure out CI disk usage by tests on main
mivanit Oct 7, 2025
7978c44
alternate strategy for install
mivanit Oct 7, 2025
98ba633
fixes to the last commit
mivanit Oct 7, 2025
989c2dd
cleanup temp changes
mivanit Oct 7, 2025
5f9500f
make in CI
mivanit Oct 7, 2025
7eeea74
wip
mivanit Oct 7, 2025
4bc5728
wip
mivanit Oct 7, 2025
3c17aee
wip
mivanit Oct 7, 2025
1110c2d
uv sync
mivanit Oct 7, 2025
9cfc051
try to fix markupsafe?
mivanit Oct 7, 2025
23c41dd
pin markup safe with explanation
mivanit Oct 7, 2025
63d6432
update lockfile??
mivanit Oct 7, 2025
1d8da48
Merge branch 'main' into fix/ci-disk-usage
mivanit Oct 7, 2025
1bb5ac8
nope i think we need the index strategy
mivanit Oct 7, 2025
75a0efe
?
mivanit Oct 7, 2025
019e1b3
markupsafe issue
mivanit Oct 7, 2025
3082817
remove disk usage printing
mivanit Oct 7, 2025
d520a54
fix pyright issue
mivanit Oct 7, 2025
1eab6fe
dependency hell
mivanit Oct 7, 2025
3a05d5a
fix deps???
mivanit Oct 7, 2025
d053433
oops, missing index strategy. moved to makefile
mivanit Oct 7, 2025
25aa615
re-lock
mivanit Oct 7, 2025
c1ffcda
make from /usr/bin/ ?
mivanit Oct 7, 2025
f006884
dependency hell
mivanit Oct 7, 2025
533ad20
type checking hell
mivanit Oct 7, 2025
4524965
Update spd/utils/general_utils.py
mivanit Oct 7, 2025
15b314e
wrap and fix Conv1D imports
mivanit Oct 7, 2025
f28b670
minimize diff cleanup
mivanit Oct 7, 2025
aa98d4a
try compile-bytecode for ci install
mivanit Oct 7, 2025
740f6a2
dont compile bytecode actually
mivanit Oct 7, 2025
cfe1f81
remove markupsafe constraint?
mivanit Oct 7, 2025
17803ed
switched to use get_obj_device
mivanit Oct 7, 2025
153d044
remove device: torch.device type hints
mivanit Oct 7, 2025
2c60412
remove "distributed" test marker
mivanit Oct 7, 2025
44c6cc5
fix another timeout
mivanit Oct 7, 2025
0d5137c
Merge branch 'main' into refactor/clustering-prereqs
mivanit Oct 7, 2025
a48ce55
Merge branch 'fix/ci-disk-usage' into feature/clustering
mivanit Oct 7, 2025
055f3cc
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit Oct 7, 2025
b1604bb
replace get_module_device -> get_obj_device
mivanit Oct 7, 2025
d20ea4f
Merge branch 'main' into refactor/clustering-prereqs
mivanit Oct 7, 2025
42f58a7
better comments on port uniqueness
mivanit Oct 7, 2025
c722ddd
remove old markers
mivanit Oct 7, 2025
b472f5d
remove timeout TODO comments
mivanit Oct 7, 2025
921010b
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit Oct 7, 2025
888f5f2
Merge branch 'main' into feature/clustering
mivanit Oct 7, 2025
5d092e8
removed checks.yaml timeout todo, clustering tests pass in ~12min
mivanit Oct 7, 2025
5eb1af7
[diff-min] transformers version issue from #139 resolved
mivanit Oct 7, 2025
dc6e32c
fix comment
mivanit Oct 7, 2025
aae70da
wip jaccard
mivanit Oct 7, 2025
9c6103f
pyright fixes to jaccard, wip
mivanit Oct 8, 2025
6ccbd79
Update docs about grad syncing with DDP
danbraunai-goodfire Oct 8, 2025
fe0de02
Mention feature/memorization-experiments in README
danbraunai-goodfire Oct 8, 2025
0bdbd4e
Fix train and eval metrics and hidden_act_recon (#189)
danbraunai-goodfire Oct 9, 2025
a2bcaa3
Update canonical runs and change target model path (#197)
danbraunai-goodfire Oct 9, 2025
dadb9c7
Avoid using too many processes in tests
danbraunai-goodfire Oct 9, 2025
882d659
Merge branch 'main' into feature/clustering
mivanit Oct 10, 2025
e7e1b1d
fix wandb model paths to older runs
mivanit Oct 10, 2025
4 changes: 3 additions & 1 deletion .gitignore
@@ -1,6 +1,8 @@
spd/scripts/sweep_params.yaml
spd/scripts/sweep_params.yaml
docs/coverage/**
artifacts/**
docs/dep_graph/**
tests/.temp/**

**/out/
neuronpedia_outputs/
73 changes: 73 additions & 0 deletions TODO.md
@@ -0,0 +1,73 @@
# TODO: Cluster Coactivation Matrix Implementation

## What Was Changed

### 1. Added `ClusterActivations` dataclass (`spd/clustering/dashboard/compute_max_act.py`)
- New dataclass to hold vectorized cluster activations for all clusters
- Contains `activations` tensor [n_samples, n_clusters] and `cluster_indices` list

### 2. Added `compute_all_cluster_activations()` function
- Vectorized computation of all cluster activations at once
- Replaces the per-cluster loop for better performance
- Returns `ClusterActivations` object
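The vectorized replacement for the per-cluster loop can be sketched as a single matrix multiplication. The actual `compute_all_cluster_activations()` implementation isn't shown here, so this is a minimal NumPy illustration under the assumption that a cluster's activation is the sum of its components' activations; the function name and signature below are hypothetical:

```python
import numpy as np

def all_cluster_activations(
    component_acts: np.ndarray,  # [n_samples, n_components]
    cluster_of: np.ndarray,      # [n_components] int cluster id per component
    n_clusters: int,
) -> np.ndarray:
    """Sum component activations into their clusters in one matmul,
    instead of looping over clusters. Returns [n_samples, n_clusters]."""
    # one-hot [n_components, n_clusters] assignment matrix
    onehot = np.eye(n_clusters)[cluster_of]
    return component_acts @ onehot

acts = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]])
out = all_cluster_activations(acts, np.array([0, 0, 1]), n_clusters=2)
# out == [[3., 3.], [1., 0.]]: components 0 and 1 belong to cluster 0, component 2 to cluster 1
```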

### 3. Added `compute_cluster_coactivations()` function
- Computes coactivation matrix from list of `ClusterActivations` across batches
- Binarizes activations (acts > 0) and computes matrix multiplication: `activation_mask.T @ activation_mask`
- Follows the pattern from `spd/clustering/merge.py:69`
- Returns tuple of (coactivation_matrix, cluster_indices)
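The binarize-then-matmul pattern described above can be sketched as follows (a NumPy stand-in for the torch code referenced from `spd/clustering/merge.py:69`; the standalone function here is illustrative, not the repository's API):

```python
import numpy as np

def cluster_coactivations(acts: np.ndarray) -> np.ndarray:
    """acts: [n_samples, n_clusters] cluster activations.
    Entry (i, j) counts samples where clusters i and j are both active."""
    mask = (acts > 0).astype(np.int64)  # binarize at the hardcoded > 0 threshold
    return mask.T @ mask                # activation_mask.T @ activation_mask

acts = np.array([[0.5, 0.0], [1.2, 0.3], [0.0, 0.0]])
coact = cluster_coactivations(acts)
# coact == [[2, 1], [1, 1]]: cluster 0 active in 2 samples, both active together in 1
```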

### 4. Modified `compute_max_activations()` function
- Now accumulates `ClusterActivations` from each batch in `all_cluster_activations` list
- Calls `compute_cluster_coactivations()` to compute the matrix
- **Changed return type**: now returns `tuple[DashboardData, np.ndarray, list[int]]`
- Added coactivation matrix and cluster_indices to return value

### 5. Modified `spd/clustering/dashboard/run.py`
- Updated to handle new return value from `compute_max_activations()`
- Saves coactivation matrix as `coactivations.npz` in the dashboard output directory
- NPZ file contains:
- `coactivations`: the [n_clusters, n_clusters] matrix
- `cluster_indices`: array mapping matrix positions to cluster IDs
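The NPZ layout above amounts to a simple `np.savez` round-trip. A minimal sketch with toy values (the matrix and cluster IDs here are hypothetical; only the filename and key names come from the description above):

```python
import numpy as np

# toy stand-ins for the real [n_clusters, n_clusters] matrix and ID mapping
coact = np.array([[2, 1], [1, 1]])
cluster_indices = np.array([3, 7])  # matrix position -> cluster ID

np.savez("coactivations.npz", coactivations=coact, cluster_indices=cluster_indices)

loaded = np.load("coactivations.npz")
assert (loaded["coactivations"] == coact).all()
assert loaded["cluster_indices"].tolist() == [3, 7]
```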

## What Needs to be Checked

### Testing
- [ ] **Run the dashboard pipeline** on a real clustering run to verify:
- Coactivation computation doesn't crash
- Coactivations are saved correctly to NPZ file
- Matrix dimensions are correct
- `cluster_indices` mapping is correct

### Type Checking
- [ ] Run `make type` to ensure no type errors were introduced
- [ ] Verify jaxtyping annotations are correct

### Verification
- [ ] Load a saved `coactivations.npz` file and verify:
```python
data = np.load("coactivations.npz")
coact = data["coactivations"]
cluster_indices = data["cluster_indices"]
# Check: coact should be symmetric
# Check: diagonal should be >= off-diagonal (clusters coactivate with themselves most)
# Check: cluster_indices length should match coact.shape[0]
```

### Performance
- [ ] Check if vectorization actually improved performance
- [ ] Monitor memory usage with large numbers of clusters

### Edge Cases
- [ ] Test with clusters that have zero activations
- [ ] Test with single-batch runs
- [ ] Test with very large number of clusters

### Integration
- [ ] Verify the coactivation matrix can be used in downstream analysis
- [ ] Consider if visualization of coactivations should be added to dashboard

## Notes
- The coactivation matrix is computed over all samples processed (n_batches * batch_size * seq_len samples)
- Binarization threshold is currently hardcoded as `> 0` - may want to make this configurable
- The computation happens in the dashboard pipeline, NOT during the main clustering pipeline
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -28,6 +28,8 @@ dependencies = [
# see: https://github.com/huggingface/datasets/issues/6980 https://github.com/huggingface/datasets/pull/6991 (fixed in https://github.com/huggingface/datasets/releases/tag/2.21.0 )
"datasets>=2.21.0",
"simple_stories_train @ git+https://github.com/goodfire-ai/simple_stories_train.git@dev",
"scipy>=1.14.1",
"muutils",
]

[dependency-groups]
@@ -42,6 +44,7 @@ dev = [

[project.scripts]
spd-run = "spd.scripts.run:cli"
spd-cluster = "spd.clustering.scripts.main:cli"

[build-system]
requires = ["setuptools", "wheel"]
Empty file added spd/clustering/__init__.py
269 changes: 269 additions & 0 deletions spd/clustering/activations.py
@@ -0,0 +1,269 @@
from dataclasses import dataclass
from functools import cached_property
from typing import Literal, NamedTuple

import torch
from jaxtyping import Bool, Float, Float16, Int
from torch import Tensor

from spd.clustering.consts import (
    ActivationsTensor,
    BoolActivationsTensor,
    ClusterCoactivationShaped,
    ComponentLabels,
)
from spd.clustering.util import ModuleFilterFunc
from spd.models.component_model import ComponentModel, OutputWithCache
from spd.models.sigmoids import SigmoidTypes


def component_activations(
    model: ComponentModel,
    device: torch.device | str,
    batch: Int[Tensor, "batch_size n_ctx"],
    sigmoid_type: SigmoidTypes,
) -> dict[str, ActivationsTensor]:
    """Get the component activations over a **single** batch."""
    causal_importances: dict[str, ActivationsTensor]
    with torch.no_grad():
        model_output: OutputWithCache = model(
            batch.to(device),
            cache_type="input",
        )

        causal_importances, _ = model.calc_causal_importances(
            pre_weight_acts=model_output.cache,
            sigmoid_type=sigmoid_type,
            sampling="continuous",
            detach_inputs=False,
        )

    return causal_importances


def compute_coactivatons(
    activations: ActivationsTensor | BoolActivationsTensor,
) -> ClusterCoactivationShaped:
    """Compute the coactivations matrix from the activations."""
    # TODO: this works for both boolean and continuous activations,
    # but we could do better by just using OR for boolean activations
    # and maybe even some bitshift hacks. but for now, we convert to float16
    activations_f16: Float16[Tensor, "samples C"] = activations.to(torch.float16)
    return activations_f16.T @ activations_f16


class FilteredActivations(NamedTuple):
    activations: ActivationsTensor
    "activations after filtering dead components"

    labels: ComponentLabels
    "list of length c with labels for each preserved component"

    dead_components_labels: ComponentLabels | None
    "list of labels for dead components, or None if no filtering was applied"

    @property
    def n_alive(self) -> int:
        """Number of alive components after filtering."""
        n_alive: int = len(self.labels)
        assert n_alive == self.activations.shape[1], (
            f"{n_alive = } != {self.activations.shape[1] = }"
        )
        return n_alive

    @property
    def n_dead(self) -> int:
        """Number of dead components after filtering."""
        return len(self.dead_components_labels) if self.dead_components_labels else 0


def filter_dead_components(
    activations: ActivationsTensor,
    labels: ComponentLabels,
    filter_dead_threshold: float = 0.01,
) -> FilteredActivations:
    """Filter out dead components based on a threshold

    if `filter_dead_threshold` is 0, no filtering is applied.
    activations and labels are returned as is, `dead_components_labels` is `None`.

    otherwise, components whose **maximum** activation across all samples is below the threshold
    are considered dead and filtered out. The labels of these components are returned in `dead_components_labels`.
    `dead_components_labels` will also be `None` if no components were below the threshold.
    """
    dead_components_lst: ComponentLabels | None = None
    if filter_dead_threshold > 0:
        dead_components_lst = ComponentLabels(list())
        max_act: Float[Tensor, " c"] = activations.max(dim=0).values
        dead_components: Bool[Tensor, " c"] = max_act < filter_dead_threshold

        if dead_components.any():
            activations = activations[:, ~dead_components]
            alive_labels: list[tuple[str, bool]] = [
                (lbl, bool(keep.item()))
                for lbl, keep in zip(labels, ~dead_components, strict=False)
            ]
            # re-assign labels only if we are filtering
            labels = ComponentLabels([label for label, keep in alive_labels if keep])
            dead_components_lst = ComponentLabels(
                [label for label, keep in alive_labels if not keep]
            )

    return FilteredActivations(
        activations=activations,
        labels=labels,
        dead_components_labels=dead_components_lst if dead_components_lst else None,
    )


@dataclass(frozen=True)
class ProcessedActivations:
    """Processed activations after filtering and concatenation"""

    activations_raw: dict[str, ActivationsTensor]
    "activations after filtering, but prior to concatenation"

    activations: ActivationsTensor
    "activations after filtering and concatenation"

    labels: ComponentLabels
    "list of length c with labels for each preserved component, format `{module_name}:{component_index}`"

    dead_components_lst: ComponentLabels | None
    "list of labels for dead components, or None if no filtering was applied"

    def validate(self) -> None:
        """Validate the processed activations"""
        # getting this property will also perform a variety of other checks
        assert self.n_components_alive > 0

    @property
    def n_components_original(self) -> int:
        """Total number of components before filtering. equal to the sum of all components in `activations_raw`, or to `n_components_alive + n_components_dead`"""
        return sum(act.shape[1] for act in self.activations_raw.values())

    @property
    def n_components_alive(self) -> int:
        """Number of alive components after filtering. equal to the length of `labels`"""
        n_alive: int = len(self.labels)
        assert n_alive + self.n_components_dead == self.n_components_original, (
            f"({n_alive = }) + ({self.n_components_dead = }) != ({self.n_components_original = })"
        )
        assert n_alive == self.activations.shape[1], (
            f"{n_alive = } != {self.activations.shape[1] = }"
        )

        return n_alive

    @property
    def n_components_dead(self) -> int:
        """Number of dead components after filtering. equal to the length of `dead_components_lst` if it is not None, or 0 otherwise"""
        return len(self.dead_components_lst) if self.dead_components_lst else 0

    @cached_property
    def label_index(self) -> dict[str, int | None]:
        """Create a mapping from label to alive index (`None` if dead)"""
        return {
            **{label: i for i, label in enumerate(self.labels)},
            **(
                {label: None for label in self.dead_components_lst}
                if self.dead_components_lst
                else {}
            ),
        }

    def get_label_index(self, label: str) -> int | None:
        """Get the index of a label in the activations, or None if it is dead"""
        return self.label_index[label]

    def get_label_index_alive(self, label: str) -> int:
        """Get the index of a label in the activations, or raise if it is dead"""
        idx: int | None = self.get_label_index(label)
        if idx is None:
            raise ValueError(f"Label '{label}' is dead and has no index in the activations.")
        return idx

    @property
    def module_keys(self) -> list[str]:
        """Get the module keys from the activations_raw"""
        return list(self.activations_raw.keys())

    def get_module_indices(self, module_key: str) -> list[int | None]:
        """Given a module key, return a list of length "num components in that module", with the int index in alive components, or None if dead"""
        num_components: int = self.activations_raw[module_key].shape[1]
        return [self.label_index[f"{module_key}:{i}"] for i in range(num_components)]


def process_activations(
    activations: dict[
        str,  # module name to
        Float[Tensor, "samples C"]  # (sample x component gate activations)
        | Float[Tensor, " n_sample n_ctx C"],  # (sample x seq index x component gate activations)
    ],
    filter_dead_threshold: float = 0.01,
    seq_mode: Literal["concat", "seq_mean", None] = None,
    filter_modules: ModuleFilterFunc | None = None,
) -> ProcessedActivations:
    """get back a dict of coactivations, slices, and concatenated activations

    Args:
        activations: Dictionary of activations by module
        filter_dead_threshold: Threshold for filtering dead components
        seq_mode: How to handle the sequence dimension
        filter_modules: Function to filter modules
    """

    # reshape -- special cases for llms
    # ============================================================
    activations_: dict[str, ActivationsTensor]
    if seq_mode == "concat":
        # Concatenate the sequence dimension into the sample dimension
        activations_ = {
            key: act.reshape(act.shape[0] * act.shape[1], act.shape[2])
            for key, act in activations.items()
        }
    elif seq_mode == "seq_mean":
        # Take the mean over the sequence dimension
        activations_ = {
            key: act.mean(dim=1) if act.ndim == 3 else act for key, act in activations.items()
        }
    else:
        # Use the activations as they are
        activations_ = activations

    # put the labelled activations into one big matrix and filter them
    # ============================================================

    # filter activations for only the modules we want
    if filter_modules is not None:
        activations_ = {key: act for key, act in activations_.items() if filter_modules(key)}

    # compute the labels and total component count
    total_c: int = 0
    labels: ComponentLabels = ComponentLabels(list())
    for key, act in activations_.items():
        c: int = act.shape[-1]
        labels.extend([f"{key}:{i}" for i in range(c)])
        total_c += c

    # concat the activations
    act_concat: ActivationsTensor = torch.cat([activations_[key] for key in activations_], dim=-1)

    # filter dead components
    filtered_components: FilteredActivations = filter_dead_components(
        activations=act_concat,
        labels=labels,
        filter_dead_threshold=filter_dead_threshold,
    )

    assert filtered_components.n_alive + filtered_components.n_dead == total_c, (
        f"({filtered_components.n_alive = }) + ({filtered_components.n_dead = }) != ({total_c = })"
    )

    return ProcessedActivations(
        activations_raw=activations_,
        activations=filtered_components.activations,
        labels=filtered_components.labels,
        dead_components_lst=filtered_components.dead_components_labels,
    )