clustering #43
Merged

Changes from all commits — 447 commits
8906628
refactor to use MergeRunConfig everywhere
mivanit b197aad
Wip
mivanit 65e7b93
wip
mivanit 4fa6b2d
wip
mivanit 78ea395
some info to yaml
mivanit 21a244c
todos
mivanit 82e38c6
LFGgsgsgs
mivanit 99993f5
wip
mivanit 5a3f78f
refactor configs. new system with ability to choose sampling mechanis…
mivanit 69f1ac7
refactor makefile to use new configs
mivanit d28b1c1
add tests. code still not comitted yet
mivanit cfb17bb
format
mivanit 6bcad5c
fix tests
mivanit acadfac
test configs
mivanit 6294631
TESTS PASSIN LFGGGG
mivanit 543a457
format
mivanit d984d3d
pyright fixes
mivanit d44c176
Merge branch 'dev' into feature/clustering
mivanit d2796fe
fix some pyright issues
mivanit b50c08f
parallelizing tests
mivanit 11a25fe
distributed tests in CI
mivanit ed719f1
fix action
mivanit 06ffad4
try to make tests faster
mivanit f96e864
remove experiment with no canonical run
mivanit fdb2572
try to debug issue with normalizing ensemble
mivanit 1fe9ad3
[important] remove old files
mivanit d7a3343
move the merge pair samplers code to math folder
mivanit 4e04542
fix import in tests
mivanit 9bff3a1
default to cpu if no cuda in spd-cluster
mivanit 9c31c1b
default to cpu if no cuda in spd-cluster
mivanit 863ae60
wip wandb logging for spd-cluster refactor
mivanit f3e0b48
more wip wandb logging for spd-cluster refactor
mivanit edae24c
format?
mivanit a159f60
wandb log tensor info
mivanit 9c2cea6
wandb log tensor info wip
mivanit 7e73cf7
wandb log tensor info wip
mivanit 969243a
wandb log tensor info wip
mivanit b7b8f61
some figs on wandb
mivanit 4b1a9fb
format
mivanit b25823c
wip
mivanit 14b896c
[temp] ignore config rename/deprecated warns
mivanit 7c51471
wip
mivanit c6f5b44
wip
mivanit 3d03fed
wip
mivanit 1d278b7
wip
mivanit 74fe2f5
wip
mivanit fae9ca7
wip
mivanit 6eeb11f
wip
mivanit 4e54996
wip
mivanit ad4fc74
wip
mivanit 54d1921
wip
mivanit 8293a15
Merge branch 'dev' into feature/clustering
mivanit 4c310f1
[hack] lack of canonical runs causing error
mivanit 60dc336
wip
mivanit b76c93c
wip, cleaning up s2
mivanit 481beda
wip
mivanit 4c6446c
nathan approved MDL cost computation
mivanit e377f07
swap to log2 in mdl
mivanit 8e91e6e
log MDL cost to wandb, other minor fixes
mivanit 162f229
wip
mivanit bc7f178
wip
mivanit 852ece2
wip
mivanit 5f78a7e
some logging stuff fixed. now tryna figure out why way too many merge…
mivanit d00e7b9
[important] got run loading to work right!!!
mivanit 4ee62db
wip
mivanit 2a0edfd
format
mivanit f6c62d2
[important] FEATURE COMPLETE??
mivanit b5f7483
format
mivanit 82602c0
minor changes to wandb stuff
mivanit 5d0a5d0
fix tests?
mivanit 69c239e
format
mivanit 6c8ba33
fix pyright errors, some warnings remain
mivanit 0a57602
pyright passing!!!
mivanit 26a5631
removed old ignores in pyproject.toml, fixed the resulting errors
mivanit e1ee44e
?
mivanit 1fc039c
fix
mivanit 2117700
Merge branch 'dev' into feature/clustering
mivanit 32d6e9f
fix registry
mivanit 057d82e
numprocess to auto for pytest with xdist
mivanit b4ee5ee
move the TaskName definition
mivanit a035627
remove pyright ignore
mivanit 820b6cf
move clustering stuff into clustering folders
mivanit 87dcd44
format
mivanit 0adbd18
fix type error
mivanit 11e6091
no more dep graph
mivanit ad47db2
remove outdated TODO.md
mivanit 5e13e96
remove old ruff exclude
mivanit 79b953e
remove unused function
mivanit 3ab600e
remove `from __future__ import annotations`
mivanit 3f288b2
factor out component filtering, add tests for it
mivanit 2c8270b
tensor stats randomized for big matrix
mivanit 70de4e1
wip
mivanit 9b55cd1
minor fix related to passing filtered labels
mivanit 485b3c7
histograms as wandb Images instead of plotly
mivanit 354ea73
wip
mivanit 64b3d30
factor out some logging stuff, lots of minor changes
mivanit 3777b1d
wandb log semilog of merge pair cost
mivanit 39613f7
oops, fix semilog
mivanit eb6ffb8
log hists for stuff
mivanit 4a2cdac
close figures to avoid memory leaks, add some logging to help profile
mivanit de5e7ad
wip, added some logging
mivanit ce96406
logging stuff
mivanit 6640430
pyright fixes
mivanit 01059b2
make the merge history more lightweight
mivanit c03969f
wip
mivanit f6ec48e
format
mivanit 371e015
intervals dict
mivanit 58b50f3
fix issue with zip file closed
mivanit a0e621d
wip
mivanit f735c5a
remove merge profiling code
mivanit f8f98e4
get rid of some old comments
mivanit 7247800
more explanation around popping
mivanit ae3ec43
components_in_pop_grp -> n_components_in_pop_grp
mivanit 8e86541
Merge branch 'main' into feature/clustering
mivanit 14f44e7
remove ignore deprecated config warnings todo
mivanit 1a900ef
some typing fixes
mivanit 4de2154
fix pyright issues by pinning `transformers` package
mivanit 038ad58
fix annoying mock test issue
mivanit 069dffc
Merge branch 'main' into feature/clustering
mivanit a99c2b1
fix pyright issue
mivanit 8b7ecfd
remove call to deprecated plotting function
mivanit faca7bb
Merge branch 'main' into feature/clustering
mivanit 933a71d
remove dead test
oclivegriffin c051b69
Merge branch 'main' into feature/clustering
oclivegriffin 7d2885d
give sigmoid type
oclivegriffin f1799b3
update default ss decomp run
mivanit f192b69
Merge branch 'main' into feature/clustering
mivanit 91242aa
wip
oclivegriffin c631c61
give sigmoid type
oclivegriffin 0cafdd6
wip
oclivegriffin 5c45477
merge
oclivegriffin 3026407
fix interface, store ComponentModel.module_paths
mivanit 14a0c89
patched model -> target model when using config
mivanit 95bbdb5
fix: no more patched_model, use target_model for config
mivanit 675fb5e
comments
mivanit 7ff6246
fix cuda_memory_used.py
mivanit caf09f1
removed `spd/clustering/math/dev.py`
mivanit 2b6449f
`StatsKeys` -> `StatsKey`
mivanit 4c57042
remove `plot_merge_history_costs`
mivanit 6617a1d
remove old js-embedding-vis dep
mivanit 6e6da1d
Merge branch 'main' into feature/clustering
mivanit 8b2a58b
Oli clustering refactor (#172)
mivanit df72249
fixes/improvements to dist_utils
mivanit b00243e
Merge branch 'main' into feature/clustering
mivanit 667c836
Merge branch 'main' into feature/clustering
mivanit 9efcbca
type fixes?
mivanit 24eff75
wip
mivanit c29e868
Sync everything from feature/clustering-dashboard except spd/clusteri…
mivanit 961894f
minimizing diff
mivanit cf45a57
minimize pyprojec.toml diff
mivanit d3a21e9
minimizing diff, removed deps
mivanit 6789ca8
uv sync
mivanit 1ab117d
fixing state dict mappings
mivanit 9baa5e3
Merge branch 'fix/state-dict-key-mapping' into feature/clustering
mivanit d45df85
test parallelization
mivanit 86b6f95
parallelize tests
mivanit 7c3a2f0
device getting utils
mivanit 61f0482
add TaskName type
mivanit 7144f18
uv sync (pytest-xdist dep)
mivanit 2130566
Merge branch 'main' into feature/clustering
mivanit 7951b31
remove old junk from Makefile
mivanit 4c9b6be
Merge branch 'main' into refactor/clustering-prereqs
mivanit 44f8c94
globally unique ports in tests to allow parallelization
mivanit 88022ea
comments explaining port allocation in tests
mivanit 5bcd8a3
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit 0e38131
Merge branch 'main' into feature/clustering
danbraunai-goodfire 3f55ffa
add distributed marker, rull all distributed tests on same worker
mivanit 29c2738
Revert "add distributed marker, rull all distributed tests on same wo…
mivanit f5b3288
add distributed marker, rull all distributed tests on same worker
mivanit c3dca4a
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit 3b116f5
Merge branch 'main' into refactor/clustering-prereqs
mivanit 2e60193
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit deb59a8
refactor: use general_utils methods for getting device everywhere
mivanit bac464c
Merge branch 'main' into refactor/clustering-prereqs
mivanit 2e29b2f
Merge branch 'main' into feature/clustering
mivanit 2eb96fc
wip jaccard
mivanit c71c696
wip jaccard
mivanit f9b228c
Merge branch 'main' into refactor/clustering-prereqs
mivanit 8a7a423
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit a25a9a9
wip jaccard
mivanit d23823d
wip jaccard (plotting)
mivanit 6280591
found where to increase timeout
mivanit 85f789b
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit db59ce6
wip jaccard
mivanit e218796
make format
mivanit 6225e0b
fixes
mivanit 9f4c347
typing fixes
mivanit b12403a
claude doing a bunch of type hinting
mivanit 9a06d29
Merge branch 'main' into refactor/clustering-prereqs
mivanit 7f605e5
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit f112af9
trying to get pyright passing?
mivanit cfc03ed
pyright works both locally and in CI
mivanit a92ae11
allow installing cpu-only torch in CI
mivanit 5ad324c
figure out CI disk usage by tests on main
mivanit 7978c44
alternate strategy for install
mivanit 98ba633
fixes to the last commit
mivanit 989c2dd
cleanup temp changes
mivanit 5f9500f
make in CI
mivanit 7eeea74
wip
mivanit 4bc5728
wip
mivanit 3c17aee
wip
mivanit 1110c2d
uv sync
mivanit 9cfc051
try to fix markupsafe?
mivanit 23c41dd
pin markup safe with explanation
mivanit 63d6432
update lockfile??
mivanit 1d8da48
Merge branch 'main' into fix/ci-disk-usage
mivanit 1bb5ac8
nope i think we need the index strategy
mivanit 75a0efe
?
mivanit 019e1b3
markupsafe issue
mivanit 3082817
remove disk usage printing
mivanit d520a54
fix pyright issue
mivanit 1eab6fe
dependency hell
mivanit 3a05d5a
fix deps???
mivanit d053433
oops, missing index strategy. moved to makefile
mivanit 25aa615
re-lock
mivanit c1ffcda
make from /usr/bin/ ?
mivanit f006884
dependency hell
mivanit 533ad20
type checking hell
mivanit 4524965
Update spd/utils/general_utils.py
mivanit 15b314e
wrap and fix Conv1D imports
mivanit f28b670
minimize diff cleanup
mivanit aa98d4a
try compile-bytecode for ci install
mivanit 740f6a2
dont compile bytecode actually
mivanit cfe1f81
remove markupsafe constraint?
mivanit 17803ed
switched to use get_obj_device
mivanit 153d044
remove device: torch.device type hints
mivanit 2c60412
remove "distributed" test marker
mivanit 44c6cc5
fix another timeout
mivanit 0d5137c
Merge branch 'main' into refactor/clustering-prereqs
mivanit a48ce55
Merge branch 'fix/ci-disk-usage' into feature/clustering
mivanit 055f3cc
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit b1604bb
replace get_module_device -> get_obj_device
mivanit d20ea4f
Merge branch 'main' into refactor/clustering-prereqs
mivanit 42f58a7
better comments on port uniqueness
mivanit c722ddd
remove old markers
mivanit b472f5d
remove timeout TODO comments
mivanit 921010b
Merge branch 'refactor/clustering-prereqs' into feature/clustering
mivanit 888f5f2
Merge branch 'main' into feature/clustering
mivanit 5d092e8
removed checks.yaml timeout todo, clustering tests pass in ~12min
mivanit 5eb1af7
[diff-min] transformers version issue from #139 resolved
mivanit dc6e32c
fix comment
mivanit aae70da
wip jaccard
mivanit 9c6103f
pyright fixes to jaccard, wip
mivanit 6ccbd79
Update docs about grad syncing with DDP
danbraunai-goodfire fe0de02
Mention feature/memorization-experiments in README
danbraunai-goodfire 0bdbd4e
Fix train and eval metrics and hidden_act_recon (#189)
danbraunai-goodfire a2bcaa3
Update canonical runs and change target model path (#197)
danbraunai-goodfire dadb9c7
Avoid using too many processes in tests
danbraunai-goodfire 882d659
Merge branch 'main' into feature/clustering
mivanit e7e1b1d
fix wandb model paths to older runs
mivanit
# TODO: Cluster Coactivation Matrix Implementation

## What Was Changed

### 1. Added `ClusterActivations` dataclass (`spd/clustering/dashboard/compute_max_act.py`)
- New dataclass to hold vectorized cluster activations for all clusters
- Contains `activations` tensor [n_samples, n_clusters] and `cluster_indices` list

### 2. Added `compute_all_cluster_activations()` function
- Vectorized computation of all cluster activations at once
- Replaces the per-cluster loop for better performance
- Returns a `ClusterActivations` object

### 3. Added `compute_cluster_coactivations()` function
- Computes the coactivation matrix from a list of `ClusterActivations` across batches
- Binarizes activations (acts > 0) and computes the matrix product `activation_mask.T @ activation_mask`
- Follows the pattern from `spd/clustering/merge.py:69`
- Returns a tuple of `(coactivation_matrix, cluster_indices)`

### 4. Modified `compute_max_activations()` function
- Now accumulates `ClusterActivations` from each batch in an `all_cluster_activations` list
- Calls `compute_cluster_coactivations()` to compute the matrix
- **Changed return type**: now returns `tuple[DashboardData, np.ndarray, list[int]]`
- Added the coactivation matrix and `cluster_indices` to the return value

### 5. Modified `spd/clustering/dashboard/run.py`
- Updated to handle the new return value from `compute_max_activations()`
- Saves the coactivation matrix as `coactivations.npz` in the dashboard output directory
- The NPZ file contains:
  - `coactivations`: the [n_clusters, n_clusters] matrix
  - `cluster_indices`: array mapping matrix positions to cluster IDs

## What Needs to Be Checked

### Testing
- [ ] **Run the dashboard pipeline** on a real clustering run to verify:
  - Coactivation computation doesn't crash
  - Coactivations are saved correctly to the NPZ file
  - Matrix dimensions are correct
  - The `cluster_indices` mapping is correct

### Type Checking
- [ ] Run `make type` to ensure no type errors were introduced
- [ ] Verify jaxtyping annotations are correct

### Verification
- [ ] Load a saved `coactivations.npz` file and verify:
  ```python
  data = np.load("coactivations.npz")
  coact = data["coactivations"]
  cluster_indices = data["cluster_indices"]
  # Check: coact should be symmetric
  # Check: diagonal should be >= off-diagonal (clusters coactivate with themselves most)
  # Check: cluster_indices length should match coact.shape[0]
  ```

### Performance
- [ ] Check whether vectorization actually improved performance
- [ ] Monitor memory usage with large numbers of clusters

### Edge Cases
- [ ] Test with clusters that have zero activations
- [ ] Test with single-batch runs
- [ ] Test with a very large number of clusters

### Integration
- [ ] Verify the coactivation matrix can be used in downstream analysis
- [ ] Consider whether a visualization of coactivations should be added to the dashboard

## Notes
- The coactivation matrix is computed over all samples processed (n_batches * batch_size * seq_len samples)
- The binarization threshold is currently hardcoded as `> 0`; it may be worth making this configurable
- The computation happens in the dashboard pipeline, NOT during the main clustering pipeline
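The binarize-then-matmul step described for `compute_cluster_coactivations()` can be illustrated with a small torch-free sketch. The function name and list-of-lists representation here are illustrative only, not the repo's API; the real code works on tensors and a list of `ClusterActivations`:

```python
def coactivation_counts(acts: list[list[float]], threshold: float = 0.0) -> list[list[int]]:
    """Count, for each pair of clusters (i, j), how many samples activate both.

    Equivalent to binarizing `acts` at `threshold` and computing mask.T @ mask,
    where mask is the [n_samples, n_clusters] boolean activation matrix.
    """
    n_clusters = len(acts[0])
    counts = [[0] * n_clusters for _ in range(n_clusters)]
    for row in acts:
        on = [v > threshold for v in row]  # binarize one sample's activations
        for i in range(n_clusters):
            if not on[i]:
                continue
            for j in range(n_clusters):
                if on[j]:
                    counts[i][j] += 1
    return counts
```

The result is symmetric, and entry `(i, i)` counts the samples where cluster `i` is active at all, which is why the diagonal should dominate each row — the same properties the verification checklist above asserts for the saved `coactivations.npz`.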
```python
from dataclasses import dataclass
from functools import cached_property
from typing import Literal, NamedTuple

import torch
from jaxtyping import Bool, Float, Float16, Int
from torch import Tensor

from spd.clustering.consts import (
    ActivationsTensor,
    BoolActivationsTensor,
    ClusterCoactivationShaped,
    ComponentLabels,
)
from spd.clustering.util import ModuleFilterFunc
from spd.models.component_model import ComponentModel, OutputWithCache
from spd.models.sigmoids import SigmoidTypes


def component_activations(
    model: ComponentModel,
    device: torch.device | str,
    batch: Int[Tensor, "batch_size n_ctx"],
    sigmoid_type: SigmoidTypes,
) -> dict[str, ActivationsTensor]:
    """Get the component activations over a **single** batch."""
    causal_importances: dict[str, ActivationsTensor]
    with torch.no_grad():
        model_output: OutputWithCache = model(
            batch.to(device),
            cache_type="input",
        )

        causal_importances, _ = model.calc_causal_importances(
            pre_weight_acts=model_output.cache,
            sigmoid_type=sigmoid_type,
            sampling="continuous",
            detach_inputs=False,
        )

    return causal_importances


def compute_coactivations(
    activations: ActivationsTensor | BoolActivationsTensor,
) -> ClusterCoactivationShaped:
    """Compute the coactivation matrix from the activations."""
    # TODO: this works for both boolean and continuous activations,
    # but we could do better by just using OR for boolean activations
    # and maybe even some bitshift hacks. For now, we convert to float16.
    activations_f16: Float16[Tensor, "samples C"] = activations.to(torch.float16)
    return activations_f16.T @ activations_f16


class FilteredActivations(NamedTuple):
    activations: ActivationsTensor
    "activations after filtering dead components"

    labels: ComponentLabels
    "list of length c with labels for each preserved component"

    dead_components_labels: ComponentLabels | None
    "list of labels for dead components, or None if no filtering was applied"

    @property
    def n_alive(self) -> int:
        """Number of alive components after filtering."""
        n_alive: int = len(self.labels)
        assert n_alive == self.activations.shape[1], (
            f"{n_alive = } != {self.activations.shape[1] = }"
        )
        return n_alive

    @property
    def n_dead(self) -> int:
        """Number of dead components after filtering."""
        return len(self.dead_components_labels) if self.dead_components_labels else 0


def filter_dead_components(
    activations: ActivationsTensor,
    labels: ComponentLabels,
    filter_dead_threshold: float = 0.01,
) -> FilteredActivations:
    """Filter out dead components based on a threshold.

    If `filter_dead_threshold` is 0, no filtering is applied:
    activations and labels are returned as-is, and `dead_components_labels` is `None`.

    Otherwise, components whose **maximum** activation across all samples is below the
    threshold are considered dead and filtered out. The labels of these components are
    returned in `dead_components_labels`, which will also be `None` if no components
    were below the threshold.
    """
    dead_components_lst: ComponentLabels | None = None
    if filter_dead_threshold > 0:
        dead_components_lst = ComponentLabels(list())
        max_act: Float[Tensor, " c"] = activations.max(dim=0).values
        dead_components: Bool[Tensor, " c"] = max_act < filter_dead_threshold

        if dead_components.any():
            activations = activations[:, ~dead_components]
            alive_labels: list[tuple[str, bool]] = [
                (lbl, bool(keep.item()))
                for lbl, keep in zip(labels, ~dead_components, strict=False)
            ]
            # re-assign labels only if we are filtering
            labels = ComponentLabels([label for label, keep in alive_labels if keep])
            dead_components_lst = ComponentLabels(
                [label for label, keep in alive_labels if not keep]
            )

    return FilteredActivations(
        activations=activations,
        labels=labels,
        dead_components_labels=dead_components_lst if dead_components_lst else None,
    )


@dataclass(frozen=True)
class ProcessedActivations:
    """Processed activations after filtering and concatenation"""

    activations_raw: dict[str, ActivationsTensor]
    "activations after filtering, but prior to concatenation"

    activations: ActivationsTensor
    "activations after filtering and concatenation"

    labels: ComponentLabels
    "list of length c with labels for each preserved component, format `{module_name}:{component_index}`"

    dead_components_lst: ComponentLabels | None
    "list of labels for dead components, or None if no filtering was applied"

    def validate(self) -> None:
        """Validate the processed activations"""
        # getting this property will also perform a variety of other checks
        assert self.n_components_alive > 0

    @property
    def n_components_original(self) -> int:
        """Total number of components before filtering. Equal to the sum of all components in `activations_raw`, or to `n_components_alive + n_components_dead`."""
        return sum(act.shape[1] for act in self.activations_raw.values())

    @property
    def n_components_alive(self) -> int:
        """Number of alive components after filtering. Equal to the length of `labels`."""
        n_alive: int = len(self.labels)
        assert n_alive + self.n_components_dead == self.n_components_original, (
            f"({n_alive = }) + ({self.n_components_dead = }) != ({self.n_components_original = })"
        )
        assert n_alive == self.activations.shape[1], (
            f"{n_alive = } != {self.activations.shape[1] = }"
        )

        return n_alive

    @property
    def n_components_dead(self) -> int:
        """Number of dead components after filtering. Equal to the length of `dead_components_lst` if it is not None, or 0 otherwise."""
        return len(self.dead_components_lst) if self.dead_components_lst else 0

    @cached_property
    def label_index(self) -> dict[str, int | None]:
        """Create a mapping from label to alive index (`None` if dead)"""
        return {
            **{label: i for i, label in enumerate(self.labels)},
            **(
                {label: None for label in self.dead_components_lst}
                if self.dead_components_lst
                else {}
            ),
        }

    def get_label_index(self, label: str) -> int | None:
        """Get the index of a label in the activations, or None if it is dead"""
        return self.label_index[label]

    def get_label_index_alive(self, label: str) -> int:
        """Get the index of a label in the activations, or raise if it is dead"""
        idx: int | None = self.get_label_index(label)
        if idx is None:
            raise ValueError(f"Label '{label}' is dead and has no index in the activations.")
        return idx

    @property
    def module_keys(self) -> list[str]:
        """Get the module keys from `activations_raw`"""
        return list(self.activations_raw.keys())

    def get_module_indices(self, module_key: str) -> list[int | None]:
        """Given a module key, return a list of length "num components in that module", holding each component's int index among alive components, or None if dead."""
        num_components: int = self.activations_raw[module_key].shape[1]
        return [self.label_index[f"{module_key}:{i}"] for i in range(num_components)]


def process_activations(
    activations: dict[
        str,  # module name to
        Float[Tensor, "samples C"]  # (sample x component gate activations)
        | Float[Tensor, " n_sample n_ctx C"],  # (sample x seq index x component gate activations)
    ],
    filter_dead_threshold: float = 0.01,
    seq_mode: Literal["concat", "seq_mean", None] = None,
    filter_modules: ModuleFilterFunc | None = None,
) -> ProcessedActivations:
    """Get back a dict of coactivations, slices, and concatenated activations.

    Args:
        activations: Dictionary of activations by module
        filter_dead_threshold: Threshold for filtering dead components
        seq_mode: How to handle the sequence dimension
        filter_modules: Function to filter modules
    """

    # reshape -- special cases for llms
    # ============================================================
    activations_: dict[str, ActivationsTensor]
    if seq_mode == "concat":
        # Concatenate the sequence dimension into the sample dimension
        activations_ = {
            key: act.reshape(act.shape[0] * act.shape[1], act.shape[2])
            for key, act in activations.items()
        }
    elif seq_mode == "seq_mean":
        # Take the mean over the sequence dimension
        activations_ = {
            key: act.mean(dim=1) if act.ndim == 3 else act for key, act in activations.items()
        }
    else:
        # Use the activations as they are
        activations_ = activations

    # put the labelled activations into one big matrix and filter them
    # ============================================================

    # filter activations for only the modules we want
    if filter_modules is not None:
        activations_ = {key: act for key, act in activations_.items() if filter_modules(key)}

    # compute the labels and total component count
    total_c: int = 0
    labels: ComponentLabels = ComponentLabels(list())
    for key, act in activations_.items():
        c: int = act.shape[-1]
        labels.extend([f"{key}:{i}" for i in range(c)])
        total_c += c

    # concat the activations
    act_concat: ActivationsTensor = torch.cat([activations_[key] for key in activations_], dim=-1)

    # filter dead components
    filtered_components: FilteredActivations = filter_dead_components(
        activations=act_concat,
        labels=labels,
        filter_dead_threshold=filter_dead_threshold,
    )

    assert filtered_components.n_alive + filtered_components.n_dead == total_c, (
        f"({filtered_components.n_alive = }) + ({filtered_components.n_dead = }) != ({total_c = })"
    )

    return ProcessedActivations(
        activations_raw=activations_,
        activations=filtered_components.activations,
        labels=filtered_components.labels,
        dead_components_lst=filtered_components.dead_components_labels,
    )
```
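The dead-component filtering above can be sketched in a minimal torch-free form. The function and names here are a toy illustration of the same logic, not the repo's API; the real `filter_dead_components` operates on tensors and `ComponentLabels`:

```python
def filter_dead(
    acts: list[list[float]],
    labels: list[str],
    threshold: float = 0.01,
) -> tuple[list[list[float]], list[str], list[str]]:
    """Drop components whose maximum activation over all samples is below `threshold`.

    Mirrors filter_dead_components: returns (filtered acts, alive labels, dead labels).
    """
    n = len(labels)
    # max activation per component, taken over the sample dimension
    max_act = [max(row[i] for row in acts) for i in range(n)]
    alive = [i for i in range(n) if max_act[i] >= threshold]
    dead_labels = [labels[i] for i in range(n) if max_act[i] < threshold]
    # keep only the alive columns in every sample row
    filtered = [[row[i] for i in alive] for row in acts]
    return filtered, [labels[i] for i in alive], dead_labels
```

Note the same invariant the real code asserts: the number of alive labels plus dead labels always equals the original component count.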