Draft
705 commits
88022ea
comments explaining port allocation in tests
mivanit Oct 6, 2025
5bcd8a3
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit Oct 6, 2025
0e38131
Merge branch 'main' into feature/clustering
danbraunai-goodfire Oct 6, 2025
3f55ffa
add distributed marker, rull all distributed tests on same worker
mivanit Oct 6, 2025
29c2738
Revert "add distributed marker, rull all distributed tests on same wo…
mivanit Oct 6, 2025
f5b3288
add distributed marker, rull all distributed tests on same worker
mivanit Oct 6, 2025
c3dca4a
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit Oct 6, 2025
3b116f5
Merge branch 'main' into refactor/clustering-prereqs
mivanit Oct 6, 2025
2e60193
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit Oct 6, 2025
4f72b11
Merge branch 'feature/clustering' into feature/clustering-dashboard
mivanit Oct 6, 2025
d06ff34
add "num_nonsingleton_groups" metric to spd-cluster
mivanit Oct 6, 2025
f39e251
uv sync
mivanit Oct 6, 2025
b058a2d
wip jaccard
mivanit Oct 6, 2025
d7c3258
make format
mivanit Oct 6, 2025
deb59a8
refactor: use general_utils methods for getting device everywhere
mivanit Oct 6, 2025
bac464c
Merge branch 'main' into refactor/clustering-prereqs
mivanit Oct 6, 2025
2e29b2f
Merge branch 'main' into feature/clustering
mivanit Oct 6, 2025
e23d6a8
Merge branch 'feature/clustering' into feature/clustering-dashboard
mivanit Oct 6, 2025
2eb96fc
wip jaccard
mivanit Oct 6, 2025
c71c696
wip jaccard
mivanit Oct 6, 2025
f9b228c
Merge branch 'main' into refactor/clustering-prereqs
mivanit Oct 6, 2025
8a7a423
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit Oct 6, 2025
89530b1
Merge branch 'feature/clustering' into feature/clustering-dashboard
mivanit Oct 6, 2025
a25a9a9
wip jaccard
mivanit Oct 6, 2025
d23823d
wip jaccard (plotting)
mivanit Oct 6, 2025
6280591
found where to increase timeout
mivanit Oct 6, 2025
85f789b
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit Oct 6, 2025
db59ce6
wip jaccard
mivanit Oct 6, 2025
e6fb87a
Merge branch 'feature/clustering' into feature/clustering-dashboard
mivanit Oct 6, 2025
3ddd037
wip???
mivanit Oct 6, 2025
e218796
make format
mivanit Oct 6, 2025
6225e0b
fixes
mivanit Oct 6, 2025
9f4c347
typing fixes
mivanit Oct 6, 2025
b12403a
claude doing a bunch of type hinting
mivanit Oct 6, 2025
9a06d29
Merge branch 'main' into refactor/clustering-prereqs
mivanit Oct 6, 2025
7f605e5
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit Oct 6, 2025
f112af9
trying to get pyright passing?
mivanit Oct 6, 2025
cfc03ed
pyright works both locally and in CI
mivanit Oct 6, 2025
45fe5be
Merge branch 'feature/clustering' into feature/clustering-dashboard
mivanit Oct 6, 2025
87d6071
[temp] debugging disk usage
mivanit Oct 6, 2025
d910d93
[temp] still debugging disk usage
mivanit Oct 6, 2025
39f68d1
[temp] more debug disk usage
mivanit Oct 6, 2025
d18845c
[temp] disk debug
mivanit Oct 6, 2025
a4784ad
[temp]
mivanit Oct 6, 2025
4665aab
[temp] remove docker? 🤔🤔🤔
mivanit Oct 6, 2025
04ef2fe
[temp] uv custom stuff
mivanit Oct 6, 2025
0a53bbb
ugh
mivanit Oct 6, 2025
51ae645
wip
mivanit Oct 6, 2025
cce22bb
wip
mivanit Oct 6, 2025
59d62eb
du in root takes forever????
mivanit Oct 6, 2025
50f7510
wip
mivanit Oct 6, 2025
9aa618b
add back docker container
mivanit Oct 6, 2025
a92ae11
allow installing cpu-only torch in CI
mivanit Oct 6, 2025
5ad324c
figure out CI disk usage by tests on main
mivanit Oct 7, 2025
7978c44
alternate strategy for install
mivanit Oct 7, 2025
98ba633
fixes to the last commit
mivanit Oct 7, 2025
989c2dd
cleanup temp changes
mivanit Oct 7, 2025
5f9500f
make in CI
mivanit Oct 7, 2025
7eeea74
wip
mivanit Oct 7, 2025
4bc5728
wip
mivanit Oct 7, 2025
3c17aee
wip
mivanit Oct 7, 2025
1110c2d
uv sync
mivanit Oct 7, 2025
9cfc051
try to fix markupsafe?
mivanit Oct 7, 2025
23c41dd
pin markup safe with explanation
mivanit Oct 7, 2025
63d6432
update lockfile??
mivanit Oct 7, 2025
1d8da48
Merge branch 'main' into fix/ci-disk-usage
mivanit Oct 7, 2025
1bb5ac8
nope i think we need the index strategy
mivanit Oct 7, 2025
75a0efe
?
mivanit Oct 7, 2025
019e1b3
markupsafe issue
mivanit Oct 7, 2025
3082817
remove disk usage printing
mivanit Oct 7, 2025
d520a54
fix pyright issue
mivanit Oct 7, 2025
1eab6fe
dependency hell
mivanit Oct 7, 2025
3a05d5a
fix deps???
mivanit Oct 7, 2025
d053433
oops, missing index strategy. moved to makefile
mivanit Oct 7, 2025
25aa615
re-lock
mivanit Oct 7, 2025
c1ffcda
make from /usr/bin/ ?
mivanit Oct 7, 2025
f006884
dependency hell
mivanit Oct 7, 2025
533ad20
type checking hell
mivanit Oct 7, 2025
4524965
Update spd/utils/general_utils.py
mivanit Oct 7, 2025
15b314e
wrap and fix Conv1D imports
mivanit Oct 7, 2025
f28b670
minimize diff cleanup
mivanit Oct 7, 2025
aa98d4a
try compile-bytecode for ci install
mivanit Oct 7, 2025
740f6a2
dont compile bytecode actually
mivanit Oct 7, 2025
cfe1f81
remove markupsafe constraint?
mivanit Oct 7, 2025
17803ed
switched to use get_obj_device
mivanit Oct 7, 2025
153d044
remove device: torch.device type hints
mivanit Oct 7, 2025
2c60412
remove "distributed" test marker
mivanit Oct 7, 2025
44c6cc5
fix another timeout
mivanit Oct 7, 2025
0d5137c
Merge branch 'main' into refactor/clustering-prereqs
mivanit Oct 7, 2025
a48ce55
Merge branch 'fix/ci-disk-usage' into feature/clustering
mivanit Oct 7, 2025
055f3cc
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit Oct 7, 2025
afcea86
Merge branch 'feature/clustering' into feature/clustering-dashboard
mivanit Oct 7, 2025
b1604bb
replace get_module_device -> get_obj_device
mivanit Oct 7, 2025
d20ea4f
Merge branch 'main' into refactor/clustering-prereqs
mivanit Oct 7, 2025
42f58a7
better comments on port uniqueness
mivanit Oct 7, 2025
c722ddd
remove old markers
mivanit Oct 7, 2025
b472f5d
remove timeout TODO comments
mivanit Oct 7, 2025
921010b
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit Oct 7, 2025
888f5f2
Merge branch 'main' into feature/clustering
mivanit Oct 7, 2025
5d092e8
removed checks.yaml timeout todo, clustering tests pass in ~12min
mivanit Oct 7, 2025
5eb1af7
[diff-min] transformers version issue from #139 resolved
mivanit Oct 7, 2025
dc6e32c
fix comment
mivanit Oct 7, 2025
aae70da
wip jaccard
mivanit Oct 7, 2025
9c6103f
pyright fixes to jaccard, wip
mivanit Oct 8, 2025
c38a6ec
Merge branch 'feature/clustering' into feature/clustering-dashboard
mivanit Oct 8, 2025
9b4c7d4
properly merge checks.yaml
mivanit Oct 8, 2025
fc6e57e
fix torch dep issues
mivanit Oct 8, 2025
1890401
wip
mivanit Oct 8, 2025
bbb1c0d
wip dashboard
mivanit Oct 8, 2025
5475528
wip
mivanit Oct 8, 2025
5bd70d9
wip dashboard. bit of alpine, better filters, kinda fixed hists
mivanit Oct 8, 2025
cb10f87
wip
mivanit Oct 8, 2025
4fa3c10
wip
mivanit Oct 8, 2025
9fb9ac1
wip
mivanit Oct 8, 2025
c31a3af
refactor core dashboard data gen
mivanit Oct 8, 2025
6370d10
make format
mivanit Oct 8, 2025
360a164
wip big refactor of compute_max_act
mivanit Oct 8, 2025
b2d62a5
more refactor, split up the big one
mivanit Oct 8, 2025
95e0203
filtering working, table uses id & not key. more data refactor
mivanit Oct 8, 2025
7a1b296
format and bundle
mivanit Oct 8, 2025
fc5e538
wip, more spinners
mivanit Oct 8, 2025
66e7efd
wip. better logging, idk
mivanit Oct 8, 2025
c4adf0a
wip
mivanit Oct 8, 2025
af38dd2
wip
mivanit Oct 8, 2025
b5d0f6d
wip
mivanit Oct 8, 2025
0feb1cc
wip
mivanit Oct 8, 2025
74f4a2a
make tokenizer decode go brr?
mivanit Oct 8, 2025
116a2cc
tokenizer go brr
mivanit Oct 8, 2025
b1de815
tokenization!
mivanit Oct 8, 2025
03877ee
wip
mivanit Oct 8, 2025
0db9957
[!!!] cleanup
mivanit Oct 8, 2025
5c40c5e
wip
mivanit Oct 8, 2025
398c8ef
wip
mivanit Oct 8, 2025
f8ede56
cleanup
mivanit Oct 8, 2025
cc37557
wip
mivanit Oct 8, 2025
ff6af18
pyright fixes
mivanit Oct 8, 2025
57e17f9
old
mivanit Oct 9, 2025
6ccbd79
Update docs about grad syncing with DDP
danbraunai-goodfire Oct 8, 2025
fe0de02
Mention feature/memorization-experiments in README
danbraunai-goodfire Oct 8, 2025
0bdbd4e
Fix train and eval metrics and hidden_act_recon (#189)
danbraunai-goodfire Oct 9, 2025
a2bcaa3
Update canonical runs and change target model path (#197)
danbraunai-goodfire Oct 9, 2025
dadb9c7
Avoid using too many processes in tests
danbraunai-goodfire Oct 9, 2025
882d659
Merge branch 'main' into feature/clustering
mivanit Oct 10, 2025
e7e1b1d
fix wandb model paths to older runs
mivanit Oct 10, 2025
3401f84
Merge branch 'clustering/main' into feature/clustering-merge-history-…
mivanit Oct 10, 2025
ca8a367
Merge branch 'feature/clustering-merge-history-adapt' into feature/cl…
mivanit Oct 10, 2025
61f9822
delete old temp json file
mivanit Oct 10, 2025
a782c5e
minimize diff
mivanit Oct 10, 2025
6f08e7e
remove traces of component activation data
mivanit Oct 10, 2025
f54c7fd
add `make clustering-dashboard` to CI
mivanit Oct 10, 2025
55f0b5c
comment out a lot of component-level data code
mivanit Oct 10, 2025
5d059af
fix test issues resulting from bytes arr -> unicode arr tok switch
mivanit Oct 10, 2025
07b7564
Merge branch 'clustering/main' into feature/clustering-dashboard
mivanit Oct 10, 2025
3bd4f19
fix makefile env usage
mivanit Oct 10, 2025
7261cca
wandb api key for the clustering dashboard CI test
mivanit Oct 10, 2025
6f32e76
Merge branch 'clustering/main' into feature/clustering-dashboard
mivanit Oct 13, 2025
cfb7cd2
merge tokenization and text_processing files
mivanit Oct 13, 2025
beb3649
fix import relating to last commit
mivanit Oct 13, 2025
89e1464
reorg
mivanit Oct 13, 2025
e9fba91
rename frontend config file
mivanit Oct 13, 2025
f059257
add `DashboardConfig` instead of just tons of cli args
mivanit Oct 13, 2025
4c6d1cf
wip reorg
mivanit Oct 13, 2025
e0ace71
clean up interface of compute_max_activations
mivanit Oct 13, 2025
fea5725
wip
mivanit Oct 13, 2025
bc2d2de
wip
mivanit Oct 13, 2025
e06be80
some renaming, working on using SubComponentInfo instead of str label
mivanit Oct 13, 2025
8246f10
working on using SubComponentInfo in more places
mivanit Oct 13, 2025
10a42c7
more switching to SubComponentInfo
mivanit Oct 13, 2025
c13e921
move `_zip_save_arr*` merge_history.py -> data_utils.py
mivanit Oct 13, 2025
633523b
wip
mivanit Oct 13, 2025
3eff182
wip refactor, continued
mivanit Oct 13, 2025
4e6afca
wip
mivanit Oct 13, 2025
9c33140
big rename `subcomponents` -> `subcomponent_keys`
mivanit Oct 13, 2025
5aabc1a
fix tests, add hashing
mivanit Oct 13, 2025
7c11139
Merge branch 'clustering/main' into feature/clustering-dashboard
mivanit Oct 13, 2025
480acc6
DashboardConfig inherits from BaseConfig instead of BaseModel
mivanit Oct 13, 2025
e9cad1f
wip
mivanit Oct 13, 2025
16d05ef
wip
mivanit Oct 13, 2025
6b45faa
wip
mivanit Oct 14, 2025
5e6c6d9
Merge branch 'clustering/main' into feature/clustering-dashboard
mivanit Oct 14, 2025
2d322ee
Merge branch 'clustering/main' into feature/clustering-dashboard
mivanit Oct 14, 2025
14e679e
Merge branch 'clustering/main' into feature/clustering-dashboard
mivanit Oct 16, 2025
b30b338
Merge branch 'clustering/main' into feature/clustering-dashboard
mivanit Oct 20, 2025
22ecf45
Merge branch 'clustering/main' into feature/clustering-dashboard
mivanit Oct 20, 2025
3a206ed
allow specifying either config path or mrc cfg in pipeline cfg
mivanit Oct 21, 2025
eb831c0
[wip] reorg configs
mivanit Oct 21, 2025
89e5c36
added default `None` for slurm partition and job name prefix
mivanit Oct 21, 2025
8910bb4
refactor configs, add config tests
mivanit Oct 21, 2025
0b957f5
fix tests
mivanit Oct 21, 2025
7de545b
allow `None` or `-1` idx_in_ensemble
mivanit Oct 21, 2025
3d45ac4
whoops, wrong name on fixture
mivanit Oct 21, 2025
4adde10
fix idx passed in tests when not needed
mivanit Oct 21, 2025
189b64a
rename "mrc" -> "crc" in paths
mivanit Oct 21, 2025
57f445a
rename merge_run_config.py -> clustering_run_config.py
mivanit Oct 21, 2025
91f5348
fix pyright
mivanit Oct 21, 2025
11e5501
fix idx_in_ensemble being passed in tests
mivanit Oct 21, 2025
1d96054
rename cache dir 'merge_run_configs' -> 'clustering_run_configs'
mivanit Oct 21, 2025
a1f1146
remove component popping
mivanit Oct 21, 2025
1e3fbb2
dont pass batch size, change not brought in here
mivanit Oct 21, 2025
76757c7
uv sync
mivanit Oct 21, 2025
9ebcf53
Merge branch 'clustering/main' into feature/clustering-dashboard
mivanit Oct 21, 2025
00397d4
fix clustering dashboard!
mivanit Oct 21, 2025
1bbf34f
add zanj.js from https://github.com/mivanit/zanj.js
mivanit Oct 21, 2025
6ae1dc7
zanj dep. big refactor incoming
mivanit Oct 21, 2025
ef40780
major refactor. very broken, but not even due to refactor
mivanit Oct 21, 2025
3de57c9
continued
mivanit Oct 21, 2025
efec7e7
Merge branch 'clustering/add-mrc-inline-in-pipeline-cfg' into feature…
mivanit Oct 22, 2025
e86adc9
fix history_path extension and storage usage
mivanit Oct 22, 2025
b8bbb08
dev pipeline
mivanit Oct 22, 2025
733d47f
better config validation tests
mivanit Oct 22, 2025
2fa1f21
set default base output dir
mivanit Oct 22, 2025
eec80fb
wandb use run id for clustering, TODO for spd decomp
mivanit Oct 22, 2025
6098536
basedpyright 1.32.0 causes issues, esp w/ wandb
mivanit Oct 22, 2025
4f0a761
swap out clustering run
mivanit Oct 22, 2025
ca72074
Merge branch 'clustering/add-mrc-inline-in-pipeline-cfg' into feature…
mivanit Oct 22, 2025
02dd44a
fix merge
mivanit Oct 22, 2025
0c94f1e
fix history path
mivanit Oct 22, 2025
d4f3c4d
pyright fixes
mivanit Oct 22, 2025
d612da0
pyright stuff
mivanit Oct 22, 2025
305624e
better asserts for debugging
mivanit Oct 22, 2025
3df9900
[!!!] semi working again
mivanit Oct 22, 2025
bd60f2c
Merge branch 'clustering/main' into clustering/add-mrc-inline-in-pipe…
mivanit Oct 23, 2025
4a0f53f
wip
mivanit Oct 23, 2025
d6a5930
fix token display?
mivanit Oct 23, 2025
3da458a
[!!!] WORKING
mivanit Oct 23, 2025
097b511
wip
mivanit Oct 24, 2025
4cf945f
testing via playwright
mivanit Oct 24, 2025
c28d0eb
wip testing framework
mivanit Oct 24, 2025
11315c8
wip
mivanit Oct 24, 2025
72d8589
working but not in tests?
mivanit Oct 24, 2025
40df505
remove idx_in_ensemble, always auto-assigned now
mivanit Oct 24, 2025
bd8a442
Merge branch 'clustering/main' into clustering/add-mrc-inline-in-pipe…
mivanit Oct 24, 2025
7339a9f
wip
mivanit Oct 24, 2025
c93aea4
format
mivanit Oct 24, 2025
8e95223
something fixed???
mivanit Oct 24, 2025
5fc21a3
!!!!!!!!!!!!!! PASSING?????
mivanit Oct 24, 2025
4561ed7
claude still working on stuff
mivanit Oct 24, 2025
bc15921
make playwright tests work in CI?
mivanit Oct 24, 2025
3caf5e9
fix pyright issues
mivanit Oct 24, 2025
cf64a79
only allow passing clustering run config path, not inline
mivanit Oct 24, 2025
2a9f731
rename run_clustering_config_path -> clustering_run_config_path
mivanit Oct 24, 2025
59875b9
remove junk from readme
mivanit Oct 24, 2025
7bbd792
Merge branch 'clustering/main' into feature/clustering-dashboard
miv-goodfire Oct 24, 2025
9be9929
Merge branch 'clustering/add-mrc-inline-in-pipeline-cfg' into feature…
mivanit Oct 24, 2025
8df885e
Merge branch 'clustering/main' into feature/clustering-dashboard
mivanit Oct 24, 2025
5c40ef6
properly install playwright
mivanit Oct 24, 2025
2062b18
dataset streaming for dashboard tests
mivanit Oct 24, 2025
e5ca948
remove temp CI step, we now test this in pytest
mivanit Oct 24, 2025
040abd5
Merge branch 'clustering/main' into feature/clustering-dashboard
mivanit Oct 27, 2025
687ef4c
Merge branch 'clustering/main' into feature/clustering-dashboard
mivanit Oct 29, 2025
3 changes: 3 additions & 0 deletions .github/workflows/checks.yaml
@@ -40,6 +40,9 @@ jobs:
- name: Print dependencies
run: uv pip list

- name: Install Playwright
run: make install-playwright

- name: Run basedpyright
run: uv run basedpyright

1 change: 1 addition & 0 deletions .gitignore
@@ -3,6 +3,7 @@ docs/coverage/**
artifacts/**
docs/dep_graph/**
tests/.temp/**
*.prof

**/out/
neuronpedia_outputs/
44 changes: 44 additions & 0 deletions Makefile
@@ -77,6 +77,50 @@ coverage:
uv run python -m coverage html --directory=$(COVERAGE_DIR)/html/


BUNDLED_DASHBOARD_DIR=spd/clustering/dashboard/_bundled

.PHONY: bundle-dashboard
bundle-dashboard:
@mkdir -p $(BUNDLED_DASHBOARD_DIR)
uv run python -m muutils.web.bundle_html \
spd/clustering/dashboard/index.html \
--output $(BUNDLED_DASHBOARD_DIR)/index.html \
--source-dir spd/clustering/dashboard
uv run python -m muutils.web.bundle_html \
spd/clustering/dashboard/cluster.html \
--output $(BUNDLED_DASHBOARD_DIR)/cluster.html \
--source-dir spd/clustering/dashboard
@echo "Bundled HTML files to $(BUNDLED_DASHBOARD_DIR)/"

.PHONY: clean-test-dashboard
clean-test-dashboard:
rm -rf tests/.temp/dashboard-integration


.PHONY: install-playwright
install-playwright:
@echo "Install Playwright browsers, used for dashboard tests"
uv run playwright install chromium
uv run playwright install-deps

.PHONY: test-dashboard
test-dashboard: clean-test-dashboard bundle-dashboard
uv run pytest tests/clustering/dashboard/test_dashboard_integration.py --runslow -v --durations 10


.PHONY: clustering-dashboard
clustering-dashboard: bundle-dashboard
uv run python spd/clustering/dashboard/run.py \
spd/clustering/dashboard/dashboard_config.yaml

.PHONY: clustering-dashboard-profile
clustering-dashboard-profile: bundle-dashboard
uv run python -m cProfile -o dashboard.prof spd/clustering/dashboard/run.py \
spd/clustering/dashboard/dashboard_config.yaml
@echo "\nProfile saved to dashboard.prof"
@echo "View with: python -m pstats dashboard.prof"
@echo "Or install snakeviz and run: snakeviz dashboard.prof"

.PHONY: clean
clean:
@echo "Cleaning Python cache and build artifacts..."
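
The `clustering-dashboard-profile` target above writes `dashboard.prof` and suggests `python -m pstats` for viewing it. As a rough sketch of reading such a file from Python with only the standard library: `busy_work`, `profile_to_file`, and `top_functions` are hypothetical names used for illustration, and the workload is a stand-in rather than the dashboard script.

```python
import cProfile
import pstats
from pathlib import Path


def busy_work(n: int) -> int:
    """Stand-in workload (hypothetical); not the dashboard code."""
    return sum(i * i for i in range(n))


def profile_to_file(prof_path: Path) -> None:
    """Profile `busy_work` and dump raw stats, similar in spirit to `python -m cProfile -o`."""
    profiler = cProfile.Profile()
    profiler.enable()
    busy_work(50_000)
    profiler.disable()
    profiler.dump_stats(str(prof_path))


def top_functions(prof_path: Path, limit: int = 5) -> list[str]:
    """Return up to `limit` function names, highest cumulative time first."""
    profile = pstats.Stats(str(prof_path)).get_stats_profile()
    ranked = sorted(
        profile.func_profiles.items(),
        key=lambda item: item[1].cumtime,  # cumulative seconds per function
        reverse=True,
    )
    return [name for name, _ in ranked[:limit]]
```

Calling `top_functions(Path("dashboard.prof"))` on a real profile would list the most expensive entry points by cumulative time, which is usually the first thing to look at before opening an interactive viewer.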
73 changes: 0 additions & 73 deletions TODO.md

This file was deleted.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -30,6 +30,7 @@ dependencies = [
"simple_stories_train @ git+https://github.com/goodfire-ai/simple_stories_train.git@dev",
"scipy>=1.14.1",
"muutils",
"zanj", # for dashboard data saving/loading
"fastapi",
"uvicorn",
]
@@ -42,6 +43,7 @@ dev = [
"ruff",
"basedpyright<1.32.0", # pyright and wandb issues, see https://github.com/goodfire-ai/spd/pull/232
"pre-commit",
"playwright", # for browser-based integration tests
]

[project.scripts]
78 changes: 38 additions & 40 deletions spd/clustering/activations.py
@@ -10,7 +10,7 @@
ActivationsTensor,
BoolActivationsTensor,
ClusterCoactivationShaped,
ComponentLabels,
SubComponentKey,
)
from spd.clustering.util import ModuleFilterFunc
from spd.models.component_model import ComponentModel, OutputWithCache
@@ -54,16 +54,16 @@ class FilteredActivations(NamedTuple):
activations: ActivationsTensor
"activations after filtering dead components"

labels: ComponentLabels
"list of length c with labels for each preserved component"
subcomponent_keys: list[SubComponentKey]
"list of length c with the SubComponentKey of each preserved component"

dead_components_labels: ComponentLabels | None
"list of labels for dead components, or None if no filtering was applied"
dead_subcomponent_keys: list[SubComponentKey] | None
"list of SubComponentKey for dead components, or None if no filtering was applied"

@property
def n_alive(self) -> int:
"""Number of alive components after filtering."""
n_alive: int = len(self.labels)
n_alive: int = len(self.subcomponent_keys)
assert n_alive == self.activations.shape[1], (
f"{n_alive = } != {self.activations.shape[1] = }"
)
@@ -72,12 +72,12 @@ def n_alive(self) -> int:
@property
def n_dead(self) -> int:
"""Number of dead components after filtering."""
return len(self.dead_components_labels) if self.dead_components_labels else 0
return len(self.dead_subcomponent_keys) if self.dead_subcomponent_keys else 0


def filter_dead_components(
activations: ActivationsTensor,
labels: ComponentLabels,
subcomponent_keys: list[SubComponentKey],
filter_dead_threshold: float = 0.01,
) -> FilteredActivations:
"""Filter out dead components based on a threshold
@@ -86,31 +86,29 @@ def filter_dead_components(
activations and keys are returned as is, `dead_subcomponent_keys` is `None`.

otherwise, components whose **maximum** activations across all samples is below the threshold
are considered dead and filtered out. The labels of these components are returned in `dead_components_labels`.
are considered dead and filtered out. The `SubComponentKey`s of these components are returned in `dead_subcomponent_keys`.
`dead_subcomponent_keys` will also be `None` if no components were below the threshold.
"""
dead_components_lst: ComponentLabels | None = None
dead_components_lst: list[SubComponentKey] | None = None
if filter_dead_threshold > 0:
dead_components_lst = ComponentLabels(list())
dead_components_lst = []
max_act: Float[Tensor, " c"] = activations.max(dim=0).values
dead_components: Bool[Tensor, " c"] = max_act < filter_dead_threshold
dead_components_mask: Bool[Tensor, " c"] = max_act < filter_dead_threshold

if dead_components.any():
activations = activations[:, ~dead_components]
alive_labels: list[tuple[str, bool]] = [
(lbl, bool(keep.item()))
for lbl, keep in zip(labels, ~dead_components, strict=False)
if dead_components_mask.any():
activations = activations[:, ~dead_components_mask]
alive_labels: list[tuple[SubComponentKey, bool]] = [
(comp, bool(keep.item()))
for comp, keep in zip(subcomponent_keys, ~dead_components_mask, strict=False)
]
# re-assign labels only if we are filtering
labels = ComponentLabels([label for label, keep in alive_labels if keep])
dead_components_lst = ComponentLabels(
[label for label, keep in alive_labels if not keep]
)
subcomponent_keys = [comp for comp, keep in alive_labels if keep]
dead_components_lst = [comp for comp, keep in alive_labels if not keep]

return FilteredActivations(
activations=activations,
labels=labels,
dead_components_labels=dead_components_lst if dead_components_lst else None,
subcomponent_keys=subcomponent_keys,
dead_subcomponent_keys=dead_components_lst if dead_components_lst else None,
)


@@ -124,11 +122,11 @@ class ProcessedActivations:
activations: ActivationsTensor
"activations after filtering and concatenation"

labels: ComponentLabels
"list of length c with labels for each preserved component, format `{module_name}:{component_index}`"
subcomponent_keys: list[SubComponentKey]
"list of length c with the SubComponentKey of each preserved component"

dead_components_lst: ComponentLabels | None
"list of labels for dead components, or None if no filtering was applied"
dead_subcomponent_keys: list[SubComponentKey] | None
"list of SubComponentKey for dead components, or None if no filtering was applied"

def validate(self) -> None:
"""Validate the processed activations"""
@@ -143,7 +141,7 @@ def n_components_original(self) -> int:
@property
def n_components_alive(self) -> int:
"""Number of alive components after filtering. Equal to the length of `subcomponent_keys`."""
n_alive: int = len(self.labels)
n_alive: int = len(self.subcomponent_keys)
assert n_alive + self.n_components_dead == self.n_components_original, (
f"({n_alive = }) + ({self.n_components_dead = }) != ({self.n_components_original = })"
)
@@ -156,26 +154,26 @@ def n_components_alive(self) -> int:
@property
def n_components_dead(self) -> int:
"""Number of dead components after filtering. Equal to the length of `dead_subcomponent_keys` if it is not None, or 0 otherwise."""
return len(self.dead_components_lst) if self.dead_components_lst else 0
return len(self.dead_subcomponent_keys) if self.dead_subcomponent_keys else 0

@cached_property
def label_index(self) -> dict[str, int | None]:
"""Create a mapping from label to alive index (`None` if dead)"""
"""Create a mapping from label string to alive index (`None` if dead)"""
return {
**{label: i for i, label in enumerate(self.labels)},
**{comp.label: i for i, comp in enumerate(self.subcomponent_keys)},
**(
{label: None for label in self.dead_components_lst}
if self.dead_components_lst
{comp.label: None for comp in self.dead_subcomponent_keys}
if self.dead_subcomponent_keys
else {}
),
}

def get_label_index(self, label: str) -> int | None:
"""Get the index of a label in the activations, or None if it is dead"""
"""Get the index of a label string in the activations, or None if it is dead"""
return self.label_index[label]

def get_label_index_alive(self, label: str) -> int:
"""Get the index of a label in the activations, or raise if it is dead"""
"""Get the index of a label string in the activations, or raise if it is dead"""
idx: int | None = self.get_label_index(label)
if idx is None:
raise ValueError(f"Label '{label}' is dead and has no index in the activations.")
@@ -239,10 +237,10 @@ def process_activations(

# compute the labels and total component count
total_c: int = 0
labels: ComponentLabels = ComponentLabels(list())
labels: list[SubComponentKey] = []
for key, act in activations_.items():
c: int = act.shape[-1]
labels.extend([f"{key}:{i}" for i in range(c)])
labels.extend([SubComponentKey(module=key, index=i) for i in range(c)])
total_c += c

# concat the activations
@@ -251,7 +249,7 @@ def process_activations(
# filter dead components
filtered_components: FilteredActivations = filter_dead_components(
activations=act_concat,
labels=labels,
subcomponent_keys=labels,
filter_dead_threshold=filter_dead_threshold,
)

@@ -262,6 +260,6 @@ def process_activations(
return ProcessedActivations(
activations_raw=activations_,
activations=filtered_components.activations,
labels=filtered_components.labels,
dead_components_lst=filtered_components.dead_components_labels,
subcomponent_keys=filtered_components.subcomponent_keys,
dead_subcomponent_keys=filtered_components.dead_subcomponent_keys,
)
39 changes: 36 additions & 3 deletions spd/clustering/consts.py
@@ -1,8 +1,10 @@
"""Constants and shared abstractions for clustering pipeline."""

import hashlib
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, NewType
from typing import Literal, NewType, override

import numpy as np
from jaxtyping import Bool, Float, Int
@@ -15,8 +17,39 @@
DistancesArray = Float[np.ndarray, "n_iters n_ens n_ens"]

# Component and label types (NewType for stronger type safety)
ComponentLabel = NewType("ComponentLabel", str) # Format: "module_name:component_index"
ComponentLabels = NewType("ComponentLabels", list[str])
SubComponentLabel = NewType("SubComponentLabel", str) # Format: "module_name:component_index"


@dataclass(frozen=True, slots=True, kw_only=True)
class SubComponentKey:
"""unique identifier of a subcomponent. indices can refer to dead components"""

module: str
index: int

@property
def label(self) -> SubComponentLabel:
"""Component label as 'module:index'."""
return SubComponentLabel(f"{self.module}:{self.index}")

@classmethod
def from_label(cls, label: SubComponentLabel) -> "SubComponentKey":
"""Create a SubComponentKey from a component label of the form 'module:index'."""
assert label.count(":") == 1, (
"Invalid component label format, expected '{module}:{index}'"
)
module, index_str = label.rsplit(":", 1)
return cls(module=module, index=int(index_str))

@override
def __str__(self) -> str:
return self.label

@override
def __hash__(self) -> int:
return int(hashlib.md5(str(self).encode()).hexdigest(), 16)


BatchId = NewType("BatchId", str)

# Path types