From 83cb844af2595fe837b896a83cf2dee0587113a6 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 13 Oct 2025 10:56:58 +0100 Subject: [PATCH 01/61] remove component popping functionality --- spd/clustering/compute_costs.py | 110 +----------------- spd/clustering/configs/example.toml | 1 - spd/clustering/configs/example.yaml | 1 - spd/clustering/configs/resid_mlp1.json | 1 - spd/clustering/configs/resid_mlp2.json | 1 - spd/clustering/configs/resid_mlp3.json | 1 - spd/clustering/configs/simplestories_dev.json | 1 - spd/clustering/configs/test-resid_mlp1.json | 1 - .../configs/test-simplestories.json | 1 - spd/clustering/merge.py | 51 +------- spd/clustering/merge_config.py | 5 - tests/clustering/scripts/cluster_resid_mlp.py | 1 - tests/clustering/scripts/cluster_ss.py | 1 - tests/clustering/test_merge_config.py | 2 - tests/clustering/test_merge_integration.py | 39 ------- tests/clustering/test_storage.py | 3 - tests/clustering/test_wandb_integration.py | 4 - 17 files changed, 2 insertions(+), 222 deletions(-) diff --git a/spd/clustering/compute_costs.py b/spd/clustering/compute_costs.py index ba1ff274c..f1b3425d1 100644 --- a/spd/clustering/compute_costs.py +++ b/spd/clustering/compute_costs.py @@ -1,7 +1,7 @@ import math import torch -from jaxtyping import Bool, Float, Int +from jaxtyping import Bool, Float from torch import Tensor from spd.clustering.consts import ClusterCoactivationShaped, MergePair @@ -187,111 +187,3 @@ def recompute_coacts_merge_pair( coact_new, activation_mask_new, ) - - -def recompute_coacts_pop_group( - coact: ClusterCoactivationShaped, - merges: GroupMerge, - component_idx: int, - activation_mask: Bool[Tensor, "n_samples k_groups"], - activation_mask_orig: Bool[Tensor, "n_samples n_components"], -) -> tuple[ - GroupMerge, - Float[Tensor, "k_groups+1 k_groups+1"], - Bool[Tensor, "n_samples k_groups+1"], -]: - # sanity check dims - # ================================================== - - k_groups: int = coact.shape[0] - n_samples: int 
= activation_mask.shape[0] - k_groups_new: int = k_groups + 1 - assert coact.shape[1] == k_groups, "Coactivation matrix must be square" - assert activation_mask.shape[1] == k_groups, ( - "Activation mask must match coactivation matrix shape" - ) - assert n_samples == activation_mask_orig.shape[0], ( - "Activation mask original must match number of samples" - ) - - # get the activations we need - # ================================================== - # which group does the component belong to? - group_idx: int = int(merges.group_idxs[component_idx].item()) - group_size_old: int = int(merges.components_per_group[group_idx].item()) - group_size_new: int = group_size_old - 1 - - # activations of component we are popping out - acts_pop: Bool[Tensor, " samples"] = activation_mask_orig[:, component_idx] - - # activations of the "remainder" -- everything other than the component we are popping out, - # in the group we're popping it out of - acts_remainder: Bool[Tensor, " samples"] = ( - activation_mask_orig[ - :, [i for i in merges.components_in_group(group_idx) if i != component_idx] - ] - .max(dim=-1) - .values - ) - - # assemble the new activation mask - # ================================================== - # first concat the popped-out component onto the end - activation_mask_new: Bool[Tensor, " samples k_groups+1"] = torch.cat( - [activation_mask, acts_pop.unsqueeze(1)], - dim=1, - ) - # then replace the group we are popping out of with the remainder - activation_mask_new[:, group_idx] = acts_remainder - - # assemble the new coactivation matrix - # ================================================== - coact_new: Float[Tensor, "k_groups+1 k_groups+1"] = torch.full( - (k_groups_new, k_groups_new), - fill_value=float("nan"), - dtype=coact.dtype, - device=coact.device, - ) - # copy in the old coactivation matrix - coact_new[:k_groups, :k_groups] = coact.clone() - # compute new coactivations we need - coact_pop: Float[Tensor, " k_groups"] = acts_pop.float() @ 
activation_mask_new.float() - coact_remainder: Float[Tensor, " k_groups"] = ( - acts_remainder.float() @ activation_mask_new.float() - ) - - # replace the relevant rows and columns - coact_new[group_idx, :] = coact_remainder - coact_new[:, group_idx] = coact_remainder - coact_new[-1, :] = coact_pop - coact_new[:, -1] = coact_pop - - # assemble the new group merge - # ================================================== - group_idxs_new: Int[Tensor, " k_groups+1"] = merges.group_idxs.clone() - # the popped-out component is now its own group - new_group_idx: int = k_groups_new - 1 - group_idxs_new[component_idx] = new_group_idx - merge_new: GroupMerge = GroupMerge( - group_idxs=group_idxs_new, - k_groups=k_groups_new, - ) - - # sanity check - assert merge_new.components_per_group.shape == (k_groups_new,), ( - "New merge must have k_groups+1 components" - ) - assert merge_new.components_per_group[new_group_idx] == 1, ( - "New group must have exactly one component" - ) - assert merge_new.components_per_group[group_idx] == group_size_new, ( - "Old group must have one less component" - ) - - # return - # ================================================== - return ( - merge_new, - coact_new, - activation_mask_new, - ) diff --git a/spd/clustering/configs/example.toml b/spd/clustering/configs/example.toml index d5cfe46d6..d73d74a59 100644 --- a/spd/clustering/configs/example.toml +++ b/spd/clustering/configs/example.toml @@ -27,7 +27,6 @@ artifact = 100 # for calling the artifact callback activation_threshold = 0.01 # set to null to use scalar activations for cost calculation alpha = 1.0 # rank penalty term iters = 100 # iterations to run. setting this to exactly the number of components can be buggy when doing ensembles, so set it to a bit less? -pop_component_prob = 0 # Probability of popping a component. 
i recommend 0 if you're doing an ensemble anyway filter_dead_threshold = 0.001 # Threshold for filtering dead components module_name_filter = "__NULL__" # Can be a string prefix like "model.layers.0." if you want to do only some modules rank_cost_fn_name = "const_1" # Options: const_1, const_2, log, linear diff --git a/spd/clustering/configs/example.yaml b/spd/clustering/configs/example.yaml index 5f3cd5fa5..62577d5bd 100644 --- a/spd/clustering/configs/example.yaml +++ b/spd/clustering/configs/example.yaml @@ -8,7 +8,6 @@ merge_config: merge_pair_sampling_method: "range" # Method for sampling merge pairs: 'range' or 'mcmc' merge_pair_sampling_kwargs: threshold: 0.05 # For range sampler: fraction of the range of costs to sample from - pop_component_prob: 0 # Probability of popping a component. i recommend 0 if you're doing an ensemble anyway filter_dead_threshold: 0.001 # Threshold for filtering dead components module_name_filter: null # Can be a string prefix like "model.layers.0." if you want to do only some modules rank_cost_fn_name: const_1 # Options: const_1, const_2, log, linear diff --git a/spd/clustering/configs/resid_mlp1.json b/spd/clustering/configs/resid_mlp1.json index e825215ee..7e00a9aa6 100644 --- a/spd/clustering/configs/resid_mlp1.json +++ b/spd/clustering/configs/resid_mlp1.json @@ -5,7 +5,6 @@ "iters": null, "merge_pair_sampling_method": "range", "merge_pair_sampling_kwargs": {"threshold": 0.05}, - "pop_component_prob": 0, "filter_dead_threshold": 0, "module_name_filter": null }, diff --git a/spd/clustering/configs/resid_mlp2.json b/spd/clustering/configs/resid_mlp2.json index 2be350979..e6c4e7b07 100644 --- a/spd/clustering/configs/resid_mlp2.json +++ b/spd/clustering/configs/resid_mlp2.json @@ -5,7 +5,6 @@ "iters": 100, "merge_pair_sampling_method": "range", "merge_pair_sampling_kwargs": {"threshold": 0.05}, - "pop_component_prob": 0, "filter_dead_threshold": 0.01, "module_name_filter": null }, diff --git 
a/spd/clustering/configs/resid_mlp3.json b/spd/clustering/configs/resid_mlp3.json index 5d87e08d5..2fa5fb1b6 100644 --- a/spd/clustering/configs/resid_mlp3.json +++ b/spd/clustering/configs/resid_mlp3.json @@ -5,7 +5,6 @@ "iters": 350, "merge_pair_sampling_method": "range", "merge_pair_sampling_kwargs": {"threshold": 0.05}, - "pop_component_prob": 0, "filter_dead_threshold": 0.01, "module_name_filter": null }, diff --git a/spd/clustering/configs/simplestories_dev.json b/spd/clustering/configs/simplestories_dev.json index c82b11710..a28b58a95 100644 --- a/spd/clustering/configs/simplestories_dev.json +++ b/spd/clustering/configs/simplestories_dev.json @@ -5,7 +5,6 @@ "iters": null, "merge_pair_sampling_method": "range", "merge_pair_sampling_kwargs": {"threshold": 0.001}, - "pop_component_prob": 0, "filter_dead_threshold": 0.1, "module_name_filter": null }, diff --git a/spd/clustering/configs/test-resid_mlp1.json b/spd/clustering/configs/test-resid_mlp1.json index 75877dd25..0adb9b42f 100644 --- a/spd/clustering/configs/test-resid_mlp1.json +++ b/spd/clustering/configs/test-resid_mlp1.json @@ -5,7 +5,6 @@ "iters": 140, "merge_pair_sampling_method": "range", "merge_pair_sampling_kwargs": {"threshold": 0.05}, - "pop_component_prob": 0, "filter_dead_threshold": 0.1, "module_name_filter": null, "rank_cost_fn_name": "const_1" diff --git a/spd/clustering/configs/test-simplestories.json b/spd/clustering/configs/test-simplestories.json index 377eb6af1..fb79c7021 100644 --- a/spd/clustering/configs/test-simplestories.json +++ b/spd/clustering/configs/test-simplestories.json @@ -5,7 +5,6 @@ "iters": 5, "merge_pair_sampling_method": "range", "merge_pair_sampling_kwargs": {"threshold": 0.05}, - "pop_component_prob": 0, "filter_dead_threshold": 0.9, "module_name_filter": "model.layers.0" }, diff --git a/spd/clustering/merge.py b/spd/clustering/merge.py index 3692e1687..38f11462c 100644 --- a/spd/clustering/merge.py +++ b/spd/clustering/merge.py @@ -8,7 +8,7 @@ from typing import 
Protocol import torch -from jaxtyping import Bool, Float, Int +from jaxtyping import Bool, Float from torch import Tensor from tqdm import tqdm @@ -16,7 +16,6 @@ compute_mdl_cost, compute_merge_costs, recompute_coacts_merge_pair, - recompute_coacts_pop_group, ) from spd.clustering.consts import ( ActivationsTensor, @@ -83,24 +82,6 @@ def merge_iteration( # determine number of iterations based on config and number of components num_iters: int = merge_config.get_num_iters(c_components) - # pop logic setup - # -------------------------------------------------- - # for speed, we precompute whether to pop components and which components to pop - # if we are not popping, we don't need these variables and can also delete other things - do_pop: bool = merge_config.pop_component_prob > 0.0 - if do_pop: - # at each iteration, we will pop a component with probability `pop_component_prob` - iter_pop: Bool[Tensor, " iters"] = ( - torch.rand(num_iters, device=coact.device) < merge_config.pop_component_prob - ) - # we pick a subcomponent at random, and if we decide to pop, we pop that one out of its group - # if the component is a singleton, nothing happens. 
this naturally biases towards popping - # less at the start and more at the end, since the effective probability of popping a component - # is actually something like `pop_component_prob * (c_components - k_groups) / c_components` - pop_component_idx: Int[Tensor, " iters"] = torch.randint( - 0, c_components, (num_iters,), device=coact.device - ) - # initialize vars # -------------------------------------------------- # start with an identity merge @@ -117,12 +98,6 @@ def merge_iteration( labels=component_labels, ) - # free up memory - if not do_pop: - del coact - del activation_mask_orig - activation_mask_orig = None - # merge iteration # ================================================== pbar: tqdm[int] = tqdm( @@ -131,30 +106,6 @@ def merge_iteration( total=num_iters, ) for iter_idx in pbar: - # pop components - # -------------------------------------------------- - if do_pop and iter_pop[iter_idx]: # pyright: ignore[reportPossiblyUnboundVariable] - # we split up the group which our chosen component belongs to - pop_component_idx_i: int = int(pop_component_idx[iter_idx].item()) # pyright: ignore[reportPossiblyUnboundVariable] - n_components_in_pop_grp: int = int( - current_merge.components_per_group[ # pyright: ignore[reportArgumentType] - current_merge.group_idxs[pop_component_idx_i].item() - ] - ) - - # but, if the component is the only one in its group, there is nothing to do - if n_components_in_pop_grp > 1: - current_merge, current_coact, current_act_mask = recompute_coacts_pop_group( - coact=current_coact, - merges=current_merge, - component_idx=pop_component_idx_i, - activation_mask=current_act_mask, - # this complains if `activation_mask_orig is None`, but this is only the case - # if `do_pop` is False, which it won't be here. 
we do this to save memory - activation_mask_orig=activation_mask_orig, # pyright: ignore[reportArgumentType] - ) - k_groups = current_coact.shape[0] - # compute costs, figure out what to merge # -------------------------------------------------- # HACK: this is messy diff --git a/spd/clustering/merge_config.py b/spd/clustering/merge_config.py index 03c601a9f..2f75a8f32 100644 --- a/spd/clustering/merge_config.py +++ b/spd/clustering/merge_config.py @@ -23,7 +23,6 @@ "iters", "merge_pair_sampling_method", "merge_pair_sampling_kwargs", - "pop_component_prob", "filter_dead_threshold", ] @@ -65,10 +64,6 @@ class MergeConfig(BaseModel): default_factory=lambda: {"threshold": 0.05}, description="Keyword arguments for the merge pair sampling method.", ) - pop_component_prob: Probability = Field( - default=0, - description="Probability of popping a component in each iteration. If 0, no components are popped.", - ) filter_dead_threshold: float = Field( default=0.001, description="Threshold for filtering out dead components. 
If a component's activation is below this threshold, it is considered dead and not included in the merge.", diff --git a/tests/clustering/scripts/cluster_resid_mlp.py b/tests/clustering/scripts/cluster_resid_mlp.py index df8dd0a2f..52fd20cb8 100644 --- a/tests/clustering/scripts/cluster_resid_mlp.py +++ b/tests/clustering/scripts/cluster_resid_mlp.py @@ -121,7 +121,6 @@ iters=int(PROCESSED_ACTIVATIONS.n_components_alive * 0.9), merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.0}, - pop_component_prob=0, filter_dead_threshold=FILTER_DEAD_THRESHOLD, ) diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 00ef733d6..41a87375c 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -105,7 +105,6 @@ iters=2, merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.1}, - pop_component_prob=0, module_name_filter=FILTER_MODULES, filter_dead_threshold=FILTER_DEAD_THRESHOLD, ) diff --git a/tests/clustering/test_merge_config.py b/tests/clustering/test_merge_config.py index 9f191075b..63f4e88f7 100644 --- a/tests/clustering/test_merge_config.py +++ b/tests/clustering/test_merge_config.py @@ -74,7 +74,6 @@ def test_config_with_all_parameters(self): iters=200, merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 0.5}, - pop_component_prob=0.1, filter_dead_threshold=0.001, module_name_filter="model.layers", ) @@ -84,7 +83,6 @@ def test_config_with_all_parameters(self): assert config.iters == 200 assert config.merge_pair_sampling_method == "mcmc" assert config.merge_pair_sampling_kwargs == {"temperature": 0.5} - assert config.pop_component_prob == 0.1 assert config.filter_dead_threshold == 0.001 assert config.module_name_filter == "model.layers" diff --git a/tests/clustering/test_merge_integration.py b/tests/clustering/test_merge_integration.py index 6463ad07b..25ae62319 100644 --- 
a/tests/clustering/test_merge_integration.py +++ b/tests/clustering/test_merge_integration.py @@ -25,7 +25,6 @@ def test_merge_with_range_sampler(self): iters=5, merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.1}, - pop_component_prob=0, filter_dead_threshold=0.001, ) @@ -62,7 +61,6 @@ def test_merge_with_mcmc_sampler(self): iters=5, merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 1.0}, - pop_component_prob=0, filter_dead_threshold=0.001, ) @@ -83,40 +81,6 @@ def test_merge_with_mcmc_sampler(self): assert history.merges.k_groups[-1].item() < n_components assert history.merges.k_groups[-1].item() >= 2 - def test_merge_with_popping(self): - """Test merge iteration with component popping.""" - # Create test data - n_samples = 100 - n_components = 15 - activations = torch.rand(n_samples, n_components) - component_labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) - - # Configure with popping enabled - config = MergeConfig( - activation_threshold=0.1, - alpha=1.0, - iters=10, - merge_pair_sampling_method="range", - merge_pair_sampling_kwargs={"threshold": 0.05}, - pop_component_prob=0.3, # 30% chance of popping - filter_dead_threshold=0.001, - ) - - # Run merge iteration - history = merge_iteration( - activations=activations, - batch_id="test_merge_with_popping", - merge_config=config, - component_labels=component_labels, - ) - - # Check results - assert history is not None - # First entry is after first merge, so should be n_components - 1 - assert history.merges.k_groups[0].item() == n_components - 1 - # Final group count depends on pops, but should be less than initial - assert history.merges.k_groups[-1].item() < n_components - def test_merge_comparison_samplers(self): """Compare behavior of different samplers with same data.""" # Create test data with clear structure @@ -137,7 +101,6 @@ def test_merge_comparison_samplers(self): iters=3, merge_pair_sampling_method="range", 
merge_pair_sampling_kwargs={"threshold": 0.0}, # Always select minimum - pop_component_prob=0, ) history_range = merge_iteration( @@ -154,7 +117,6 @@ def test_merge_comparison_samplers(self): iters=3, merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 0.01}, # Very low temp - pop_component_prob=0, ) history_mcmc = merge_iteration( @@ -184,7 +146,6 @@ def test_merge_with_small_components(self): iters=1, # Just one merge merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 2.0}, - pop_component_prob=0, ) history = merge_iteration( diff --git a/tests/clustering/test_storage.py b/tests/clustering/test_storage.py index d5e3d535e..2bf322a22 100644 --- a/tests/clustering/test_storage.py +++ b/tests/clustering/test_storage.py @@ -30,7 +30,6 @@ def sample_config() -> MergeConfig: iters=5, alpha=1.0, activation_threshold=None, - pop_component_prob=0.0, ) @@ -74,7 +73,6 @@ def test_save_and_load_run_config(self, temp_storage: ClusteringStorage): iters=10, alpha=1.0, activation_threshold=None, - pop_component_prob=0.0, ), model_path="wandb:entity/project/run_id", task_name="lm", @@ -331,7 +329,6 @@ def test_storage_filesystem_structure(self, temp_storage: ClusteringStorage): iters=1, alpha=1.0, activation_threshold=None, - pop_component_prob=0.0, ), model_path="wandb:e/p/r", task_name="lm", diff --git a/tests/clustering/test_wandb_integration.py b/tests/clustering/test_wandb_integration.py index cf400ca2b..7892089a2 100644 --- a/tests/clustering/test_wandb_integration.py +++ b/tests/clustering/test_wandb_integration.py @@ -27,7 +27,6 @@ def test_wandb_url_parsing_short_format(): iters=5, alpha=1.0, activation_threshold=None, - pop_component_prob=0.0, ) # Save histories using storage @@ -59,7 +58,6 @@ def test_merge_history_ensemble(): iters=3, alpha=1.0, activation_threshold=None, - pop_component_prob=0.0, ) histories = [] @@ -88,7 +86,6 @@ def test_save_merge_history_to_wandb(): iters=5, alpha=1.0, activation_threshold=None, 
- pop_component_prob=0.0, ) history = MergeHistory.from_config( @@ -135,7 +132,6 @@ def test_wandb_url_field_in_merge_history(): iters=10, alpha=1.0, activation_threshold=None, - pop_component_prob=0.0, ) # Create MergeHistory with wandb_url From 7856ea52f2793b176ebf7e10b611d2b91f4cc353 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 13 Oct 2025 11:14:35 +0100 Subject: [PATCH 02/61] rename s1 to dataset in prep for refactor --- spd/clustering/{pipeline/s1_split_dataset.py => dataset.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename spd/clustering/{pipeline/s1_split_dataset.py => dataset.py} (100%) diff --git a/spd/clustering/pipeline/s1_split_dataset.py b/spd/clustering/dataset.py similarity index 100% rename from spd/clustering/pipeline/s1_split_dataset.py rename to spd/clustering/dataset.py From 94f82f39b8966cad8ee139e9d4fbfe97166e98b5 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 13 Oct 2025 12:00:26 +0100 Subject: [PATCH 03/61] wip --- spd/clustering/dataset.py | 33 ++++++++++++++++++-------- spd/clustering/merge_config.py | 9 +++++++ spd/clustering/merge_run_config.py | 4 ---- tests/clustering/scripts/cluster_ss.py | 5 ++-- 4 files changed, 34 insertions(+), 17 deletions(-) diff --git a/spd/clustering/dataset.py b/spd/clustering/dataset.py index 711cda65a..cd2b20b81 100644 --- a/spd/clustering/dataset.py +++ b/spd/clustering/dataset.py @@ -21,8 +21,10 @@ from spd.models.component_model import ComponentModel, SPDRunInfo -def split_dataset( +def get_clustering_dataloader( config: ClusteringRunConfig, + ddp_rank: int = 0, + ddp_world_size: int = 1, **kwargs: Any, ) -> tuple[Iterator[BatchTensor], dict[str, Any]]: """Split a dataset into n_batches of batch_size, returning iterator and config""" @@ -32,13 +34,17 @@ def split_dataset( case "lm": ds, ds_config_dict = _get_dataloader_lm( model_path=config.model_path, - batch_size=config.batch_size, + batch_size=config.merge_config.batch_size, + ddp_rank=ddp_rank, + 
ddp_world_size=ddp_world_size, **kwargs, ) case "resid_mlp": ds, ds_config_dict = _get_dataloader_resid_mlp( model_path=config.model_path, - batch_size=config.batch_size, + batch_size=config.merge_config.batch_size, + ddp_rank=ddp_rank, + ddp_world_size=ddp_world_size, **kwargs, ) case name: @@ -61,7 +67,9 @@ def limited_iterator() -> Iterator[BatchTensor]: def _get_dataloader_lm( model_path: str, batch_size: int, - config_kwargs: dict[str, Any] | None = None, + dataset_config_kwargs: dict[str, Any] | None = None, + ddp_rank: int = 0, + ddp_world_size: int = 1, ) -> tuple[Generator[BatchTensor, None, None], dict[str, Any]]: """split up a SS dataset into n_batches of batch_size, returned the saved paths @@ -87,13 +95,13 @@ def _get_dataloader_lm( f"Expected task_config to be of type LMTaskConfig since using `_get_dataloader_lm`, but got {type(cfg.task_config) = }" ) - config_kwargs_: dict[str, Any] = { + dataset_config_kwargs_: dict[str, Any] = { **dict( is_tokenized=False, streaming=False, seed=0, ), - **(config_kwargs or {}), + **(dataset_config_kwargs or {}), } dataset_config: DatasetConfig = DatasetConfig( @@ -102,7 +110,7 @@ def _get_dataloader_lm( split=cfg.task_config.train_data_split, n_ctx=cfg.task_config.max_seq_len, column_name=cfg.task_config.column_name, - **config_kwargs_, + **dataset_config_kwargs_, ) with SpinnerContext(message="getting dataloader..."): @@ -112,8 +120,8 @@ def _get_dataloader_lm( batch_size=batch_size, buffer_size=cfg.task_config.buffer_size, global_seed=cfg.seed, - ddp_rank=0, - ddp_world_size=1, + ddp_rank=ddp_rank, + ddp_world_size=ddp_world_size, ) return (batch["input_ids"] for batch in dataloader), dataset_config.model_dump(mode="json") @@ -122,11 +130,16 @@ def _get_dataloader_lm( def _get_dataloader_resid_mlp( model_path: str, batch_size: int, + ddp_rank: int = 0, + ddp_world_size: int = 1, ) -> tuple[Generator[torch.Tensor, None, None], dict[str, Any]]: """Split a ResidMLP dataset into n_batches of batch_size and save the 
batches.""" from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset from spd.utils.data_utils import DatasetGeneratedDataLoader + # TODO: this is a hack. idk what the best way to handle this is + shuffle_data: bool = ddp_world_size <= 1 + with SpinnerContext(message=f"Loading SPD Run Config for '{model_path}'"): spd_run: SPDRunInfo = SPDRunInfo.from_path(model_path) # SPD_RUN = SPDRunInfo.from_path(EXPERIMENT_REGISTRY["resid_mlp3"].canonical_run) @@ -158,7 +171,7 @@ def _get_dataloader_resid_mlp( dataset: ResidMLPDataset = ResidMLPDataset(**resid_mlp_dataset_kwargs) dataloader: DatasetGeneratedDataLoader[tuple[Tensor, Tensor]] = DatasetGeneratedDataLoader( - dataset, batch_size=batch_size, shuffle=False + dataset, batch_size=batch_size, shuffle=shuffle_data ) return (batch[0] for batch in dataloader), resid_mlp_dataset_kwargs diff --git a/spd/clustering/merge_config.py b/spd/clustering/merge_config.py index 2f75a8f32..aa4684ed6 100644 --- a/spd/clustering/merge_config.py +++ b/spd/clustering/merge_config.py @@ -72,6 +72,15 @@ class MergeConfig(BaseModel): default=None, description="Filter for module names. Can be a string prefix, a set of names, or a callable that returns True for modules to include.", ) + # TODO: unsure of this var name + recompute_costs_every: PositiveInt = Field( + default=10, + description="How often to recompute the full cost matrix, replacing NaN values of merged components with their true value. 
Higher values mean less accurate merges but faster computation.", + ) + batch_size: PositiveInt = Field( + default=64, + description="Size of each batch for processing", + ) @property def merge_pair_sample_func(self) -> MergePairSampler: diff --git a/spd/clustering/merge_run_config.py b/spd/clustering/merge_run_config.py index d86cbc7a6..0a0fb07ae 100644 --- a/spd/clustering/merge_run_config.py +++ b/spd/clustering/merge_run_config.py @@ -92,10 +92,6 @@ class ClusteringRunConfig(BaseModel): default=10, description="Number of batches to split the dataset into (ensemble size)", ) - batch_size: PositiveInt = Field( - default=64, - description="Size of each batch for processing", - ) distances_method: DistancesMethod = Field( default="perm_invariant_hamming", description="Method to use for computing distances between clusterings", diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 41a87375c..ac144420e 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -16,11 +16,11 @@ component_activations, process_activations, ) +from spd.clustering.dataset import get_clustering_dataloader from spd.clustering.merge import merge_iteration from spd.clustering.merge_config import MergeConfig from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble from spd.clustering.merge_run_config import ClusteringRunConfig -from spd.clustering.pipeline.s1_split_dataset import split_dataset from spd.clustering.plotting.activations import plot_activations from spd.clustering.plotting.merge import plot_dists_distribution from spd.models.component_model import ComponentModel, SPDRunInfo @@ -51,10 +51,9 @@ model_path=MODEL_PATH, task_name="lm", n_batches=1, - batch_size=2, ) -BATCHES, _ = split_dataset( +BATCHES, _ = get_clustering_dataloader( config=CONFIG, config_kwargs=dict(streaming=True), # see https://github.com/goodfire-ai/spd/pull/199 ) From 0ccc65d2bbffbb5b6890f82798d0be1292e652c5 
Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 20 Oct 2025 10:55:22 +0100 Subject: [PATCH 04/61] make format --- spd/clustering/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spd/clustering/dataset.py b/spd/clustering/dataset.py index 82b5c4167..ec9bfba04 100644 --- a/spd/clustering/dataset.py +++ b/spd/clustering/dataset.py @@ -132,7 +132,7 @@ def _get_dataloader_resid_mlp( batch_size: int, ddp_rank: int = 0, ddp_world_size: int = 1, -) -> tuple[Generator[torch.Tensor, None, None], dict[str, Any]]: +) -> tuple[Generator[torch.Tensor], dict[str, Any]]: """Split a ResidMLP dataset into n_batches of batch_size and save the batches.""" from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset from spd.utils.data_utils import DatasetGeneratedDataLoader From 064632cb0582ab650ef74d99a9100c054a2068f2 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 20 Oct 2025 12:25:18 +0100 Subject: [PATCH 05/61] format --- spd/clustering/dataset_multibatch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/spd/clustering/dataset_multibatch.py b/spd/clustering/dataset_multibatch.py index f44014f30..11fd067e4 100644 --- a/spd/clustering/dataset_multibatch.py +++ b/spd/clustering/dataset_multibatch.py @@ -1,6 +1,7 @@ """ Loads and splits dataset into batches, returning them as an iterator. 
""" + # TODO: figure out this file vs spd/clustering/dataset.py from collections.abc import Generator, Iterator from typing import Any From b5804556dc6547deb8b6a68651448133f2832f7c Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 20 Oct 2025 12:32:25 +0100 Subject: [PATCH 06/61] typing fixes --- spd/clustering/dataset_multibatch.py | 11 ++++++++--- tests/clustering/test_calc_distances.py | 1 - tests/clustering/test_run_clustering_happy_path.py | 1 - 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/spd/clustering/dataset_multibatch.py b/spd/clustering/dataset_multibatch.py index 11fd067e4..d0eb4b636 100644 --- a/spd/clustering/dataset_multibatch.py +++ b/spd/clustering/dataset_multibatch.py @@ -20,10 +20,13 @@ from spd.experiments.resid_mlp.configs import ResidMLPModelConfig, ResidMLPTaskConfig from spd.experiments.resid_mlp.models import ResidMLP from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.spd_types import TaskName def get_clustering_dataloader( config: ClusteringRunConfig, + task_name: TaskName, + n_batches: int, ddp_rank: int = 0, ddp_world_size: int = 1, **kwargs: Any, @@ -31,7 +34,7 @@ def get_clustering_dataloader( """Split a dataset into n_batches of batch_size, returning iterator and config""" ds: Generator[BatchTensor] ds_config_dict: dict[str, Any] - match config.task_name: + match task_name: case "lm": ds, ds_config_dict = _get_dataloader_lm( model_path=config.model_path, @@ -57,8 +60,8 @@ def get_clustering_dataloader( def limited_iterator() -> Iterator[BatchTensor]: batch_idx: int batch: BatchTensor - for batch_idx, batch in tqdm(enumerate(ds), total=config.n_batches, unit="batch"): - if batch_idx >= config.n_batches: + for batch_idx, batch in tqdm(enumerate(ds), total=n_batches, unit="batch"): + if batch_idx >= n_batches: break yield batch @@ -141,6 +144,8 @@ def _get_dataloader_resid_mlp( # TODO: this is a hack. 
idk what the best way to handle this is shuffle_data: bool = ddp_world_size <= 1 + assert ddp_rank >= 0 + with SpinnerContext(message=f"Loading SPD Run Config for '{model_path}'"): spd_run: SPDRunInfo = SPDRunInfo.from_path(model_path) # SPD_RUN = SPDRunInfo.from_path(EXPERIMENT_REGISTRY["resid_mlp3"].canonical_run) diff --git a/tests/clustering/test_calc_distances.py b/tests/clustering/test_calc_distances.py index d8971df05..b06350f4b 100644 --- a/tests/clustering/test_calc_distances.py +++ b/tests/clustering/test_calc_distances.py @@ -11,7 +11,6 @@ def test_merge_history_normalization_happy_path(): iters=3, alpha=1.0, activation_threshold=None, - pop_component_prob=0.0, ) histories = [] diff --git a/tests/clustering/test_run_clustering_happy_path.py b/tests/clustering/test_run_clustering_happy_path.py index 91a7cf2ad..a418c1999 100644 --- a/tests/clustering/test_run_clustering_happy_path.py +++ b/tests/clustering/test_run_clustering_happy_path.py @@ -25,7 +25,6 @@ def test_run_clustering_happy_path(): iters=3, merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.05}, - pop_component_prob=0, ), wandb_project=None, wandb_entity="goodfire", From 3a206ed64c65c7854157d0cc1126929c9b767c28 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 11:23:25 +0100 Subject: [PATCH 07/61] allow specifying either config path or mrc cfg in pipeline cfg --- spd/clustering/merge_run_config.py | 19 +++++++ spd/clustering/scripts/run_pipeline.py | 74 ++++++++++++++++++++++++-- 2 files changed, 90 insertions(+), 3 deletions(-) diff --git a/spd/clustering/merge_run_config.py b/spd/clustering/merge_run_config.py index 60a5244d6..6671127b9 100644 --- a/spd/clustering/merge_run_config.py +++ b/spd/clustering/merge_run_config.py @@ -1,5 +1,8 @@ """ClusteringRunConfig""" +import base64 +import hashlib +import json from pathlib import Path from typing import Any, Self @@ -127,3 +130,19 @@ def model_dump_with_properties(self) -> dict[str, Any]: ) 
return base_dump + + def stable_hash_b64(self) -> str: + """Generate a stable, deterministic base64-encoded hash of this config. + + Uses SHA256 hash of the JSON representation with sorted keys for determinism. + Returns URL-safe base64 encoding without padding. + + Returns: + URL-safe base64-encoded hash (without padding) + """ + config_dict: dict[str, Any] = self.model_dump(mode="json") + config_json: str = json.dumps(config_dict, indent=2, sort_keys=True) + hash_digest: bytes = hashlib.sha256(config_json.encode()).digest() + # Use base64 URL-safe encoding and strip padding for filesystem safety + hash_b64: str = base64.urlsafe_b64encode(hash_digest).decode().rstrip("=") + return hash_b64 diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index cde83ffa1..34a7ef2f8 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -25,12 +25,14 @@ from typing import Any import wandb_workspaces.workspaces as ws -from pydantic import Field, PositiveInt, field_validator +from pydantic import Field, PositiveInt, field_validator, model_validator from spd.base_config import BaseConfig from spd.clustering.consts import DistancesMethod +from spd.clustering.merge_run_config import ClusteringRunConfig from spd.clustering.storage import StorageBase from spd.log import logger +from spd.settings import SPD_CACHE_DIR from spd.utils.command_utils import run_script_array_local from spd.utils.general_utils import replace_pydantic_model from spd.utils.run_utils import _NO_ARG_PARSSED_SENTINEL, ExecutionStamp, read_noneable_str @@ -69,7 +71,14 @@ def distances_path(self, method: DistancesMethod) -> Path: class ClusteringPipelineConfig(BaseConfig): """Configuration for submitting an ensemble of clustering runs to SLURM.""" - run_clustering_config_path: Path = Field(description="Path to ClusteringRunConfig file.") + run_clustering_config_path: Path | None = Field( + default=None, + description="Path to 
ClusteringRunConfig file. Mutually exclusive with run_clustering_config.", + ) + run_clustering_config: ClusteringRunConfig | None = Field( + default=None, + description="Inline ClusteringRunConfig. Mutually exclusive with run_clustering_config_path.", + ) n_runs: PositiveInt = Field(description="Number of clustering runs in the ensemble") distances_methods: list[DistancesMethod] = Field( description="List of method(s) to use for calculating distances" @@ -84,6 +93,25 @@ class ClusteringPipelineConfig(BaseConfig): wandb_entity: str = Field(description="WandB entity (team/user) name") create_git_snapshot: bool = Field(description="Create a git snapshot for the run") + @model_validator(mode="after") + def validate_config_fields(self) -> "ClusteringPipelineConfig": + """Validate that exactly one of run_clustering_config_path or run_clustering_config is provided.""" + has_path: bool = self.run_clustering_config_path is not None + has_inline: bool = self.run_clustering_config is not None + + if not has_path and not has_inline: + raise ValueError( + "Must specify exactly one of 'run_clustering_config_path' or 'run_clustering_config'" + ) + + if has_path and has_inline: + raise ValueError( + "Cannot specify both 'run_clustering_config_path' and 'run_clustering_config'. " + "Use only one." + ) + + return self + @field_validator("distances_methods") @classmethod def validate_distances_methods(cls, v: list[DistancesMethod]) -> list[DistancesMethod]: @@ -94,6 +122,46 @@ def validate_distances_methods(cls, v: list[DistancesMethod]) -> list[DistancesM return v + def get_config_path(self) -> Path: + """Get the path to the ClusteringRunConfig file. + + - If run_clustering_config_path is provided, returns it directly. + - If run_clustering_config is provided, caches it to a deterministic path + based on its content hash and returns that path. + - if the config file already exists in the cache, assert that it is identical. 
+ + Returns: + Path to the (potentially newly created) ClusteringRunConfig file + """ + if self.run_clustering_config_path is not None: + return self.run_clustering_config_path + + assert self.run_clustering_config is not None, ( + "Either run_clustering_config_path or run_clustering_config must be set" + ) + + # Generate deterministic hash from config + hash_b64: str = self.run_clustering_config.stable_hash_b64() + + # Create cache directory + cache_dir: Path = SPD_CACHE_DIR / "merge_run_configs" + cache_dir.mkdir(parents=True, exist_ok=True) + + # Write config to cache if it doesn't exist + config_path: Path = cache_dir / f"{hash_b64}.json" + if not config_path.exists(): + self.run_clustering_config.to_file(config_path) + logger.info(f"Cached inline config to {config_path}") + else: + # Verify that existing file matches + existing_config = ClusteringRunConfig.from_file(config_path) + if existing_config != self.run_clustering_config: + raise ValueError( + f"Hash collision detected for config hash {hash_b64} at {config_path}\n{existing_config=}\n{self.run_clustering_config=}" + ) + + return config_path + def create_clustering_workspace_view(ensemble_id: str, project: str, entity: str) -> str: """Create WandB workspace view for clustering runs. 
@@ -148,7 +216,7 @@ def generate_clustering_commands( "python", "spd/clustering/scripts/run_clustering.py", "--config", - pipeline_config.run_clustering_config_path.as_posix(), + pipeline_config.get_config_path().as_posix(), "--pipeline-run-id", pipeline_run_id, "--idx-in-ensemble", From eb831c0eff389fea68a836a59db8ea5160ef454a Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 11:26:56 +0100 Subject: [PATCH 08/61] [wip] reorg configs --- spd/clustering/configs/{ => mrc}/example.yaml | 0 spd/clustering/configs/{ => mrc}/resid_mlp1.json | 0 spd/clustering/configs/{ => mrc}/resid_mlp2.json | 0 spd/clustering/configs/{ => mrc}/resid_mlp3.json | 0 spd/clustering/configs/{ => mrc}/simplestories_dev.json | 0 spd/clustering/configs/{ => mrc}/test-resid_mlp1.json | 0 spd/clustering/configs/{ => mrc}/test-simplestories.json | 0 7 files changed, 0 insertions(+), 0 deletions(-) rename spd/clustering/configs/{ => mrc}/example.yaml (100%) rename spd/clustering/configs/{ => mrc}/resid_mlp1.json (100%) rename spd/clustering/configs/{ => mrc}/resid_mlp2.json (100%) rename spd/clustering/configs/{ => mrc}/resid_mlp3.json (100%) rename spd/clustering/configs/{ => mrc}/simplestories_dev.json (100%) rename spd/clustering/configs/{ => mrc}/test-resid_mlp1.json (100%) rename spd/clustering/configs/{ => mrc}/test-simplestories.json (100%) diff --git a/spd/clustering/configs/example.yaml b/spd/clustering/configs/mrc/example.yaml similarity index 100% rename from spd/clustering/configs/example.yaml rename to spd/clustering/configs/mrc/example.yaml diff --git a/spd/clustering/configs/resid_mlp1.json b/spd/clustering/configs/mrc/resid_mlp1.json similarity index 100% rename from spd/clustering/configs/resid_mlp1.json rename to spd/clustering/configs/mrc/resid_mlp1.json diff --git a/spd/clustering/configs/resid_mlp2.json b/spd/clustering/configs/mrc/resid_mlp2.json similarity index 100% rename from spd/clustering/configs/resid_mlp2.json rename to 
spd/clustering/configs/mrc/resid_mlp2.json diff --git a/spd/clustering/configs/resid_mlp3.json b/spd/clustering/configs/mrc/resid_mlp3.json similarity index 100% rename from spd/clustering/configs/resid_mlp3.json rename to spd/clustering/configs/mrc/resid_mlp3.json diff --git a/spd/clustering/configs/simplestories_dev.json b/spd/clustering/configs/mrc/simplestories_dev.json similarity index 100% rename from spd/clustering/configs/simplestories_dev.json rename to spd/clustering/configs/mrc/simplestories_dev.json diff --git a/spd/clustering/configs/test-resid_mlp1.json b/spd/clustering/configs/mrc/test-resid_mlp1.json similarity index 100% rename from spd/clustering/configs/test-resid_mlp1.json rename to spd/clustering/configs/mrc/test-resid_mlp1.json diff --git a/spd/clustering/configs/test-simplestories.json b/spd/clustering/configs/mrc/test-simplestories.json similarity index 100% rename from spd/clustering/configs/test-simplestories.json rename to spd/clustering/configs/mrc/test-simplestories.json From 89e5c36aef43139e7c19364f85a39f0d1a7b1ea0 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 11:52:03 +0100 Subject: [PATCH 09/61] added default `None` for slurm partition and job name prefix --- spd/clustering/scripts/run_pipeline.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index 34a7ef2f8..d970cc31d 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -84,14 +84,18 @@ class ClusteringPipelineConfig(BaseConfig): description="List of method(s) to use for calculating distances" ) base_output_dir: Path = Field(description="Base directory for outputs of clustering runs.") - slurm_job_name_prefix: str | None = Field(description="Prefix for SLURM job names") - slurm_partition: str | None = Field(description="SLURM partition to use") + slurm_job_name_prefix: str | None = Field( + 
default=None, description="Prefix for SLURM job names" + ) + slurm_partition: str | None = Field(default=None, description="SLURM partition to use") wandb_project: str | None = Field( default=None, description="Weights & Biases project name (set to None to disable WandB logging)", ) - wandb_entity: str = Field(description="WandB entity (team/user) name") - create_git_snapshot: bool = Field(description="Create a git snapshot for the run") + wandb_entity: str = Field(default="goodfire", description="WandB entity (team/user) name") + create_git_snapshot: bool = Field( + default=False, description="Create a git snapshot for the run" + ) @model_validator(mode="after") def validate_config_fields(self) -> "ClusteringPipelineConfig": From 8910bb479081dd8be312707b88fed20f5b18b69d Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 11:54:33 +0100 Subject: [PATCH 10/61] refactor configs, add config tests --- spd/clustering/configs/README.md | 1 + spd/clustering/configs/mrc/resid_mlp1.json | 5 +- spd/clustering/configs/mrc/resid_mlp2.json | 4 +- spd/clustering/configs/mrc/resid_mlp3.json | 23 - .../configs/pipeline-dev-simplestories.yaml | 2 +- .../configs/pipeline-test-resid_mlp1.yaml | 2 +- .../configs/pipeline-test-simplestories.yaml | 2 +- spd/clustering/configs/pipeline_config.yaml | 2 +- tests/clustering/test_pipeline_config.py | 473 ++++++++++++++++++ 9 files changed, 480 insertions(+), 34 deletions(-) create mode 100644 spd/clustering/configs/README.md delete mode 100644 spd/clustering/configs/mrc/resid_mlp3.json create mode 100644 tests/clustering/test_pipeline_config.py diff --git a/spd/clustering/configs/README.md b/spd/clustering/configs/README.md new file mode 100644 index 000000000..ed6efe090 --- /dev/null +++ b/spd/clustering/configs/README.md @@ -0,0 +1 @@ +this folder `configs/` contains files with `ClusteringPipelineConfig`s. 
the folder `configs/mrc/` contains files with `ClusteringRunConfig`s, which may be referenced by the `run_clustering_config_path` field in the pipeline configs. \ No newline at end of file diff --git a/spd/clustering/configs/mrc/resid_mlp1.json b/spd/clustering/configs/mrc/resid_mlp1.json index a7d118ac7..506717282 100644 --- a/spd/clustering/configs/mrc/resid_mlp1.json +++ b/spd/clustering/configs/mrc/resid_mlp1.json @@ -10,12 +10,9 @@ "module_name_filter": null }, "experiment_key": "resid_mlp1", - "distances_methods": ["perm_invariant_hamming"], - "n_batches": 8, "batch_size": 128, - "wandb_enabled": true, "wandb_project": "spd-cluster", - "intervals": { + "logging_intervals": { "stat": 1, "tensor": 5, "plot": 5, diff --git a/spd/clustering/configs/mrc/resid_mlp2.json b/spd/clustering/configs/mrc/resid_mlp2.json index 2be350979..af645f3bd 100644 --- a/spd/clustering/configs/mrc/resid_mlp2.json +++ b/spd/clustering/configs/mrc/resid_mlp2.json @@ -10,11 +10,9 @@ "module_name_filter": null }, "experiment_key": "resid_mlp2", - "n_batches": 16, "batch_size": 1024, - "wandb_enabled": true, "wandb_project": "spd-cluster", - "intervals": { + "logging_intervals": { "stat": 1, "tensor": 5, "plot": 5, diff --git a/spd/clustering/configs/mrc/resid_mlp3.json b/spd/clustering/configs/mrc/resid_mlp3.json deleted file mode 100644 index 5d87e08d5..000000000 --- a/spd/clustering/configs/mrc/resid_mlp3.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "merge_config": { - "activation_threshold": 0.01, - "alpha": 1, - "iters": 350, - "merge_pair_sampling_method": "range", - "merge_pair_sampling_kwargs": {"threshold": 0.05}, - "pop_component_prob": 0, - "filter_dead_threshold": 0.01, - "module_name_filter": null - }, - "experiment_key": "resid_mlp3", - "n_batches": 4, - "batch_size": 1024, - "wandb_enabled": true, - "wandb_project": "spd-cluster", - "intervals": { - "stat": 1, - "tensor": 32, - "plot": 32, - "artifact": 32 - } -} \ No newline at end of file diff --git 
a/spd/clustering/configs/pipeline-dev-simplestories.yaml b/spd/clustering/configs/pipeline-dev-simplestories.yaml index 6909c5841..7eef9cfc9 100644 --- a/spd/clustering/configs/pipeline-dev-simplestories.yaml +++ b/spd/clustering/configs/pipeline-dev-simplestories.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/simplestories_dev.json" +run_clustering_config_path: "spd/clustering/configs/mrc/simplestories_dev.json" n_runs: 4 distances_methods: ["matching_dist", "matching_dist_vec", "perm_invariant_hamming"] base_output_dir: "tests/.temp/clustering" diff --git a/spd/clustering/configs/pipeline-test-resid_mlp1.yaml b/spd/clustering/configs/pipeline-test-resid_mlp1.yaml index a413a5438..a3c02da5e 100644 --- a/spd/clustering/configs/pipeline-test-resid_mlp1.yaml +++ b/spd/clustering/configs/pipeline-test-resid_mlp1.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/test-resid_mlp1.json" +run_clustering_config_path: "spd/clustering/configs/mrc/test-resid_mlp1.json" n_runs: 3 distances_methods: ["matching_dist"] base_output_dir: "tests/.temp/clustering" diff --git a/spd/clustering/configs/pipeline-test-simplestories.yaml b/spd/clustering/configs/pipeline-test-simplestories.yaml index e406628c4..c98895ab4 100644 --- a/spd/clustering/configs/pipeline-test-simplestories.yaml +++ b/spd/clustering/configs/pipeline-test-simplestories.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/test-simplestories.json" +run_clustering_config_path: "spd/clustering/configs/mrc/test-simplestories.json" n_runs: 2 distances_methods: ["matching_dist"] base_output_dir: "tests/.temp/clustering" diff --git a/spd/clustering/configs/pipeline_config.yaml b/spd/clustering/configs/pipeline_config.yaml index 6a40c9b29..42db7ac84 100644 --- a/spd/clustering/configs/pipeline_config.yaml +++ b/spd/clustering/configs/pipeline_config.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/example.yaml" 
+run_clustering_config_path: "spd/clustering/configs/mrc/example.yaml" n_runs: 2 distances_methods: ["perm_invariant_hamming"] base_output_dir: "/mnt/polished-lake/spd/clustering" diff --git a/tests/clustering/test_pipeline_config.py b/tests/clustering/test_pipeline_config.py new file mode 100644 index 000000000..f64ff15c6 --- /dev/null +++ b/tests/clustering/test_pipeline_config.py @@ -0,0 +1,473 @@ +"""Tests for ClusteringPipelineConfig and ClusteringRunConfig with inline config support.""" + +from pathlib import Path + +import pytest + +from spd.clustering.merge_config import MergeConfig +from spd.clustering.merge_run_config import ClusteringRunConfig +from spd.clustering.scripts.run_pipeline import ClusteringPipelineConfig +from spd.settings import REPO_ROOT, SPD_CACHE_DIR + + +class TestClusteringRunConfigStableHash: + """Test ClusteringRunConfig.stable_hash_b64() method.""" + + def test_deterministic_hash(self): + """Test that stable_hash_b64 is deterministic for identical configs.""" + config1 = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + config2 = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + + hash1 = config1.stable_hash_b64() + hash2 = config2.stable_hash_b64() + + assert hash1 == hash2 + assert isinstance(hash1, str) + assert len(hash1) > 0 + + def test_different_configs_different_hashes(self): + """Test that different configs produce different hashes.""" + config1 = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + config2 = ClusteringRunConfig( + model_path="wandb:test/project/run2", # Different model_path + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + config3 = ClusteringRunConfig( + 
model_path="wandb:test/project/run1", + batch_size=64, # Different batch_size + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + + hash1 = config1.stable_hash_b64() + hash2 = config2.stable_hash_b64() + hash3 = config3.stable_hash_b64() + + assert hash1 != hash2 + assert hash1 != hash3 + assert hash2 != hash3 + + def test_hash_is_url_safe(self): + """Test that hash is URL-safe base64 (no padding, URL-safe chars).""" + config = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + + hash_value = config.stable_hash_b64() + + # Should not contain padding + assert "=" not in hash_value + + # Should only contain URL-safe base64 characters + valid_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_") + assert all(c in valid_chars for c in hash_value) + + def test_nested_config_changes_hash(self): + """Test that changes in nested merge_config affect the hash.""" + config1 = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(activation_threshold=0.1), + ) + config2 = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(activation_threshold=0.2), # Different threshold + ) + + assert config1.stable_hash_b64() != config2.stable_hash_b64() + + +class TestClusteringPipelineConfigValidation: + """Test ClusteringPipelineConfig validation logic.""" + + def test_error_when_neither_field_provided(self): + """Test that error is raised when neither path nor inline config is provided.""" + with pytest.raises(ValueError, match="Must specify exactly one"): + ClusteringPipelineConfig( + n_runs=2, + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + slurm_job_name_prefix=None, + slurm_partition=None, + wandb_entity="test", + 
create_git_snapshot=False, + ) + + def test_error_when_both_fields_provided(self): + """Test that error is raised when both path and inline config are provided.""" + with pytest.raises(ValueError, match="Cannot specify both"): + ClusteringPipelineConfig( + run_clustering_config_path=Path("some/path.json"), + run_clustering_config=ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + merge_config=MergeConfig(), + dataset_seed=0, + idx_in_ensemble=0, + ), + n_runs=2, + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + slurm_job_name_prefix=None, + slurm_partition=None, + wandb_entity="test", + create_git_snapshot=False, + ) + + def test_success_with_only_path(self): + """Test that config validates successfully with only path provided.""" + config = ClusteringPipelineConfig( + run_clustering_config_path=Path("some/path.json"), + n_runs=2, + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + wandb_entity="test", + create_git_snapshot=False, + ) + + assert config.run_clustering_config_path == Path("some/path.json") + assert config.run_clustering_config is None + + def test_success_with_only_inline_config(self): + """Test that config validates successfully with only inline config provided.""" + inline_config = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + + config = ClusteringPipelineConfig( + run_clustering_config=inline_config, + n_runs=2, + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + wandb_entity="test", + create_git_snapshot=False, + ) + + assert config.run_clustering_config_path is None + assert config.run_clustering_config == inline_config + + +class TestClusteringPipelineConfigGetConfigPath: + """Test ClusteringPipelineConfig.get_config_path() method.""" + + def test_returns_path_directly_when_using_path_field(self): + 
"""Test that get_config_path returns the path directly when using run_clustering_config_path.""" + expected_path = Path("some/path.json") + + config = ClusteringPipelineConfig( + run_clustering_config_path=expected_path, + n_runs=2, + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + wandb_entity="test", + create_git_snapshot=False, + ) + + assert config.get_config_path() == expected_path + + def test_creates_cached_file_when_using_inline_config(self): + """Test that get_config_path creates a cached file when using inline config.""" + inline_config = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + + config = ClusteringPipelineConfig( + run_clustering_config=inline_config, + n_runs=2, + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + wandb_entity="test", + create_git_snapshot=False, + ) + + config_path = config.get_config_path() + + # Check that file exists + assert config_path.exists() + + # Check that it's in the expected directory + expected_cache_dir = SPD_CACHE_DIR / "merge_run_configs" + assert config_path.parent == expected_cache_dir + + # Check that filename is the hash + expected_hash = inline_config.stable_hash_b64() + assert config_path.name == f"{expected_hash}.json" + + # Check that file contents match the config + loaded_config = ClusteringRunConfig.from_file(config_path) + assert loaded_config == inline_config + + # Clean up + config_path.unlink() + + def test_reuses_existing_cached_file(self): + """Test that get_config_path reuses existing cached file with same hash.""" + inline_config = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + + config1 = ClusteringPipelineConfig( + run_clustering_config=inline_config, + n_runs=2, + distances_methods=["perm_invariant_hamming"], + 
base_output_dir=Path("/tmp/test"), + wandb_entity="test", + create_git_snapshot=False, + ) + + # First call creates the file + config_path1 = config1.get_config_path() + assert config_path1.exists() + + # Record modification time + mtime1 = config_path1.stat().st_mtime + + # Create another config with same inline config + config2 = ClusteringPipelineConfig( + run_clustering_config=inline_config, + n_runs=3, # Different n_runs shouldn't matter + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + wandb_entity="test", + create_git_snapshot=False, + ) + + # Second call should reuse the file + config_path2 = config2.get_config_path() + + assert config_path1 == config_path2 + assert config_path2.stat().st_mtime == mtime1 # File not modified + + # Clean up + config_path1.unlink() + + def test_hash_collision_detection(self): + """Test that hash collision is detected when existing file differs.""" + inline_config = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + + # Create a fake collision by manually creating a file with same hash + hash_value = inline_config.stable_hash_b64() + cache_dir = SPD_CACHE_DIR / "merge_run_configs" + cache_dir.mkdir(parents=True, exist_ok=True) + collision_path = cache_dir / f"{hash_value}.json" + + # Write a different config to the file + different_config = ClusteringRunConfig( + model_path="wandb:test/project/run2", # Different! 
+ batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + different_config.to_file(collision_path) + + try: + config = ClusteringPipelineConfig( + run_clustering_config=inline_config, + n_runs=2, + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + slurm_job_name_prefix=None, + slurm_partition=None, + wandb_entity="test", + create_git_snapshot=False, + ) + + # Should raise ValueError about hash collision + with pytest.raises(ValueError, match="Hash collision detected"): + config.get_config_path() + finally: + # Clean up + if collision_path.exists(): + collision_path.unlink() + + def test_cache_directory_created_if_not_exists(self): + """Test that cache directory is created if it doesn't exist.""" + inline_config = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + + config = ClusteringPipelineConfig( + run_clustering_config=inline_config, + n_runs=2, + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + wandb_entity="test", + create_git_snapshot=False, + ) + + cache_dir = SPD_CACHE_DIR / "merge_run_configs" + + # Even if cache dir doesn't exist, get_config_path should create it + config_path = config.get_config_path() + + assert cache_dir.exists() + assert config_path.exists() + + # Clean up + config_path.unlink() + + +class TestAllConfigsValidation: + """Test that all existing config files can be loaded and validated.""" + + def test_all_pipeline_configs_valid(self): + """Test that all pipeline config files are valid.""" + configs_dir = REPO_ROOT / "spd" / "clustering" / "configs" + + # Find all YAML/YML files in the configs directory (not subdirectories) + pipeline_config_files = list(configs_dir.glob("*.yaml")) + list(configs_dir.glob("*.yml")) + + # Should have at least some configs + assert len(pipeline_config_files) > 0, "No pipeline config files found" + + 
errors: list[tuple[Path, Exception]] = [] + + for config_file in pipeline_config_files: + try: + config = ClusteringPipelineConfig.from_file(config_file) + # Basic sanity checks + assert config.n_runs > 0 + assert len(config.distances_methods) > 0 + assert config.wandb_entity is not None + except Exception as e: + errors.append((config_file, e)) + + # Report all errors at once + if errors: + error_msg = "Failed to validate pipeline configs:\n" + for path, exc in errors: + error_msg += f" - {path.name}: {exc}\n" + pytest.fail(error_msg) + + def test_all_merge_run_configs_valid(self): + """Test that all merge run config files are valid.""" + mrc_dir = REPO_ROOT / "spd" / "clustering" / "configs" / "mrc" + + # Find all JSON/YAML/YML files in the mrc directory + mrc_files = ( + list(mrc_dir.glob("*.json")) + + list(mrc_dir.glob("*.yaml")) + + list(mrc_dir.glob("*.yml")) + ) + + # Should have at least some configs + assert len(mrc_files) > 0, "No merge run config files found" + + errors: list[tuple[Path, Exception]] = [] + + for config_file in mrc_files: + try: + config = ClusteringRunConfig.from_file(config_file) + # Basic sanity checks + assert config.batch_size > 0 + assert config.model_path.startswith("wandb:") + assert config.merge_config is not None + except Exception as e: + errors.append((config_file, e)) + + # Report all errors at once + if errors: + error_msg = "Failed to validate merge run configs:\n" + for path, exc in errors: + error_msg += f" - {path.name}: {exc}\n" + pytest.fail(error_msg) + + def test_pipeline_configs_reference_valid_mrc_files(self): + """Test that pipeline configs reference merge run config files that exist.""" + configs_dir = REPO_ROOT / "spd" / "clustering" / "configs" + pipeline_config_files = list(configs_dir.glob("*.yaml")) + list(configs_dir.glob("*.yml")) + + errors: list[tuple[Path, str]] = [] + + for config_file in pipeline_config_files: + try: + config = ClusteringPipelineConfig.from_file(config_file) + + # Skip configs that 
use inline config + if config.run_clustering_config is not None: + continue + + # Check that referenced file exists + assert config.run_clustering_config_path is not None + mrc_path = REPO_ROOT / config.run_clustering_config_path + + if not mrc_path.exists(): + errors.append( + ( + config_file, + f"References non-existent file: {config.run_clustering_config_path}", + ) + ) + else: + # Try to load the referenced config + try: + ClusteringRunConfig.from_file(mrc_path) + except Exception as e: + errors.append( + ( + config_file, + f"Referenced file {mrc_path.name} is invalid: {e}", + ) + ) + except Exception as e: + errors.append((config_file, f"Failed to load pipeline config: {e}")) + + if errors: + error_msg = "Pipeline configs with invalid merge run config references:\n" + for path, msg in errors: + error_msg += f" - {path.name}: {msg}\n" + pytest.fail(error_msg) From 0b957f5fbf1ef97102c8aef2a8c83816ed9b4635 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 12:09:40 +0100 Subject: [PATCH 11/61] fix tests --- spd/clustering/merge_run_config.py | 2 + spd/clustering/scripts/run_pipeline.py | 23 ++- tests/clustering/test_pipeline_config.py | 209 +++-------------------- 3 files changed, 45 insertions(+), 189 deletions(-) diff --git a/spd/clustering/merge_run_config.py b/spd/clustering/merge_run_config.py index 6671127b9..19450ff66 100644 --- a/spd/clustering/merge_run_config.py +++ b/spd/clustering/merge_run_config.py @@ -54,6 +54,8 @@ class ClusteringRunConfig(BaseConfig): default=None, description="Ensemble identifier for WandB grouping", ) + # TODO: allow idx_in_ensemble to be `None` if ensemble_id is `None`? + # TODO: allow idx_in_ensemble to be auto-assigned by reading from db if -1? 
idx_in_ensemble: int = Field(0, description="Index of this run in the ensemble") merge_config: MergeConfig = Field(description="Merge algorithm configuration") diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index d970cc31d..334f5b418 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -98,7 +98,7 @@ class ClusteringPipelineConfig(BaseConfig): ) @model_validator(mode="after") - def validate_config_fields(self) -> "ClusteringPipelineConfig": + def validate_mrc_fields(self) -> "ClusteringPipelineConfig": """Validate that exactly one of run_clustering_config_path or run_clustering_config is provided.""" has_path: bool = self.run_clustering_config_path is not None has_inline: bool = self.run_clustering_config is not None @@ -108,11 +108,19 @@ def validate_config_fields(self) -> "ClusteringPipelineConfig": "Must specify exactly one of 'run_clustering_config_path' or 'run_clustering_config'" ) - if has_path and has_inline: - raise ValueError( - "Cannot specify both 'run_clustering_config_path' and 'run_clustering_config'. " - "Use only one." - ) + if has_path: + if has_inline: + raise ValueError( + "Cannot specify both 'run_clustering_config_path' and 'run_clustering_config'. " + "Use only one." 
+ ) + else: + # Ensure the path exists + # pyright ignore because it doesn't recognize that has_path implies not None + if not self.run_clustering_config_path.exists(): # pyright: ignore[reportOptionalMemberAccess] + raise ValueError( + f"run_clustering_config_path does not exist: {self.run_clustering_config_path = }" + ) return self @@ -138,6 +146,9 @@ def get_config_path(self) -> Path: Path to the (potentially newly created) ClusteringRunConfig file """ if self.run_clustering_config_path is not None: + assert self.run_clustering_config_path.exists(), ( + f"no file at run_clustering_config_path: {self.run_clustering_config_path = }" + ) return self.run_clustering_config_path assert self.run_clustering_config is not None, ( diff --git a/tests/clustering/test_pipeline_config.py b/tests/clustering/test_pipeline_config.py index f64ff15c6..826a7fc28 100644 --- a/tests/clustering/test_pipeline_config.py +++ b/tests/clustering/test_pipeline_config.py @@ -13,8 +13,9 @@ class TestClusteringRunConfigStableHash: """Test ClusteringRunConfig.stable_hash_b64() method.""" - def test_deterministic_hash(self): - """Test that stable_hash_b64 is deterministic for identical configs.""" + def test_stable_hash_b64(self): + """Test that stable_hash_b64 is deterministic, unique, and URL-safe.""" + # Create 4 configs: 2 identical, 2 different config1 = ClusteringRunConfig( model_path="wandb:test/project/run1", batch_size=32, @@ -29,83 +30,44 @@ def test_deterministic_hash(self): idx_in_ensemble=0, merge_config=MergeConfig(), ) - - hash1 = config1.stable_hash_b64() - hash2 = config2.stable_hash_b64() - - assert hash1 == hash2 - assert isinstance(hash1, str) - assert len(hash1) > 0 - - def test_different_configs_different_hashes(self): - """Test that different configs produce different hashes.""" - config1 = ClusteringRunConfig( - model_path="wandb:test/project/run1", - batch_size=32, - dataset_seed=0, - idx_in_ensemble=0, - merge_config=MergeConfig(), - ) - config2 = ClusteringRunConfig( + 
config3 = ClusteringRunConfig( model_path="wandb:test/project/run2", # Different model_path batch_size=32, dataset_seed=0, idx_in_ensemble=0, merge_config=MergeConfig(), ) - config3 = ClusteringRunConfig( + config4 = ClusteringRunConfig( model_path="wandb:test/project/run1", - batch_size=64, # Different batch_size + batch_size=32, dataset_seed=0, idx_in_ensemble=0, - merge_config=MergeConfig(), + merge_config=MergeConfig( + activation_threshold=0.2 + ), # Different merge_config to test nested fields ) hash1 = config1.stable_hash_b64() hash2 = config2.stable_hash_b64() hash3 = config3.stable_hash_b64() + hash4 = config4.stable_hash_b64() - assert hash1 != hash2 - assert hash1 != hash3 - assert hash2 != hash3 - - def test_hash_is_url_safe(self): - """Test that hash is URL-safe base64 (no padding, URL-safe chars).""" - config = ClusteringRunConfig( - model_path="wandb:test/project/run1", - batch_size=32, - dataset_seed=0, - idx_in_ensemble=0, - merge_config=MergeConfig(), - ) + # Identical configs produce identical hashes + assert hash1 == hash2 - hash_value = config.stable_hash_b64() + # Different configs produce different hashes + assert hash1 != hash3 + assert hash1 != hash4 + assert hash3 != hash4 - # Should not contain padding - assert "=" not in hash_value + # Hashes are strings + assert isinstance(hash1, str) + assert len(hash1) > 0 - # Should only contain URL-safe base64 characters + # Hashes are URL-safe base64 (no padding, URL-safe chars only) + assert "=" not in hash1 valid_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_") - assert all(c in valid_chars for c in hash_value) - - def test_nested_config_changes_hash(self): - """Test that changes in nested merge_config affect the hash.""" - config1 = ClusteringRunConfig( - model_path="wandb:test/project/run1", - batch_size=32, - dataset_seed=0, - idx_in_ensemble=0, - merge_config=MergeConfig(activation_threshold=0.1), - ) - config2 = ClusteringRunConfig( - 
model_path="wandb:test/project/run1", - batch_size=32, - dataset_seed=0, - idx_in_ensemble=0, - merge_config=MergeConfig(activation_threshold=0.2), # Different threshold - ) - - assert config1.stable_hash_b64() != config2.stable_hash_b64() + assert all(c in valid_chars for c in hash1) class TestClusteringPipelineConfigValidation: @@ -145,49 +107,13 @@ def test_error_when_both_fields_provided(self): create_git_snapshot=False, ) - def test_success_with_only_path(self): - """Test that config validates successfully with only path provided.""" - config = ClusteringPipelineConfig( - run_clustering_config_path=Path("some/path.json"), - n_runs=2, - distances_methods=["perm_invariant_hamming"], - base_output_dir=Path("/tmp/test"), - wandb_entity="test", - create_git_snapshot=False, - ) - - assert config.run_clustering_config_path == Path("some/path.json") - assert config.run_clustering_config is None - - def test_success_with_only_inline_config(self): - """Test that config validates successfully with only inline config provided.""" - inline_config = ClusteringRunConfig( - model_path="wandb:test/project/run1", - batch_size=32, - dataset_seed=0, - idx_in_ensemble=0, - merge_config=MergeConfig(), - ) - - config = ClusteringPipelineConfig( - run_clustering_config=inline_config, - n_runs=2, - distances_methods=["perm_invariant_hamming"], - base_output_dir=Path("/tmp/test"), - wandb_entity="test", - create_git_snapshot=False, - ) - - assert config.run_clustering_config_path is None - assert config.run_clustering_config == inline_config - class TestClusteringPipelineConfigGetConfigPath: """Test ClusteringPipelineConfig.get_config_path() method.""" def test_returns_path_directly_when_using_path_field(self): """Test that get_config_path returns the path directly when using run_clustering_config_path.""" - expected_path = Path("some/path.json") + expected_path = Path("spd/clustering/configs/mrc/resid_mlp1.json") config = ClusteringPipelineConfig( 
run_clustering_config_path=expected_path, @@ -330,36 +256,6 @@ def test_hash_collision_detection(self): if collision_path.exists(): collision_path.unlink() - def test_cache_directory_created_if_not_exists(self): - """Test that cache directory is created if it doesn't exist.""" - inline_config = ClusteringRunConfig( - model_path="wandb:test/project/run1", - batch_size=32, - dataset_seed=0, - idx_in_ensemble=0, - merge_config=MergeConfig(), - ) - - config = ClusteringPipelineConfig( - run_clustering_config=inline_config, - n_runs=2, - distances_methods=["perm_invariant_hamming"], - base_output_dir=Path("/tmp/test"), - wandb_entity="test", - create_git_snapshot=False, - ) - - cache_dir = SPD_CACHE_DIR / "merge_run_configs" - - # Even if cache dir doesn't exist, get_config_path should create it - config_path = config.get_config_path() - - assert cache_dir.exists() - assert config_path.exists() - - # Clean up - config_path.unlink() - class TestAllConfigsValidation: """Test that all existing config files can be loaded and validated.""" @@ -378,11 +274,8 @@ def test_all_pipeline_configs_valid(self): for config_file in pipeline_config_files: try: - config = ClusteringPipelineConfig.from_file(config_file) - # Basic sanity checks - assert config.n_runs > 0 - assert len(config.distances_methods) > 0 - assert config.wandb_entity is not None + _config = ClusteringPipelineConfig.from_file(config_file) + assert _config.get_config_path().exists() except Exception as e: errors.append((config_file, e)) @@ -411,11 +304,7 @@ def test_all_merge_run_configs_valid(self): for config_file in mrc_files: try: - config = ClusteringRunConfig.from_file(config_file) - # Basic sanity checks - assert config.batch_size > 0 - assert config.model_path.startswith("wandb:") - assert config.merge_config is not None + _config = ClusteringRunConfig.from_file(config_file) except Exception as e: errors.append((config_file, e)) @@ -425,49 +314,3 @@ def test_all_merge_run_configs_valid(self): for path, exc in 
errors: error_msg += f" - {path.name}: {exc}\n" pytest.fail(error_msg) - - def test_pipeline_configs_reference_valid_mrc_files(self): - """Test that pipeline configs reference merge run config files that exist.""" - configs_dir = REPO_ROOT / "spd" / "clustering" / "configs" - pipeline_config_files = list(configs_dir.glob("*.yaml")) + list(configs_dir.glob("*.yml")) - - errors: list[tuple[Path, str]] = [] - - for config_file in pipeline_config_files: - try: - config = ClusteringPipelineConfig.from_file(config_file) - - # Skip configs that use inline config - if config.run_clustering_config is not None: - continue - - # Check that referenced file exists - assert config.run_clustering_config_path is not None - mrc_path = REPO_ROOT / config.run_clustering_config_path - - if not mrc_path.exists(): - errors.append( - ( - config_file, - f"References non-existent file: {config.run_clustering_config_path}", - ) - ) - else: - # Try to load the referenced config - try: - ClusteringRunConfig.from_file(mrc_path) - except Exception as e: - errors.append( - ( - config_file, - f"Referenced file {mrc_path.name} is invalid: {e}", - ) - ) - except Exception as e: - errors.append((config_file, f"Failed to load pipeline config: {e}")) - - if errors: - error_msg = "Pipeline configs with invalid merge run config references:\n" - for path, msg in errors: - error_msg += f" - {path.name}: {msg}\n" - pytest.fail(error_msg) From 7de545b1e5502146623c83a16fed704d0eeff007 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 12:33:04 +0100 Subject: [PATCH 12/61] allow `None` or `-1` idx_in_ensemble - idx_in_ensemble is None iff ensemble_id is None - idx_in_ensemble == -1 will make register_clustering_run() auto-assign next avalible index - added tests for ensemble registry --- spd/clustering/ensemble_registry.py | 28 ++- spd/clustering/merge_run_config.py | 41 ++-- spd/clustering/scripts/run_clustering.py | 22 ++- tests/clustering/test_ensemble_registry.py | 215 
+++++++++++++++++++++ tests/clustering/test_pipeline_config.py | 9 - 5 files changed, 278 insertions(+), 37 deletions(-) create mode 100644 tests/clustering/test_ensemble_registry.py diff --git a/spd/clustering/ensemble_registry.py b/spd/clustering/ensemble_registry.py index 7756877d8..b3b1711ab 100644 --- a/spd/clustering/ensemble_registry.py +++ b/spd/clustering/ensemble_registry.py @@ -6,6 +6,7 @@ import sqlite3 from contextlib import contextmanager +from spd.clustering.merge_run_config import ClusteringEnsembleIndex from spd.settings import SPD_CACHE_DIR # SQLite database path @@ -39,21 +40,42 @@ def _get_connection(): conn.close() -def register_clustering_run(pipeline_run_id: str, idx: int, clustering_run_id: str) -> None: +def register_clustering_run( + pipeline_run_id: str, idx: ClusteringEnsembleIndex, clustering_run_id: str +) -> int: """Register a clustering run as part of a pipeline ensemble. Args: pipeline_run_id: The ensemble/pipeline run ID - idx: Index of this run in the ensemble + idx: Index of this run in the ensemble. If -1, auto-assigns the next available index. 
clustering_run_id: The individual clustering run ID + + Returns: + The index assigned to this run (either the provided idx or the auto-assigned one) """ with _get_connection() as conn: + # Use BEGIN IMMEDIATE for thread-safe auto-increment + conn.execute("BEGIN IMMEDIATE") + + assigned_idx: int + if idx == -1: + # Auto-assign next available index + cursor = conn.execute( + "SELECT COALESCE(MAX(idx), -1) + 1 FROM ensemble_runs WHERE pipeline_run_id = ?", + (pipeline_run_id,), + ) + assigned_idx = cursor.fetchone()[0] + else: + assigned_idx = idx + conn.execute( "INSERT INTO ensemble_runs (pipeline_run_id, idx, clustering_run_id) VALUES (?, ?, ?)", - (pipeline_run_id, idx, clustering_run_id), + (pipeline_run_id, assigned_idx, clustering_run_id), ) conn.commit() + return assigned_idx + def get_clustering_runs(pipeline_run_id: str) -> list[tuple[int, str]]: """Get all clustering runs for a pipeline ensemble. diff --git a/spd/clustering/merge_run_config.py b/spd/clustering/merge_run_config.py index 19450ff66..f82e00203 100644 --- a/spd/clustering/merge_run_config.py +++ b/spd/clustering/merge_run_config.py @@ -4,9 +4,9 @@ import hashlib import json from pathlib import Path -from typing import Any, Self +from typing import Any, Literal, Self -from pydantic import Field, PositiveInt, model_validator +from pydantic import Field, NonNegativeInt, PositiveInt, field_validator, model_validator from spd.base_config import BaseConfig from spd.clustering.merge_config import MergeConfig @@ -31,6 +31,10 @@ class LoggingIntervals(BaseConfig): ) +ClusteringEnsembleIndex = NonNegativeInt | Literal[-1] +"index in an ensemble; -1 will cause register_clustering_run() to auto-assign the next available index" + + class ClusteringRunConfig(BaseConfig): """Configuration for a single clustering run. 
@@ -54,9 +58,11 @@ class ClusteringRunConfig(BaseConfig): default=None, description="Ensemble identifier for WandB grouping", ) - # TODO: allow idx_in_ensemble to be `None` if ensemble_id is `None`? - # TODO: allow idx_in_ensemble to be auto-assigned by reading from db if -1? - idx_in_ensemble: int = Field(0, description="Index of this run in the ensemble") + # TODO: given our use of `register_clustering_run()` and the atomic guarantees of that, do we even need this index? + # probably still nice to have for clarity + idx_in_ensemble: ClusteringEnsembleIndex | None = Field( + default=None, description="Index of this run in the ensemble" + ) merge_config: MergeConfig = Field(description="Merge algorithm configuration") logging_intervals: LoggingIntervals = Field( @@ -74,16 +80,6 @@ class ClusteringRunConfig(BaseConfig): description="Whether to use streaming dataset loading (if supported by the dataset). see https://github.com/goodfire-ai/spd/pull/199", ) - # TODO: no way to check this without knowing task - # @model_validator(mode="after") - # def validate_streaming_compatibility(self) -> Self: - # """Ensure dataset_streaming is only enabled for compatible tasks.""" - # if self.dataset_streaming and self.task_name != "lm": - # raise ValueError( - # f"Streaming dataset loading only supported for 'lm' task, got '{self.task_name}'" - # ) - # return self - @model_validator(mode="before") def process_experiment_key(cls, values: dict[str, Any]) -> dict[str, Any]: experiment_key: str | None = values.get("experiment_key") @@ -105,11 +101,18 @@ def process_experiment_key(cls, values: dict[str, Any]) -> dict[str, Any]: return values - @model_validator(mode="after") - def validate_model_path(self) -> Self: + @field_validator("model_path") + def validate_model_path(cls, v: str) -> str: """Validate that model_path is a proper WandB path.""" - if not self.model_path.startswith("wandb:"): - raise ValueError(f"model_path must start with 'wandb:', got: {self.model_path}") + if not 
v.startswith("wandb:"): + raise ValueError(f"model_path must start with 'wandb:', got: {v}") + return v + + @model_validator(mode="after") + def validate_ensemble_id_index(self) -> Self: + assert (self.idx_in_ensemble is None) == (self.ensemble_id is None), ( + "If ensemble_id is None, idx_in_ensemble must also be None" + ) return self @property diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index 7c614407a..04c48be7f 100644 --- a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -229,17 +229,29 @@ def main(run_config: ClusteringRunConfig) -> Path: # Register with ensemble if this is part of a pipeline if run_config.ensemble_id: assert run_config.idx_in_ensemble is not None, ( - "idx_in_ensemble must be set when ensemble_id is provided" + "idx_in_ensemble must be set when ensemble_id is provided! to auto-assign, set idx_in_ensemble = -1.\n" + f"{'!' * 50}\nNOTE: this should be an unreachable state -- such a case should have been caught by the pydantic validator.\n{'!' 
* 50}" ) - register_clustering_run( + assigned_idx: int = register_clustering_run( run_config.ensemble_id, run_config.idx_in_ensemble, clustering_run_id, ) + + # Update config if index was auto-assigned + if run_config.idx_in_ensemble == -1: + run_config = replace_pydantic_model(run_config, {"idx_in_ensemble": assigned_idx}) + logger.info(f"Auto-assigned ensemble index: {assigned_idx}") + logger.info( - f"Registered with pipeline {run_config.ensemble_id} at index {run_config.idx_in_ensemble} in {_ENSEMBLE_REGISTRY_DB}" + f"Registered with pipeline {run_config.ensemble_id} at index {assigned_idx} in {_ENSEMBLE_REGISTRY_DB}" ) + # save config + run_config.to_file(storage.config_path) + logger.info(f"Config saved to {storage.config_path}") + + # start logger.info("Starting clustering run") logger.info(f"Output directory: {storage.base_dir}") device = get_device() @@ -347,9 +359,7 @@ def main(run_config: ClusteringRunConfig) -> Path: log_callback=log_callback, ) - # 8. Save merge history and config - run_config.to_file(storage.config_path) - logger.info(f"Config saved to {storage.config_path}") + # 8. 
Save merge history history.save(storage.history_path) logger.info(f"History saved to {storage.history_path}") diff --git a/tests/clustering/test_ensemble_registry.py b/tests/clustering/test_ensemble_registry.py new file mode 100644 index 000000000..e71e8e228 --- /dev/null +++ b/tests/clustering/test_ensemble_registry.py @@ -0,0 +1,215 @@ +"""Tests for ensemble_registry module.""" + +import tempfile +from pathlib import Path +from typing import Any + +import pytest + +from spd.clustering.ensemble_registry import ( + get_clustering_runs, + register_clustering_run, +) + + +@pytest.fixture +def temp_registry_db(monkeypatch: Any): + """Create a temporary registry database for testing.""" + with tempfile.TemporaryDirectory() as tmpdir: + temp_db_path = Path(tmpdir) / "test_registry.db" + monkeypatch.setattr("spd.clustering.ensemble_registry._ENSEMBLE_REGISTRY_DB", temp_db_path) + yield temp_db_path + + +class TestRegisterClusteringRun: + """Test register_clustering_run() function.""" + + def test_register_with_explicit_index(self, _temp_registry_db: Any): + """Test registering a run with an explicit index.""" + pipeline_id = "pipeline_001" + idx = 0 + run_id = "run_001" + + assigned_idx = register_clustering_run(pipeline_id, idx, run_id) + + # Should return the same index + assert assigned_idx == idx + + # Verify in database + runs = get_clustering_runs(pipeline_id) + assert runs == [(0, "run_001")] + + def test_register_multiple_explicit_indices(self, _temp_registry_db: Any): + """Test registering multiple runs with explicit indices.""" + pipeline_id = "pipeline_002" + + idx0 = register_clustering_run(pipeline_id, 0, "run_001") + idx1 = register_clustering_run(pipeline_id, 1, "run_002") + idx2 = register_clustering_run(pipeline_id, 2, "run_003") + + assert idx0 == 0 + assert idx1 == 1 + assert idx2 == 2 + + # Verify order in database + runs = get_clustering_runs(pipeline_id) + assert runs == [(0, "run_001"), (1, "run_002"), (2, "run_003")] + + def 
test_auto_assign_single_index(self, _temp_registry_db: Any): + """Test auto-assigning a single index.""" + pipeline_id = "pipeline_003" + run_id = "run_001" + + assigned_idx = register_clustering_run(pipeline_id, -1, run_id) + + # First auto-assigned index should be 0 + assert assigned_idx == 0 + + # Verify in database + runs = get_clustering_runs(pipeline_id) + assert runs == [(0, "run_001")] + + def test_auto_assign_multiple_indices(self, _temp_registry_db: Any): + """Test auto-assigning multiple indices sequentially.""" + pipeline_id = "pipeline_004" + + idx0 = register_clustering_run(pipeline_id, -1, "run_001") + idx1 = register_clustering_run(pipeline_id, -1, "run_002") + idx2 = register_clustering_run(pipeline_id, -1, "run_003") + + # Should auto-assign 0, 1, 2 + assert idx0 == 0 + assert idx1 == 1 + assert idx2 == 2 + + # Verify in database + runs = get_clustering_runs(pipeline_id) + assert runs == [(0, "run_001"), (1, "run_002"), (2, "run_003")] + + def test_auto_assign_after_explicit_indices(self, _temp_registry_db: Any): + """Test that auto-assignment continues from max existing index.""" + pipeline_id = "pipeline_005" + + # Register explicit indices + register_clustering_run(pipeline_id, 0, "run_001") + register_clustering_run(pipeline_id, 1, "run_002") + + # Auto-assign should get index 2 + idx = register_clustering_run(pipeline_id, -1, "run_003") + assert idx == 2 + + # Auto-assign again should get index 3 + idx = register_clustering_run(pipeline_id, -1, "run_004") + assert idx == 3 + + # Verify in database + runs = get_clustering_runs(pipeline_id) + assert runs == [ + (0, "run_001"), + (1, "run_002"), + (2, "run_003"), + (3, "run_004"), + ] + + def test_auto_assign_with_gaps(self, _temp_registry_db: Any): + """Test that auto-assignment uses max+1, even with gaps.""" + pipeline_id = "pipeline_006" + + # Register with gaps: 0, 5, 10 + register_clustering_run(pipeline_id, 0, "run_001") + register_clustering_run(pipeline_id, 5, "run_002") + 
register_clustering_run(pipeline_id, 10, "run_003") + + # Auto-assign should get index 11 (max + 1) + idx = register_clustering_run(pipeline_id, -1, "run_004") + assert idx == 11 + + # Verify in database (ordered by idx) + runs = get_clustering_runs(pipeline_id) + assert runs == [ + (0, "run_001"), + (5, "run_002"), + (10, "run_003"), + (11, "run_004"), + ] + + def test_mixed_explicit_and_auto_assign(self, _temp_registry_db: Any): + """Test mixing explicit and auto-assigned indices.""" + pipeline_id = "pipeline_007" + + # Mix of explicit and auto-assigned + idx0 = register_clustering_run(pipeline_id, -1, "run_001") # auto: 0 + idx1 = register_clustering_run(pipeline_id, 5, "run_002") # explicit: 5 + idx2 = register_clustering_run(pipeline_id, -1, "run_003") # auto: 6 + idx3 = register_clustering_run(pipeline_id, 2, "run_004") # explicit: 2 + idx4 = register_clustering_run(pipeline_id, -1, "run_005") # auto: 7 + + assert idx0 == 0 + assert idx1 == 5 + assert idx2 == 6 + assert idx3 == 2 + assert idx4 == 7 + + # Verify in database (ordered by idx) + runs = get_clustering_runs(pipeline_id) + assert runs == [ + (0, "run_001"), + (2, "run_004"), + (5, "run_002"), + (6, "run_003"), + (7, "run_005"), + ] + + def test_different_pipelines_independent(self, _temp_registry_db: Any): + """Test that different pipelines have independent index sequences.""" + pipeline_a = "pipeline_a" + pipeline_b = "pipeline_b" + + # Both should start at 0 when auto-assigning + idx_a0 = register_clustering_run(pipeline_a, -1, "run_a1") + idx_b0 = register_clustering_run(pipeline_b, -1, "run_b1") + + assert idx_a0 == 0 + assert idx_b0 == 0 + + # Both should increment independently + idx_a1 = register_clustering_run(pipeline_a, -1, "run_a2") + idx_b1 = register_clustering_run(pipeline_b, -1, "run_b2") + + assert idx_a1 == 1 + assert idx_b1 == 1 + + # Verify in database + runs_a = get_clustering_runs(pipeline_a) + runs_b = get_clustering_runs(pipeline_b) + + assert runs_a == [(0, "run_a1"), (1, 
"run_a2")] + assert runs_b == [(0, "run_b1"), (1, "run_b2")] + + +class TestGetClusteringRuns: + """Test get_clustering_runs() function.""" + + def test_get_empty_pipeline(self, _temp_registry_db: Any): + """Test getting runs from a pipeline that doesn't exist.""" + runs = get_clustering_runs("nonexistent_pipeline") + assert runs == [] + + def test_get_runs_sorted_by_index(self, _temp_registry_db: Any): + """Test that runs are returned sorted by index.""" + pipeline_id = "pipeline_sort" + + # Register out of order + register_clustering_run(pipeline_id, 5, "run_005") + register_clustering_run(pipeline_id, 1, "run_001") + register_clustering_run(pipeline_id, 3, "run_003") + register_clustering_run(pipeline_id, 0, "run_000") + + # Should be returned in sorted order + runs = get_clustering_runs(pipeline_id) + assert runs == [ + (0, "run_000"), + (1, "run_001"), + (3, "run_003"), + (5, "run_005"), + ] diff --git a/tests/clustering/test_pipeline_config.py b/tests/clustering/test_pipeline_config.py index 826a7fc28..010195694 100644 --- a/tests/clustering/test_pipeline_config.py +++ b/tests/clustering/test_pipeline_config.py @@ -20,28 +20,24 @@ def test_stable_hash_b64(self): model_path="wandb:test/project/run1", batch_size=32, dataset_seed=0, - idx_in_ensemble=0, merge_config=MergeConfig(), ) config2 = ClusteringRunConfig( model_path="wandb:test/project/run1", batch_size=32, dataset_seed=0, - idx_in_ensemble=0, merge_config=MergeConfig(), ) config3 = ClusteringRunConfig( model_path="wandb:test/project/run2", # Different model_path batch_size=32, dataset_seed=0, - idx_in_ensemble=0, merge_config=MergeConfig(), ) config4 = ClusteringRunConfig( model_path="wandb:test/project/run1", batch_size=32, dataset_seed=0, - idx_in_ensemble=0, merge_config=MergeConfig( activation_threshold=0.2 ), # Different merge_config to test nested fields @@ -96,7 +92,6 @@ def test_error_when_both_fields_provided(self): batch_size=32, merge_config=MergeConfig(), dataset_seed=0, - idx_in_ensemble=0, 
), n_runs=2, distances_methods=["perm_invariant_hamming"], @@ -132,7 +127,6 @@ def test_creates_cached_file_when_using_inline_config(self): model_path="wandb:test/project/run1", batch_size=32, dataset_seed=0, - idx_in_ensemble=0, merge_config=MergeConfig(), ) @@ -171,7 +165,6 @@ def test_reuses_existing_cached_file(self): model_path="wandb:test/project/run1", batch_size=32, dataset_seed=0, - idx_in_ensemble=0, merge_config=MergeConfig(), ) @@ -216,7 +209,6 @@ def test_hash_collision_detection(self): model_path="wandb:test/project/run1", batch_size=32, dataset_seed=0, - idx_in_ensemble=0, merge_config=MergeConfig(), ) @@ -231,7 +223,6 @@ def test_hash_collision_detection(self): model_path="wandb:test/project/run2", # Different! batch_size=32, dataset_seed=0, - idx_in_ensemble=0, merge_config=MergeConfig(), ) different_config.to_file(collision_path) From 3d45ac4f38ab1a15415384ba05955eb56aa4fa3f Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 12:40:14 +0100 Subject: [PATCH 13/61] whoops, wrong name on fixture --- tests/clustering/test_ensemble_registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/clustering/test_ensemble_registry.py b/tests/clustering/test_ensemble_registry.py index e71e8e228..ff751d5c0 100644 --- a/tests/clustering/test_ensemble_registry.py +++ b/tests/clustering/test_ensemble_registry.py @@ -13,7 +13,7 @@ @pytest.fixture -def temp_registry_db(monkeypatch: Any): +def _temp_registry_db(monkeypatch: Any): """Create a temporary registry database for testing.""" with tempfile.TemporaryDirectory() as tmpdir: temp_db_path = Path(tmpdir) / "test_registry.db" From 4adde100ca740571a13efd9309639c8958d09bcb Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 12:41:45 +0100 Subject: [PATCH 14/61] fix idx passed in tests when not needed --- tests/clustering/test_run_clustering_happy_path.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/clustering/test_run_clustering_happy_path.py 
b/tests/clustering/test_run_clustering_happy_path.py index 91a7cf2ad..12c12c8b0 100644 --- a/tests/clustering/test_run_clustering_happy_path.py +++ b/tests/clustering/test_run_clustering_happy_path.py @@ -16,7 +16,6 @@ def test_run_clustering_happy_path(): model_path="wandb:goodfire/spd/runs/zxbu57pt", # An ss_llama run batch_size=4, dataset_seed=0, - idx_in_ensemble=0, base_output_dir=Path(temp_dir), ensemble_id=None, merge_config=MergeConfig( From 189b64aee14189f5e2f566caf0bf3c1dc26799aa Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 12:43:21 +0100 Subject: [PATCH 15/61] rename "mrc" -> "crc" in paths I forgot its no longer called "MergeRunConfig" --- spd/clustering/configs/README.md | 2 +- .../configs/{mrc => crc}/example.yaml | 0 .../configs/{mrc => crc}/resid_mlp1.json | 0 .../configs/{mrc => crc}/resid_mlp2.json | 0 .../{mrc => crc}/simplestories_dev.json | 0 .../configs/{mrc => crc}/test-resid_mlp1.json | 0 .../{mrc => crc}/test-simplestories.json | 0 .../configs/pipeline-dev-simplestories.yaml | 2 +- .../configs/pipeline-test-resid_mlp1.yaml | 2 +- .../configs/pipeline-test-simplestories.yaml | 2 +- spd/clustering/configs/pipeline_config.yaml | 2 +- spd/clustering/scripts/run_pipeline.py | 2 +- tests/clustering/test_pipeline_config.py | 18 +++++++++--------- 13 files changed, 15 insertions(+), 15 deletions(-) rename spd/clustering/configs/{mrc => crc}/example.yaml (100%) rename spd/clustering/configs/{mrc => crc}/resid_mlp1.json (100%) rename spd/clustering/configs/{mrc => crc}/resid_mlp2.json (100%) rename spd/clustering/configs/{mrc => crc}/simplestories_dev.json (100%) rename spd/clustering/configs/{mrc => crc}/test-resid_mlp1.json (100%) rename spd/clustering/configs/{mrc => crc}/test-simplestories.json (100%) diff --git a/spd/clustering/configs/README.md b/spd/clustering/configs/README.md index ed6efe090..51db8e8a0 100644 --- a/spd/clustering/configs/README.md +++ b/spd/clustering/configs/README.md @@ -1 +1 @@ -this folder 
`configs/` contains files with `ClusteringPipelineConfig`s. the folder `configs/mrc/` contains files with `ClusteringRunConfig`s, which may be referenced by the `run_clustering_config_path` field in the pipeline configs. \ No newline at end of file +this folder `configs/` contains files with `ClusteringPipelineConfig`s. the folder `configs/crc/` contains files with `ClusteringRunConfig`s, which may be referenced by the `run_clustering_config_path` field in the pipeline configs. \ No newline at end of file diff --git a/spd/clustering/configs/mrc/example.yaml b/spd/clustering/configs/crc/example.yaml similarity index 100% rename from spd/clustering/configs/mrc/example.yaml rename to spd/clustering/configs/crc/example.yaml diff --git a/spd/clustering/configs/mrc/resid_mlp1.json b/spd/clustering/configs/crc/resid_mlp1.json similarity index 100% rename from spd/clustering/configs/mrc/resid_mlp1.json rename to spd/clustering/configs/crc/resid_mlp1.json diff --git a/spd/clustering/configs/mrc/resid_mlp2.json b/spd/clustering/configs/crc/resid_mlp2.json similarity index 100% rename from spd/clustering/configs/mrc/resid_mlp2.json rename to spd/clustering/configs/crc/resid_mlp2.json diff --git a/spd/clustering/configs/mrc/simplestories_dev.json b/spd/clustering/configs/crc/simplestories_dev.json similarity index 100% rename from spd/clustering/configs/mrc/simplestories_dev.json rename to spd/clustering/configs/crc/simplestories_dev.json diff --git a/spd/clustering/configs/mrc/test-resid_mlp1.json b/spd/clustering/configs/crc/test-resid_mlp1.json similarity index 100% rename from spd/clustering/configs/mrc/test-resid_mlp1.json rename to spd/clustering/configs/crc/test-resid_mlp1.json diff --git a/spd/clustering/configs/mrc/test-simplestories.json b/spd/clustering/configs/crc/test-simplestories.json similarity index 100% rename from spd/clustering/configs/mrc/test-simplestories.json rename to spd/clustering/configs/crc/test-simplestories.json diff --git 
a/spd/clustering/configs/pipeline-dev-simplestories.yaml b/spd/clustering/configs/pipeline-dev-simplestories.yaml index 7eef9cfc9..dc6e729d3 100644 --- a/spd/clustering/configs/pipeline-dev-simplestories.yaml +++ b/spd/clustering/configs/pipeline-dev-simplestories.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/mrc/simplestories_dev.json" +run_clustering_config_path: "spd/clustering/configs/crc/simplestories_dev.json" n_runs: 4 distances_methods: ["matching_dist", "matching_dist_vec", "perm_invariant_hamming"] base_output_dir: "tests/.temp/clustering" diff --git a/spd/clustering/configs/pipeline-test-resid_mlp1.yaml b/spd/clustering/configs/pipeline-test-resid_mlp1.yaml index a3c02da5e..db72fa3c0 100644 --- a/spd/clustering/configs/pipeline-test-resid_mlp1.yaml +++ b/spd/clustering/configs/pipeline-test-resid_mlp1.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/mrc/test-resid_mlp1.json" +run_clustering_config_path: "spd/clustering/configs/crc/test-resid_mlp1.json" n_runs: 3 distances_methods: ["matching_dist"] base_output_dir: "tests/.temp/clustering" diff --git a/spd/clustering/configs/pipeline-test-simplestories.yaml b/spd/clustering/configs/pipeline-test-simplestories.yaml index c98895ab4..24e686023 100644 --- a/spd/clustering/configs/pipeline-test-simplestories.yaml +++ b/spd/clustering/configs/pipeline-test-simplestories.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/mrc/test-simplestories.json" +run_clustering_config_path: "spd/clustering/configs/crc/test-simplestories.json" n_runs: 2 distances_methods: ["matching_dist"] base_output_dir: "tests/.temp/clustering" diff --git a/spd/clustering/configs/pipeline_config.yaml b/spd/clustering/configs/pipeline_config.yaml index 42db7ac84..297b47d7b 100644 --- a/spd/clustering/configs/pipeline_config.yaml +++ b/spd/clustering/configs/pipeline_config.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/mrc/example.yaml" 
+run_clustering_config_path: "spd/clustering/configs/crc/example.yaml" n_runs: 2 distances_methods: ["perm_invariant_hamming"] base_output_dir: "/mnt/polished-lake/spd/clustering" diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index 334f5b418..7b04bcfc0 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -98,7 +98,7 @@ class ClusteringPipelineConfig(BaseConfig): ) @model_validator(mode="after") - def validate_mrc_fields(self) -> "ClusteringPipelineConfig": + def validate_crc_fields(self) -> "ClusteringPipelineConfig": """Validate that exactly one of run_clustering_config_path or run_clustering_config is provided.""" has_path: bool = self.run_clustering_config_path is not None has_inline: bool = self.run_clustering_config is not None diff --git a/tests/clustering/test_pipeline_config.py b/tests/clustering/test_pipeline_config.py index 010195694..723192118 100644 --- a/tests/clustering/test_pipeline_config.py +++ b/tests/clustering/test_pipeline_config.py @@ -108,7 +108,7 @@ class TestClusteringPipelineConfigGetConfigPath: def test_returns_path_directly_when_using_path_field(self): """Test that get_config_path returns the path directly when using run_clustering_config_path.""" - expected_path = Path("spd/clustering/configs/mrc/resid_mlp1.json") + expected_path = Path("spd/clustering/configs/crc/resid_mlp1.json") config = ClusteringPipelineConfig( run_clustering_config_path=expected_path, @@ -279,21 +279,21 @@ def test_all_pipeline_configs_valid(self): def test_all_merge_run_configs_valid(self): """Test that all merge run config files are valid.""" - mrc_dir = REPO_ROOT / "spd" / "clustering" / "configs" / "mrc" + crc_dir = REPO_ROOT / "spd" / "clustering" / "configs" / "crc" - # Find all JSON/YAML/YML files in the mrc directory - mrc_files = ( - list(mrc_dir.glob("*.json")) - + list(mrc_dir.glob("*.yaml")) - + list(mrc_dir.glob("*.yml")) + # Find all JSON/YAML/YML files in 
the crc directory + crc_files = ( + list(crc_dir.glob("*.json")) + + list(crc_dir.glob("*.yaml")) + + list(crc_dir.glob("*.yml")) ) # Should have at least some configs - assert len(mrc_files) > 0, "No merge run config files found" + assert len(crc_files) > 0, "No merge run config files found" errors: list[tuple[Path, Exception]] = [] - for config_file in mrc_files: + for config_file in crc_files: try: _config = ClusteringRunConfig.from_file(config_file) except Exception as e: From 57f445a1fb276a47652a6a070d1da68a4637cf03 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 12:44:25 +0100 Subject: [PATCH 16/61] rename merge_run_config.py -> clustering_run_config.py --- .../{merge_run_config.py => clustering_run_config.py} | 0 spd/clustering/ensemble_registry.py | 2 +- spd/clustering/scripts/run_clustering.py | 2 +- spd/clustering/scripts/run_pipeline.py | 2 +- tests/clustering/scripts/cluster_ss.py | 2 +- tests/clustering/test_pipeline_config.py | 2 +- tests/clustering/test_run_clustering_happy_path.py | 2 +- 7 files changed, 6 insertions(+), 6 deletions(-) rename spd/clustering/{merge_run_config.py => clustering_run_config.py} (100%) diff --git a/spd/clustering/merge_run_config.py b/spd/clustering/clustering_run_config.py similarity index 100% rename from spd/clustering/merge_run_config.py rename to spd/clustering/clustering_run_config.py diff --git a/spd/clustering/ensemble_registry.py b/spd/clustering/ensemble_registry.py index b3b1711ab..540312d8e 100644 --- a/spd/clustering/ensemble_registry.py +++ b/spd/clustering/ensemble_registry.py @@ -6,7 +6,7 @@ import sqlite3 from contextlib import contextmanager -from spd.clustering.merge_run_config import ClusteringEnsembleIndex +from spd.clustering.clustering_run_config import ClusteringEnsembleIndex from spd.settings import SPD_CACHE_DIR # SQLite database path diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index 04c48be7f..6b52a8bb3 100644 --- 
a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -31,6 +31,7 @@ component_activations, process_activations, ) +from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.consts import ( ActivationsTensor, BatchTensor, @@ -43,7 +44,6 @@ from spd.clustering.math.semilog import semilog from spd.clustering.merge import merge_iteration from spd.clustering.merge_history import MergeHistory -from spd.clustering.merge_run_config import ClusteringRunConfig from spd.clustering.plotting.activations import plot_activations from spd.clustering.plotting.merge import plot_merge_history_cluster_sizes, plot_merge_iteration from spd.clustering.storage import StorageBase diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index 7b04bcfc0..5910599b7 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -28,8 +28,8 @@ from pydantic import Field, PositiveInt, field_validator, model_validator from spd.base_config import BaseConfig +from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.consts import DistancesMethod -from spd.clustering.merge_run_config import ClusteringRunConfig from spd.clustering.storage import StorageBase from spd.log import logger from spd.settings import SPD_CACHE_DIR diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index acb6f394e..173b8abe5 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -16,11 +16,11 @@ component_activations, process_activations, ) +from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.dataset import load_dataset from spd.clustering.merge import merge_iteration from spd.clustering.merge_config import MergeConfig from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble -from spd.clustering.merge_run_config 
import ClusteringRunConfig from spd.clustering.plotting.activations import plot_activations from spd.clustering.plotting.merge import plot_dists_distribution from spd.models.component_model import ComponentModel, SPDRunInfo diff --git a/tests/clustering/test_pipeline_config.py b/tests/clustering/test_pipeline_config.py index 723192118..311981037 100644 --- a/tests/clustering/test_pipeline_config.py +++ b/tests/clustering/test_pipeline_config.py @@ -4,8 +4,8 @@ import pytest +from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.merge_config import MergeConfig -from spd.clustering.merge_run_config import ClusteringRunConfig from spd.clustering.scripts.run_pipeline import ClusteringPipelineConfig from spd.settings import REPO_ROOT, SPD_CACHE_DIR diff --git a/tests/clustering/test_run_clustering_happy_path.py b/tests/clustering/test_run_clustering_happy_path.py index 12c12c8b0..57bb5e1ff 100644 --- a/tests/clustering/test_run_clustering_happy_path.py +++ b/tests/clustering/test_run_clustering_happy_path.py @@ -3,8 +3,8 @@ import pytest +from spd.clustering.clustering_run_config import ClusteringRunConfig, LoggingIntervals from spd.clustering.merge_config import MergeConfig -from spd.clustering.merge_run_config import ClusteringRunConfig, LoggingIntervals from spd.clustering.scripts.run_clustering import main From 91f53484f0fab58dca1b0f018fa09680ee6849fa Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 12:45:43 +0100 Subject: [PATCH 17/61] fix pyright --- tests/clustering/test_ensemble_registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/clustering/test_ensemble_registry.py b/tests/clustering/test_ensemble_registry.py index ff751d5c0..bb2936cfd 100644 --- a/tests/clustering/test_ensemble_registry.py +++ b/tests/clustering/test_ensemble_registry.py @@ -13,7 +13,7 @@ @pytest.fixture -def _temp_registry_db(monkeypatch: Any): +def _temp_registry_db(monkeypatch: Any): # pyright: 
ignore[reportUnusedFunction] """Create a temporary registry database for testing.""" with tempfile.TemporaryDirectory() as tmpdir: temp_db_path = Path(tmpdir) / "test_registry.db" From 11e5501637625630fadb87ec7e67eadff53f6e3b Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 12:50:22 +0100 Subject: [PATCH 18/61] fix idx_in_ensemble being passed in tests --- tests/clustering/scripts/cluster_ss.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 173b8abe5..3f5da34a0 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -51,7 +51,6 @@ model_path=MODEL_PATH, batch_size=2, dataset_seed=42, - idx_in_ensemble=0, dataset_streaming=True, # no effect since we do this manually ) From 1d96054af7ab1d86352acdfb81cac66eac42b801 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 13:17:17 +0100 Subject: [PATCH 19/61] rename cache dir 'merge_run_configs' -> 'clustering_run_configs' --- spd/clustering/scripts/run_pipeline.py | 2 +- tests/clustering/test_pipeline_config.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index 5910599b7..cebc8fb06 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -159,7 +159,7 @@ def get_config_path(self) -> Path: hash_b64: str = self.run_clustering_config.stable_hash_b64() # Create cache directory - cache_dir: Path = SPD_CACHE_DIR / "merge_run_configs" + cache_dir: Path = SPD_CACHE_DIR / "clustering_run_configs" cache_dir.mkdir(parents=True, exist_ok=True) # Write config to cache if it doesn't exist diff --git a/tests/clustering/test_pipeline_config.py b/tests/clustering/test_pipeline_config.py index 311981037..05dfa17b0 100644 --- a/tests/clustering/test_pipeline_config.py +++ b/tests/clustering/test_pipeline_config.py @@ 
-145,7 +145,7 @@ def test_creates_cached_file_when_using_inline_config(self): assert config_path.exists() # Check that it's in the expected directory - expected_cache_dir = SPD_CACHE_DIR / "merge_run_configs" + expected_cache_dir = SPD_CACHE_DIR / "clustering_run_configs" assert config_path.parent == expected_cache_dir # Check that filename is the hash @@ -214,7 +214,7 @@ def test_hash_collision_detection(self): # Create a fake collision by manually creating a file with same hash hash_value = inline_config.stable_hash_b64() - cache_dir = SPD_CACHE_DIR / "merge_run_configs" + cache_dir = SPD_CACHE_DIR / "clustering_run_configs" cache_dir.mkdir(parents=True, exist_ok=True) collision_path = cache_dir / f"{hash_value}.json" @@ -277,7 +277,7 @@ def test_all_pipeline_configs_valid(self): error_msg += f" - {path.name}: {exc}\n" pytest.fail(error_msg) - def test_all_merge_run_configs_valid(self): + def test_all_clustering_run_configs_valid(self): """Test that all merge run config files are valid.""" crc_dir = REPO_ROOT / "spd" / "clustering" / "configs" / "crc" From 9cbb52fd09cac8d79481a16de0a9e4c517960a33 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 13:51:43 +0100 Subject: [PATCH 20/61] fix import --- spd/clustering/dataset_multibatch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spd/clustering/dataset_multibatch.py b/spd/clustering/dataset_multibatch.py index d0eb4b636..c64b4f08b 100644 --- a/spd/clustering/dataset_multibatch.py +++ b/spd/clustering/dataset_multibatch.py @@ -12,8 +12,8 @@ from torch.utils.data import DataLoader from tqdm import tqdm +from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.consts import BatchTensor -from spd.clustering.merge_run_config import ClusteringRunConfig from spd.configs import Config from spd.data import DatasetConfig, create_data_loader from spd.experiments.lm.configs import LMTaskConfig From a1f1146d0480b4ee08cfc2a7070be6170b9394d1 Mon Sep 17 
00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 13:57:09 +0100 Subject: [PATCH 21/61] remove component popping changes brought in from PR https://github.com/goodfire-ai/spd/pull/206 branch clustering/refactor-multi-batch commit [9cbb52f](https://github.com/goodfire-ai/spd/pull/206/commits/9cbb52fd09cac8d79481a16de0a9e4c517960a33) --- spd/clustering/compute_costs.py | 110 +----------------- spd/clustering/configs/crc/example.yaml | 1 - spd/clustering/configs/crc/resid_mlp1.json | 1 - spd/clustering/configs/crc/resid_mlp2.json | 1 - .../configs/crc/simplestories_dev.json | 3 +- .../configs/crc/test-resid_mlp1.json | 1 - .../configs/crc/test-simplestories.json | 1 - spd/clustering/merge.py | 51 +------- spd/clustering/merge_config.py | 5 - tests/clustering/scripts/cluster_resid_mlp.py | 1 - tests/clustering/scripts/cluster_ss.py | 3 +- tests/clustering/test_calc_distances.py | 1 - tests/clustering/test_merge_config.py | 2 - tests/clustering/test_merge_integration.py | 36 ------ .../test_run_clustering_happy_path.py | 1 - 15 files changed, 4 insertions(+), 214 deletions(-) diff --git a/spd/clustering/compute_costs.py b/spd/clustering/compute_costs.py index ba1ff274c..f1b3425d1 100644 --- a/spd/clustering/compute_costs.py +++ b/spd/clustering/compute_costs.py @@ -1,7 +1,7 @@ import math import torch -from jaxtyping import Bool, Float, Int +from jaxtyping import Bool, Float from torch import Tensor from spd.clustering.consts import ClusterCoactivationShaped, MergePair @@ -187,111 +187,3 @@ def recompute_coacts_merge_pair( coact_new, activation_mask_new, ) - - -def recompute_coacts_pop_group( - coact: ClusterCoactivationShaped, - merges: GroupMerge, - component_idx: int, - activation_mask: Bool[Tensor, "n_samples k_groups"], - activation_mask_orig: Bool[Tensor, "n_samples n_components"], -) -> tuple[ - GroupMerge, - Float[Tensor, "k_groups+1 k_groups+1"], - Bool[Tensor, "n_samples k_groups+1"], -]: - # sanity check dims - # 
================================================== - - k_groups: int = coact.shape[0] - n_samples: int = activation_mask.shape[0] - k_groups_new: int = k_groups + 1 - assert coact.shape[1] == k_groups, "Coactivation matrix must be square" - assert activation_mask.shape[1] == k_groups, ( - "Activation mask must match coactivation matrix shape" - ) - assert n_samples == activation_mask_orig.shape[0], ( - "Activation mask original must match number of samples" - ) - - # get the activations we need - # ================================================== - # which group does the component belong to? - group_idx: int = int(merges.group_idxs[component_idx].item()) - group_size_old: int = int(merges.components_per_group[group_idx].item()) - group_size_new: int = group_size_old - 1 - - # activations of component we are popping out - acts_pop: Bool[Tensor, " samples"] = activation_mask_orig[:, component_idx] - - # activations of the "remainder" -- everything other than the component we are popping out, - # in the group we're popping it out of - acts_remainder: Bool[Tensor, " samples"] = ( - activation_mask_orig[ - :, [i for i in merges.components_in_group(group_idx) if i != component_idx] - ] - .max(dim=-1) - .values - ) - - # assemble the new activation mask - # ================================================== - # first concat the popped-out component onto the end - activation_mask_new: Bool[Tensor, " samples k_groups+1"] = torch.cat( - [activation_mask, acts_pop.unsqueeze(1)], - dim=1, - ) - # then replace the group we are popping out of with the remainder - activation_mask_new[:, group_idx] = acts_remainder - - # assemble the new coactivation matrix - # ================================================== - coact_new: Float[Tensor, "k_groups+1 k_groups+1"] = torch.full( - (k_groups_new, k_groups_new), - fill_value=float("nan"), - dtype=coact.dtype, - device=coact.device, - ) - # copy in the old coactivation matrix - coact_new[:k_groups, :k_groups] = coact.clone() - # 
compute new coactivations we need - coact_pop: Float[Tensor, " k_groups"] = acts_pop.float() @ activation_mask_new.float() - coact_remainder: Float[Tensor, " k_groups"] = ( - acts_remainder.float() @ activation_mask_new.float() - ) - - # replace the relevant rows and columns - coact_new[group_idx, :] = coact_remainder - coact_new[:, group_idx] = coact_remainder - coact_new[-1, :] = coact_pop - coact_new[:, -1] = coact_pop - - # assemble the new group merge - # ================================================== - group_idxs_new: Int[Tensor, " k_groups+1"] = merges.group_idxs.clone() - # the popped-out component is now its own group - new_group_idx: int = k_groups_new - 1 - group_idxs_new[component_idx] = new_group_idx - merge_new: GroupMerge = GroupMerge( - group_idxs=group_idxs_new, - k_groups=k_groups_new, - ) - - # sanity check - assert merge_new.components_per_group.shape == (k_groups_new,), ( - "New merge must have k_groups+1 components" - ) - assert merge_new.components_per_group[new_group_idx] == 1, ( - "New group must have exactly one component" - ) - assert merge_new.components_per_group[group_idx] == group_size_new, ( - "Old group must have one less component" - ) - - # return - # ================================================== - return ( - merge_new, - coact_new, - activation_mask_new, - ) diff --git a/spd/clustering/configs/crc/example.yaml b/spd/clustering/configs/crc/example.yaml index efa36d693..3729941ce 100644 --- a/spd/clustering/configs/crc/example.yaml +++ b/spd/clustering/configs/crc/example.yaml @@ -12,7 +12,6 @@ merge_config: merge_pair_sampling_method: "range" # Method for sampling merge pairs: 'range' or 'mcmc' merge_pair_sampling_kwargs: threshold: 0.05 # For range sampler: fraction of the range of costs to sample from - pop_component_prob: 0 # Probability of popping a component. 
i recommend 0 if you're doing an ensemble anyway filter_dead_threshold: 0.001 # Threshold for filtering dead components module_name_filter: null # Can be a string prefix like "model.layers.0." if you want to do only some modules diff --git a/spd/clustering/configs/crc/resid_mlp1.json b/spd/clustering/configs/crc/resid_mlp1.json index 506717282..1e13ce23e 100644 --- a/spd/clustering/configs/crc/resid_mlp1.json +++ b/spd/clustering/configs/crc/resid_mlp1.json @@ -5,7 +5,6 @@ "iters": 5, "merge_pair_sampling_method": "range", "merge_pair_sampling_kwargs": {"threshold": 0.05}, - "pop_component_prob": 0, "filter_dead_threshold": 0, "module_name_filter": null }, diff --git a/spd/clustering/configs/crc/resid_mlp2.json b/spd/clustering/configs/crc/resid_mlp2.json index af645f3bd..edc4849e2 100644 --- a/spd/clustering/configs/crc/resid_mlp2.json +++ b/spd/clustering/configs/crc/resid_mlp2.json @@ -5,7 +5,6 @@ "iters": 100, "merge_pair_sampling_method": "range", "merge_pair_sampling_kwargs": {"threshold": 0.05}, - "pop_component_prob": 0, "filter_dead_threshold": 0.01, "module_name_filter": null }, diff --git a/spd/clustering/configs/crc/simplestories_dev.json b/spd/clustering/configs/crc/simplestories_dev.json index f585e848f..e1647b6e4 100644 --- a/spd/clustering/configs/crc/simplestories_dev.json +++ b/spd/clustering/configs/crc/simplestories_dev.json @@ -4,8 +4,7 @@ "alpha": 1.0, "iters": 100, "merge_pair_sampling_method": "range", - "merge_pair_sampling_kwargs": {"threshold": 0.01}, - "pop_component_prob": 0, + "merge_pair_sampling_kwargs": {"threshold": 0.001}, "filter_dead_threshold": 0.1, "module_name_filter": null }, diff --git a/spd/clustering/configs/crc/test-resid_mlp1.json b/spd/clustering/configs/crc/test-resid_mlp1.json index 01b510200..4b3a26ff8 100644 --- a/spd/clustering/configs/crc/test-resid_mlp1.json +++ b/spd/clustering/configs/crc/test-resid_mlp1.json @@ -5,7 +5,6 @@ "iters": 16, "merge_pair_sampling_method": "range", "merge_pair_sampling_kwargs": 
{"threshold": 0.05}, - "pop_component_prob": 0, "filter_dead_threshold": 0.1, "module_name_filter": null }, diff --git a/spd/clustering/configs/crc/test-simplestories.json b/spd/clustering/configs/crc/test-simplestories.json index 147634edb..911f71529 100644 --- a/spd/clustering/configs/crc/test-simplestories.json +++ b/spd/clustering/configs/crc/test-simplestories.json @@ -5,7 +5,6 @@ "iters": 5, "merge_pair_sampling_method": "range", "merge_pair_sampling_kwargs": {"threshold": 0.05}, - "pop_component_prob": 0, "filter_dead_threshold": 0.9, "module_name_filter": "model.layers.0" }, diff --git a/spd/clustering/merge.py b/spd/clustering/merge.py index fd982b83f..dba55c878 100644 --- a/spd/clustering/merge.py +++ b/spd/clustering/merge.py @@ -8,7 +8,7 @@ from typing import Protocol import torch -from jaxtyping import Bool, Float, Int +from jaxtyping import Bool, Float from torch import Tensor from tqdm import tqdm @@ -16,7 +16,6 @@ compute_mdl_cost, compute_merge_costs, recompute_coacts_merge_pair, - recompute_coacts_pop_group, ) from spd.clustering.consts import ( ActivationsTensor, @@ -76,24 +75,6 @@ def merge_iteration( # determine number of iterations based on config and number of components num_iters: int = merge_config.get_num_iters(c_components) - # pop logic setup - # -------------------------------------------------- - # for speed, we precompute whether to pop components and which components to pop - # if we are not popping, we don't need these variables and can also delete other things - do_pop: bool = merge_config.pop_component_prob > 0.0 - if do_pop: - # at each iteration, we will pop a component with probability `pop_component_prob` - iter_pop: Bool[Tensor, " iters"] = ( - torch.rand(num_iters, device=coact.device) < merge_config.pop_component_prob - ) - # we pick a subcomponent at random, and if we decide to pop, we pop that one out of its group - # if the component is a singleton, nothing happens. 
this naturally biases towards popping - # less at the start and more at the end, since the effective probability of popping a component - # is actually something like `pop_component_prob * (c_components - k_groups) / c_components` - pop_component_idx: Int[Tensor, " iters"] = torch.randint( - 0, c_components, (num_iters,), device=coact.device - ) - # initialize vars # -------------------------------------------------- # start with an identity merge @@ -110,12 +91,6 @@ def merge_iteration( labels=component_labels, ) - # free up memory - if not do_pop: - del coact - del activation_mask_orig - activation_mask_orig = None - # merge iteration # ================================================== pbar: tqdm[int] = tqdm( @@ -124,30 +99,6 @@ def merge_iteration( total=num_iters, ) for iter_idx in pbar: - # pop components - # -------------------------------------------------- - if do_pop and iter_pop[iter_idx]: # pyright: ignore[reportPossiblyUnboundVariable] - # we split up the group which our chosen component belongs to - pop_component_idx_i: int = int(pop_component_idx[iter_idx].item()) # pyright: ignore[reportPossiblyUnboundVariable] - n_components_in_pop_grp: int = int( - current_merge.components_per_group[ # pyright: ignore[reportArgumentType] - current_merge.group_idxs[pop_component_idx_i].item() - ] - ) - - # but, if the component is the only one in its group, there is nothing to do - if n_components_in_pop_grp > 1: - current_merge, current_coact, current_act_mask = recompute_coacts_pop_group( - coact=current_coact, - merges=current_merge, - component_idx=pop_component_idx_i, - activation_mask=current_act_mask, - # this complains if `activation_mask_orig is None`, but this is only the case - # if `do_pop` is False, which it won't be here. 
we do this to save memory - activation_mask_orig=activation_mask_orig, # pyright: ignore[reportArgumentType] - ) - k_groups = current_coact.shape[0] - # compute costs, figure out what to merge # -------------------------------------------------- # HACK: this is messy diff --git a/spd/clustering/merge_config.py b/spd/clustering/merge_config.py index 3bf8b6d5b..f471879b2 100644 --- a/spd/clustering/merge_config.py +++ b/spd/clustering/merge_config.py @@ -23,7 +23,6 @@ "iters", "merge_pair_sampling_method", "merge_pair_sampling_kwargs", - "pop_component_prob", "filter_dead_threshold", ] @@ -65,10 +64,6 @@ class MergeConfig(BaseConfig): default_factory=lambda: {"threshold": 0.05}, description="Keyword arguments for the merge pair sampling method.", ) - pop_component_prob: Probability = Field( - default=0, - description="Probability of popping a component in each iteration. If 0, no components are popped.", - ) filter_dead_threshold: float = Field( default=0.001, description="Threshold for filtering out dead components. 
If a component's activation is below this threshold, it is considered dead and not included in the merge.", diff --git a/tests/clustering/scripts/cluster_resid_mlp.py b/tests/clustering/scripts/cluster_resid_mlp.py index 1d2a69c93..bbfb5259e 100644 --- a/tests/clustering/scripts/cluster_resid_mlp.py +++ b/tests/clustering/scripts/cluster_resid_mlp.py @@ -120,7 +120,6 @@ iters=int(PROCESSED_ACTIVATIONS.n_components_alive * 0.9), merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.0}, - pop_component_prob=0, filter_dead_threshold=FILTER_DEAD_THRESHOLD, ) diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 3f5da34a0..45c142fa0 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -47,7 +47,7 @@ # Use load_dataset with RunConfig to get real data CONFIG: ClusteringRunConfig = ClusteringRunConfig( - merge_config=MergeConfig(), + merge_config=MergeConfig(batch_size=2), model_path=MODEL_PATH, batch_size=2, dataset_seed=42, @@ -103,7 +103,6 @@ iters=2, merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.1}, - pop_component_prob=0, module_name_filter=FILTER_MODULES, filter_dead_threshold=FILTER_DEAD_THRESHOLD, ) diff --git a/tests/clustering/test_calc_distances.py b/tests/clustering/test_calc_distances.py index d8971df05..b06350f4b 100644 --- a/tests/clustering/test_calc_distances.py +++ b/tests/clustering/test_calc_distances.py @@ -11,7 +11,6 @@ def test_merge_history_normalization_happy_path(): iters=3, alpha=1.0, activation_threshold=None, - pop_component_prob=0.0, ) histories = [] diff --git a/tests/clustering/test_merge_config.py b/tests/clustering/test_merge_config.py index 9f191075b..63f4e88f7 100644 --- a/tests/clustering/test_merge_config.py +++ b/tests/clustering/test_merge_config.py @@ -74,7 +74,6 @@ def test_config_with_all_parameters(self): iters=200, merge_pair_sampling_method="mcmc", 
merge_pair_sampling_kwargs={"temperature": 0.5}, - pop_component_prob=0.1, filter_dead_threshold=0.001, module_name_filter="model.layers", ) @@ -84,7 +83,6 @@ def test_config_with_all_parameters(self): assert config.iters == 200 assert config.merge_pair_sampling_method == "mcmc" assert config.merge_pair_sampling_kwargs == {"temperature": 0.5} - assert config.pop_component_prob == 0.1 assert config.filter_dead_threshold == 0.001 assert config.module_name_filter == "model.layers" diff --git a/tests/clustering/test_merge_integration.py b/tests/clustering/test_merge_integration.py index 14811b7c5..8492300de 100644 --- a/tests/clustering/test_merge_integration.py +++ b/tests/clustering/test_merge_integration.py @@ -25,7 +25,6 @@ def test_merge_with_range_sampler(self): iters=5, merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.1}, - pop_component_prob=0, filter_dead_threshold=0.001, ) @@ -59,7 +58,6 @@ def test_merge_with_mcmc_sampler(self): iters=5, merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 1.0}, - pop_component_prob=0, filter_dead_threshold=0.001, ) @@ -77,37 +75,6 @@ def test_merge_with_mcmc_sampler(self): assert history.merges.k_groups[-1].item() < n_components assert history.merges.k_groups[-1].item() >= 2 - def test_merge_with_popping(self): - """Test merge iteration with component popping.""" - # Create test data - n_samples = 100 - n_components = 15 - activations = torch.rand(n_samples, n_components) - component_labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) - - # Configure with popping enabled - config = MergeConfig( - activation_threshold=0.1, - alpha=1.0, - iters=10, - merge_pair_sampling_method="range", - merge_pair_sampling_kwargs={"threshold": 0.05}, - pop_component_prob=0.3, # 30% chance of popping - filter_dead_threshold=0.001, - ) - - # Run merge iteration - history = merge_iteration( - activations=activations, merge_config=config, component_labels=component_labels - 
) - - # Check results - assert history is not None - # First entry is after first merge, so should be n_components - 1 - assert history.merges.k_groups[0].item() == n_components - 1 - # Final group count depends on pops, but should be less than initial - assert history.merges.k_groups[-1].item() < n_components - def test_merge_comparison_samplers(self): """Compare behavior of different samplers with same data.""" # Create test data with clear structure @@ -128,7 +95,6 @@ def test_merge_comparison_samplers(self): iters=3, merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.0}, # Always select minimum - pop_component_prob=0, ) history_range = merge_iteration( @@ -144,7 +110,6 @@ def test_merge_comparison_samplers(self): iters=3, merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 0.01}, # Very low temp - pop_component_prob=0, ) history_mcmc = merge_iteration( @@ -173,7 +138,6 @@ def test_merge_with_small_components(self): iters=1, # Just one merge merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 2.0}, - pop_component_prob=0, ) history = merge_iteration( diff --git a/tests/clustering/test_run_clustering_happy_path.py b/tests/clustering/test_run_clustering_happy_path.py index 57bb5e1ff..5e2cbbd1c 100644 --- a/tests/clustering/test_run_clustering_happy_path.py +++ b/tests/clustering/test_run_clustering_happy_path.py @@ -24,7 +24,6 @@ def test_run_clustering_happy_path(): iters=3, merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.05}, - pop_component_prob=0, ), wandb_project=None, wandb_entity="goodfire", From 1e3fbb292130164bbf278314dae101386ea667fc Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 13:59:30 +0100 Subject: [PATCH 22/61] dont pass batch size, change not brought in here --- tests/clustering/scripts/cluster_ss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/clustering/scripts/cluster_ss.py 
b/tests/clustering/scripts/cluster_ss.py index 45c142fa0..0b7f8de97 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -47,7 +47,7 @@ # Use load_dataset with RunConfig to get real data CONFIG: ClusteringRunConfig = ClusteringRunConfig( - merge_config=MergeConfig(batch_size=2), + merge_config=MergeConfig(), model_path=MODEL_PATH, batch_size=2, dataset_seed=42, From e86adc988e3cf00d7c189257169335aefba7961f Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 22 Oct 2025 14:58:41 +0100 Subject: [PATCH 23/61] fix history_path extension and storage usage --- spd/clustering/scripts/calc_distances.py | 13 ++++++++++++- spd/clustering/scripts/run_clustering.py | 3 ++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/spd/clustering/scripts/calc_distances.py b/spd/clustering/scripts/calc_distances.py index 709d3c1c6..993335671 100644 --- a/spd/clustering/scripts/calc_distances.py +++ b/spd/clustering/scripts/calc_distances.py @@ -25,8 +25,10 @@ from spd.clustering.math.merge_distances import compute_distances from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble from spd.clustering.plotting.merge import plot_dists_distribution +from spd.clustering.scripts.run_clustering import ClusteringRunStorage from spd.log import logger from spd.settings import SPD_CACHE_DIR +from spd.utils.run_utils import ExecutionStamp # Set spawn method for CUDA compatibility with multiprocessing # Must be done before any CUDA operations @@ -57,7 +59,16 @@ def main(pipeline_run_id: str, distances_method: DistancesMethod) -> None: # Load histories from individual clustering run directories histories: list[MergeHistory] = [] for idx, clustering_run_id in clustering_runs: - history_path = SPD_CACHE_DIR / "cluster" / clustering_run_id / "history.npz" + history_path = ClusteringRunStorage( + ExecutionStamp( + run_id=clustering_run_id, + snapshot_branch="", + commit_hash="", + run_type="cluster", + ) + 
).history_path + + # SPD_CACHE_DIR / "cluster" / clustering_run_id / "history.npz" if not history_path.exists(): raise FileNotFoundError( f"History not found for run {clustering_run_id}: {history_path}" diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index 6b52a8bb3..de791f0b2 100644 --- a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -66,7 +66,8 @@ class ClusteringRunStorage(StorageBase): # Relative path constants _CONFIG = "clustering_run_config.json" - _HISTORY = "history.npz" + # we are saving a zip file with things in it besides npy files -- hence, `.zip` and not `.npz` + _HISTORY = "history.zip" def __init__(self, execution_stamp: ExecutionStamp) -> None: super().__init__(execution_stamp) From b8bbb088adbe7e3a6ee15e4ad7554e56cf891182 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 22 Oct 2025 15:10:41 +0100 Subject: [PATCH 24/61] dev pipeline --- .../configs/pipeline-dev-simplestories.yaml | 30 +++++++++++++++---- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/spd/clustering/configs/pipeline-dev-simplestories.yaml b/spd/clustering/configs/pipeline-dev-simplestories.yaml index dc6e729d3..dfee51d64 100644 --- a/spd/clustering/configs/pipeline-dev-simplestories.yaml +++ b/spd/clustering/configs/pipeline-dev-simplestories.yaml @@ -1,9 +1,27 @@ -run_clustering_config_path: "spd/clustering/configs/crc/simplestories_dev.json" -n_runs: 4 -distances_methods: ["matching_dist", "matching_dist_vec", "perm_invariant_hamming"] -base_output_dir: "tests/.temp/clustering" +n_runs: 2 +distances_methods: ["matching_dist"] +# base_output_dir: "tests/.temp/clustering" slurm_job_name_prefix: null slurm_partition: null -wandb_project: null # wandb fails in CI +wandb_project: "spd-cluster" # wandb fails in CI wandb_entity: "goodfire" -create_git_snapshot: false \ No newline at end of file +create_git_snapshot: false +# run_clustering_config_path: 
"spd/clustering/configs/crc/simplestories_dev.json" +run_clustering_config: + model_path: "wandb:goodfire/spd/runs/lxs77xye" + batch_size: 16 + wandb_project: "spd-cluster" + logging_intervals: + stat: 5 + tensor: 100 + plot: 10000 + artifact: 10000 + merge_config: + activation_threshold: 0.1 + alpha: 1.0 + iters: null + merge_pair_sampling_method: "range" + merge_pair_sampling_kwargs: + threshold: 0.001 + filter_dead_threshold: 0.1 + module_name_filter: null \ No newline at end of file From 733d47fb793133b4c4d76563b5fa03325a2aac87 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 22 Oct 2025 15:21:31 +0100 Subject: [PATCH 25/61] better config validation tests --- tests/clustering/test_pipeline_config.py | 78 ++++++++---------------- 1 file changed, 24 insertions(+), 54 deletions(-) diff --git a/tests/clustering/test_pipeline_config.py b/tests/clustering/test_pipeline_config.py index 05dfa17b0..5e80f9c06 100644 --- a/tests/clustering/test_pipeline_config.py +++ b/tests/clustering/test_pipeline_config.py @@ -248,60 +248,30 @@ def test_hash_collision_detection(self): collision_path.unlink() -class TestAllConfigsValidation: - """Test that all existing config files can be loaded and validated.""" - - def test_all_pipeline_configs_valid(self): - """Test that all pipeline config files are valid.""" - configs_dir = REPO_ROOT / "spd" / "clustering" / "configs" - - # Find all YAML/YML files in the configs directory (not subdirectories) - pipeline_config_files = list(configs_dir.glob("*.yaml")) + list(configs_dir.glob("*.yml")) - - # Should have at least some configs - assert len(pipeline_config_files) > 0, "No pipeline config files found" - - errors: list[tuple[Path, Exception]] = [] - - for config_file in pipeline_config_files: - try: - _config = ClusteringPipelineConfig.from_file(config_file) - assert _config.get_config_path().exists() - except Exception as e: - errors.append((config_file, e)) - - # Report all errors at once - if errors: - error_msg = "Failed 
to validate pipeline configs:\n" - for path, exc in errors: - error_msg += f" - {path.name}: {exc}\n" - pytest.fail(error_msg) - - def test_all_clustering_run_configs_valid(self): - """Test that all merge run config files are valid.""" - crc_dir = REPO_ROOT / "spd" / "clustering" / "configs" / "crc" - - # Find all JSON/YAML/YML files in the crc directory - crc_files = ( - list(crc_dir.glob("*.json")) - + list(crc_dir.glob("*.yaml")) - + list(crc_dir.glob("*.yml")) - ) +def _get_config_files(path: Path): + """Helper to get all config files.""" + pipeline_config_files = ( + list(path.glob("*.yaml")) + list(path.glob("*.yml")) + list(path.glob("*.json")) + ) + assert len(pipeline_config_files) > 0, f"No pipeline files found in {path}" + return pipeline_config_files - # Should have at least some configs - assert len(crc_files) > 0, "No merge run config files found" - errors: list[tuple[Path, Exception]] = [] - - for config_file in crc_files: - try: - _config = ClusteringRunConfig.from_file(config_file) - except Exception as e: - errors.append((config_file, e)) +class TestAllConfigsValidation: + """Test that all existing config files can be loaded and validated.""" - # Report all errors at once - if errors: - error_msg = "Failed to validate merge run configs:\n" - for path, exc in errors: - error_msg += f" - {path.name}: {exc}\n" - pytest.fail(error_msg) + @pytest.mark.parametrize( + "config_file", _get_config_files(REPO_ROOT / "spd" / "clustering" / "configs") + ) + def test_pipeline_config_valid(self, config_file: Path): + """Test that each pipeline config file is valid.""" + _config = ClusteringPipelineConfig.from_file(config_file) + assert _config.get_config_path().exists() + + @pytest.mark.parametrize( + "config_file", _get_config_files(REPO_ROOT / "spd" / "clustering" / "configs" / "crc") + ) + def test_clustering_run_config_valid(self, config_file: Path): + """Test that each clustering run config file is valid.""" + _config = 
ClusteringRunConfig.from_file(config_file) + assert isinstance(_config, ClusteringRunConfig) From 2fa1f214cac422064568033bfaf19e0e43f4d343 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 22 Oct 2025 15:25:18 +0100 Subject: [PATCH 26/61] set default base output dir --- spd/clustering/scripts/run_pipeline.py | 5 ++++- tests/clustering/test_pipeline_config.py | 18 +++++++++++++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index cebc8fb06..5396cb640 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -83,7 +83,10 @@ class ClusteringPipelineConfig(BaseConfig): distances_methods: list[DistancesMethod] = Field( description="List of method(s) to use for calculating distances" ) - base_output_dir: Path = Field(description="Base directory for outputs of clustering runs.") + base_output_dir: Path = Field( + default=SPD_CACHE_DIR / "clustering_pipeline", + description="Base directory for outputs of clustering ensemble pipeline runs.", + ) slurm_job_name_prefix: str | None = Field( default=None, description="Prefix for SLURM job names" ) diff --git a/tests/clustering/test_pipeline_config.py b/tests/clustering/test_pipeline_config.py index 5e80f9c06..8d527bd6c 100644 --- a/tests/clustering/test_pipeline_config.py +++ b/tests/clustering/test_pipeline_config.py @@ -261,17 +261,25 @@ class TestAllConfigsValidation: """Test that all existing config files can be loaded and validated.""" @pytest.mark.parametrize( - "config_file", _get_config_files(REPO_ROOT / "spd" / "clustering" / "configs") + "config_file", + _get_config_files(REPO_ROOT / "spd" / "clustering" / "configs"), + ids=lambda p: p.stem, ) - def test_pipeline_config_valid(self, config_file: Path): + def test_config_validate_pipeline(self, config_file: Path): """Test that each pipeline config file is valid.""" + print(config_file) _config = 
ClusteringPipelineConfig.from_file(config_file) - assert _config.get_config_path().exists() + crc_path = _config.get_config_path() + print(f"{crc_path = }") + assert crc_path.exists() @pytest.mark.parametrize( - "config_file", _get_config_files(REPO_ROOT / "spd" / "clustering" / "configs" / "crc") + "config_file", + _get_config_files(REPO_ROOT / "spd" / "clustering" / "configs" / "crc"), + ids=lambda p: p.stem, ) - def test_clustering_run_config_valid(self, config_file: Path): + def test_config_validate_pipeline_clustering_run(self, config_file: Path): """Test that each clustering run config file is valid.""" + print(config_file) _config = ClusteringRunConfig.from_file(config_file) assert isinstance(_config, ClusteringRunConfig) From eec80fb7f0918924b763c0d26b799a58c51eb996 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 22 Oct 2025 15:28:21 +0100 Subject: [PATCH 27/61] wandb use run id for clustering, TODO for spd decomp --- spd/clustering/scripts/run_clustering.py | 1 + spd/utils/wandb_utils.py | 1 + 2 files changed, 2 insertions(+) diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index de791f0b2..6f8dfd722 100644 --- a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -282,6 +282,7 @@ def main(run_config: ClusteringRunConfig) -> Path: wandb_run: Run | None = None if run_config.wandb_project is not None: wandb_run = wandb.init( + id=clustering_run_id, entity=run_config.wandb_entity, project=run_config.wandb_project, group=run_config.ensemble_id, diff --git a/spd/utils/wandb_utils.py b/spd/utils/wandb_utils.py index 70766bbcf..b7de574a0 100644 --- a/spd/utils/wandb_utils.py +++ b/spd/utils/wandb_utils.py @@ -153,6 +153,7 @@ def init_wandb[T_config: BaseConfig]( """ load_dotenv(override=True) + # TODO: pass run id from ExecutionStamp wandb.init( project=project, entity=os.getenv("WANDB_ENTITY"), From 6098536e1d07c6ada13e91bdecae6f146b94d5de Mon Sep 17 00:00:00 2001 
From: Michael Ivanitskiy Date: Wed, 22 Oct 2025 15:32:32 +0100 Subject: [PATCH 28/61] basedpyright 1.32.0 causes issues, esp w/ wandb https://github.com/goodfire-ai/spd/actions/runs/18719611602/job/53388090437 --- pyproject.toml | 2 +- uv.lock | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ac30280e3..03ae4b28b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dev = [ "pytest-cov", # for coverage reports "pytest-xdist", # parallel test execution "ruff", - "basedpyright", + "basedpyright<1.32.0", "pre-commit", ] diff --git a/uv.lock b/uv.lock index 26bfb9af8..d61b6ee23 100644 --- a/uv.lock +++ b/uv.lock @@ -1091,7 +1091,7 @@ name = "nvidia-cudnn-cu12" version = "9.10.2.21" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, @@ -1102,7 +1102,7 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, @@ -1129,9 +1129,9 @@ name = "nvidia-cusolver-cu12" version = "11.7.3.90" source = { registry = 
"https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, - { name = "nvidia-cusparse-cu12" }, - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, @@ -1142,7 +1142,7 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, @@ -1949,7 +1949,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ - { name = "basedpyright" }, + { name = "basedpyright", specifier = "<1.32.0" }, { name = "pre-commit" }, { name = "pytest" }, { name = "pytest-cov" }, @@ -2226,7 +2226,7 @@ name = "triton" version = "3.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "setuptools" }, + { name = "setuptools", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/30/7b/0a685684ed5322d2af0bddefed7906674f67974aa88b0fae6e82e3b766f6/triton-3.4.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:00be2964616f4c619193cb0d1b29a99bd4b001d7dc333816073f92cf2a8ccdeb", size = 155569223, upload-time = "2025-07-30T19:58:44.017Z" }, From 40df505c543bda9d5e21938728159a217640fc78 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 24 Oct 2025 16:33:17 +0100 Subject: [PATCH 29/61] remove idx_in_ensemble, always auto-assigned now see https://github.com/goodfire-ai/spd/pull/227#discussion_r2454317036 --- spd/clustering/clustering_run_config.py | 21 +-- spd/clustering/configs/crc/example.yaml | 1 - spd/clustering/dataset.py | 2 +- spd/clustering/ensemble_registry.py | 21 +-- spd/clustering/scripts/run_clustering.py | 29 ++--- tests/clustering/test_ensemble_registry.py | 145 +++------------------ 6 files changed, 42 insertions(+), 177 deletions(-) diff --git a/spd/clustering/clustering_run_config.py b/spd/clustering/clustering_run_config.py index f82e00203..95d72f9bd 100644 --- a/spd/clustering/clustering_run_config.py +++ b/spd/clustering/clustering_run_config.py @@ -4,9 +4,9 @@ import hashlib import json from pathlib import Path -from typing import Any, Literal, Self +from typing import Any -from pydantic import Field, NonNegativeInt, PositiveInt, field_validator, model_validator +from pydantic import Field, PositiveInt, field_validator, model_validator from spd.base_config import BaseConfig from spd.clustering.merge_config import MergeConfig @@ -31,10 +31,6 @@ class LoggingIntervals(BaseConfig): ) -ClusteringEnsembleIndex = NonNegativeInt | Literal[-1] -"index in an ensemble; -1 will cause register_clustering_run() to auto-assign the next available index" - - class ClusteringRunConfig(BaseConfig): """Configuration for a single clustering run. @@ -58,12 +54,6 @@ class ClusteringRunConfig(BaseConfig): default=None, description="Ensemble identifier for WandB grouping", ) - # TODO: given our use of `register_clustering_run()` and the atomic guarantees of that, do we even need this index? 
- # probably still nice to have for clarity - idx_in_ensemble: ClusteringEnsembleIndex | None = Field( - default=None, description="Index of this run in the ensemble" - ) - merge_config: MergeConfig = Field(description="Merge algorithm configuration") logging_intervals: LoggingIntervals = Field( default_factory=LoggingIntervals, @@ -108,13 +98,6 @@ def validate_model_path(cls, v: str) -> str: raise ValueError(f"model_path must start with 'wandb:', got: {v}") return v - @model_validator(mode="after") - def validate_ensemble_id_index(self) -> Self: - assert (self.idx_in_ensemble is None) == (self.ensemble_id is None), ( - "If ensemble_id is None, idx_in_ensemble must also be None" - ) - return self - @property def wandb_decomp_model(self) -> str: """Extract the WandB run ID of the source decomposition.""" diff --git a/spd/clustering/configs/crc/example.yaml b/spd/clustering/configs/crc/example.yaml index 3729941ce..9345307d2 100644 --- a/spd/clustering/configs/crc/example.yaml +++ b/spd/clustering/configs/crc/example.yaml @@ -1,7 +1,6 @@ model_path: wandb:goodfire/spd/runs/zxbu57pt # WandB path to the decomposed model batch_size: 8 # Batch size for processing -- number of samples for each run in the ensemble dataset_seed: 0 # Note, overridden if run in the pipeline (spd/clustering/scripts/run_pipeline.py) -# idx_in_ensemble: 0 # Note, overridden if run in the pipeline (spd/clustering/scripts/run_pipeline.py) # output_dir: .data/clustering/clustering_runs # Note, overridden if run in the pipeline (spd/clustering/scripts/run_pipeline.py) # ensemble_id: 1234567890 # Note, overridden if run in the pipeline (spd/clustering/scripts/run_pipeline.py) diff --git a/spd/clustering/dataset.py b/spd/clustering/dataset.py index c514aa69f..ea9b9f904 100644 --- a/spd/clustering/dataset.py +++ b/spd/clustering/dataset.py @@ -23,7 +23,7 @@ def load_dataset( ) -> BatchTensor: """Load a single batch for clustering. - Each run gets its own dataset batch, seeded by idx_in_ensemble. 
+ Each run gets its own dataset batch, seeded by index in ensemble. Args: model_path: Path to decomposed model diff --git a/spd/clustering/ensemble_registry.py b/spd/clustering/ensemble_registry.py index 540312d8e..c54fe408b 100644 --- a/spd/clustering/ensemble_registry.py +++ b/spd/clustering/ensemble_registry.py @@ -6,7 +6,6 @@ import sqlite3 from contextlib import contextmanager -from spd.clustering.clustering_run_config import ClusteringEnsembleIndex from spd.settings import SPD_CACHE_DIR # SQLite database path @@ -40,9 +39,7 @@ def _get_connection(): conn.close() -def register_clustering_run( - pipeline_run_id: str, idx: ClusteringEnsembleIndex, clustering_run_id: str -) -> int: +def register_clustering_run(pipeline_run_id: str, clustering_run_id: str) -> int: """Register a clustering run as part of a pipeline ensemble. Args: @@ -57,16 +54,12 @@ def register_clustering_run( # Use BEGIN IMMEDIATE for thread-safe auto-increment conn.execute("BEGIN IMMEDIATE") - assigned_idx: int - if idx == -1: - # Auto-assign next available index - cursor = conn.execute( - "SELECT COALESCE(MAX(idx), -1) + 1 FROM ensemble_runs WHERE pipeline_run_id = ?", - (pipeline_run_id,), - ) - assigned_idx = cursor.fetchone()[0] - else: - assigned_idx = idx + # Auto-assign next available index, we rely on atomicity of the transaction here + cursor = conn.execute( + "SELECT COALESCE(MAX(idx), -1) + 1 FROM ensemble_runs WHERE pipeline_run_id = ?", + (pipeline_run_id,), + ) + assigned_idx: int = cursor.fetchone()[0] conn.execute( "INSERT INTO ensemble_runs (pipeline_run_id, idx, clustering_run_id) VALUES (?, ?, ?)", diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index 6f8dfd722..54f0805c6 100644 --- a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -228,25 +228,23 @@ def main(run_config: ClusteringRunConfig) -> Path: logger.info(f"Clustering run ID: {clustering_run_id}") # Register with ensemble if 
this is part of a pipeline + assigned_idx: int | None if run_config.ensemble_id: - assert run_config.idx_in_ensemble is not None, ( - "idx_in_ensemble must be set when ensemble_id is provided! to auto-assign, set idx_in_ensemble = -1.\n" - f"{'!' * 50}\nNOTE: this should be an unreachable state -- such a case should have been caught by the pydantic validator.\n{'!' * 50}" + assigned_idx = register_clustering_run( + pipeline_run_id=run_config.ensemble_id, + clustering_run_id=clustering_run_id, ) - assigned_idx: int = register_clustering_run( - run_config.ensemble_id, - run_config.idx_in_ensemble, - clustering_run_id, - ) - - # Update config if index was auto-assigned - if run_config.idx_in_ensemble == -1: - run_config = replace_pydantic_model(run_config, {"idx_in_ensemble": assigned_idx}) - logger.info(f"Auto-assigned ensemble index: {assigned_idx}") logger.info( f"Registered with pipeline {run_config.ensemble_id} at index {assigned_idx} in {_ENSEMBLE_REGISTRY_DB}" ) + # IMPORTANT: set dataset seed based on assigned index + run_config = replace_pydantic_model( + run_config, + {"dataset_seed": run_config.dataset_seed + assigned_idx}, + ) + else: + assigned_idx = None # save config run_config.to_file(storage.config_path) @@ -292,7 +290,7 @@ def main(run_config: ClusteringRunConfig) -> Path: f"task:{task_name}", f"model:{run_config.wandb_decomp_model}", f"ensemble_id:{run_config.ensemble_id}", - f"idx:{run_config.idx_in_ensemble}", + f"assigned_idx:{assigned_idx}", ], ) # logger.info(f"WandB run: {wandb_run.url}") @@ -426,9 +424,6 @@ def cli() -> None: } # Handle ensemble-related overrides - if args.idx_in_ensemble is not None: - overrides["dataset_seed"] = run_config.dataset_seed + args.idx_in_ensemble - overrides["idx_in_ensemble"] = args.idx_in_ensemble if args.pipeline_run_id is not None: overrides["ensemble_id"] = args.pipeline_run_id diff --git a/tests/clustering/test_ensemble_registry.py b/tests/clustering/test_ensemble_registry.py index bb2936cfd..c903af801 
100644 --- a/tests/clustering/test_ensemble_registry.py +++ b/tests/clustering/test_ensemble_registry.py @@ -24,58 +24,27 @@ def _temp_registry_db(monkeypatch: Any): # pyright: ignore[reportUnusedFunction class TestRegisterClusteringRun: """Test register_clustering_run() function.""" - def test_register_with_explicit_index(self, _temp_registry_db: Any): - """Test registering a run with an explicit index.""" + def test_register_single_run(self, _temp_registry_db: Any): + """Test registering a single run.""" pipeline_id = "pipeline_001" - idx = 0 run_id = "run_001" - assigned_idx = register_clustering_run(pipeline_id, idx, run_id) + assigned_idx = register_clustering_run(pipeline_id, run_id) - # Should return the same index - assert assigned_idx == idx - - # Verify in database - runs = get_clustering_runs(pipeline_id) - assert runs == [(0, "run_001")] - - def test_register_multiple_explicit_indices(self, _temp_registry_db: Any): - """Test registering multiple runs with explicit indices.""" - pipeline_id = "pipeline_002" - - idx0 = register_clustering_run(pipeline_id, 0, "run_001") - idx1 = register_clustering_run(pipeline_id, 1, "run_002") - idx2 = register_clustering_run(pipeline_id, 2, "run_003") - - assert idx0 == 0 - assert idx1 == 1 - assert idx2 == 2 - - # Verify order in database - runs = get_clustering_runs(pipeline_id) - assert runs == [(0, "run_001"), (1, "run_002"), (2, "run_003")] - - def test_auto_assign_single_index(self, _temp_registry_db: Any): - """Test auto-assigning a single index.""" - pipeline_id = "pipeline_003" - run_id = "run_001" - - assigned_idx = register_clustering_run(pipeline_id, -1, run_id) - - # First auto-assigned index should be 0 + # First index should be 0 assert assigned_idx == 0 # Verify in database runs = get_clustering_runs(pipeline_id) assert runs == [(0, "run_001")] - def test_auto_assign_multiple_indices(self, _temp_registry_db: Any): - """Test auto-assigning multiple indices sequentially.""" - pipeline_id = "pipeline_004" + 
def test_register_multiple_runs(self, _temp_registry_db: Any): + """Test registering multiple runs sequentially.""" + pipeline_id = "pipeline_002" - idx0 = register_clustering_run(pipeline_id, -1, "run_001") - idx1 = register_clustering_run(pipeline_id, -1, "run_002") - idx2 = register_clustering_run(pipeline_id, -1, "run_003") + idx0 = register_clustering_run(pipeline_id, "run_001") + idx1 = register_clustering_run(pipeline_id, "run_002") + idx2 = register_clustering_run(pipeline_id, "run_003") # Should auto-assign 0, 1, 2 assert idx0 == 0 @@ -86,95 +55,21 @@ def test_auto_assign_multiple_indices(self, _temp_registry_db: Any): runs = get_clustering_runs(pipeline_id) assert runs == [(0, "run_001"), (1, "run_002"), (2, "run_003")] - def test_auto_assign_after_explicit_indices(self, _temp_registry_db: Any): - """Test that auto-assignment continues from max existing index.""" - pipeline_id = "pipeline_005" - - # Register explicit indices - register_clustering_run(pipeline_id, 0, "run_001") - register_clustering_run(pipeline_id, 1, "run_002") - - # Auto-assign should get index 2 - idx = register_clustering_run(pipeline_id, -1, "run_003") - assert idx == 2 - - # Auto-assign again should get index 3 - idx = register_clustering_run(pipeline_id, -1, "run_004") - assert idx == 3 - - # Verify in database - runs = get_clustering_runs(pipeline_id) - assert runs == [ - (0, "run_001"), - (1, "run_002"), - (2, "run_003"), - (3, "run_004"), - ] - - def test_auto_assign_with_gaps(self, _temp_registry_db: Any): - """Test that auto-assignment uses max+1, even with gaps.""" - pipeline_id = "pipeline_006" - - # Register with gaps: 0, 5, 10 - register_clustering_run(pipeline_id, 0, "run_001") - register_clustering_run(pipeline_id, 5, "run_002") - register_clustering_run(pipeline_id, 10, "run_003") - - # Auto-assign should get index 11 (max + 1) - idx = register_clustering_run(pipeline_id, -1, "run_004") - assert idx == 11 - - # Verify in database (ordered by idx) - runs = 
get_clustering_runs(pipeline_id) - assert runs == [ - (0, "run_001"), - (5, "run_002"), - (10, "run_003"), - (11, "run_004"), - ] - - def test_mixed_explicit_and_auto_assign(self, _temp_registry_db: Any): - """Test mixing explicit and auto-assigned indices.""" - pipeline_id = "pipeline_007" - - # Mix of explicit and auto-assigned - idx0 = register_clustering_run(pipeline_id, -1, "run_001") # auto: 0 - idx1 = register_clustering_run(pipeline_id, 5, "run_002") # explicit: 5 - idx2 = register_clustering_run(pipeline_id, -1, "run_003") # auto: 6 - idx3 = register_clustering_run(pipeline_id, 2, "run_004") # explicit: 2 - idx4 = register_clustering_run(pipeline_id, -1, "run_005") # auto: 7 - - assert idx0 == 0 - assert idx1 == 5 - assert idx2 == 6 - assert idx3 == 2 - assert idx4 == 7 - - # Verify in database (ordered by idx) - runs = get_clustering_runs(pipeline_id) - assert runs == [ - (0, "run_001"), - (2, "run_004"), - (5, "run_002"), - (6, "run_003"), - (7, "run_005"), - ] - def test_different_pipelines_independent(self, _temp_registry_db: Any): """Test that different pipelines have independent index sequences.""" pipeline_a = "pipeline_a" pipeline_b = "pipeline_b" # Both should start at 0 when auto-assigning - idx_a0 = register_clustering_run(pipeline_a, -1, "run_a1") - idx_b0 = register_clustering_run(pipeline_b, -1, "run_b1") + idx_a0 = register_clustering_run(pipeline_a, "run_a1") + idx_b0 = register_clustering_run(pipeline_b, "run_b1") assert idx_a0 == 0 assert idx_b0 == 0 # Both should increment independently - idx_a1 = register_clustering_run(pipeline_a, -1, "run_a2") - idx_b1 = register_clustering_run(pipeline_b, -1, "run_b2") + idx_a1 = register_clustering_run(pipeline_a, "run_a2") + idx_b1 = register_clustering_run(pipeline_b, "run_b2") assert idx_a1 == 1 assert idx_b1 == 1 @@ -199,17 +94,17 @@ def test_get_runs_sorted_by_index(self, _temp_registry_db: Any): """Test that runs are returned sorted by index.""" pipeline_id = "pipeline_sort" - # Register out 
of order - register_clustering_run(pipeline_id, 5, "run_005") - register_clustering_run(pipeline_id, 1, "run_001") - register_clustering_run(pipeline_id, 3, "run_003") - register_clustering_run(pipeline_id, 0, "run_000") + # Register runs (indices will be auto-assigned in order) + register_clustering_run(pipeline_id, "run_000") + register_clustering_run(pipeline_id, "run_001") + register_clustering_run(pipeline_id, "run_002") + register_clustering_run(pipeline_id, "run_003") # Should be returned in sorted order runs = get_clustering_runs(pipeline_id) assert runs == [ (0, "run_000"), (1, "run_001"), + (2, "run_002"), (3, "run_003"), - (5, "run_005"), ] From cf64a7972127fe59c8360ec720599c842b0b3e11 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 24 Oct 2025 17:56:27 +0100 Subject: [PATCH 30/61] only allow passing clustering run config path, not inline see discussion at https://github.com/goodfire-ai/spd/pull/227#discussion_r2454299922 have tried to make this change as isolated as possible -- i think this was a useful feature and we may want to add it back at some point --- .../configs/pipeline-dev-simplestories.yaml | 20 +-- spd/clustering/scripts/run_pipeline.py | 84 ++------- tests/clustering/test_pipeline_config.py | 168 ++---------------- 3 files changed, 21 insertions(+), 251 deletions(-) diff --git a/spd/clustering/configs/pipeline-dev-simplestories.yaml b/spd/clustering/configs/pipeline-dev-simplestories.yaml index dfee51d64..6d181424a 100644 --- a/spd/clustering/configs/pipeline-dev-simplestories.yaml +++ b/spd/clustering/configs/pipeline-dev-simplestories.yaml @@ -6,22 +6,4 @@ slurm_partition: null wandb_project: "spd-cluster" # wandb fails in CI wandb_entity: "goodfire" create_git_snapshot: false -# run_clustering_config_path: "spd/clustering/configs/crc/simplestories_dev.json" -run_clustering_config: - model_path: "wandb:goodfire/spd/runs/lxs77xye" - batch_size: 16 - wandb_project: "spd-cluster" - logging_intervals: - stat: 5 - tensor: 100 - 
plot: 10000 - artifact: 10000 - merge_config: - activation_threshold: 0.1 - alpha: 1.0 - iters: null - merge_pair_sampling_method: "range" - merge_pair_sampling_kwargs: - threshold: 0.001 - filter_dead_threshold: 0.1 - module_name_filter: null \ No newline at end of file +run_clustering_config_path: "spd/clustering/configs/crc/simplestories_dev.json" \ No newline at end of file diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index 5396cb640..614d7ac17 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -71,13 +71,8 @@ def distances_path(self, method: DistancesMethod) -> Path: class ClusteringPipelineConfig(BaseConfig): """Configuration for submitting an ensemble of clustering runs to SLURM.""" - run_clustering_config_path: Path | None = Field( - default=None, - description="Path to ClusteringRunConfig file. Mutually exclusive with run_clustering_config.", - ) - run_clustering_config: ClusteringRunConfig | None = Field( - default=None, - description="Inline ClusteringRunConfig. 
Mutually exclusive with run_clustering_config_path.", + run_clustering_config_path: Path = Field( + description="Path to ClusteringRunConfig file.", ) n_runs: PositiveInt = Field(description="Number of clustering runs in the ensemble") distances_methods: list[DistancesMethod] = Field( @@ -101,29 +96,13 @@ class ClusteringPipelineConfig(BaseConfig): ) @model_validator(mode="after") - def validate_crc_fields(self) -> "ClusteringPipelineConfig": - """Validate that exactly one of run_clustering_config_path or run_clustering_config is provided.""" - has_path: bool = self.run_clustering_config_path is not None - has_inline: bool = self.run_clustering_config is not None - - if not has_path and not has_inline: - raise ValueError( - "Must specify exactly one of 'run_clustering_config_path' or 'run_clustering_config'" - ) - - if has_path: - if has_inline: - raise ValueError( - "Cannot specify both 'run_clustering_config_path' and 'run_clustering_config'. " - "Use only one." - ) - else: - # Ensure the path exists - # pyright ignore because it doesn't recognize that has_path implies not None - if not self.run_clustering_config_path.exists(): # pyright: ignore[reportOptionalMemberAccess] - raise ValueError( - f"run_clustering_config_path does not exist: {self.run_clustering_config_path = }" - ) + def validate_crc(self) -> "ClusteringPipelineConfig": + """Validate that exactly one of run_clustering_config_path points to a valid `ClusteringRunConfig`.""" + assert self.run_clustering_config_path.exists(), ( + f"run_clustering_config_path does not exist: {self.run_clustering_config_path}" + ) + # Try to load ClusteringRunConfig + assert ClusteringRunConfig.from_file(self.run_clustering_config_path) return self @@ -137,49 +116,6 @@ def validate_distances_methods(cls, v: list[DistancesMethod]) -> list[DistancesM return v - def get_config_path(self) -> Path: - """Get the path to the ClusteringRunConfig file. - - - If run_clustering_config_path is provided, returns it directly. 
- - If run_clustering_config is provided, caches it to a deterministic path - based on its content hash and returns that path. - - if the config file already exists in the cache, assert that it is identical. - - Returns: - Path to the (potentially newly created) ClusteringRunConfig file - """ - if self.run_clustering_config_path is not None: - assert self.run_clustering_config_path.exists(), ( - f"no file at run_clustering_config_path: {self.run_clustering_config_path = }" - ) - return self.run_clustering_config_path - - assert self.run_clustering_config is not None, ( - "Either run_clustering_config_path or run_clustering_config must be set" - ) - - # Generate deterministic hash from config - hash_b64: str = self.run_clustering_config.stable_hash_b64() - - # Create cache directory - cache_dir: Path = SPD_CACHE_DIR / "clustering_run_configs" - cache_dir.mkdir(parents=True, exist_ok=True) - - # Write config to cache if it doesn't exist - config_path: Path = cache_dir / f"{hash_b64}.json" - if not config_path.exists(): - self.run_clustering_config.to_file(config_path) - logger.info(f"Cached inline config to {config_path}") - else: - # Verify that existing file matches - existing_config = ClusteringRunConfig.from_file(config_path) - if existing_config != self.run_clustering_config: - raise ValueError( - f"Hash collision detected for config hash {hash_b64} at {config_path}\n{existing_config=}\n{self.run_clustering_config=}" - ) - - return config_path - def create_clustering_workspace_view(ensemble_id: str, project: str, entity: str) -> str: """Create WandB workspace view for clustering runs. 
@@ -234,7 +170,7 @@ def generate_clustering_commands( "python", "spd/clustering/scripts/run_clustering.py", "--config", - pipeline_config.get_config_path().as_posix(), + pipeline_config.run_clustering_config_path.as_posix(), "--pipeline-run-id", pipeline_run_id, "--idx-in-ensemble", diff --git a/tests/clustering/test_pipeline_config.py b/tests/clustering/test_pipeline_config.py index 8d527bd6c..264078392 100644 --- a/tests/clustering/test_pipeline_config.py +++ b/tests/clustering/test_pipeline_config.py @@ -2,12 +2,13 @@ from pathlib import Path +import pydantic_core import pytest from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.merge_config import MergeConfig from spd.clustering.scripts.run_pipeline import ClusteringPipelineConfig -from spd.settings import REPO_ROOT, SPD_CACHE_DIR +from spd.settings import REPO_ROOT class TestClusteringRunConfigStableHash: @@ -69,10 +70,11 @@ def test_stable_hash_b64(self): class TestClusteringPipelineConfigValidation: """Test ClusteringPipelineConfig validation logic.""" - def test_error_when_neither_field_provided(self): - """Test that error is raised when neither path nor inline config is provided.""" - with pytest.raises(ValueError, match="Must specify exactly one"): + def test_error_when_path_does_not_exist(self): + """Test that error is raised when run_clustering_config_path does not exist.""" + with pytest.raises(pydantic_core._pydantic_core.ValidationError): ClusteringPipelineConfig( + run_clustering_config_path=Path("nonexistent/path.json"), n_runs=2, distances_methods=["perm_invariant_hamming"], base_output_dir=Path("/tmp/test"), @@ -82,32 +84,8 @@ def test_error_when_neither_field_provided(self): create_git_snapshot=False, ) - def test_error_when_both_fields_provided(self): - """Test that error is raised when both path and inline config are provided.""" - with pytest.raises(ValueError, match="Cannot specify both"): - ClusteringPipelineConfig( - 
run_clustering_config_path=Path("some/path.json"), - run_clustering_config=ClusteringRunConfig( - model_path="wandb:test/project/run1", - batch_size=32, - merge_config=MergeConfig(), - dataset_seed=0, - ), - n_runs=2, - distances_methods=["perm_invariant_hamming"], - base_output_dir=Path("/tmp/test"), - slurm_job_name_prefix=None, - slurm_partition=None, - wandb_entity="test", - create_git_snapshot=False, - ) - - -class TestClusteringPipelineConfigGetConfigPath: - """Test ClusteringPipelineConfig.get_config_path() method.""" - - def test_returns_path_directly_when_using_path_field(self): - """Test that get_config_path returns the path directly when using run_clustering_config_path.""" + def test_valid_config_with_existing_path(self): + """Test that config is valid when path points to existing file.""" expected_path = Path("spd/clustering/configs/crc/resid_mlp1.json") config = ClusteringPipelineConfig( @@ -119,133 +97,7 @@ def test_returns_path_directly_when_using_path_field(self): create_git_snapshot=False, ) - assert config.get_config_path() == expected_path - - def test_creates_cached_file_when_using_inline_config(self): - """Test that get_config_path creates a cached file when using inline config.""" - inline_config = ClusteringRunConfig( - model_path="wandb:test/project/run1", - batch_size=32, - dataset_seed=0, - merge_config=MergeConfig(), - ) - - config = ClusteringPipelineConfig( - run_clustering_config=inline_config, - n_runs=2, - distances_methods=["perm_invariant_hamming"], - base_output_dir=Path("/tmp/test"), - wandb_entity="test", - create_git_snapshot=False, - ) - - config_path = config.get_config_path() - - # Check that file exists - assert config_path.exists() - - # Check that it's in the expected directory - expected_cache_dir = SPD_CACHE_DIR / "clustering_run_configs" - assert config_path.parent == expected_cache_dir - - # Check that filename is the hash - expected_hash = inline_config.stable_hash_b64() - assert config_path.name == 
f"{expected_hash}.json" - - # Check that file contents match the config - loaded_config = ClusteringRunConfig.from_file(config_path) - assert loaded_config == inline_config - - # Clean up - config_path.unlink() - - def test_reuses_existing_cached_file(self): - """Test that get_config_path reuses existing cached file with same hash.""" - inline_config = ClusteringRunConfig( - model_path="wandb:test/project/run1", - batch_size=32, - dataset_seed=0, - merge_config=MergeConfig(), - ) - - config1 = ClusteringPipelineConfig( - run_clustering_config=inline_config, - n_runs=2, - distances_methods=["perm_invariant_hamming"], - base_output_dir=Path("/tmp/test"), - wandb_entity="test", - create_git_snapshot=False, - ) - - # First call creates the file - config_path1 = config1.get_config_path() - assert config_path1.exists() - - # Record modification time - mtime1 = config_path1.stat().st_mtime - - # Create another config with same inline config - config2 = ClusteringPipelineConfig( - run_clustering_config=inline_config, - n_runs=3, # Different n_runs shouldn't matter - distances_methods=["perm_invariant_hamming"], - base_output_dir=Path("/tmp/test"), - wandb_entity="test", - create_git_snapshot=False, - ) - - # Second call should reuse the file - config_path2 = config2.get_config_path() - - assert config_path1 == config_path2 - assert config_path2.stat().st_mtime == mtime1 # File not modified - - # Clean up - config_path1.unlink() - - def test_hash_collision_detection(self): - """Test that hash collision is detected when existing file differs.""" - inline_config = ClusteringRunConfig( - model_path="wandb:test/project/run1", - batch_size=32, - dataset_seed=0, - merge_config=MergeConfig(), - ) - - # Create a fake collision by manually creating a file with same hash - hash_value = inline_config.stable_hash_b64() - cache_dir = SPD_CACHE_DIR / "clustering_run_configs" - cache_dir.mkdir(parents=True, exist_ok=True) - collision_path = cache_dir / f"{hash_value}.json" - - # Write a 
different config to the file - different_config = ClusteringRunConfig( - model_path="wandb:test/project/run2", # Different! - batch_size=32, - dataset_seed=0, - merge_config=MergeConfig(), - ) - different_config.to_file(collision_path) - - try: - config = ClusteringPipelineConfig( - run_clustering_config=inline_config, - n_runs=2, - distances_methods=["perm_invariant_hamming"], - base_output_dir=Path("/tmp/test"), - slurm_job_name_prefix=None, - slurm_partition=None, - wandb_entity="test", - create_git_snapshot=False, - ) - - # Should raise ValueError about hash collision - with pytest.raises(ValueError, match="Hash collision detected"): - config.get_config_path() - finally: - # Clean up - if collision_path.exists(): - collision_path.unlink() + assert config.run_clustering_config_path == expected_path def _get_config_files(path: Path): @@ -269,7 +121,7 @@ def test_config_validate_pipeline(self, config_file: Path): """Test that each pipeline config file is valid.""" print(config_file) _config = ClusteringPipelineConfig.from_file(config_file) - crc_path = _config.get_config_path() + crc_path = _config.run_clustering_config_path print(f"{crc_path = }") assert crc_path.exists() From 2a9f731b6e9453f1d439f23a862cd69085a02191 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 24 Oct 2025 17:58:46 +0100 Subject: [PATCH 31/61] rename run_clustering_config_path -> clustering_run_config_path old name didnt make sense, since it should be a path to a file with a `ClusteringRunConfig` see https://github.com/goodfire-ai/spd/pull/227#discussion_r2454299922 --- spd/clustering/configs/README.md | 2 +- .../configs/pipeline-dev-simplestories.yaml | 2 +- spd/clustering/configs/pipeline-test-resid_mlp1.yaml | 2 +- .../configs/pipeline-test-simplestories.yaml | 2 +- spd/clustering/configs/pipeline_config.yaml | 2 +- spd/clustering/scripts/run_pipeline.py | 12 ++++++------ tests/clustering/test_pipeline_config.py | 10 +++++----- 7 files changed, 16 insertions(+), 16 deletions(-) 
diff --git a/spd/clustering/configs/README.md b/spd/clustering/configs/README.md index 51db8e8a0..e1ac41f47 100644 --- a/spd/clustering/configs/README.md +++ b/spd/clustering/configs/README.md @@ -1 +1 @@ -this folder `configs/` contains files with `ClusteringPipelineConfig`s. the folder `configs/crc/` contains files with `ClusteringRunConfig`s, which may be referenced by the `run_clustering_config_path` field in the pipeline configs. \ No newline at end of file +this folder `configs/` contains files with `ClusteringPipelineConfig`s. the folder `configs/crc/` contains files with `ClusteringRunConfig`s, which may be referenced by the `clustering_run_config_path` field in the pipeline configs. \ No newline at end of file diff --git a/spd/clustering/configs/pipeline-dev-simplestories.yaml b/spd/clustering/configs/pipeline-dev-simplestories.yaml index 6d181424a..1868b5887 100644 --- a/spd/clustering/configs/pipeline-dev-simplestories.yaml +++ b/spd/clustering/configs/pipeline-dev-simplestories.yaml @@ -6,4 +6,4 @@ slurm_partition: null wandb_project: "spd-cluster" # wandb fails in CI wandb_entity: "goodfire" create_git_snapshot: false -run_clustering_config_path: "spd/clustering/configs/crc/simplestories_dev.json" \ No newline at end of file +clustering_run_config_path: "spd/clustering/configs/crc/simplestories_dev.json" \ No newline at end of file diff --git a/spd/clustering/configs/pipeline-test-resid_mlp1.yaml b/spd/clustering/configs/pipeline-test-resid_mlp1.yaml index db72fa3c0..37833c82c 100644 --- a/spd/clustering/configs/pipeline-test-resid_mlp1.yaml +++ b/spd/clustering/configs/pipeline-test-resid_mlp1.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/crc/test-resid_mlp1.json" +clustering_run_config_path: "spd/clustering/configs/crc/test-resid_mlp1.json" n_runs: 3 distances_methods: ["matching_dist"] base_output_dir: "tests/.temp/clustering" diff --git a/spd/clustering/configs/pipeline-test-simplestories.yaml 
b/spd/clustering/configs/pipeline-test-simplestories.yaml index 24e686023..9872062d2 100644 --- a/spd/clustering/configs/pipeline-test-simplestories.yaml +++ b/spd/clustering/configs/pipeline-test-simplestories.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/crc/test-simplestories.json" +clustering_run_config_path: "spd/clustering/configs/crc/test-simplestories.json" n_runs: 2 distances_methods: ["matching_dist"] base_output_dir: "tests/.temp/clustering" diff --git a/spd/clustering/configs/pipeline_config.yaml b/spd/clustering/configs/pipeline_config.yaml index 297b47d7b..3a533885d 100644 --- a/spd/clustering/configs/pipeline_config.yaml +++ b/spd/clustering/configs/pipeline_config.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/crc/example.yaml" +clustering_run_config_path: "spd/clustering/configs/crc/example.yaml" n_runs: 2 distances_methods: ["perm_invariant_hamming"] base_output_dir: "/mnt/polished-lake/spd/clustering" diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index 614d7ac17..179bc8bca 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -71,7 +71,7 @@ def distances_path(self, method: DistancesMethod) -> Path: class ClusteringPipelineConfig(BaseConfig): """Configuration for submitting an ensemble of clustering runs to SLURM.""" - run_clustering_config_path: Path = Field( + clustering_run_config_path: Path = Field( description="Path to ClusteringRunConfig file.", ) n_runs: PositiveInt = Field(description="Number of clustering runs in the ensemble") @@ -97,12 +97,12 @@ class ClusteringPipelineConfig(BaseConfig): @model_validator(mode="after") def validate_crc(self) -> "ClusteringPipelineConfig": - """Validate that exactly one of run_clustering_config_path points to a valid `ClusteringRunConfig`.""" - assert self.run_clustering_config_path.exists(), ( - f"run_clustering_config_path does not exist: 
{self.run_clustering_config_path}" + """Validate that exactly one of clustering_run_config_path points to a valid `ClusteringRunConfig`.""" + assert self.clustering_run_config_path.exists(), ( + f"clustering_run_config_path does not exist: {self.clustering_run_config_path}" ) # Try to load ClusteringRunConfig - assert ClusteringRunConfig.from_file(self.run_clustering_config_path) + assert ClusteringRunConfig.from_file(self.clustering_run_config_path) return self @@ -170,7 +170,7 @@ def generate_clustering_commands( "python", "spd/clustering/scripts/run_clustering.py", "--config", - pipeline_config.run_clustering_config_path.as_posix(), + pipeline_config.clustering_run_config_path.as_posix(), "--pipeline-run-id", pipeline_run_id, "--idx-in-ensemble", diff --git a/tests/clustering/test_pipeline_config.py b/tests/clustering/test_pipeline_config.py index 264078392..ca6bad6ee 100644 --- a/tests/clustering/test_pipeline_config.py +++ b/tests/clustering/test_pipeline_config.py @@ -71,10 +71,10 @@ class TestClusteringPipelineConfigValidation: """Test ClusteringPipelineConfig validation logic.""" def test_error_when_path_does_not_exist(self): - """Test that error is raised when run_clustering_config_path does not exist.""" + """Test that error is raised when clustering_run_config_path does not exist.""" with pytest.raises(pydantic_core._pydantic_core.ValidationError): ClusteringPipelineConfig( - run_clustering_config_path=Path("nonexistent/path.json"), + clustering_run_config_path=Path("nonexistent/path.json"), n_runs=2, distances_methods=["perm_invariant_hamming"], base_output_dir=Path("/tmp/test"), @@ -89,7 +89,7 @@ def test_valid_config_with_existing_path(self): expected_path = Path("spd/clustering/configs/crc/resid_mlp1.json") config = ClusteringPipelineConfig( - run_clustering_config_path=expected_path, + clustering_run_config_path=expected_path, n_runs=2, distances_methods=["perm_invariant_hamming"], base_output_dir=Path("/tmp/test"), @@ -97,7 +97,7 @@ def 
test_valid_config_with_existing_path(self): create_git_snapshot=False, ) - assert config.run_clustering_config_path == expected_path + assert config.clustering_run_config_path == expected_path def _get_config_files(path: Path): @@ -121,7 +121,7 @@ def test_config_validate_pipeline(self, config_file: Path): """Test that each pipeline config file is valid.""" print(config_file) _config = ClusteringPipelineConfig.from_file(config_file) - crc_path = _config.run_clustering_config_path + crc_path = _config.clustering_run_config_path print(f"{crc_path = }") assert crc_path.exists() From 04ae8699529800b0231e656ba55ab3b6dad11b81 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 12:31:37 +0000 Subject: [PATCH 32/61] 4 claudes in parallel --- spd/clustering/batched_activations.py | 182 +++++++++++++++++++++ spd/clustering/clustering_run_config.py | 4 + spd/clustering/math/merge_pair_samplers.py | 49 ++++-- spd/clustering/merge.py | 112 ++++++++++--- spd/clustering/merge_config.py | 5 +- spd/clustering/scripts/run_clustering.py | 169 +++++++++++-------- spd/clustering/scripts/run_pipeline.py | 20 +++ 7 files changed, 437 insertions(+), 104 deletions(-) create mode 100644 spd/clustering/batched_activations.py diff --git a/spd/clustering/batched_activations.py b/spd/clustering/batched_activations.py new file mode 100644 index 000000000..30703a0ba --- /dev/null +++ b/spd/clustering/batched_activations.py @@ -0,0 +1,182 @@ +"""Activation batch storage and precomputation for multi-batch clustering. + +This module provides: +1. Data structures for storing and loading activation batches (ActivationBatch, BatchedActivations) +2. 
Precomputation logic to generate batches for ensemble runs (precompute_batches_for_ensemble) +""" + +import gc +from dataclasses import dataclass +from pathlib import Path + +import torch +from torch import Tensor +from tqdm import tqdm + +from spd.clustering.activations import component_activations, process_activations +from spd.clustering.dataset import load_dataset +from spd.log import logger +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.utils.distributed_utils import get_device + + +@dataclass +class ActivationBatch: + """Single batch of activations - just tensors, no processing.""" + + activations: Tensor # [samples, n_components] + labels: list[str] # ["module:idx", ...] + + def save(self, path: Path) -> None: + torch.save({"activations": self.activations, "labels": self.labels}, path) + + @staticmethod + def load(path: Path) -> "ActivationBatch": + data = torch.load(path, weights_only=False) + return ActivationBatch( + activations=data["activations"], + labels=data["labels"], + ) + + +class BatchedActivations: + """Iterator over activation batches from disk.""" + + def __init__(self, batch_dir: Path): + self.batch_dir = batch_dir + # Find all batch files: batch_0.pt, batch_1.pt, ... + self.batch_paths = sorted(batch_dir.glob("batch_*.pt")) + assert len(self.batch_paths) > 0, f"No batch files found in {batch_dir}" + self.current_idx = 0 + + @property + def n_batches(self) -> int: + return len(self.batch_paths) + + def get_next_batch(self) -> ActivationBatch: + """Load and return next batch, cycling through available batches.""" + batch = ActivationBatch.load(self.batch_paths[self.current_idx]) + self.current_idx = (self.current_idx + 1) % self.n_batches + return batch + + +def precompute_batches_for_ensemble( + clustering_run_config: "ClusteringRunConfig", + n_runs: int, + output_dir: Path, +) -> Path | None: + """ + Precompute activation batches for all runs in ensemble. 
+ + This loads the model ONCE and generates all batches for all runs, + then saves them to disk. Each clustering run will load batches + from disk without needing the model. + + Args: + clustering_run_config: Configuration for clustering runs + n_runs: Number of runs in the ensemble + output_dir: Base directory to save precomputed batches + + Returns: + Path to base directory containing batches for all runs, + or None if single-batch mode (recompute_costs_every=1) + """ + # Check if multi-batch mode + recompute_every = clustering_run_config.merge_config.recompute_costs_every + if recompute_every == 1: + logger.info("Single-batch mode (recompute_costs_every=1), skipping precomputation") + return None + + logger.info("Multi-batch mode detected, precomputing activation batches") + + # Load model to determine number of components + device = get_device() + spd_run = SPDRunInfo.from_path(clustering_run_config.model_path) + model = ComponentModel.from_run_info(spd_run).to(device) + task_name = spd_run.config.task_config.task_name + + # Get number of components (no filtering, so just count from model) + # Load a sample to count components + logger.info("Loading sample batch to count components") + sample_batch = load_dataset( + model_path=clustering_run_config.model_path, + task_name=task_name, + batch_size=clustering_run_config.batch_size, + seed=0, + ).to(device) + + with torch.no_grad(): + sample_acts = component_activations(model, device, sample_batch) + + # Count total components across all modules + n_components = sum(act.shape[-1] for act in sample_acts.values()) + + # Calculate number of iterations + n_iters = clustering_run_config.merge_config.get_num_iters(n_components) + + # Calculate batches needed per run + n_batches_needed = (n_iters + recompute_every - 1) // recompute_every + + logger.info(f"Precomputing {n_batches_needed} batches per run for {n_runs} runs") + logger.info(f"Total: {n_batches_needed * n_runs} batches") + + # Create batches directory + 
batches_base_dir = output_dir / "precomputed_batches" + batches_base_dir.mkdir(exist_ok=True, parents=True) + + # For each run in ensemble + for run_idx in tqdm(range(n_runs), desc="Ensemble runs"): + run_batch_dir = batches_base_dir / f"run_{run_idx}" + run_batch_dir.mkdir(exist_ok=True) + + # Generate batches for this run + for batch_idx in tqdm( + range(n_batches_needed), + desc=f" Run {run_idx} batches", + leave=False, + ): + # Use unique seed: base_seed + run_idx * 1000 + batch_idx + # This ensures different data for each run and each batch + seed = clustering_run_config.dataset_seed + run_idx * 1000 + batch_idx + + # Load data + batch_data = load_dataset( + model_path=clustering_run_config.model_path, + task_name=task_name, + batch_size=clustering_run_config.batch_size, + seed=seed, + ).to(device) + + # Compute activations + with torch.no_grad(): + acts_dict = component_activations(model, device, batch_data) + + # Process (concat, NO FILTERING) + processed = process_activations( + activations=acts_dict, + filter_dead_threshold=0.0, # NO FILTERING + seq_mode="concat" if task_name == "lm" else None, + filter_modules=None, + ) + + # Save as ActivationBatch + activation_batch = ActivationBatch( + activations=processed.activations.cpu(), # Move to CPU for storage + labels=list(processed.labels), + ) + activation_batch.save(run_batch_dir / f"batch_{batch_idx}.pt") + + # Clean up + del batch_data, acts_dict, processed, activation_batch + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Clean up model + del model, sample_batch, sample_acts + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + logger.info(f"All batches precomputed and saved to {batches_base_dir}") + + return batches_base_dir diff --git a/spd/clustering/clustering_run_config.py b/spd/clustering/clustering_run_config.py index 95d72f9bd..1da579488 100644 --- a/spd/clustering/clustering_run_config.py +++ b/spd/clustering/clustering_run_config.py @@ -69,6 +69,10 @@ 
class ClusteringRunConfig(BaseConfig): default=False, description="Whether to use streaming dataset loading (if supported by the dataset). see https://github.com/goodfire-ai/spd/pull/199", ) + precomputed_activations_dir: Path | None = Field( + default=None, + description="Path to directory containing precomputed activation batches. If None, batches will be auto-generated before merging starts.", + ) @model_validator(mode="before") def process_experiment_key(cls, values: dict[str, Any]) -> dict[str, Any]: diff --git a/spd/clustering/math/merge_pair_samplers.py b/spd/clustering/math/merge_pair_samplers.py index 24c050d36..a68f7c327 100644 --- a/spd/clustering/math/merge_pair_samplers.py +++ b/spd/clustering/math/merge_pair_samplers.py @@ -36,7 +36,7 @@ def range_sampler( of the range of non-diagonal costs, then randomly selects one. Args: - costs: Cost matrix for all possible merges + costs: Cost matrix for all possible merges (may contain NaN for invalid pairs) k_groups: Number of current groups threshold: Fraction of cost range to consider (0=min only, 1=all pairs) @@ -47,22 +47,36 @@ def range_sampler( k_groups: int = costs.shape[0] assert costs.shape[1] == k_groups, "Cost matrix must be square" - # Find the range of non-diagonal costs - non_diag_costs: Float[Tensor, " k_groups_squared_minus_k"] = costs[ - ~torch.eye(k_groups, dtype=torch.bool, device=costs.device) - ] - min_cost: float = float(non_diag_costs.min().item()) - max_cost: float = float(non_diag_costs.max().item()) + # Mask out NaN entries and diagonal + valid_mask: Bool[Tensor, "k_groups k_groups"] = ~torch.isnan(costs) + diag_mask: Bool[Tensor, "k_groups k_groups"] = ~torch.eye( + k_groups, dtype=torch.bool, device=costs.device + ) + valid_mask = valid_mask & diag_mask + + # Get valid costs + valid_costs: Float[Tensor, " n_valid"] = costs[valid_mask] + + if valid_costs.numel() == 0: + raise ValueError("All costs are NaN, cannot sample merge pair") + + # Find the range of valid costs + min_cost: 
float = float(valid_costs.min().item()) + max_cost: float = float(valid_costs.max().item()) # Calculate threshold cost max_considered_cost: float = (max_cost - min_cost) * threshold + min_cost - # Find all pairs below threshold + # Find all valid pairs below threshold + within_range: Bool[Tensor, "k_groups k_groups"] = (costs <= max_considered_cost) & valid_mask + + # Get indices of candidate pairs considered_idxs: Int[Tensor, "n_considered 2"] = torch.stack( - torch.where(costs <= max_considered_cost), dim=1 + torch.where(within_range), dim=1 ) - # Remove diagonal entries (i == j) - considered_idxs = considered_idxs[considered_idxs[:, 0] != considered_idxs[:, 1]] + + if considered_idxs.shape[0] == 0: + raise ValueError("No valid pairs within threshold range") # Randomly select one of the considered pairs selected_idx: int = random.randint(0, considered_idxs.shape[0] - 1) @@ -78,7 +92,7 @@ def mcmc_sampler( """Sample a merge pair using MCMC with probability proportional to exp(-cost/temperature). 
Args: - costs: Cost matrix for all possible merges + costs: Cost matrix for all possible merges (may contain NaN for invalid pairs) k_groups: Number of current groups temperature: Temperature parameter for softmax (higher = more uniform sampling) @@ -89,21 +103,26 @@ def mcmc_sampler( k_groups: int = costs.shape[0] assert costs.shape[1] == k_groups, "Cost matrix must be square" - # Create mask for valid pairs (non-diagonal) + # Create mask for valid pairs (non-diagonal and non-NaN) valid_mask: Bool[Tensor, "k_groups k_groups"] = ~torch.eye( k_groups, dtype=torch.bool, device=costs.device ) + valid_mask = valid_mask & ~torch.isnan(costs) + + # Check if we have any valid pairs + if not valid_mask.any(): + raise ValueError("All costs are NaN, cannot sample merge pair") # Compute probabilities: exp(-cost/temperature) # Use stable softmax computation to avoid overflow costs_masked: ClusterCoactivationShaped = costs.clone() - costs_masked[~valid_mask] = float("inf") # Set diagonal to inf so exp gives 0 + costs_masked[~valid_mask] = float("inf") # Set invalid entries to inf so exp gives 0 # Subtract min for numerical stability min_cost: float = float(costs_masked[valid_mask].min()) probs: ClusterCoactivationShaped = ( torch.exp((min_cost - costs_masked) / temperature) * valid_mask - ) # Zero out diagonal + ) # Zero out invalid entries probs_flatten: Float[Tensor, " k_groups_squared"] = probs.flatten() probs_flatten = probs_flatten / probs_flatten.sum() diff --git a/spd/clustering/merge.py b/spd/clustering/merge.py index dba55c878..9419abd44 100644 --- a/spd/clustering/merge.py +++ b/spd/clustering/merge.py @@ -5,13 +5,14 @@ """ import warnings -from typing import Protocol +from typing import Protocol, runtime_checkable import torch from jaxtyping import Bool, Float from torch import Tensor from tqdm import tqdm +from spd.clustering.batched_activations import ActivationBatch, BatchedActivations from spd.clustering.compute_costs import ( compute_mdl_cost, 
compute_merge_costs, @@ -29,6 +30,38 @@ from spd.clustering.merge_history import MergeHistory +def recompute_coacts_from_scratch( + activations: Tensor, + current_merge: GroupMerge, + activation_threshold: float | None, +) -> tuple[Tensor, Tensor]: + """ + Recompute coactivations from fresh activations using current merge state. + + Args: + activations: Fresh activation tensor [samples, n_components_original] + current_merge: Current merge state mapping original -> groups + activation_threshold: Threshold for binarizing activations + + Returns: + (coact, activation_mask) - coact matrix [k_groups, k_groups] and + mask [samples, k_groups] for current groups + """ + # Apply threshold + activation_mask = ( + activations > activation_threshold if activation_threshold is not None else activations + ) + + # Apply current merge to get group-level activations + # current_merge.matrix is [c_original, k_groups] + group_activations = activation_mask @ current_merge.matrix.to(activation_mask.device) + + # Compute coactivations + coact = group_activations.float().T @ group_activations.float() + + return coact, group_activations + + class LogCallback(Protocol): def __call__( self, @@ -48,20 +81,25 @@ def __call__( def merge_iteration( merge_config: MergeConfig, - activations: ActivationsTensor, + batched_activations: BatchedActivations, component_labels: ComponentLabels, log_callback: LogCallback | None = None, ) -> MergeHistory: """ - Merge iteration with optional logging/plotting callbacks. + Merge iteration with multi-batch support and optional logging/plotting callbacks. - This wraps the pure computation with logging capabilities while maintaining - the same core algorithm logic. + This implementation uses NaN masking to track invalid coactivation entries + and periodically recomputes the full coactivation matrix from fresh batches. 
""" - # compute coactivations + # Load first batch + # -------------------------------------------------- + first_batch: ActivationBatch = batched_activations.get_next_batch() + activations: Tensor = first_batch.activations + + # Compute initial coactivations # -------------------------------------------------- - activation_mask_orig: BoolActivationsTensor | ActivationsTensor | None = ( + activation_mask_orig: BoolActivationsTensor | ActivationsTensor = ( activations > merge_config.activation_threshold if merge_config.activation_threshold is not None else activations @@ -110,27 +148,64 @@ def merge_iteration( merge_pair: MergePair = merge_config.merge_pair_sample(costs) + # Store merge pair cost before updating + # -------------------------------------------------- + merge_pair_cost: float = float(costs[merge_pair].item()) + # merge the pair # -------------------------------------------------- - # we do this *before* logging, so we can see how the sampled pair cost compares - # to the costs of all the other possible pairs - current_merge, current_coact, current_act_mask = recompute_coacts_merge_pair( - coact=current_coact, - merges=current_merge, - merge_pair=merge_pair, - activation_mask=current_act_mask, + # Update merge state BEFORE NaN-ing out + current_merge = current_merge.merge_groups(merge_pair[0], merge_pair[1]) + + # NaN out the merged components' rows/cols + i, j = merge_pair + new_idx: int = min(i, j) + remove_idx: int = max(i, j) + + # Mark affected entries as invalid (can't compute cost anymore without recompute) + current_coact[remove_idx, :] = float("nan") + current_coact[:, remove_idx] = float("nan") + current_coact[new_idx, :] = float("nan") + current_coact[:, new_idx] = float("nan") + + # Remove the deleted row/col to maintain shape consistency + mask: Bool[Tensor, " k_groups"] = torch.ones( + k_groups, dtype=torch.bool, device=current_coact.device ) + mask[remove_idx] = False + current_coact = current_coact[mask, :][:, mask] + current_act_mask = 
current_act_mask[:, mask] + + k_groups -= 1 - # metrics and logging - # -------------------------------------------------- # Store in history + # -------------------------------------------------- merge_history.add_iteration( idx=iter_idx, selected_pair=merge_pair, current_merge=current_merge, ) + # Recompute from new batch if it's time + # -------------------------------------------------- + should_recompute: bool = ( + (iter_idx + 1) % merge_config.recompute_costs_every == 0 + and iter_idx + 1 < num_iters + ) + + if should_recompute: + new_batch: ActivationBatch = batched_activations.get_next_batch() + activations = new_batch.activations + + # Recompute fresh coacts with current merge groups + current_coact, current_act_mask = recompute_coacts_from_scratch( + activations=activations, + current_merge=current_merge, + activation_threshold=merge_config.activation_threshold, + ) + # Compute metrics for logging + # -------------------------------------------------- # the MDL loss computed here is the *cost of the current merge*, a single scalar value # rather than the *delta in cost from merging a specific pair* (which is what `costs` matrix contains) diag_acts: Float[Tensor, " k_groups"] = torch.diag(current_coact) @@ -140,8 +215,6 @@ def merge_iteration( alpha=merge_config.alpha, ) mdl_loss_norm: float = mdl_loss / current_act_mask.shape[0] - # this is the cost for the selected pair - merge_pair_cost: float = float(costs[merge_pair].item()) # Update progress bar pbar.set_description(f"k={k_groups}, mdl={mdl_loss_norm:.4f}, pair={merge_pair_cost:.4f}") @@ -161,9 +234,8 @@ def merge_iteration( diag_acts=diag_acts, ) - # iterate and sanity checks + # Sanity checks # -------------------------------------------------- - k_groups -= 1 assert current_coact.shape[0] == k_groups, ( "Coactivation matrix shape should match number of groups" ) diff --git a/spd/clustering/merge_config.py b/spd/clustering/merge_config.py index 0f94996f4..8c2f44ca6 100644 --- 
a/spd/clustering/merge_config.py +++ b/spd/clustering/merge_config.py @@ -72,10 +72,9 @@ class MergeConfig(BaseConfig): default=None, description="Filter for module names. Can be a string prefix, a set of names, or a callable that returns True for modules to include.", ) - # TODO: unsure of this var name recompute_costs_every: PositiveInt = Field( - default=10, - description="How often to recompute the full cost matrix, replacing NaN values of merged components with their true value. Higher values mean less accurate merges but faster computation.", + default=1, + description="Number of merges before recomputing costs with new batch. Set to 1 for original behavior.", ) batch_size: PositiveInt = Field( default=64, diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index 54f0805c6..b978c92a6 100644 --- a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -31,6 +31,7 @@ component_activations, process_activations, ) +from spd.clustering.batched_activations import ActivationBatch, BatchedActivations from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.consts import ( ActivationsTensor, @@ -258,25 +259,7 @@ def main(run_config: ClusteringRunConfig) -> Path: spd_run = SPDRunInfo.from_path(run_config.model_path) task_name: TaskName = spd_run.config.task_config.task_name - # 1. Load dataset - logger.info(f"Loading dataset (seed={run_config.dataset_seed})") - load_dataset_kwargs: dict[str, Any] = dict() - if run_config.dataset_streaming: - logger.info("Using streaming dataset loading") - load_dataset_kwargs["config_kwargs"] = dict(streaming=True) - assert task_name == "lm", ( - f"Streaming dataset loading only supported for 'lm' task, got '{task_name = }'. Remove dataset_streaming=True from config or use a different task." 
- ) - batch: BatchTensor = load_dataset( - model_path=run_config.model_path, - task_name=task_name, - batch_size=run_config.batch_size, - seed=run_config.dataset_seed, - **load_dataset_kwargs, - ) - batch = batch.to(device) - - # 2. Setup WandB for this run + # Setup WandB for this run wandb_run: Run | None = None if run_config.wandb_project is not None: wandb_run = wandb.init( @@ -293,58 +276,104 @@ def main(run_config: ClusteringRunConfig) -> Path: f"assigned_idx:{assigned_idx}", ], ) - # logger.info(f"WandB run: {wandb_run.url}") - - # 3. Load model - logger.info("Loading model") - model = ComponentModel.from_run_info(spd_run).to(device) - - # 4. Compute activations - logger.info("Computing activations") - activations_dict: ( - dict[str, Float[Tensor, "batch seq C"]] | dict[str, Float[Tensor, "batch C"]] - ) = component_activations( - model=model, - batch=batch, - device=device, - ) - # 5. Process activations - logger.info("Processing activations") - processed_activations: ProcessedActivations = process_activations( - activations=activations_dict, - filter_dead_threshold=run_config.merge_config.filter_dead_threshold, - seq_mode="concat" if task_name == "lm" else None, - filter_modules=run_config.merge_config.filter_modules, - ) + # Load or compute activations + # ===================================== + batched_activations: BatchedActivations + component_labels: ComponentLabels - # 6. 
Log activations (if WandB enabled) - if wandb_run is not None: - logger.info("Plotting activations") - plot_activations( - processed_activations=processed_activations, - save_dir=None, # Don't save to disk, only WandB - n_samples_max=256, - wandb_run=wandb_run, + if run_config.precomputed_activations_dir is not None: + # Case 1: Use precomputed batches from disk + logger.info(f"Loading precomputed batches from {run_config.precomputed_activations_dir}") + batched_activations = BatchedActivations(run_config.precomputed_activations_dir) + + # Get labels from first batch + first_batch = batched_activations.get_next_batch() + component_labels = ComponentLabels(first_batch.labels) + + logger.info(f"Loaded {batched_activations.n_batches} precomputed batches") + + else: + # Case 2: Compute single batch on-the-fly (original behavior) + logger.info(f"Computing single batch (seed={run_config.dataset_seed})") + + # Load model + logger.info("Loading model") + model = ComponentModel.from_run_info(spd_run).to(device) + + # Load data + logger.info("Loading dataset") + load_dataset_kwargs: dict[str, Any] = dict() + if run_config.dataset_streaming: + logger.info("Using streaming dataset loading") + load_dataset_kwargs["config_kwargs"] = dict(streaming=True) + assert task_name == "lm", ( + f"Streaming dataset loading only supported for 'lm' task, got '{task_name = }'. Remove dataset_streaming=True from config or use a different task." 
+ ) + + batch: BatchTensor = load_dataset( + model_path=run_config.model_path, + task_name=task_name, + batch_size=run_config.batch_size, + seed=run_config.dataset_seed, + **load_dataset_kwargs, + ).to(device) + + # Compute activations + logger.info("Computing activations") + activations_dict: ( + dict[str, Float[Tensor, "batch seq C"]] | dict[str, Float[Tensor, "batch C"]] + ) = component_activations( + model=model, + batch=batch, + device=device, ) - wandb_log_tensor( - wandb_run, - processed_activations.activations, - "activations", - 0, - single=True, + + # Process (concat modules, with filtering) + logger.info("Processing activations") + processed: ProcessedActivations = process_activations( + activations=activations_dict, + filter_dead_threshold=run_config.merge_config.filter_dead_threshold, + seq_mode="concat" if task_name == "lm" else None, + filter_modules=run_config.merge_config.filter_modules, ) - # Clean up memory - activations: ActivationsTensor = processed_activations.activations - component_labels: ComponentLabels = ComponentLabels(processed_activations.labels.copy()) - del processed_activations - del activations_dict - del model - del batch - gc.collect() + # Save as single batch to temp dir + temp_batch_dir = storage.base_dir / "temp_batch" + temp_batch_dir.mkdir(exist_ok=True) + + single_batch = ActivationBatch( + activations=processed.activations, + labels=list(processed.labels), + ) + single_batch.save(temp_batch_dir / "batch_0.pt") + + batched_activations = BatchedActivations(temp_batch_dir) + component_labels = processed.labels + + # Log activations to WandB (if enabled) + if wandb_run is not None: + logger.info("Plotting activations") + plot_activations( + processed_activations=processed, + save_dir=None, + n_samples_max=256, + wandb_run=wandb_run, + ) + wandb_log_tensor( + wandb_run, + processed.activations, + "activations", + 0, + single=True, + ) - # 7. 
Run merge iteration + # Clean up memory + del model, batch, activations_dict, processed + gc.collect() + + # Run merge iteration + # ===================================== logger.info("Starting merging") log_callback: LogCallback | None = ( partial(_log_callback, run=wandb_run, run_config=run_config) @@ -354,7 +383,7 @@ def main(run_config: ClusteringRunConfig) -> Path: history: MergeHistory = merge_iteration( merge_config=run_config.merge_config, - activations=activations, + batched_activations=batched_activations, component_labels=component_labels, log_callback=log_callback, ) @@ -412,6 +441,12 @@ def cli() -> None: action="store_true", help="Whether to use streaming dataset loading (if supported by the dataset)", ) + parser.add_argument( + "--precomputed-activations-dir", + type=Path, + default=None, + help="Path to directory containing precomputed activation batches", + ) args: argparse.Namespace = parser.parse_args() @@ -431,6 +466,8 @@ def cli() -> None: overrides["wandb_project"] = args.wandb_project if args.wandb_entity is not None: overrides["wandb_entity"] = args.wandb_entity + if args.precomputed_activations_dir is not None: + overrides["precomputed_activations_dir"] = args.precomputed_activations_dir run_config = replace_pydantic_model(run_config, overrides) diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index 179bc8bca..cf4971a63 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -28,6 +28,7 @@ from pydantic import Field, PositiveInt, field_validator, model_validator from spd.base_config import BaseConfig +from spd.clustering.batched_activations import precompute_batches_for_ensemble from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.consts import DistancesMethod from spd.clustering.storage import StorageBase @@ -151,6 +152,7 @@ def create_clustering_workspace_view(ensemble_id: str, project: str, entity: str def 
generate_clustering_commands( pipeline_config: ClusteringPipelineConfig, pipeline_run_id: str, + batches_base_dir: Path | None = None, dataset_streaming: bool = False, ) -> list[str]: """Generate commands for each clustering run. @@ -158,6 +160,7 @@ def generate_clustering_commands( Args: pipeline_config: Pipeline configuration pipeline_run_id: Pipeline run ID (each run will create its own ExecutionStamp) + batches_base_dir: Path to precomputed batches directory, or None for single-batch mode dataset_streaming: Whether to use dataset streaming Returns: @@ -180,6 +183,12 @@ def generate_clustering_commands( "--wandb-entity", pipeline_config.wandb_entity, ] + + # Add precomputed batches path if available + if batches_base_dir is not None: + run_batch_dir = batches_base_dir / f"run_{idx}" + cmd_parts.extend(["--precomputed-activations-dir", str(run_batch_dir)]) + if dataset_streaming: cmd_parts.append("--dataset-streaming") @@ -268,10 +277,21 @@ def main( ) logger.info(f"WandB workspace: {workspace_url}") + # Precompute batches if multi-batch mode + clustering_run_config = ClusteringRunConfig.from_file( + pipeline_config.clustering_run_config_path + ) + batches_base_dir = precompute_batches_for_ensemble( + clustering_run_config=clustering_run_config, + n_runs=pipeline_config.n_runs, + output_dir=storage.base_dir, + ) + # Generate commands for clustering runs clustering_commands = generate_clustering_commands( pipeline_config=pipeline_config, pipeline_run_id=pipeline_run_id, + batches_base_dir=batches_base_dir, dataset_streaming=dataset_streaming, ) From a18a8072c8a81b5169c152f42902a5d6efed5a3e Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 12:32:43 +0000 Subject: [PATCH 33/61] make format --- spd/clustering/batched_activations.py | 4 ++++ spd/clustering/math/merge_pair_samplers.py | 4 +--- spd/clustering/merge.py | 8 +++----- spd/clustering/scripts/run_clustering.py | 1 - 4 files changed, 8 insertions(+), 9 deletions(-) diff --git 
a/spd/clustering/batched_activations.py b/spd/clustering/batched_activations.py index 30703a0ba..7916dec72 100644 --- a/spd/clustering/batched_activations.py +++ b/spd/clustering/batched_activations.py @@ -8,6 +8,7 @@ import gc from dataclasses import dataclass from pathlib import Path +from typing import TYPE_CHECKING import torch from torch import Tensor @@ -19,6 +20,9 @@ from spd.models.component_model import ComponentModel, SPDRunInfo from spd.utils.distributed_utils import get_device +if TYPE_CHECKING: + from spd.clustering.clustering_run_config import ClusteringRunConfig + @dataclass class ActivationBatch: diff --git a/spd/clustering/math/merge_pair_samplers.py b/spd/clustering/math/merge_pair_samplers.py index a68f7c327..384122711 100644 --- a/spd/clustering/math/merge_pair_samplers.py +++ b/spd/clustering/math/merge_pair_samplers.py @@ -71,9 +71,7 @@ def range_sampler( within_range: Bool[Tensor, "k_groups k_groups"] = (costs <= max_considered_cost) & valid_mask # Get indices of candidate pairs - considered_idxs: Int[Tensor, "n_considered 2"] = torch.stack( - torch.where(within_range), dim=1 - ) + considered_idxs: Int[Tensor, "n_considered 2"] = torch.stack(torch.where(within_range), dim=1) if considered_idxs.shape[0] == 0: raise ValueError("No valid pairs within threshold range") diff --git a/spd/clustering/merge.py b/spd/clustering/merge.py index 9419abd44..025390959 100644 --- a/spd/clustering/merge.py +++ b/spd/clustering/merge.py @@ -5,7 +5,7 @@ """ import warnings -from typing import Protocol, runtime_checkable +from typing import Protocol import torch from jaxtyping import Bool, Float @@ -16,7 +16,6 @@ from spd.clustering.compute_costs import ( compute_mdl_cost, compute_merge_costs, - recompute_coacts_merge_pair, ) from spd.clustering.consts import ( ActivationsTensor, @@ -189,9 +188,8 @@ def merge_iteration( # Recompute from new batch if it's time # -------------------------------------------------- should_recompute: bool = ( - (iter_idx + 1) % 
merge_config.recompute_costs_every == 0 - and iter_idx + 1 < num_iters - ) + iter_idx + 1 + ) % merge_config.recompute_costs_every == 0 and iter_idx + 1 < num_iters if should_recompute: new_batch: ActivationBatch = batched_activations.get_next_batch() diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index b978c92a6..1ef7cd6cb 100644 --- a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -34,7 +34,6 @@ from spd.clustering.batched_activations import ActivationBatch, BatchedActivations from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.consts import ( - ActivationsTensor, BatchTensor, ClusterCoactivationShaped, ComponentLabels, From 483e91cb0ab56a788c2e31c2e1eaa7e154e4cc4d Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 12:42:19 +0000 Subject: [PATCH 34/61] type checking and compat fixes --- spd/clustering/batched_activations.py | 30 +++++++++++++++++ spd/clustering/merge.py | 18 +++++++++-- tests/clustering/scripts/cluster_resid_mlp.py | 13 ++++++-- tests/clustering/scripts/cluster_ss.py | 7 +++- tests/clustering/test_merge_integration.py | 32 ++++++++++++++++--- 5 files changed, 89 insertions(+), 11 deletions(-) diff --git a/spd/clustering/batched_activations.py b/spd/clustering/batched_activations.py index 7916dec72..ede4cf774 100644 --- a/spd/clustering/batched_activations.py +++ b/spd/clustering/batched_activations.py @@ -64,6 +64,36 @@ def get_next_batch(self) -> ActivationBatch: return batch +def batched_activations_from_tensor( + activations: Tensor, + labels: list[str], +) -> BatchedActivations: + """ + Create a BatchedActivations instance from a single activation tensor. + + This is a helper for backward compatibility with tests and code that uses + single-batch mode. It creates a temporary directory with a single batch file. 
+ + Args: + activations: Activation tensor [samples, n_components] + labels: Component labels ["module:idx", ...] + + Returns: + BatchedActivations instance that cycles through the single batch + """ + import tempfile + + # Create a temporary directory + temp_dir = Path(tempfile.mkdtemp(prefix="batch_temp_")) + + # Save the single batch + batch = ActivationBatch(activations=activations, labels=labels) + batch.save(temp_dir / "batch_0.pt") + + # Return BatchedActivations that will cycle through this single batch + return BatchedActivations(temp_dir) + + def precompute_batches_for_ensemble( clustering_run_config: "ClusteringRunConfig", n_runs: int, diff --git a/spd/clustering/merge.py b/spd/clustering/merge.py index 025390959..0e7a1b800 100644 --- a/spd/clustering/merge.py +++ b/spd/clustering/merge.py @@ -51,9 +51,21 @@ def recompute_coacts_from_scratch( activations > activation_threshold if activation_threshold is not None else activations ) - # Apply current merge to get group-level activations - # current_merge.matrix is [c_original, k_groups] - group_activations = activation_mask @ current_merge.matrix.to(activation_mask.device) + # Map component-level activations to group-level using scatter_add + # This is more efficient than materializing the full merge matrix + # current_merge.group_idxs: [n_components] with values 0 to k_groups-1 + n_samples = activation_mask.shape[0] + group_activations = torch.zeros( + (n_samples, current_merge.k_groups), + dtype=activation_mask.dtype, + device=activation_mask.device, + ) + + # Expand group_idxs to match batch dimension and scatter-add activations by group + group_idxs_expanded = ( + current_merge.group_idxs.unsqueeze(0).expand(n_samples, -1).to(activation_mask.device) + ) + group_activations.scatter_add_(1, group_idxs_expanded, activation_mask) # Compute coactivations coact = group_activations.float().T @ group_activations.float() diff --git a/tests/clustering/scripts/cluster_resid_mlp.py 
b/tests/clustering/scripts/cluster_resid_mlp.py index bbfb5259e..a5fdd6956 100644 --- a/tests/clustering/scripts/cluster_resid_mlp.py +++ b/tests/clustering/scripts/cluster_resid_mlp.py @@ -12,6 +12,7 @@ component_activations, process_activations, ) +from spd.clustering.batched_activations import batched_activations_from_tensor from spd.clustering.consts import ComponentLabels from spd.clustering.merge import merge_iteration from spd.clustering.merge_config import MergeConfig @@ -148,9 +149,13 @@ def _plot_func( ) +BATCHED_ACTIVATIONS = batched_activations_from_tensor( + activations=PROCESSED_ACTIVATIONS.activations, + labels=list(PROCESSED_ACTIVATIONS.labels), +) MERGE_HIST: MergeHistory = merge_iteration( merge_config=MERGE_CFG, - activations=PROCESSED_ACTIVATIONS.activations, + batched_activations=BATCHED_ACTIVATIONS, component_labels=PROCESSED_ACTIVATIONS.labels, log_callback=_plot_func, ) @@ -172,9 +177,13 @@ def _plot_func( ENSEMBLE_SIZE: int = 4 HISTORIES: list[MergeHistory] = [] for _i in range(ENSEMBLE_SIZE): + batched_acts = batched_activations_from_tensor( + activations=PROCESSED_ACTIVATIONS.activations, + labels=list(PROCESSED_ACTIVATIONS.labels), + ) HISTORY: MergeHistory = merge_iteration( merge_config=MERGE_CFG, - activations=PROCESSED_ACTIVATIONS.activations, + batched_activations=batched_acts, component_labels=PROCESSED_ACTIVATIONS.labels, log_callback=None, ) diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 45c142fa0..8c7b42feb 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -16,6 +16,7 @@ component_activations, process_activations, ) +from spd.clustering.batched_activations import batched_activations_from_tensor from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.dataset import load_dataset from spd.clustering.merge import merge_iteration @@ -111,9 +112,13 @@ ENSEMBLE_SIZE: int = 2 HISTORIES: 
list[MergeHistory] = [] for _i in range(ENSEMBLE_SIZE): + batched_acts = batched_activations_from_tensor( + activations=PROCESSED_ACTIVATIONS.activations, + labels=list(PROCESSED_ACTIVATIONS.labels), + ) HISTORY: MergeHistory = merge_iteration( merge_config=MERGE_CFG, - activations=PROCESSED_ACTIVATIONS.activations, + batched_activations=batched_acts, component_labels=PROCESSED_ACTIVATIONS.labels, log_callback=None, ) diff --git a/tests/clustering/test_merge_integration.py b/tests/clustering/test_merge_integration.py index 8492300de..d1b1a9571 100644 --- a/tests/clustering/test_merge_integration.py +++ b/tests/clustering/test_merge_integration.py @@ -2,6 +2,7 @@ import torch +from spd.clustering.batched_activations import batched_activations_from_tensor from spd.clustering.consts import ComponentLabels from spd.clustering.merge import merge_iteration from spd.clustering.merge_config import MergeConfig @@ -29,8 +30,13 @@ def test_merge_with_range_sampler(self): ) # Run merge iteration + batched_activations = batched_activations_from_tensor( + activations=activations, labels=list(component_labels) + ) history = merge_iteration( - activations=activations, merge_config=config, component_labels=component_labels + batched_activations=batched_activations, + merge_config=config, + component_labels=component_labels, ) # Check results @@ -62,8 +68,13 @@ def test_merge_with_mcmc_sampler(self): ) # Run merge iteration + batched_activations = batched_activations_from_tensor( + activations=activations, labels=list(component_labels) + ) history = merge_iteration( - activations=activations, merge_config=config, component_labels=component_labels + batched_activations=batched_activations, + merge_config=config, + component_labels=component_labels, ) # Check results @@ -97,8 +108,11 @@ def test_merge_comparison_samplers(self): merge_pair_sampling_kwargs={"threshold": 0.0}, # Always select minimum ) + batched_activations_range = batched_activations_from_tensor( + 
activations=activations.clone(), labels=list(component_labels) + ) history_range = merge_iteration( - activations=activations.clone(), + batched_activations=batched_activations_range, merge_config=config_range, component_labels=ComponentLabels(component_labels.copy()), ) @@ -112,8 +126,11 @@ def test_merge_comparison_samplers(self): merge_pair_sampling_kwargs={"temperature": 0.01}, # Very low temp ) + batched_activations_mcmc = batched_activations_from_tensor( + activations=activations.clone(), labels=list(component_labels) + ) history_mcmc = merge_iteration( - activations=activations.clone(), + batched_activations=batched_activations_mcmc, merge_config=config_mcmc, component_labels=ComponentLabels(component_labels.copy()), ) @@ -140,8 +157,13 @@ def test_merge_with_small_components(self): merge_pair_sampling_kwargs={"temperature": 2.0}, ) + batched_activations = batched_activations_from_tensor( + activations=activations, labels=list(component_labels) + ) history = merge_iteration( - activations=activations, merge_config=config, component_labels=component_labels + batched_activations=batched_activations, + merge_config=config, + component_labels=component_labels, ) # First entry is after first merge, so should be 3 - 1 = 2 From 6f6696ade9c4d105ce5e4146139cb5e1ff946549 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 12:44:36 +0000 Subject: [PATCH 35/61] old todos, removing in next commit --- TASKS-claude-1.md | 170 ++++++++++++ TASKS-claude-2.md | 98 +++++++ TASKS-claude-3.md | 266 +++++++++++++++++++ TASKS-claude-4.md | 208 +++++++++++++++ TODO-multibatch.md | 640 +++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 1382 insertions(+) create mode 100644 TASKS-claude-1.md create mode 100644 TASKS-claude-2.md create mode 100644 TASKS-claude-3.md create mode 100644 TASKS-claude-4.md create mode 100644 TODO-multibatch.md diff --git a/TASKS-claude-1.md b/TASKS-claude-1.md new file mode 100644 index 000000000..7ad79bbd2 --- /dev/null 
+++ b/TASKS-claude-1.md @@ -0,0 +1,170 @@ +# Task 1: Batch Storage Infrastructure - Completion Report + +## Summary + +Completed Task 1 from `TODO-multibatch.md`: Batch Storage Infrastructure. This provides the foundation for multi-batch clustering by implementing data structures and precomputation logic for activation batches. + +## Changes Made + +### 1. Created `spd/clustering/batched_activations.py` (~180 lines) + +This file consolidates all batch-related functionality (previously split between `batched_activations.py` and `precompute_batches.py`). + +**Components:** +- **`ActivationBatch`**: Dataclass for storing a single batch of activations with labels + - `save()`: Saves batch to disk as `.pt` file + - `load()`: Loads batch from disk + +- **`BatchedActivations`**: Iterator for cycling through multiple batches from disk + - Finds all `batch_*.pt` files in a directory + - `get_next_batch()`: Returns next batch in round-robin fashion + - `n_batches`: Property for total number of batches + +- **`precompute_batches_for_ensemble()`**: Function to generate all batches for ensemble runs + - Loads model once, generates all batches for all runs + - Saves batches to disk with structure: `/precomputed_batches/run_{idx}/batch_{idx}.pt` + - Returns `None` if `recompute_costs_every=1` (single-batch mode) + - Uses unique seeds per batch: `base_seed + run_idx * 1000 + batch_idx` + +### 2. Updated `spd/clustering/merge_config.py` + +**Changed:** +- `recompute_costs_every`: Updated default from `10` to `1` (original behavior) +- Updated description to match TODO spec + +**Rationale:** Default of 1 maintains backward compatibility - single batch mode is the original behavior. + +### 3. Updated `spd/clustering/clustering_run_config.py` + +**Added:** +- `precomputed_activations_dir: Path | None = None` +- Description: "Path to directory containing precomputed activation batches. If None, batches will be auto-generated before merging starts." 
+ +**Key Design Decision:** When `None`, the system will auto-generate all required batches in a temp directory before merge starts (Option A from user clarification). + +### 4. Updated `spd/clustering/scripts/run_clustering.py` + +**Added:** +- CLI argument: `--precomputed-activations-dir` +- Override logic to pass this value to the config + +### 5. Merged Files + +**Deleted:** `spd/clustering/precompute_batches.py` + +**Rationale:** The two files were tightly coupled (~50 and ~140 lines), used together, and would evolve together. Combining them into one `batched_activations.py` file makes the codebase simpler. + +**Updated imports:** +- `spd/clustering/scripts/run_pipeline.py`: Now imports from `batched_activations` + +## Concerns & Notes + +### 1. **Circular Import Risk** ⚠️ + +In `batched_activations.py`, the `precompute_batches_for_ensemble()` function has a type annotation: +```python +def precompute_batches_for_ensemble( + clustering_run_config: "ClusteringRunConfig", # String annotation to avoid circular import + ... +) +``` + +This is a forward reference (string annotation) because: +- `batched_activations.py` imports from `clustering` modules +- `clustering_run_config.py` might import from `batched_activations.py` in the future + +**Status:** Currently safe, but watch for circular imports if we add imports to configs. + +### 2. **Type Checking Not Verified** ⚠️ + +I attempted to run a basic import test but the user interrupted. We should verify: +- `basedpyright` passes +- No circular import errors at runtime +- All imports resolve correctly + +**Recommended:** Run `make check` to verify type safety. + +### 3. **Config Default Behavior Change** + +The default for `recompute_costs_every` was changed from `10` → `1`. This affects any existing configs that relied on the implicit default of 10. + +**Impact:** Likely minimal since this appears to be new functionality, but worth noting for any in-progress experiments. + +### 4. 
**Seeding Strategy**
+
+The seed calculation for batches is:
+```python
+seed = base_seed + run_idx * 1000 + batch_idx
+```
+
+**Assumption:** Maximum of 1000 batches per run. If more than 1000 batches needed, seeds could collide across runs.
+
+**Mitigation:** Very unlikely - 1000 batches would require either:
+- Very long merge iterations, or
+- Very small `recompute_costs_every` values
+
+### 5. **Disk Space Considerations**
+
+Batches are saved to disk with activations on CPU. For large ensembles:
+- `n_runs * n_batches_per_run * batch_size_on_disk`
+- Could be substantial for large models/datasets
+
+**Note:** No cleanup mechanism implemented - batches persist after runs complete.
+
+### 6. **TODO Document Discrepancy**
+
+The TODO document (lines 418-483) describes behavior where:
+- `precomputed_activations_dir=None` → "compute single batch on-the-fly"
+- `precomputed_activations_dir=<path>` → "use precomputed batches"
+
+**Actual Implementation (per user request):**
+- `precomputed_activations_dir=None` → "auto-generate all batches, then run merge"
+- `precomputed_activations_dir=<path>` → "use precomputed batches"
+
+This follows "Option A" clarified by the user. The TODO document may need updating.
+
+## Next Steps
+
+### Immediate
+1. **Verify type checking:** Run `make check` or `basedpyright`
+2. **Test imports:** Ensure no circular import issues
+3. 
**Update TODO-multibatch.md:** Reflect the actual implementation of Option A + +### For Task 2 (Core Merge Logic Refactor) +The infrastructure is ready: +- `BatchedActivations` can be used in `merge_iteration()` +- `recompute_coacts_from_scratch()` needs to be added +- NaN masking logic needs to be implemented +- Merge pair samplers need NaN handling + +## Files Modified + +``` +Created: +- spd/clustering/batched_activations.py (new, ~180 lines) + +Modified: +- spd/clustering/merge_config.py (updated default + description) +- spd/clustering/clustering_run_config.py (added 1 field) +- spd/clustering/scripts/run_clustering.py (added CLI arg + override) +- spd/clustering/scripts/run_pipeline.py (updated import) + +Deleted: +- spd/clustering/precompute_batches.py (merged into batched_activations.py) +``` + +## Testing Recommendations + +1. **Unit tests** (create `tests/clustering/test_batched_activations.py`): + - Test `ActivationBatch.save()` and `load()` + - Test `BatchedActivations` cycling behavior + - Test that `n_batches` property works correctly + +2. **Integration test**: + - Run `precompute_batches_for_ensemble()` with small config + - Verify batch files are created with correct naming + - Verify `BatchedActivations` can load and cycle through them + +3. **Backward compatibility test**: + - Run with `recompute_costs_every=1` (should behave as before) + - Verify no batches are precomputed when not needed diff --git a/TASKS-claude-2.md b/TASKS-claude-2.md new file mode 100644 index 000000000..0bb7f92e0 --- /dev/null +++ b/TASKS-claude-2.md @@ -0,0 +1,98 @@ +# Task 2: Core Merge Logic Refactor - COMPLETED ✅ + +## Summary + +Task 2 has been completed successfully. All changes have been implemented in `spd/clustering/merge.py` and `spd/clustering/math/merge_pair_samplers.py`. + +## Completed Work + +### ✅ 1. 
Helper Function Added +**File:** `spd/clustering/merge.py` (lines 33-61) + +Added `recompute_coacts_from_scratch()` function that: +- Takes fresh activations and current merge state +- Applies threshold to activations +- Applies current merge matrix to get group-level activations +- Computes coactivations for current groups +- Returns both coact matrix and activation mask + +### ✅ 2. Samplers Updated for NaN Handling +**File:** `spd/clustering/math/merge_pair_samplers.py` + +#### `range_sampler` (lines 28-84) +- Added NaN masking alongside diagonal masking +- Only considers valid (non-NaN, non-diagonal) pairs +- Raises clear error if all costs are NaN +- Updated docstring to document NaN handling + +#### `mcmc_sampler` (lines 87-134) +- Added NaN check to valid_mask +- Sets invalid entries to inf (so exp gives 0) +- Raises error if no valid pairs exist +- Updated docstring to document NaN handling + +### ✅ 3. merge_iteration() Refactored +**File:** `spd/clustering/merge.py` (lines 82-260) + +#### Changed Function Signature (line 82-87) +- Now accepts `batched_activations: BatchedActivations` instead of `activations: ActivationsTensor` +- Updated docstring to reflect multi-batch support + +#### Initial Batch Loading (lines 95-107) +- Loads first batch using `batched_activations.get_next_batch()` +- Extracts activations tensor from ActivationBatch +- Computes initial coactivations as before + +#### NaN Masking Instead of Incremental Updates (lines 155-179) +- Stores merge_pair_cost BEFORE updating (line 153) +- Updates merge state first (line 158) +- NaN out affected rows/cols (lines 166-169) +- Removes deleted row/col to maintain shape (lines 172-177) +- Decrements k_groups immediately (line 179) + +#### Batch Recomputation Logic (lines 189-205) +- Checks if it's time to recompute based on `merge_config.recompute_costs_every` +- Loads new batch from disk +- Calls `recompute_coacts_from_scratch()` to get fresh coactivations +- Updates both `current_coact` and 
`current_act_mask` + +#### Cleanup +- Removed duplicate `k_groups -= 1` (was at line ~239, now only at line 179) +- Kept all metrics and logging logic intact +- Maintained all sanity checks + +## Key Changes Summary + +| Component | Before | After | +|-----------|--------|-------| +| Function param | `activations: ActivationsTensor` | `batched_activations: BatchedActivations` | +| Coact updates | Incremental via `recompute_coacts_merge_pair()` | NaN masking + periodic full recompute | +| Invalid entries | Never existed | Marked as NaN | +| Batch handling | Single batch only | Multiple batches with cycling | +| Samplers | Assumed no NaN | Handle NaN gracefully | + +## Backward Compatibility + +The refactored code maintains backward compatibility: +- When `recompute_costs_every=1`, it recomputes every iteration (similar to old behavior but with fresh data) +- When using a single batch in `BatchedActivations`, it cycles through that one batch +- All existing metrics, logging, and callbacks continue to work + +## Notes + +- The import `from spd.clustering.compute_costs import recompute_coacts_merge_pair` is still present but the function is no longer used in `merge_iteration()` +- This function may still be used elsewhere in the codebase, so it was left in place +- The NaN masking approach is more memory-efficient as it doesn't require keeping the model loaded + +## Testing Recommendations + +1. **Single-batch backward compatibility:** Test with `recompute_costs_every=1` and verify results match old behavior +2. **Multi-batch mode:** Test with `recompute_costs_every=10` and multiple batches +3. **NaN handling:** Verify samplers don't crash when costs contain NaN +4. **Metrics/logging:** Ensure WandB logging and callbacks still work correctly +5. 
**Edge cases:** Test with very small k_groups values (near early stopping) + +## Dependencies + +- ✅ Task 1 completed: `BatchedActivations` and `ActivationBatch` classes exist in `spd/clustering/batched_activations.py` +- ⏭️ Task 3: Will need to update `run_clustering.py` to use the new `merge_iteration()` signature diff --git a/TASKS-claude-3.md b/TASKS-claude-3.md new file mode 100644 index 000000000..9cd2d7801 --- /dev/null +++ b/TASKS-claude-3.md @@ -0,0 +1,266 @@ +# Task 3 Implementation: Update `run_clustering.py` for Multi-Batch Support + +**Date**: 2025-10-27 +**Branch**: `clustering/refactor-multi-batch` +**Reference**: `spd/clustering/TODO-multibatch.md` - Task 3 + +## Executive Summary + +✅ **Implementation Complete**: All Task 3 code changes have been made to `run_clustering.py` +❌ **Functional Status**: Code will **fail at runtime** due to incomplete Task 2 +⚠️ **Critical Blocker**: `merge_iteration()` in `merge.py` must be refactored before this code can work + +**The Issue**: Task 3 calls `merge_iteration(batched_activations=...)` but Task 2 has not yet updated `merge_iteration()` to accept this parameter. The function still expects `activations: ActivationsTensor`. + +**See Section**: [Concern #1 - Critical Task 2 Dependency](#1-🚨-critical-task-2-is-incomplete) + +## Overview + +Implemented multi-batch clustering support in `run_clustering.py` to allow clustering runs to either: +1. Use precomputed activation batches from disk, OR +2. Compute a single batch on-the-fly (original behavior) + +This enables the model to be unloaded before merge iteration begins, saving memory during long merge processes. + +## Changes Made + +### 1. Import Addition (`run_clustering.py:34`) +```python +from spd.clustering.batched_activations import ActivationBatch, BatchedActivations +``` + +### 2. 
Refactored `main()` Function (lines 280-373) + +Replaced the monolithic data loading and activation computation section with a branching structure: + +#### Case 1: Precomputed Batches (lines 285-294) +When `run_config.precomputed_activations_dir` is provided: +- Loads `BatchedActivations` from disk directory +- Retrieves component labels from first batch +- No model loading required +- Logs number of batches available + +#### Case 2: Single Batch On-the-Fly (lines 296-373) +When `precomputed_activations_dir` is `None` (original behavior): +- Loads model from SPDRunInfo +- Loads dataset with appropriate seed +- Computes activations using `component_activations()` +- Processes activations (filtering, concatenation) +- **Saves to temporary directory** as `batch_0.pt` +- Creates `BatchedActivations` instance from temp directory +- Logs activations to WandB if enabled +- Cleans up model, batch, and intermediate tensors from memory + +### 3. Updated `merge_iteration()` Call (line 386) +```python +# Changed from: +activations=activations + +# Changed to: +batched_activations=batched_activations +``` + +## Dependencies + +### Completed (by other Claude instances): +- ✅ **Task 1** (TASKS-claude-1.md): + - Created `spd/clustering/batched_activations.py` with `ActivationBatch`, `BatchedActivations`, and `precompute_batches_for_ensemble()` + - Added `precomputed_activations_dir` field to `ClusteringRunConfig` (line 72-75) + - Added `--precomputed-activations-dir` CLI argument to `run_clustering.py` (line 415-420) + - Added CLI argument wiring to config overrides (line 440-441) + - Changed `recompute_costs_every` default from 10 → 1 in `MergeConfig` + - **Note**: Task 1 implemented "Option A" where `precomputed_activations_dir=None` means "auto-generate all batches" + +- ⚠️ **Task 2** (TASKS-claude-2.md): **INCOMPLETE** + - ✅ Added `recompute_coacts_from_scratch()` helper function to `merge.py` + - ✅ Updated merge pair samplers to handle NaN values + - ❌ **NOT DONE**: 
`merge_iteration()` refactor blocked, waiting for Task 1 + - **Status**: Task 2 is NOT complete - the main refactor is still pending + +- ✅ **Task 4** (TASKS-claude-4.md): + - Implemented batch precomputation in `run_pipeline.py` + - Added `generate_clustering_commands()` support for passing batch directories + - **Note**: Depends on Tasks 2 & 3 being complete for full integration + +## Testing Status + +❌ **No tests run yet** - changes are untested + +## Concerns and Potential Issues + +### 1. **🚨 CRITICAL: Task 2 is INCOMPLETE** +Per TASKS-claude-2.md, `merge_iteration()` has **NOT** been refactored yet. The Task 2 work was blocked waiting for Task 1, but now Task 1 is complete. + +**What's Missing in merge.py**: +- ❌ Function signature still uses `activations: ActivationsTensor` instead of `batched_activations: BatchedActivations` +- ❌ No batch cycling logic implemented +- ❌ No NaN masking for merged component rows/columns +- ❌ No periodic recomputation of coactivation matrix + +**Impact**: My changes to `run_clustering.py` will **FAIL** at runtime when calling `merge_iteration()` because: +```python +# This call will error - wrong parameter name/type +history: MergeHistory = merge_iteration( + merge_config=run_config.merge_config, + batched_activations=batched_activations, # ❌ merge_iteration() doesn't accept this yet + component_labels=component_labels, + log_callback=log_callback, +) +``` + +**Action Required**: Complete Task 2 refactor of `merge_iteration()` before Task 3 can be tested or used. + +### 2. **Temporary Directory Cleanup** +In single-batch mode, activations are saved to `storage.base_dir / "temp_batch"`. This directory: +- Is created in the run's output directory +- Contains a single `batch_0.pt` file +- Is **never explicitly cleaned up** + +**Potential Issue**: Accumulation of temp directories if many runs are executed. 
+ +**Options**: +- Leave as-is (temp directories are part of run output, may be useful for debugging) +- Add cleanup after merge iteration completes +- Use Python's `tempfile.TemporaryDirectory()` context manager (would require restructuring) + +**Recommendation**: Leave as-is for now since it's in the run's output directory and provides transparency. + +### 3. **Batch Label Consistency** +When loading precomputed batches (Case 1), we only check the **first batch** for labels: +```python +first_batch = batched_activations.get_next_batch() +component_labels = ComponentLabels(first_batch.labels) +``` + +**Assumption**: All batches in the directory have identical labels in the same order. + +**Potential Issue**: If batches were generated incorrectly with different label sets or ordering, the merge will fail or produce incorrect results. + +**Mitigation**: The batch precomputation logic (Task 4 in `run_pipeline.py`) should ensure consistent labels. Consider adding validation. + +### 4. **Type Checking** +Have not run type checker (`basedpyright`) yet. Potential type issues: +- `batched_activations` variable assignment in two branches might confuse type checker +- `component_labels` assignment from different sources + +**Action Required**: Run `make type` to verify no type errors. + +### 5. **Memory Management** +In Case 2 (single-batch mode), we explicitly clean up: +```python +del model, batch, activations_dict, processed +gc.collect() +``` + +This is good, but: +- `temp_batch_dir` and `single_batch` objects remain in scope (though small) +- May want to explicitly `del single_batch` after save for consistency + +**Recommendation**: Low priority, current implementation is fine. + +### 6. 
**ComponentLabels Type Consistency** +In Case 1 (precomputed): +```python +component_labels = ComponentLabels(first_batch.labels) +``` + +In Case 2 (on-the-fly): +```python +component_labels = processed.labels # Already ComponentLabels type +``` + +The inconsistency is intentional (Case 1 needs wrapping, Case 2 doesn't), but could be confusing. + +**Verified**: Based on `batched_activations.py`, `ActivationBatch.labels` is `list[str]`, so wrapping in `ComponentLabels` is correct. + +### 7. **Implementation Divergence from TODO Spec** +The TODO document (lines 418-483) describes Case 2 as "compute single batch on-the-fly" without saving to disk. + +**However**, Task 1 implemented "Option A" where `precomputed_activations_dir=None` means "auto-generate all required batches before merging starts" (per TASKS-claude-1.md, concern #6). + +**My Implementation** follows the TODO spec literally: +- Case 2 creates a **single** batch +- Saves to temp directory +- Passes to merge iteration +- This is the **original behavior** (backward compatible) + +**Discrepancy**: My implementation doesn't match Task 1's "Option A" intent but **does** match the TODO document specification. + +**Clarification Needed**: Which behavior is desired? +- **Option A** (Task 1 intent): Auto-generate all batches when `precomputed_activations_dir=None` +- **Original Behavior** (TODO spec + my impl): Generate single batch on-the-fly + +**Current Status**: My implementation maintains backward compatibility with original single-batch behavior. + +### 8. **Backward Compatibility** +**Status**: ✅ Should be fully backward compatible + +When `precomputed_activations_dir=None` (default), the code: +- Follows the same logic as before +- Uses same dataset loading, activation computation, and processing +- Only difference: saves to temp directory and wraps in `BatchedActivations` + +**Concern**: The extra save/load cycle adds overhead for single-batch runs. 
+ +**Impact**: Minimal - single file I/O is fast, and the wrapper is lightweight. + +**Important Note**: This only works if Task 2 implements backward compatibility where `recompute_costs_every=1` behaves identically to the original incremental update behavior. + +## Next Steps + +### Immediate Blockers +1. **🚨 COMPLETE TASK 2**: The `merge_iteration()` function must be refactored before Task 3 can work + - See TASKS-claude-2.md for detailed requirements + - Task 2 was blocked waiting for Task 1, but Task 1 is now complete + - This is the critical path blocker + +### After Task 2 is Complete +2. **Run type checker**: `make type` to catch any type issues +3. **Resolve implementation divergence**: Decide between Option A (auto-generate batches) vs. original behavior (single batch) +4. **Test single-batch mode**: Run clustering with default config + - Verify `recompute_costs_every=1` maintains backward compatibility +5. **Test multi-batch mode**: Use precomputed batches from Task 4 + - Verify batch cycling works correctly + - Verify cost recomputation happens at correct intervals +6. 
**Integration test**: Full pipeline test with `run_pipeline.py` + - Use Task 4's precomputation + - Run ensemble with multiple batches + - Verify all runs complete successfully + +## Files Modified + +- `spd/clustering/scripts/run_clustering.py`: ~100 lines changed + - Added imports + - Refactored `main()` function (lines 259-389) + - Updated merge_iteration call + +## Files Read/Referenced + +- `spd/clustering/batched_activations.py` (already exists) +- `spd/clustering/clustering_run_config.py` (field already added) +- `spd/clustering/merge_config.py` (field already added) +- `spd/clustering/TODO-multibatch.md` (implementation guide) + +## Task Status Summary + +| Component | Status | Notes | +|-----------|--------|-------| +| Task 3 Implementation | ✅ Complete | All code changes made per TODO spec | +| Task 2 Dependency | ❌ **BLOCKER** | `merge_iteration()` not refactored - will fail at runtime | +| Type Checking | ⚠️ Not Run | Need to run `make type` | +| Testing | ❌ Not Done | Cannot test until Task 2 complete | +| Integration | ⚠️ Pending | Works with Task 1 & 4, blocked by Task 2 | + +## Estimated Completeness + +**Task 3 Implementation**: **100% complete** ✅ (all code written per TODO-multibatch.md) + +**Task 3 Usability**: **0% functional** ❌ (blocked by incomplete Task 2) + +**Overall Multi-Batch Feature**: +- Task 1: ✅ Complete +- Task 2: ❌ Incomplete (critical blocker) +- Task 3: ✅ Complete (but untestable) +- Task 4: ✅ Complete + +**Integration Status**: ~75% complete (Task 2 is the only missing piece) diff --git a/TASKS-claude-4.md b/TASKS-claude-4.md new file mode 100644 index 000000000..93f212032 --- /dev/null +++ b/TASKS-claude-4.md @@ -0,0 +1,208 @@ +# Task 4 Implementation: Batch Precomputation in run_pipeline.py + +**Date**: 2025-10-27 +**Task**: Implement Task 4 from `spd/clustering/TODO-multibatch.md` +**Status**: ✅ Complete + +## What Was Done + +### 1. 
Created New Module: `spd/clustering/precompute_batches.py` + +Created a standalone module that can be imported by both `run_clustering.py` and `run_pipeline.py`. + +**Key Function**: `precompute_batches_for_ensemble()` +- Takes `ClusteringRunConfig`, `n_runs`, and `output_dir` as parameters +- Returns `Path | None` (None if single-batch mode, Path to batches directory if multi-batch) +- Loads model once to determine component count +- Calculates number of batches needed based on `recompute_costs_every` and total iterations +- Generates all activation batches for all runs in the ensemble +- Saves batches as `ActivationBatch` objects to disk +- Implements proper memory cleanup (GPU cache clearing, garbage collection) + +### 2. Updated `spd/clustering/scripts/run_pipeline.py` + +#### Imports Added: +```python +from spd.clustering.precompute_batches import precompute_batches_for_ensemble +from spd.clustering.clustering_run_config import ClusteringRunConfig +``` + +#### Modified Functions: + +**`generate_clustering_commands()`**: +- Added `batches_base_dir: Path | None = None` parameter +- Added logic to append `--precomputed-activations-dir` to command when batches are available +- Each run gets its own batch directory: `batches_base_dir / f"run_{idx}"` + +**`main()`**: +- Added code to load `ClusteringRunConfig` from pipeline config +- Calls `precompute_batches_for_ensemble()` before generating commands +- Passes `batches_base_dir` result to `generate_clustering_commands()` + +### 3. Implementation Details + +**Batch Directory Structure**: +``` +/ +└── precomputed_batches/ + ├── run_0/ + │ ├── batch_0.pt + │ ├── batch_1.pt + │ └── ... + ├── run_1/ + │ ├── batch_0.pt + │ └── ... + └── ... 
+``` + +**Seeding Strategy**: +- Each batch uses unique seed: `base_seed + run_idx * 1000 + batch_idx` +- Ensures different data for each run and each batch within a run + +**Memory Management**: +- Model loaded once at the beginning +- Activations moved to CPU before saving +- GPU cache cleared after each batch +- Full garbage collection after all batches complete + +## Concerns & Notes + +### 1. ✅ **CLI Argument in run_clustering.py - RESOLVED** + +**Status**: Task 1 has already added the `--precomputed-activations-dir` CLI argument to `run_clustering.py`. + +See TASKS-claude-1.md for details. + +### 2. ✅ **Config Field - RESOLVED** + +**Status**: Task 1 has already added `precomputed_activations_dir` field to `ClusteringRunConfig`. + +See TASKS-claude-1.md for details. + +### 3. ✅ **Precomputation Logic - RESOLVED** + +**Status**: The `precompute_batches_for_ensemble()` function was already implemented in `spd/clustering/batched_activations.py` by Task 1. + +**What I Did**: +- Initially created a duplicate `spd/clustering/precompute_batches.py` +- Discovered it was already in `batched_activations.py` +- File has been cleaned up (does not exist) +- Import in `run_pipeline.py` correctly references `batched_activations` + +**Current State**: Everything is properly integrated and no duplicates exist. + +### 4. ⚠️ **Filter Dead Threshold - Intentional Design** + +According to TASKS-claude-1.md and the TODO document: +- **NO FILTERING** is intentional for the simplified multi-batch implementation +- `filter_dead_threshold=0.0` is used during precomputation +- This is a key simplification listed in the TODO (line 11) + +**Status**: This is correct behavior, not a bug. + +### 5. 
⚠️ **Missing Integration with run_clustering.py** + +According to TASKS-claude-2.md: +- Task 2 implemented NaN handling in samplers ✅ +- Task 2 added `recompute_coacts_from_scratch()` helper ✅ +- Task 2 did NOT complete the `merge_iteration()` refactor ⚠️ (blocked waiting for Task 1) + +**Current Status**: Task 1 is now complete, so Task 2 should be unblocked to finish the `merge_iteration()` refactor. + +**What's Still Needed**: +- `merge_iteration()` needs to accept `BatchedActivations` instead of single tensor +- NaN masking logic needs to be added to `merge_iteration()` +- Batch recomputation logic needs to be added + +See TASKS-claude-2.md for detailed instructions on what remains. + +### 6. ℹ️ **Component Count Calculation** + +We count components by summing across modules: +```python +n_components = sum(act.shape[-1] for act in sample_acts.values()) +``` + +This assumes: +- All modules are included (no module filtering) +- No dead component filtering +- All components from all modules are concatenated + +This matches the "NO FILTERING" principle in the TODO. + +### 7. ℹ️ **Dataset Streaming Not Supported** + +The precomputation uses `load_dataset()` which may not support streaming mode optimally when generating many batches. For large ensembles, this could be slow. + +**Mitigation**: Batches are generated sequentially, so memory footprint is bounded. + +### 8. ℹ️ **GPU Memory Assumptions** + +The implementation assumes: +- A single GPU is available (`get_device()`) +- The model + single batch fit in GPU memory +- Batches are small enough to process one at a time + +For very large models, this might require adjustments. + +## Testing Recommendations + +1. **Single-batch mode (backward compatibility)**: + - Set `recompute_costs_every=1` in config + - Verify `precompute_batches_for_ensemble()` returns `None` + - Verify no batch directory is created + - Verify commands don't include `--precomputed-activations-dir` + +2. 
**Multi-batch mode**: + - Set `recompute_costs_every=20` in config + - Run pipeline with `n_runs=2` + - Verify batch directories are created correctly + - Verify correct number of batches per run + - Verify batch files can be loaded with `ActivationBatch.load()` + +3. **Integration test**: + - Run full pipeline end-to-end with multi-batch mode + - Verify clustering runs complete successfully + - Verify results match single-batch mode (within tolerance) + +## Files Modified + +1. **Created**: `spd/clustering/precompute_batches.py` (~160 lines) +2. **Modified**: `spd/clustering/scripts/run_pipeline.py` + - Updated imports + - Modified `generate_clustering_commands()` (+3 lines logic) + - Modified `main()` (+7 lines for precomputation call) + +## Dependencies on Other Tasks + +- **Depends on Task 1**: ✅ COMPLETE + - `ActivationBatch` and `BatchedActivations` classes exist in `batched_activations.py` + - `precompute_batches_for_ensemble()` function exists + - Config fields added to `ClusteringRunConfig` and `MergeConfig` + - CLI argument added to `run_clustering.py` + +- **Task 2 Status**: ⚠️ PARTIALLY COMPLETE + - ✅ `recompute_coacts_from_scratch()` helper added + - ✅ NaN handling added to samplers + - ⚠️ `merge_iteration()` refactor NOT DONE (was blocked waiting for Task 1) + +- **Task 3 Status**: UNKNOWN (no TASKS-claude-3.md file found) + +## Next Steps + +1. ✅ Task 1 verified complete +2. ✅ Task 4 (this task) complete - pipeline integration done +3. ⚠️ **BLOCKER**: Task 2 needs to be finished + - Complete the `merge_iteration()` refactor per TASKS-claude-2.md + - Change function signature to accept `BatchedActivations` + - Implement NaN masking and batch recomputation logic +4. ❓ Verify Task 3 status (if it exists separately from Task 2) +5. 🧪 Run end-to-end integration tests +6. 
📝 Update TODO-multibatch.md to reflect actual implementation choices + +## Summary + +**Task 4 is functionally complete**, but the full multi-batch system won't work until Task 2's `merge_iteration()` refactor is finished. The pipeline infrastructure is ready: +- ✅ Batches can be precomputed via `run_pipeline.py` +- ✅ Commands are generated with `--precomputed-activations-dir` +- ⚠️ But `run_clustering.py` can't actually USE those batches yet (needs Task 2 completion) diff --git a/TODO-multibatch.md b/TODO-multibatch.md new file mode 100644 index 000000000..2dde9f184 --- /dev/null +++ b/TODO-multibatch.md @@ -0,0 +1,640 @@ +# Multi-Batch Clustering Implementation Plan + +## Implementation Status: ✅ ALL TASKS COMPLETE + +**Date Completed:** 2025-10-27 +**Branch:** `clustering/refactor-multi-batch` + +All four tasks have been implemented. See task completion reports: +- `TASKS-claude-1.md` - Task 1 completion +- `TASKS-claude-2.md` - Task 2 completion +- `spd/clustering/TASKS-claude-3.md` - Task 3 completion +- `TASKS-claude-4.md` - Task 4 completion + +## Overview + +Implement multi-batch clustering to avoid keeping the model loaded during merge iterations. Instead, precompute multiple batches of activations and cycle through them, recomputing costs every `m` merges with a new batch. 
+ +**Core Principle: Keep It Minimal** + +### Key Simplifications +- ✅ No dead component filtering (for now) +- ✅ No within-module clustering (defer to later) +- ✅ Always disk-based batch storage +- ✅ Model never kept during merge iteration +- ✅ NaN masking only (no separate boolean masks) +- ✅ Simple round-robin batch loading from disk + +--- + +## Tasks (All Complete) + +### ✅ Task 1: Batch Storage Infrastructure (COMPLETE) +**Files:** New file + config changes +**Dependencies:** None +**Estimated Lines:** ~180 lines (actual) + +**Completed Components:** +- ✅ Created `spd/clustering/batched_activations.py` (~180 lines) + - `ActivationBatch` class with save/load methods + - `BatchedActivations` class for round-robin batch cycling + - `precompute_batches_for_ensemble()` function +- ✅ Updated `merge_config.py`: Added `recompute_costs_every` field (default=1) +- ✅ Updated `clustering_run_config.py`: Added `precomputed_activations_dir` field +- ✅ Updated `run_clustering.py`: Added `--precomputed-activations-dir` CLI argument + +**Implementation Notes:** +- File location: `spd/clustering/batched_activations.py` +- Default `recompute_costs_every=1` maintains backward compatibility +- When `precomputed_activations_dir=None`, single batch is computed on-the-fly + +**See:** `TASKS-claude-1.md` for full implementation details + +--- + +### ✅ Task 2: Core Merge Logic Refactor (COMPLETE) +**Files:** `spd/clustering/merge.py`, `spd/clustering/math/merge_pair_samplers.py` +**Dependencies:** Task 1 (needs `BatchedActivations` interface) +**Estimated Lines:** ~150 lines (actual) + +**Completed Components:** +- ✅ Added `recompute_coacts_from_scratch()` helper function (`merge.py:33-61`) +- ✅ Refactored `merge_iteration()` to accept `BatchedActivations` parameter +- ✅ Implemented NaN masking for merged component rows/columns +- ✅ Added periodic batch recomputation logic (every `recompute_costs_every` iterations) +- ✅ Updated `range_sampler` in `merge_pair_samplers.py` to handle NaN 
values +- ✅ Updated `mcmc_sampler` in `merge_pair_samplers.py` to handle NaN values + +**Key Implementation Details:** +- Function signature changed to `batched_activations: BatchedActivations` +- Loads first batch at start, cycles through batches during merge iteration +- NaN masking invalidates affected entries after each merge +- Full coactivation recomputation from fresh batch at specified intervals +- All merge pair samplers gracefully handle NaN entries + +**See:** `TASKS-claude-2.md` for full implementation details + +--- + +### ✅ Task 3: Update `run_clustering.py` (COMPLETE) +**Files:** `spd/clustering/scripts/run_clustering.py` +**Dependencies:** Task 1 (needs `BatchedActivations`) +**Estimated Lines:** ~100 lines (actual) + +**Completed Components:** +- ✅ Refactored `main()` function to support two modes: + - **Case 1:** Load precomputed batches from disk (`precomputed_activations_dir` provided) + - **Case 2:** Compute single batch on-the-fly, save to temp directory (original behavior) +- ✅ Updated `merge_iteration()` call to use `batched_activations` parameter +- ✅ Added memory cleanup after activation computation +- ✅ Both modes wrap batches in `BatchedActivations` for unified interface + +**Key Implementation Details:** +- When `precomputed_activations_dir=None`: Single batch computed and saved to temp directory +- When `precomputed_activations_dir` provided: All batches loaded from disk +- Component labels extracted from first batch in both cases +- Memory cleanup: model, batch, activations deleted after computation in Case 2 + +**See:** `spd/clustering/TASKS-claude-3.md` for full implementation details + +--- + +### ✅ Task 4: Batch Precomputation in `run_pipeline.py` (COMPLETE) +**Files:** `spd/clustering/scripts/run_pipeline.py` +**Dependencies:** Task 1 (needs `ActivationBatch`) +**Estimated Lines:** ~50 lines modifications (actual) + +**Note:** The `precompute_batches_for_ensemble()` function was implemented in Task 1 as part of 
`batched_activations.py`, not as a separate addition to `run_pipeline.py`. + +**Completed Components:** +- ✅ Batch precomputation logic in `batched_activations.py` (`precompute_batches_for_ensemble()`) +- ✅ Updated `generate_clustering_commands()` to pass `--precomputed-activations-dir` argument +- ✅ Updated `main()` to call precomputation before generating commands +- ✅ Proper seeding strategy: `base_seed + run_idx * 1000 + batch_idx` + +**Key Implementation Details:** +- Loads model once, generates all batches for all runs in ensemble +- Batches saved to: `/precomputed_batches/run_{idx}/batch_{idx}.pt` +- Returns `None` if `recompute_costs_every=1` (single-batch mode) +- Each run gets unique seed per batch to ensure different data + +**See:** `TASKS-claude-4.md` for full implementation details + +--- + +## Testing Plan + +### Unit Tests +**File:** `tests/clustering/test_multi_batch.py` (TO BE CREATED) + +**Required Tests:** + """A single clustering run.""" + + # Create ExecutionStamp and storage + execution_stamp = ExecutionStamp.create( + run_type="cluster", + create_snapshot=False, + ) + storage = ClusteringRunStorage(execution_stamp) + clustering_run_id = execution_stamp.run_id + logger.info(f"Clustering run ID: {clustering_run_id}") + + # Register with ensemble if this is part of a pipeline + assigned_idx = None + if run_config.ensemble_id: + assigned_idx = register_clustering_run( + pipeline_run_id=run_config.ensemble_id, + clustering_run_id=clustering_run_id, + ) + logger.info( + f"Registered with pipeline {run_config.ensemble_id} at index {assigned_idx}" + ) + # IMPORTANT: set dataset seed based on assigned index + run_config = replace_pydantic_model( + run_config, + {"dataset_seed": run_config.dataset_seed + assigned_idx}, + ) + + # Save config + run_config.to_file(storage.config_path) + logger.info(f"Config saved to {storage.config_path}") + + # Start + logger.info("Starting clustering run") + logger.info(f"Output directory: {storage.base_dir}") + device 
= get_device() + + spd_run = SPDRunInfo.from_path(run_config.model_path) + task_name = spd_run.config.task_config.task_name + + # Setup WandB for this run + wandb_run = None + if run_config.wandb_project is not None: + wandb_run = wandb.init( + id=clustering_run_id, + entity=run_config.wandb_entity, + project=run_config.wandb_project, + group=run_config.ensemble_id, + config=run_config.model_dump(mode="json"), + tags=[ + "clustering", + f"task:{task_name}", + f"model:{run_config.wandb_decomp_model}", + f"ensemble_id:{run_config.ensemble_id}", + f"assigned_idx:{assigned_idx}", + ], + ) + + # Load or compute activations + # ===================================== + batched_activations: BatchedActivations + component_labels: ComponentLabels + + if run_config.precomputed_activations_dir is not None: + # Case 1: Use precomputed batches from disk + logger.info(f"Loading precomputed batches from {run_config.precomputed_activations_dir}") + batched_activations = BatchedActivations(run_config.precomputed_activations_dir) + + # Get labels from first batch + first_batch = batched_activations.get_next_batch() + component_labels = ComponentLabels(first_batch.labels) + + logger.info(f"Loaded {batched_activations.n_batches} precomputed batches") + + else: + # Case 2: Compute single batch on-the-fly (original behavior) + logger.info(f"Computing single batch (seed={run_config.dataset_seed})") + + # Load model + logger.info("Loading model") + model = ComponentModel.from_run_info(spd_run).to(device) + + # Load data + logger.info("Loading dataset") + load_dataset_kwargs = {} + if run_config.dataset_streaming: + logger.info("Using streaming dataset loading") + load_dataset_kwargs["config_kwargs"] = dict(streaming=True) + assert task_name == "lm", ( + f"Streaming dataset loading only supported for 'lm' task, got '{task_name = }'" + ) + + batch = load_dataset( + model_path=run_config.model_path, + task_name=task_name, + batch_size=run_config.batch_size, + seed=run_config.dataset_seed, + 
**load_dataset_kwargs, + ).to(device) + + # Compute activations + logger.info("Computing activations") + activations_dict = component_activations( + model=model, + batch=batch, + device=device, + ) + + # Process (concat modules, NO FILTERING) + logger.info("Processing activations") + processed = process_activations( + activations=activations_dict, + filter_dead_threshold=0.0, # NO FILTERING + seq_mode="concat" if task_name == "lm" else None, + filter_modules=None, + ) + + # Save as single batch to temp dir + temp_batch_dir = storage.base_dir / "temp_batch" + temp_batch_dir.mkdir(exist_ok=True) + + single_batch = ActivationBatch( + activations=processed.activations, + labels=list(processed.labels), + ) + single_batch.save(temp_batch_dir / "batch_0.pt") + + batched_activations = BatchedActivations(temp_batch_dir) + component_labels = processed.labels + + # Log activations to WandB (if enabled) + if wandb_run is not None: + logger.info("Plotting activations") + plot_activations( + processed_activations=processed, + save_dir=None, + n_samples_max=256, + wandb_run=wandb_run, + ) + wandb_log_tensor( + wandb_run, + processed.activations, + "activations", + 0, + single=True, + ) + + # Clean up memory + del model, batch, activations_dict, processed + gc.collect() + + # Run merge iteration + # ===================================== + logger.info("Starting merging") + log_callback = ( + partial(_log_callback, run=wandb_run, run_config=run_config) + if wandb_run is not None + else None + ) + + history = merge_iteration( + merge_config=run_config.merge_config, + batched_activations=batched_activations, + component_labels=component_labels, + log_callback=log_callback, + ) + + # Save merge history + history.save(storage.history_path) + logger.info(f"History saved to {storage.history_path}") + + # Log to WandB + if wandb_run is not None: + _log_merge_history_plots(wandb_run, history) + _save_merge_history_artifact(wandb_run, storage.history_path, history) + wandb_run.finish() + 
logger.info("WandB run finished") + + return storage.history_path +``` + +--- + +### Task 4: Batch Precomputation in `run_pipeline.py` +**Files:** `spd/clustering/scripts/run_pipeline.py` +**Dependencies:** Task 1 (needs `ActivationBatch`) +**Estimated Lines:** ~200 lines + +#### 4.1 Add Batch Precomputation Function + +Add this function before `main()`: + +```python +def precompute_batches_for_ensemble( + pipeline_config: ClusteringPipelineConfig, + pipeline_run_id: str, + storage: ClusteringPipelineStorage, +) -> Path | None: + """ + Precompute activation batches for all runs in ensemble. + + This loads the model ONCE and generates all batches for all runs, + then saves them to disk. Each clustering run will load batches + from disk without needing the model. + + Args: + pipeline_config: Pipeline configuration + pipeline_run_id: Unique ID for this pipeline run + storage: Storage paths for pipeline outputs + + Returns: + Path to base directory containing batches for all runs, + or None if single-batch mode (recompute_costs_every=1) + """ + clustering_run_config = ClusteringRunConfig.from_file( + pipeline_config.clustering_run_config_path + ) + + # Check if multi-batch mode + recompute_every = clustering_run_config.merge_config.recompute_costs_every + if recompute_every == 1: + logger.info("Single-batch mode (recompute_costs_every=1), skipping precomputation") + return None + + logger.info("Multi-batch mode detected, precomputing activation batches") + + # Load model to determine number of components + device = get_device() + spd_run = SPDRunInfo.from_path(clustering_run_config.model_path) + model = ComponentModel.from_run_info(spd_run).to(device) + task_name = spd_run.config.task_config.task_name + + # Get number of components (no filtering, so just count from model) + # Load a sample to count components + logger.info("Loading sample batch to count components") + sample_batch = load_dataset( + model_path=clustering_run_config.model_path, + task_name=task_name, + 
batch_size=clustering_run_config.batch_size, + seed=0, + ).to(device) + + with torch.no_grad(): + sample_acts = component_activations(model, device, sample_batch) + + # Count total components across all modules + n_components = sum( + act.shape[-1] for act in sample_acts.values() + ) + + # Calculate number of iterations + n_iters = clustering_run_config.merge_config.get_num_iters(n_components) + + # Calculate batches needed per run + n_batches_needed = (n_iters + recompute_every - 1) // recompute_every + + logger.info( + f"Precomputing {n_batches_needed} batches per run for {pipeline_config.n_runs} runs" + ) + logger.info(f"Total: {n_batches_needed * pipeline_config.n_runs} batches") + + # Create batches directory + batches_base_dir = storage.base_dir / "precomputed_batches" + batches_base_dir.mkdir(exist_ok=True) + + # For each run in ensemble + for run_idx in tqdm(range(pipeline_config.n_runs), desc="Ensemble runs"): + run_batch_dir = batches_base_dir / f"run_{run_idx}" + run_batch_dir.mkdir(exist_ok=True) + + # Generate batches for this run + for batch_idx in tqdm( + range(n_batches_needed), + desc=f" Run {run_idx} batches", + leave=False + ): + # Use unique seed: base_seed + run_idx * 1000 + batch_idx + # This ensures different data for each run and each batch + seed = clustering_run_config.dataset_seed + run_idx * 1000 + batch_idx + + # Load data + batch_data = load_dataset( + model_path=clustering_run_config.model_path, + task_name=task_name, + batch_size=clustering_run_config.batch_size, + seed=seed, + ).to(device) + + # Compute activations + with torch.no_grad(): + acts_dict = component_activations(model, device, batch_data) + + # Process (concat, NO FILTERING) + processed = process_activations( + activations=acts_dict, + filter_dead_threshold=0.0, # NO FILTERING + seq_mode="concat" if task_name == "lm" else None, + filter_modules=None, + ) + + # Save as ActivationBatch + activation_batch = ActivationBatch( + activations=processed.activations.cpu(), # Move 
to CPU for storage + labels=list(processed.labels), + ) + activation_batch.save(run_batch_dir / f"batch_{batch_idx}.pt") + + # Clean up + del batch_data, acts_dict, processed, activation_batch + torch.cuda.empty_cache() if torch.cuda.is_available() else None + + # Clean up model + del model, sample_batch, sample_acts + gc.collect() + torch.cuda.empty_cache() if torch.cuda.is_available() else None + + logger.info(f"All batches precomputed and saved to {batches_base_dir}") + + return batches_base_dir +``` + +#### 4.2 Update `generate_clustering_commands()` + +```python +def generate_clustering_commands( + pipeline_config: ClusteringPipelineConfig, + pipeline_run_id: str, + batches_base_dir: Path | None, # NEW PARAMETER + dataset_streaming: bool = False, +) -> list[str]: + """Generate commands for each clustering run. + + Args: + pipeline_config: Pipeline configuration + pipeline_run_id: Pipeline run ID (each run will create its own ExecutionStamp) + batches_base_dir: Path to precomputed batches directory, or None for single-batch mode + dataset_streaming: Whether to use dataset streaming + + Returns: + List of shell-safe command strings + """ + commands = [] + + for idx in range(pipeline_config.n_runs): + cmd_parts = [ + "python", + "spd/clustering/scripts/run_clustering.py", + "--config", + pipeline_config.clustering_run_config_path.as_posix(), + "--pipeline-run-id", + pipeline_run_id, + "--idx-in-ensemble", + str(idx), + "--wandb-project", + str(pipeline_config.wandb_project), + "--wandb-entity", + pipeline_config.wandb_entity, + ] + + # Add precomputed batches path if available + if batches_base_dir is not None: + run_batch_dir = batches_base_dir / f"run_{idx}" + cmd_parts.extend(["--precomputed-activations-dir", str(run_batch_dir)]) + + if dataset_streaming: + cmd_parts.append("--dataset-streaming") + + commands.append(shlex.join(cmd_parts)) + + return commands +``` + +#### 4.3 Update `main()` to Call Precomputation + +```python +def main( + pipeline_config: 
ClusteringPipelineConfig, + local: bool = False, + local_clustering_parallel: bool = False, + local_calc_distances_parallel: bool = False, + dataset_streaming: bool = False, + track_resources_calc_distances: bool = False, +) -> None: + """Submit clustering runs to SLURM.""" + + logger.set_format("console", "terse") + + # Validation + if local_clustering_parallel or local_calc_distances_parallel or track_resources_calc_distances: + assert local, ( + "local_clustering_parallel, local_calc_distances_parallel, track_resources_calc_distances " + "can only be set when running locally" + ) + + # Create ExecutionStamp for pipeline + execution_stamp = ExecutionStamp.create( + run_type="ensemble", + create_snapshot=pipeline_config.create_git_snapshot, + ) + pipeline_run_id = execution_stamp.run_id + logger.info(f"Pipeline run ID: {pipeline_run_id}") + + # Initialize storage + storage = ClusteringPipelineStorage(execution_stamp) + logger.info(f"Pipeline output directory: {storage.base_dir}") + + # Save pipeline config + pipeline_config.to_file(storage.pipeline_config_path) + logger.info(f"Pipeline config saved to {storage.pipeline_config_path}") + + # Create WandB workspace if requested + if pipeline_config.wandb_project is not None: + workspace_url = create_clustering_workspace_view( + ensemble_id=pipeline_run_id, + project=pipeline_config.wandb_project, + entity=pipeline_config.wandb_entity, + ) + logger.info(f"WandB workspace: {workspace_url}") + + # NEW: Precompute batches if multi-batch mode + batches_base_dir = precompute_batches_for_ensemble( + pipeline_config=pipeline_config, + pipeline_run_id=pipeline_run_id, + storage=storage, + ) + + # Generate commands for clustering runs + clustering_commands = generate_clustering_commands( + pipeline_config=pipeline_config, + pipeline_run_id=pipeline_run_id, + batches_base_dir=batches_base_dir, # NEW + dataset_streaming=dataset_streaming, + ) + + # Generate commands for calculating distances + calc_distances_commands = 
generate_calc_distances_commands( + pipeline_run_id=pipeline_run_id, + distances_methods=pipeline_config.distances_methods, + ) + + # ... rest of submission logic unchanged +``` + +--- + +## Testing Plan + +### Unit Tests +**File:** `tests/clustering/test_multi_batch.py` (new) + +1. Test `ActivationBatch` save/load +2. Test `BatchedActivations` cycling through batches +3. Test `recompute_coacts_from_scratch` produces correct shapes +4. Test NaN handling in merge pair samplers +5. Test backward compatibility (single batch, `recompute_costs_every=1`) + +### Integration Tests + +1. **Single-batch mode (backward compatibility):** + ```python + config = ClusteringRunConfig( + precomputed_activations_dir=None, + merge_config=MergeConfig(recompute_costs_every=1), + ... + ) + # Should behave exactly as before + ``` + +2. **Multi-batch mode:** + ```python + config = ClusteringRunConfig( + precomputed_activations_dir=Path("batches/"), + merge_config=MergeConfig(recompute_costs_every=20), + ... + ) + # Should use multiple batches + ``` + +3. 
**Ensemble with precomputation:** + - Run small ensemble (n=3) with multi-batch + - Verify batches are created correctly + - Verify clustering runs use precomputed batches + +### Manual Testing Checklist + +- [ ] Single run, single batch (original behavior) +- [ ] Single run, multi-batch with precomputed dir +- [ ] Ensemble run, single batch mode +- [ ] Ensemble run, multi-batch mode with precomputation +- [ ] Verify NaN masking doesn't break merge sampling +- [ ] Verify memory usage (model not kept during merge) +- [ ] Verify batch cycling works correctly + +--- + +## Summary + +**Total Changes:** +- New files: 1 (`spd/clustering/batched_activations.py`) +- Modified files: 5 +- Total new code: ~500 lines +- Backward compatible: Yes (defaults to original behavior) + +**Key Benefits:** +- Model loaded once, not kept during merge iteration +- Supports arbitrary number of batches +- Simple disk-based storage +- Minimal config changes (2 new fields) +- No complex scheduling or memory management + +**Dependencies:** None (PR #227 not required for this simplified version) From d7f0f683ffc826b0750b89455473f542af914f44 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 12:45:13 +0000 Subject: [PATCH 36/61] remove old todos --- TASKS-claude-1.md | 170 ------------ TASKS-claude-2.md | 98 ------- TASKS-claude-3.md | 266 ------------------- TASKS-claude-4.md | 208 --------------- TODO-multibatch.md | 640 --------------------------------------------- TODO.md | 73 ------ 6 files changed, 1455 deletions(-) delete mode 100644 TASKS-claude-1.md delete mode 100644 TASKS-claude-2.md delete mode 100644 TASKS-claude-3.md delete mode 100644 TASKS-claude-4.md delete mode 100644 TODO-multibatch.md delete mode 100644 TODO.md diff --git a/TASKS-claude-1.md b/TASKS-claude-1.md deleted file mode 100644 index 7ad79bbd2..000000000 --- a/TASKS-claude-1.md +++ /dev/null @@ -1,170 +0,0 @@ -# Task 1: Batch Storage Infrastructure - Completion Report - -## Summary - -Completed 
Task 1 from `TODO-multibatch.md`: Batch Storage Infrastructure. This provides the foundation for multi-batch clustering by implementing data structures and precomputation logic for activation batches. - -## Changes Made - -### 1. Created `spd/clustering/batched_activations.py` (~180 lines) - -This file consolidates all batch-related functionality (previously split between `batched_activations.py` and `precompute_batches.py`). - -**Components:** -- **`ActivationBatch`**: Dataclass for storing a single batch of activations with labels - - `save()`: Saves batch to disk as `.pt` file - - `load()`: Loads batch from disk - -- **`BatchedActivations`**: Iterator for cycling through multiple batches from disk - - Finds all `batch_*.pt` files in a directory - - `get_next_batch()`: Returns next batch in round-robin fashion - - `n_batches`: Property for total number of batches - -- **`precompute_batches_for_ensemble()`**: Function to generate all batches for ensemble runs - - Loads model once, generates all batches for all runs - - Saves batches to disk with structure: `/precomputed_batches/run_{idx}/batch_{idx}.pt` - - Returns `None` if `recompute_costs_every=1` (single-batch mode) - - Uses unique seeds per batch: `base_seed + run_idx * 1000 + batch_idx` - -### 2. Updated `spd/clustering/merge_config.py` - -**Changed:** -- `recompute_costs_every`: Updated default from `10` to `1` (original behavior) -- Updated description to match TODO spec - -**Rationale:** Default of 1 maintains backward compatibility - single batch mode is the original behavior. - -### 3. Updated `spd/clustering/clustering_run_config.py` - -**Added:** -- `precomputed_activations_dir: Path | None = None` -- Description: "Path to directory containing precomputed activation batches. If None, batches will be auto-generated before merging starts." 
- -**Key Design Decision:** When `None`, the system will auto-generate all required batches in a temp directory before merge starts (Option A from user clarification). - -### 4. Updated `spd/clustering/scripts/run_clustering.py` - -**Added:** -- CLI argument: `--precomputed-activations-dir` -- Override logic to pass this value to the config - -### 5. Merged Files - -**Deleted:** `spd/clustering/precompute_batches.py` - -**Rationale:** The two files were tightly coupled (~50 and ~140 lines), used together, and would evolve together. Combining them into one `batched_activations.py` file makes the codebase simpler. - -**Updated imports:** -- `spd/clustering/scripts/run_pipeline.py`: Now imports from `batched_activations` - -## Concerns & Notes - -### 1. **Circular Import Risk** ⚠️ - -In `batched_activations.py`, the `precompute_batches_for_ensemble()` function has a type annotation: -```python -def precompute_batches_for_ensemble( - clustering_run_config: "ClusteringRunConfig", # String annotation to avoid circular import - ... -) -``` - -This is a forward reference (string annotation) because: -- `batched_activations.py` imports from `clustering` modules -- `clustering_run_config.py` might import from `batched_activations.py` in the future - -**Status:** Currently safe, but watch for circular imports if we add imports to configs. - -### 2. **Type Checking Not Verified** ⚠️ - -I attempted to run a basic import test but the user interrupted. We should verify: -- `basedpyright` passes -- No circular import errors at runtime -- All imports resolve correctly - -**Recommended:** Run `make check` to verify type safety. - -### 3. **Config Default Behavior Change** - -The default for `recompute_costs_every` was changed from `10` → `1`. This affects any existing configs that relied on the implicit default of 10. - -**Impact:** Likely minimal since this appears to be new functionality, but worth noting for any in-progress experiments. - -### 4. 
**Seeding Strategy** - -The seed calculation for batches is: -```python -seed = base_seed + run_idx * 1000 + batch_idx -``` - -**Assumption:** Maximum of 1000 batches per run. If more than 1000 batches needed, seeds could collide across runs. - -**Mitigation:** Very unlikely - 1000 batches would require either: -- Very long merge iterations, or -- Very small `recompute_costs_every` values - -### 5. **Disk Space Considerations** - -Batches are saved to disk with activations on CPU. For large ensembles: -- `n_runs * n_batches_per_run * batch_size_on_disk` -- Could be substantial for large models/datasets - -**Note:** No cleanup mechanism implemented - batches persist after runs complete. - -### 6. **TODO Document Discrepancy** - -The TODO document (lines 418-483) describes behavior where: -- `precomputed_activations_dir=None` → "compute single batch on-the-fly" -- `precomputed_activations_dir=` → "use precomputed batches" - -**Actual Implementation (per user request):** -- `precomputed_activations_dir=None` → "auto-generate all batches, then run merge" -- `precomputed_activations_dir=` → "use precomputed batches" - -This follows "Option A" clarified by the user. The TODO document may need updating. - -## Next Steps - -### Immediate -1. **Verify type checking:** Run `make check` or `basedpyright` -2. **Test imports:** Ensure no circular import issues -3. 
**Update TODO-multibatch.md:** Reflect the actual implementation of Option A - -### For Task 2 (Core Merge Logic Refactor) -The infrastructure is ready: -- `BatchedActivations` can be used in `merge_iteration()` -- `recompute_coacts_from_scratch()` needs to be added -- NaN masking logic needs to be implemented -- Merge pair samplers need NaN handling - -## Files Modified - -``` -Created: -- spd/clustering/batched_activations.py (new, ~180 lines) - -Modified: -- spd/clustering/merge_config.py (updated default + description) -- spd/clustering/clustering_run_config.py (added 1 field) -- spd/clustering/scripts/run_clustering.py (added CLI arg + override) -- spd/clustering/scripts/run_pipeline.py (updated import) - -Deleted: -- spd/clustering/precompute_batches.py (merged into batched_activations.py) -``` - -## Testing Recommendations - -1. **Unit tests** (create `tests/clustering/test_batched_activations.py`): - - Test `ActivationBatch.save()` and `load()` - - Test `BatchedActivations` cycling behavior - - Test that `n_batches` property works correctly - -2. **Integration test**: - - Run `precompute_batches_for_ensemble()` with small config - - Verify batch files are created with correct naming - - Verify `BatchedActivations` can load and cycle through them - -3. **Backward compatibility test**: - - Run with `recompute_costs_every=1` (should behave as before) - - Verify no batches are precomputed when not needed diff --git a/TASKS-claude-2.md b/TASKS-claude-2.md deleted file mode 100644 index 0bb7f92e0..000000000 --- a/TASKS-claude-2.md +++ /dev/null @@ -1,98 +0,0 @@ -# Task 2: Core Merge Logic Refactor - COMPLETED ✅ - -## Summary - -Task 2 has been completed successfully. All changes have been implemented in `spd/clustering/merge.py` and `spd/clustering/math/merge_pair_samplers.py`. - -## Completed Work - -### ✅ 1. 
Helper Function Added -**File:** `spd/clustering/merge.py` (lines 33-61) - -Added `recompute_coacts_from_scratch()` function that: -- Takes fresh activations and current merge state -- Applies threshold to activations -- Applies current merge matrix to get group-level activations -- Computes coactivations for current groups -- Returns both coact matrix and activation mask - -### ✅ 2. Samplers Updated for NaN Handling -**File:** `spd/clustering/math/merge_pair_samplers.py` - -#### `range_sampler` (lines 28-84) -- Added NaN masking alongside diagonal masking -- Only considers valid (non-NaN, non-diagonal) pairs -- Raises clear error if all costs are NaN -- Updated docstring to document NaN handling - -#### `mcmc_sampler` (lines 87-134) -- Added NaN check to valid_mask -- Sets invalid entries to inf (so exp gives 0) -- Raises error if no valid pairs exist -- Updated docstring to document NaN handling - -### ✅ 3. merge_iteration() Refactored -**File:** `spd/clustering/merge.py` (lines 82-260) - -#### Changed Function Signature (line 82-87) -- Now accepts `batched_activations: BatchedActivations` instead of `activations: ActivationsTensor` -- Updated docstring to reflect multi-batch support - -#### Initial Batch Loading (lines 95-107) -- Loads first batch using `batched_activations.get_next_batch()` -- Extracts activations tensor from ActivationBatch -- Computes initial coactivations as before - -#### NaN Masking Instead of Incremental Updates (lines 155-179) -- Stores merge_pair_cost BEFORE updating (line 153) -- Updates merge state first (line 158) -- NaN out affected rows/cols (lines 166-169) -- Removes deleted row/col to maintain shape (lines 172-177) -- Decrements k_groups immediately (line 179) - -#### Batch Recomputation Logic (lines 189-205) -- Checks if it's time to recompute based on `merge_config.recompute_costs_every` -- Loads new batch from disk -- Calls `recompute_coacts_from_scratch()` to get fresh coactivations -- Updates both `current_coact` and 
`current_act_mask` - -#### Cleanup -- Removed duplicate `k_groups -= 1` (was at line ~239, now only at line 179) -- Kept all metrics and logging logic intact -- Maintained all sanity checks - -## Key Changes Summary - -| Component | Before | After | -|-----------|--------|-------| -| Function param | `activations: ActivationsTensor` | `batched_activations: BatchedActivations` | -| Coact updates | Incremental via `recompute_coacts_merge_pair()` | NaN masking + periodic full recompute | -| Invalid entries | Never existed | Marked as NaN | -| Batch handling | Single batch only | Multiple batches with cycling | -| Samplers | Assumed no NaN | Handle NaN gracefully | - -## Backward Compatibility - -The refactored code maintains backward compatibility: -- When `recompute_costs_every=1`, it recomputes every iteration (similar to old behavior but with fresh data) -- When using a single batch in `BatchedActivations`, it cycles through that one batch -- All existing metrics, logging, and callbacks continue to work - -## Notes - -- The import `from spd.clustering.compute_costs import recompute_coacts_merge_pair` is still present but the function is no longer used in `merge_iteration()` -- This function may still be used elsewhere in the codebase, so it was left in place -- The NaN masking approach is more memory-efficient as it doesn't require keeping the model loaded - -## Testing Recommendations - -1. **Single-batch backward compatibility:** Test with `recompute_costs_every=1` and verify results match old behavior -2. **Multi-batch mode:** Test with `recompute_costs_every=10` and multiple batches -3. **NaN handling:** Verify samplers don't crash when costs contain NaN -4. **Metrics/logging:** Ensure WandB logging and callbacks still work correctly -5. 
**Edge cases:** Test with very small k_groups values (near early stopping) - -## Dependencies - -- ✅ Task 1 completed: `BatchedActivations` and `ActivationBatch` classes exist in `spd/clustering/batched_activations.py` -- ⏭️ Task 3: Will need to update `run_clustering.py` to use the new `merge_iteration()` signature diff --git a/TASKS-claude-3.md b/TASKS-claude-3.md deleted file mode 100644 index 9cd2d7801..000000000 --- a/TASKS-claude-3.md +++ /dev/null @@ -1,266 +0,0 @@ -# Task 3 Implementation: Update `run_clustering.py` for Multi-Batch Support - -**Date**: 2025-10-27 -**Branch**: `clustering/refactor-multi-batch` -**Reference**: `spd/clustering/TODO-multibatch.md` - Task 3 - -## Executive Summary - -✅ **Implementation Complete**: All Task 3 code changes have been made to `run_clustering.py` -❌ **Functional Status**: Code will **fail at runtime** due to incomplete Task 2 -⚠️ **Critical Blocker**: `merge_iteration()` in `merge.py` must be refactored before this code can work - -**The Issue**: Task 3 calls `merge_iteration(batched_activations=...)` but Task 2 has not yet updated `merge_iteration()` to accept this parameter. The function still expects `activations: ActivationsTensor`. - -**See Section**: [Concern #1 - Critical Task 2 Dependency](#1-🚨-critical-task-2-is-incomplete) - -## Overview - -Implemented multi-batch clustering support in `run_clustering.py` to allow clustering runs to either: -1. Use precomputed activation batches from disk, OR -2. Compute a single batch on-the-fly (original behavior) - -This enables the model to be unloaded before merge iteration begins, saving memory during long merge processes. - -## Changes Made - -### 1. Import Addition (`run_clustering.py:34`) -```python -from spd.clustering.batched_activations import ActivationBatch, BatchedActivations -``` - -### 2. 
Refactored `main()` Function (lines 280-373) - -Replaced the monolithic data loading and activation computation section with a branching structure: - -#### Case 1: Precomputed Batches (lines 285-294) -When `run_config.precomputed_activations_dir` is provided: -- Loads `BatchedActivations` from disk directory -- Retrieves component labels from first batch -- No model loading required -- Logs number of batches available - -#### Case 2: Single Batch On-the-Fly (lines 296-373) -When `precomputed_activations_dir` is `None` (original behavior): -- Loads model from SPDRunInfo -- Loads dataset with appropriate seed -- Computes activations using `component_activations()` -- Processes activations (filtering, concatenation) -- **Saves to temporary directory** as `batch_0.pt` -- Creates `BatchedActivations` instance from temp directory -- Logs activations to WandB if enabled -- Cleans up model, batch, and intermediate tensors from memory - -### 3. Updated `merge_iteration()` Call (line 386) -```python -# Changed from: -activations=activations - -# Changed to: -batched_activations=batched_activations -``` - -## Dependencies - -### Completed (by other Claude instances): -- ✅ **Task 1** (TASKS-claude-1.md): - - Created `spd/clustering/batched_activations.py` with `ActivationBatch`, `BatchedActivations`, and `precompute_batches_for_ensemble()` - - Added `precomputed_activations_dir` field to `ClusteringRunConfig` (line 72-75) - - Added `--precomputed-activations-dir` CLI argument to `run_clustering.py` (line 415-420) - - Added CLI argument wiring to config overrides (line 440-441) - - Changed `recompute_costs_every` default from 10 → 1 in `MergeConfig` - - **Note**: Task 1 implemented "Option A" where `precomputed_activations_dir=None` means "auto-generate all batches" - -- ⚠️ **Task 2** (TASKS-claude-2.md): **INCOMPLETE** - - ✅ Added `recompute_coacts_from_scratch()` helper function to `merge.py` - - ✅ Updated merge pair samplers to handle NaN values - - ❌ **NOT DONE**: 
`merge_iteration()` refactor blocked, waiting for Task 1 - - **Status**: Task 2 is NOT complete - the main refactor is still pending - -- ✅ **Task 4** (TASKS-claude-4.md): - - Implemented batch precomputation in `run_pipeline.py` - - Added `generate_clustering_commands()` support for passing batch directories - - **Note**: Depends on Tasks 2 & 3 being complete for full integration - -## Testing Status - -❌ **No tests run yet** - changes are untested - -## Concerns and Potential Issues - -### 1. **🚨 CRITICAL: Task 2 is INCOMPLETE** -Per TASKS-claude-2.md, `merge_iteration()` has **NOT** been refactored yet. The Task 2 work was blocked waiting for Task 1, but now Task 1 is complete. - -**What's Missing in merge.py**: -- ❌ Function signature still uses `activations: ActivationsTensor` instead of `batched_activations: BatchedActivations` -- ❌ No batch cycling logic implemented -- ❌ No NaN masking for merged component rows/columns -- ❌ No periodic recomputation of coactivation matrix - -**Impact**: My changes to `run_clustering.py` will **FAIL** at runtime when calling `merge_iteration()` because: -```python -# This call will error - wrong parameter name/type -history: MergeHistory = merge_iteration( - merge_config=run_config.merge_config, - batched_activations=batched_activations, # ❌ merge_iteration() doesn't accept this yet - component_labels=component_labels, - log_callback=log_callback, -) -``` - -**Action Required**: Complete Task 2 refactor of `merge_iteration()` before Task 3 can be tested or used. - -### 2. **Temporary Directory Cleanup** -In single-batch mode, activations are saved to `storage.base_dir / "temp_batch"`. This directory: -- Is created in the run's output directory -- Contains a single `batch_0.pt` file -- Is **never explicitly cleaned up** - -**Potential Issue**: Accumulation of temp directories if many runs are executed. 
- -**Options**: -- Leave as-is (temp directories are part of run output, may be useful for debugging) -- Add cleanup after merge iteration completes -- Use Python's `tempfile.TemporaryDirectory()` context manager (would require restructuring) - -**Recommendation**: Leave as-is for now since it's in the run's output directory and provides transparency. - -### 3. **Batch Label Consistency** -When loading precomputed batches (Case 1), we only check the **first batch** for labels: -```python -first_batch = batched_activations.get_next_batch() -component_labels = ComponentLabels(first_batch.labels) -``` - -**Assumption**: All batches in the directory have identical labels in the same order. - -**Potential Issue**: If batches were generated incorrectly with different label sets or ordering, the merge will fail or produce incorrect results. - -**Mitigation**: The batch precomputation logic (Task 4 in `run_pipeline.py`) should ensure consistent labels. Consider adding validation. - -### 4. **Type Checking** -Have not run type checker (`basedpyright`) yet. Potential type issues: -- `batched_activations` variable assignment in two branches might confuse type checker -- `component_labels` assignment from different sources - -**Action Required**: Run `make type` to verify no type errors. - -### 5. **Memory Management** -In Case 2 (single-batch mode), we explicitly clean up: -```python -del model, batch, activations_dict, processed -gc.collect() -``` - -This is good, but: -- `temp_batch_dir` and `single_batch` objects remain in scope (though small) -- May want to explicitly `del single_batch` after save for consistency - -**Recommendation**: Low priority, current implementation is fine. - -### 6. 
**ComponentLabels Type Consistency** -In Case 1 (precomputed): -```python -component_labels = ComponentLabels(first_batch.labels) -``` - -In Case 2 (on-the-fly): -```python -component_labels = processed.labels # Already ComponentLabels type -``` - -The inconsistency is intentional (Case 1 needs wrapping, Case 2 doesn't), but could be confusing. - -**Verified**: Based on `batched_activations.py`, `ActivationBatch.labels` is `list[str]`, so wrapping in `ComponentLabels` is correct. - -### 7. **Implementation Divergence from TODO Spec** -The TODO document (lines 418-483) describes Case 2 as "compute single batch on-the-fly" without saving to disk. - -**However**, Task 1 implemented "Option A" where `precomputed_activations_dir=None` means "auto-generate all required batches before merging starts" (per TASKS-claude-1.md, concern #6). - -**My Implementation** follows the TODO spec literally: -- Case 2 creates a **single** batch -- Saves to temp directory -- Passes to merge iteration -- This is the **original behavior** (backward compatible) - -**Discrepancy**: My implementation doesn't match Task 1's "Option A" intent but **does** match the TODO document specification. - -**Clarification Needed**: Which behavior is desired? -- **Option A** (Task 1 intent): Auto-generate all batches when `precomputed_activations_dir=None` -- **Original Behavior** (TODO spec + my impl): Generate single batch on-the-fly - -**Current Status**: My implementation maintains backward compatibility with original single-batch behavior. - -### 8. **Backward Compatibility** -**Status**: ✅ Should be fully backward compatible - -When `precomputed_activations_dir=None` (default), the code: -- Follows the same logic as before -- Uses same dataset loading, activation computation, and processing -- Only difference: saves to temp directory and wraps in `BatchedActivations` - -**Concern**: The extra save/load cycle adds overhead for single-batch runs. 
- -**Impact**: Minimal - single file I/O is fast, and the wrapper is lightweight. - -**Important Note**: This only works if Task 2 implements backward compatibility where `recompute_costs_every=1` behaves identically to the original incremental update behavior. - -## Next Steps - -### Immediate Blockers -1. **🚨 COMPLETE TASK 2**: The `merge_iteration()` function must be refactored before Task 3 can work - - See TASKS-claude-2.md for detailed requirements - - Task 2 was blocked waiting for Task 1, but Task 1 is now complete - - This is the critical path blocker - -### After Task 2 is Complete -2. **Run type checker**: `make type` to catch any type issues -3. **Resolve implementation divergence**: Decide between Option A (auto-generate batches) vs. original behavior (single batch) -4. **Test single-batch mode**: Run clustering with default config - - Verify `recompute_costs_every=1` maintains backward compatibility -5. **Test multi-batch mode**: Use precomputed batches from Task 4 - - Verify batch cycling works correctly - - Verify cost recomputation happens at correct intervals -6. 
**Integration test**: Full pipeline test with `run_pipeline.py` - - Use Task 4's precomputation - - Run ensemble with multiple batches - - Verify all runs complete successfully - -## Files Modified - -- `spd/clustering/scripts/run_clustering.py`: ~100 lines changed - - Added imports - - Refactored `main()` function (lines 259-389) - - Updated merge_iteration call - -## Files Read/Referenced - -- `spd/clustering/batched_activations.py` (already exists) -- `spd/clustering/clustering_run_config.py` (field already added) -- `spd/clustering/merge_config.py` (field already added) -- `spd/clustering/TODO-multibatch.md` (implementation guide) - -## Task Status Summary - -| Component | Status | Notes | -|-----------|--------|-------| -| Task 3 Implementation | ✅ Complete | All code changes made per TODO spec | -| Task 2 Dependency | ❌ **BLOCKER** | `merge_iteration()` not refactored - will fail at runtime | -| Type Checking | ⚠️ Not Run | Need to run `make type` | -| Testing | ❌ Not Done | Cannot test until Task 2 complete | -| Integration | ⚠️ Pending | Works with Task 1 & 4, blocked by Task 2 | - -## Estimated Completeness - -**Task 3 Implementation**: **100% complete** ✅ (all code written per TODO-multibatch.md) - -**Task 3 Usability**: **0% functional** ❌ (blocked by incomplete Task 2) - -**Overall Multi-Batch Feature**: -- Task 1: ✅ Complete -- Task 2: ❌ Incomplete (critical blocker) -- Task 3: ✅ Complete (but untestable) -- Task 4: ✅ Complete - -**Integration Status**: ~75% complete (Task 2 is the only missing piece) diff --git a/TASKS-claude-4.md b/TASKS-claude-4.md deleted file mode 100644 index 93f212032..000000000 --- a/TASKS-claude-4.md +++ /dev/null @@ -1,208 +0,0 @@ -# Task 4 Implementation: Batch Precomputation in run_pipeline.py - -**Date**: 2025-10-27 -**Task**: Implement Task 4 from `spd/clustering/TODO-multibatch.md` -**Status**: ✅ Complete - -## What Was Done - -### 1. 
Created New Module: `spd/clustering/precompute_batches.py` - -Created a standalone module that can be imported by both `run_clustering.py` and `run_pipeline.py`. - -**Key Function**: `precompute_batches_for_ensemble()` -- Takes `ClusteringRunConfig`, `n_runs`, and `output_dir` as parameters -- Returns `Path | None` (None if single-batch mode, Path to batches directory if multi-batch) -- Loads model once to determine component count -- Calculates number of batches needed based on `recompute_costs_every` and total iterations -- Generates all activation batches for all runs in the ensemble -- Saves batches as `ActivationBatch` objects to disk -- Implements proper memory cleanup (GPU cache clearing, garbage collection) - -### 2. Updated `spd/clustering/scripts/run_pipeline.py` - -#### Imports Added: -```python -from spd.clustering.precompute_batches import precompute_batches_for_ensemble -from spd.clustering.clustering_run_config import ClusteringRunConfig -``` - -#### Modified Functions: - -**`generate_clustering_commands()`**: -- Added `batches_base_dir: Path | None = None` parameter -- Added logic to append `--precomputed-activations-dir` to command when batches are available -- Each run gets its own batch directory: `batches_base_dir / f"run_{idx}"` - -**`main()`**: -- Added code to load `ClusteringRunConfig` from pipeline config -- Calls `precompute_batches_for_ensemble()` before generating commands -- Passes `batches_base_dir` result to `generate_clustering_commands()` - -### 3. Implementation Details - -**Batch Directory Structure**: -``` -/ -└── precomputed_batches/ - ├── run_0/ - │ ├── batch_0.pt - │ ├── batch_1.pt - │ └── ... - ├── run_1/ - │ ├── batch_0.pt - │ └── ... - └── ... 
-``` - -**Seeding Strategy**: -- Each batch uses unique seed: `base_seed + run_idx * 1000 + batch_idx` -- Ensures different data for each run and each batch within a run - -**Memory Management**: -- Model loaded once at the beginning -- Activations moved to CPU before saving -- GPU cache cleared after each batch -- Full garbage collection after all batches complete - -## Concerns & Notes - -### 1. ✅ **CLI Argument in run_clustering.py - RESOLVED** - -**Status**: Task 1 has already added the `--precomputed-activations-dir` CLI argument to `run_clustering.py`. - -See TASKS-claude-1.md for details. - -### 2. ✅ **Config Field - RESOLVED** - -**Status**: Task 1 has already added `precomputed_activations_dir` field to `ClusteringRunConfig`. - -See TASKS-claude-1.md for details. - -### 3. ✅ **Precomputation Logic - RESOLVED** - -**Status**: The `precompute_batches_for_ensemble()` function was already implemented in `spd/clustering/batched_activations.py` by Task 1. - -**What I Did**: -- Initially created a duplicate `spd/clustering/precompute_batches.py` -- Discovered it was already in `batched_activations.py` -- File has been cleaned up (does not exist) -- Import in `run_pipeline.py` correctly references `batched_activations` - -**Current State**: Everything is properly integrated and no duplicates exist. - -### 4. ⚠️ **Filter Dead Threshold - Intentional Design** - -According to TASKS-claude-1.md and the TODO document: -- **NO FILTERING** is intentional for the simplified multi-batch implementation -- `filter_dead_threshold=0.0` is used during precomputation -- This is a key simplification listed in the TODO (line 11) - -**Status**: This is correct behavior, not a bug. - -### 5. 
⚠️ **Missing Integration with run_clustering.py** - -According to TASKS-claude-2.md: -- Task 2 implemented NaN handling in samplers ✅ -- Task 2 added `recompute_coacts_from_scratch()` helper ✅ -- Task 2 did NOT complete the `merge_iteration()` refactor ⚠️ (blocked waiting for Task 1) - -**Current Status**: Task 1 is now complete, so Task 2 should be unblocked to finish the `merge_iteration()` refactor. - -**What's Still Needed**: -- `merge_iteration()` needs to accept `BatchedActivations` instead of single tensor -- NaN masking logic needs to be added to `merge_iteration()` -- Batch recomputation logic needs to be added - -See TASKS-claude-2.md for detailed instructions on what remains. - -### 6. ℹ️ **Component Count Calculation** - -We count components by summing across modules: -```python -n_components = sum(act.shape[-1] for act in sample_acts.values()) -``` - -This assumes: -- All modules are included (no module filtering) -- No dead component filtering -- All components from all modules are concatenated - -This matches the "NO FILTERING" principle in the TODO. - -### 7. ℹ️ **Dataset Streaming Not Supported** - -The precomputation uses `load_dataset()` which may not support streaming mode optimally when generating many batches. For large ensembles, this could be slow. - -**Mitigation**: Batches are generated sequentially, so memory footprint is bounded. - -### 8. ℹ️ **GPU Memory Assumptions** - -The implementation assumes: -- A single GPU is available (`get_device()`) -- The model + single batch fit in GPU memory -- Batches are small enough to process one at a time - -For very large models, this might require adjustments. - -## Testing Recommendations - -1. **Single-batch mode (backward compatibility)**: - - Set `recompute_costs_every=1` in config - - Verify `precompute_batches_for_ensemble()` returns `None` - - Verify no batch directory is created - - Verify commands don't include `--precomputed-activations-dir` - -2. 
**Multi-batch mode**: - - Set `recompute_costs_every=20` in config - - Run pipeline with `n_runs=2` - - Verify batch directories are created correctly - - Verify correct number of batches per run - - Verify batch files can be loaded with `ActivationBatch.load()` - -3. **Integration test**: - - Run full pipeline end-to-end with multi-batch mode - - Verify clustering runs complete successfully - - Verify results match single-batch mode (within tolerance) - -## Files Modified - -1. **Created**: `spd/clustering/precompute_batches.py` (~160 lines) -2. **Modified**: `spd/clustering/scripts/run_pipeline.py` - - Updated imports - - Modified `generate_clustering_commands()` (+3 lines logic) - - Modified `main()` (+7 lines for precomputation call) - -## Dependencies on Other Tasks - -- **Depends on Task 1**: ✅ COMPLETE - - `ActivationBatch` and `BatchedActivations` classes exist in `batched_activations.py` - - `precompute_batches_for_ensemble()` function exists - - Config fields added to `ClusteringRunConfig` and `MergeConfig` - - CLI argument added to `run_clustering.py` - -- **Task 2 Status**: ⚠️ PARTIALLY COMPLETE - - ✅ `recompute_coacts_from_scratch()` helper added - - ✅ NaN handling added to samplers - - ⚠️ `merge_iteration()` refactor NOT DONE (was blocked waiting for Task 1) - -- **Task 3 Status**: UNKNOWN (no TASKS-claude-3.md file found) - -## Next Steps - -1. ✅ Task 1 verified complete -2. ✅ Task 4 (this task) complete - pipeline integration done -3. ⚠️ **BLOCKER**: Task 2 needs to be finished - - Complete the `merge_iteration()` refactor per TASKS-claude-2.md - - Change function signature to accept `BatchedActivations` - - Implement NaN masking and batch recomputation logic -4. ❓ Verify Task 3 status (if it exists separately from Task 2) -5. 🧪 Run end-to-end integration tests -6. 
📝 Update TODO-multibatch.md to reflect actual implementation choices - -## Summary - -**Task 4 is functionally complete**, but the full multi-batch system won't work until Task 2's `merge_iteration()` refactor is finished. The pipeline infrastructure is ready: -- ✅ Batches can be precomputed via `run_pipeline.py` -- ✅ Commands are generated with `--precomputed-activations-dir` -- ⚠️ But `run_clustering.py` can't actually USE those batches yet (needs Task 2 completion) diff --git a/TODO-multibatch.md b/TODO-multibatch.md deleted file mode 100644 index 2dde9f184..000000000 --- a/TODO-multibatch.md +++ /dev/null @@ -1,640 +0,0 @@ -# Multi-Batch Clustering Implementation Plan - -## Implementation Status: ✅ ALL TASKS COMPLETE - -**Date Completed:** 2025-10-27 -**Branch:** `clustering/refactor-multi-batch` - -All four tasks have been implemented. See task completion reports: -- `TASKS-claude-1.md` - Task 1 completion -- `TASKS-claude-2.md` - Task 2 completion -- `spd/clustering/TASKS-claude-3.md` - Task 3 completion -- `TASKS-claude-4.md` - Task 4 completion - -## Overview - -Implement multi-batch clustering to avoid keeping the model loaded during merge iterations. Instead, precompute multiple batches of activations and cycle through them, recomputing costs every `m` merges with a new batch. 
- -**Core Principle: Keep It Minimal** - -### Key Simplifications -- ✅ No dead component filtering (for now) -- ✅ No within-module clustering (defer to later) -- ✅ Always disk-based batch storage -- ✅ Model never kept during merge iteration -- ✅ NaN masking only (no separate boolean masks) -- ✅ Simple round-robin batch loading from disk - ---- - -## Tasks (All Complete) - -### ✅ Task 1: Batch Storage Infrastructure (COMPLETE) -**Files:** New file + config changes -**Dependencies:** None -**Estimated Lines:** ~180 lines (actual) - -**Completed Components:** -- ✅ Created `spd/clustering/batched_activations.py` (~180 lines) - - `ActivationBatch` class with save/load methods - - `BatchedActivations` class for round-robin batch cycling - - `precompute_batches_for_ensemble()` function -- ✅ Updated `merge_config.py`: Added `recompute_costs_every` field (default=1) -- ✅ Updated `clustering_run_config.py`: Added `precomputed_activations_dir` field -- ✅ Updated `run_clustering.py`: Added `--precomputed-activations-dir` CLI argument - -**Implementation Notes:** -- File location: `spd/clustering/batched_activations.py` -- Default `recompute_costs_every=1` maintains backward compatibility -- When `precomputed_activations_dir=None`, single batch is computed on-the-fly - -**See:** `TASKS-claude-1.md` for full implementation details - ---- - -### ✅ Task 2: Core Merge Logic Refactor (COMPLETE) -**Files:** `spd/clustering/merge.py`, `spd/clustering/math/merge_pair_samplers.py` -**Dependencies:** Task 1 (needs `BatchedActivations` interface) -**Estimated Lines:** ~150 lines (actual) - -**Completed Components:** -- ✅ Added `recompute_coacts_from_scratch()` helper function (`merge.py:33-61`) -- ✅ Refactored `merge_iteration()` to accept `BatchedActivations` parameter -- ✅ Implemented NaN masking for merged component rows/columns -- ✅ Added periodic batch recomputation logic (every `recompute_costs_every` iterations) -- ✅ Updated `range_sampler` in `merge_pair_samplers.py` to handle NaN 
values -- ✅ Updated `mcmc_sampler` in `merge_pair_samplers.py` to handle NaN values - -**Key Implementation Details:** -- Function signature changed to `batched_activations: BatchedActivations` -- Loads first batch at start, cycles through batches during merge iteration -- NaN masking invalidates affected entries after each merge -- Full coactivation recomputation from fresh batch at specified intervals -- All merge pair samplers gracefully handle NaN entries - -**See:** `TASKS-claude-2.md` for full implementation details - ---- - -### ✅ Task 3: Update `run_clustering.py` (COMPLETE) -**Files:** `spd/clustering/scripts/run_clustering.py` -**Dependencies:** Task 1 (needs `BatchedActivations`) -**Estimated Lines:** ~100 lines (actual) - -**Completed Components:** -- ✅ Refactored `main()` function to support two modes: - - **Case 1:** Load precomputed batches from disk (`precomputed_activations_dir` provided) - - **Case 2:** Compute single batch on-the-fly, save to temp directory (original behavior) -- ✅ Updated `merge_iteration()` call to use `batched_activations` parameter -- ✅ Added memory cleanup after activation computation -- ✅ Both modes wrap batches in `BatchedActivations` for unified interface - -**Key Implementation Details:** -- When `precomputed_activations_dir=None`: Single batch computed and saved to temp directory -- When `precomputed_activations_dir` provided: All batches loaded from disk -- Component labels extracted from first batch in both cases -- Memory cleanup: model, batch, activations deleted after computation in Case 2 - -**See:** `spd/clustering/TASKS-claude-3.md` for full implementation details - ---- - -### ✅ Task 4: Batch Precomputation in `run_pipeline.py` (COMPLETE) -**Files:** `spd/clustering/scripts/run_pipeline.py` -**Dependencies:** Task 1 (needs `ActivationBatch`) -**Estimated Lines:** ~50 lines modifications (actual) - -**Note:** The `precompute_batches_for_ensemble()` function was implemented in Task 1 as part of 
`batched_activations.py`, not as a separate addition to `run_pipeline.py`. - -**Completed Components:** -- ✅ Batch precomputation logic in `batched_activations.py` (`precompute_batches_for_ensemble()`) -- ✅ Updated `generate_clustering_commands()` to pass `--precomputed-activations-dir` argument -- ✅ Updated `main()` to call precomputation before generating commands -- ✅ Proper seeding strategy: `base_seed + run_idx * 1000 + batch_idx` - -**Key Implementation Details:** -- Loads model once, generates all batches for all runs in ensemble -- Batches saved to: `/precomputed_batches/run_{idx}/batch_{idx}.pt` -- Returns `None` if `recompute_costs_every=1` (single-batch mode) -- Each run gets unique seed per batch to ensure different data - -**See:** `TASKS-claude-4.md` for full implementation details - ---- - -## Testing Plan - -### Unit Tests -**File:** `tests/clustering/test_multi_batch.py` (TO BE CREATED) - -**Required Tests:** - """A single clustering run.""" - - # Create ExecutionStamp and storage - execution_stamp = ExecutionStamp.create( - run_type="cluster", - create_snapshot=False, - ) - storage = ClusteringRunStorage(execution_stamp) - clustering_run_id = execution_stamp.run_id - logger.info(f"Clustering run ID: {clustering_run_id}") - - # Register with ensemble if this is part of a pipeline - assigned_idx = None - if run_config.ensemble_id: - assigned_idx = register_clustering_run( - pipeline_run_id=run_config.ensemble_id, - clustering_run_id=clustering_run_id, - ) - logger.info( - f"Registered with pipeline {run_config.ensemble_id} at index {assigned_idx}" - ) - # IMPORTANT: set dataset seed based on assigned index - run_config = replace_pydantic_model( - run_config, - {"dataset_seed": run_config.dataset_seed + assigned_idx}, - ) - - # Save config - run_config.to_file(storage.config_path) - logger.info(f"Config saved to {storage.config_path}") - - # Start - logger.info("Starting clustering run") - logger.info(f"Output directory: {storage.base_dir}") - device 
= get_device() - - spd_run = SPDRunInfo.from_path(run_config.model_path) - task_name = spd_run.config.task_config.task_name - - # Setup WandB for this run - wandb_run = None - if run_config.wandb_project is not None: - wandb_run = wandb.init( - id=clustering_run_id, - entity=run_config.wandb_entity, - project=run_config.wandb_project, - group=run_config.ensemble_id, - config=run_config.model_dump(mode="json"), - tags=[ - "clustering", - f"task:{task_name}", - f"model:{run_config.wandb_decomp_model}", - f"ensemble_id:{run_config.ensemble_id}", - f"assigned_idx:{assigned_idx}", - ], - ) - - # Load or compute activations - # ===================================== - batched_activations: BatchedActivations - component_labels: ComponentLabels - - if run_config.precomputed_activations_dir is not None: - # Case 1: Use precomputed batches from disk - logger.info(f"Loading precomputed batches from {run_config.precomputed_activations_dir}") - batched_activations = BatchedActivations(run_config.precomputed_activations_dir) - - # Get labels from first batch - first_batch = batched_activations.get_next_batch() - component_labels = ComponentLabels(first_batch.labels) - - logger.info(f"Loaded {batched_activations.n_batches} precomputed batches") - - else: - # Case 2: Compute single batch on-the-fly (original behavior) - logger.info(f"Computing single batch (seed={run_config.dataset_seed})") - - # Load model - logger.info("Loading model") - model = ComponentModel.from_run_info(spd_run).to(device) - - # Load data - logger.info("Loading dataset") - load_dataset_kwargs = {} - if run_config.dataset_streaming: - logger.info("Using streaming dataset loading") - load_dataset_kwargs["config_kwargs"] = dict(streaming=True) - assert task_name == "lm", ( - f"Streaming dataset loading only supported for 'lm' task, got '{task_name = }'" - ) - - batch = load_dataset( - model_path=run_config.model_path, - task_name=task_name, - batch_size=run_config.batch_size, - seed=run_config.dataset_seed, - 
**load_dataset_kwargs, - ).to(device) - - # Compute activations - logger.info("Computing activations") - activations_dict = component_activations( - model=model, - batch=batch, - device=device, - ) - - # Process (concat modules, NO FILTERING) - logger.info("Processing activations") - processed = process_activations( - activations=activations_dict, - filter_dead_threshold=0.0, # NO FILTERING - seq_mode="concat" if task_name == "lm" else None, - filter_modules=None, - ) - - # Save as single batch to temp dir - temp_batch_dir = storage.base_dir / "temp_batch" - temp_batch_dir.mkdir(exist_ok=True) - - single_batch = ActivationBatch( - activations=processed.activations, - labels=list(processed.labels), - ) - single_batch.save(temp_batch_dir / "batch_0.pt") - - batched_activations = BatchedActivations(temp_batch_dir) - component_labels = processed.labels - - # Log activations to WandB (if enabled) - if wandb_run is not None: - logger.info("Plotting activations") - plot_activations( - processed_activations=processed, - save_dir=None, - n_samples_max=256, - wandb_run=wandb_run, - ) - wandb_log_tensor( - wandb_run, - processed.activations, - "activations", - 0, - single=True, - ) - - # Clean up memory - del model, batch, activations_dict, processed - gc.collect() - - # Run merge iteration - # ===================================== - logger.info("Starting merging") - log_callback = ( - partial(_log_callback, run=wandb_run, run_config=run_config) - if wandb_run is not None - else None - ) - - history = merge_iteration( - merge_config=run_config.merge_config, - batched_activations=batched_activations, - component_labels=component_labels, - log_callback=log_callback, - ) - - # Save merge history - history.save(storage.history_path) - logger.info(f"History saved to {storage.history_path}") - - # Log to WandB - if wandb_run is not None: - _log_merge_history_plots(wandb_run, history) - _save_merge_history_artifact(wandb_run, storage.history_path, history) - wandb_run.finish() - 
logger.info("WandB run finished") - - return storage.history_path -``` - ---- - -### Task 4: Batch Precomputation in `run_pipeline.py` -**Files:** `spd/clustering/scripts/run_pipeline.py` -**Dependencies:** Task 1 (needs `ActivationBatch`) -**Estimated Lines:** ~200 lines - -#### 4.1 Add Batch Precomputation Function - -Add this function before `main()`: - -```python -def precompute_batches_for_ensemble( - pipeline_config: ClusteringPipelineConfig, - pipeline_run_id: str, - storage: ClusteringPipelineStorage, -) -> Path | None: - """ - Precompute activation batches for all runs in ensemble. - - This loads the model ONCE and generates all batches for all runs, - then saves them to disk. Each clustering run will load batches - from disk without needing the model. - - Args: - pipeline_config: Pipeline configuration - pipeline_run_id: Unique ID for this pipeline run - storage: Storage paths for pipeline outputs - - Returns: - Path to base directory containing batches for all runs, - or None if single-batch mode (recompute_costs_every=1) - """ - clustering_run_config = ClusteringRunConfig.from_file( - pipeline_config.clustering_run_config_path - ) - - # Check if multi-batch mode - recompute_every = clustering_run_config.merge_config.recompute_costs_every - if recompute_every == 1: - logger.info("Single-batch mode (recompute_costs_every=1), skipping precomputation") - return None - - logger.info("Multi-batch mode detected, precomputing activation batches") - - # Load model to determine number of components - device = get_device() - spd_run = SPDRunInfo.from_path(clustering_run_config.model_path) - model = ComponentModel.from_run_info(spd_run).to(device) - task_name = spd_run.config.task_config.task_name - - # Get number of components (no filtering, so just count from model) - # Load a sample to count components - logger.info("Loading sample batch to count components") - sample_batch = load_dataset( - model_path=clustering_run_config.model_path, - task_name=task_name, - 
batch_size=clustering_run_config.batch_size, - seed=0, - ).to(device) - - with torch.no_grad(): - sample_acts = component_activations(model, device, sample_batch) - - # Count total components across all modules - n_components = sum( - act.shape[-1] for act in sample_acts.values() - ) - - # Calculate number of iterations - n_iters = clustering_run_config.merge_config.get_num_iters(n_components) - - # Calculate batches needed per run - n_batches_needed = (n_iters + recompute_every - 1) // recompute_every - - logger.info( - f"Precomputing {n_batches_needed} batches per run for {pipeline_config.n_runs} runs" - ) - logger.info(f"Total: {n_batches_needed * pipeline_config.n_runs} batches") - - # Create batches directory - batches_base_dir = storage.base_dir / "precomputed_batches" - batches_base_dir.mkdir(exist_ok=True) - - # For each run in ensemble - for run_idx in tqdm(range(pipeline_config.n_runs), desc="Ensemble runs"): - run_batch_dir = batches_base_dir / f"run_{run_idx}" - run_batch_dir.mkdir(exist_ok=True) - - # Generate batches for this run - for batch_idx in tqdm( - range(n_batches_needed), - desc=f" Run {run_idx} batches", - leave=False - ): - # Use unique seed: base_seed + run_idx * 1000 + batch_idx - # This ensures different data for each run and each batch - seed = clustering_run_config.dataset_seed + run_idx * 1000 + batch_idx - - # Load data - batch_data = load_dataset( - model_path=clustering_run_config.model_path, - task_name=task_name, - batch_size=clustering_run_config.batch_size, - seed=seed, - ).to(device) - - # Compute activations - with torch.no_grad(): - acts_dict = component_activations(model, device, batch_data) - - # Process (concat, NO FILTERING) - processed = process_activations( - activations=acts_dict, - filter_dead_threshold=0.0, # NO FILTERING - seq_mode="concat" if task_name == "lm" else None, - filter_modules=None, - ) - - # Save as ActivationBatch - activation_batch = ActivationBatch( - activations=processed.activations.cpu(), # Move 
to CPU for storage - labels=list(processed.labels), - ) - activation_batch.save(run_batch_dir / f"batch_{batch_idx}.pt") - - # Clean up - del batch_data, acts_dict, processed, activation_batch - torch.cuda.empty_cache() if torch.cuda.is_available() else None - - # Clean up model - del model, sample_batch, sample_acts - gc.collect() - torch.cuda.empty_cache() if torch.cuda.is_available() else None - - logger.info(f"All batches precomputed and saved to {batches_base_dir}") - - return batches_base_dir -``` - -#### 4.2 Update `generate_clustering_commands()` - -```python -def generate_clustering_commands( - pipeline_config: ClusteringPipelineConfig, - pipeline_run_id: str, - batches_base_dir: Path | None, # NEW PARAMETER - dataset_streaming: bool = False, -) -> list[str]: - """Generate commands for each clustering run. - - Args: - pipeline_config: Pipeline configuration - pipeline_run_id: Pipeline run ID (each run will create its own ExecutionStamp) - batches_base_dir: Path to precomputed batches directory, or None for single-batch mode - dataset_streaming: Whether to use dataset streaming - - Returns: - List of shell-safe command strings - """ - commands = [] - - for idx in range(pipeline_config.n_runs): - cmd_parts = [ - "python", - "spd/clustering/scripts/run_clustering.py", - "--config", - pipeline_config.clustering_run_config_path.as_posix(), - "--pipeline-run-id", - pipeline_run_id, - "--idx-in-ensemble", - str(idx), - "--wandb-project", - str(pipeline_config.wandb_project), - "--wandb-entity", - pipeline_config.wandb_entity, - ] - - # Add precomputed batches path if available - if batches_base_dir is not None: - run_batch_dir = batches_base_dir / f"run_{idx}" - cmd_parts.extend(["--precomputed-activations-dir", str(run_batch_dir)]) - - if dataset_streaming: - cmd_parts.append("--dataset-streaming") - - commands.append(shlex.join(cmd_parts)) - - return commands -``` - -#### 4.3 Update `main()` to Call Precomputation - -```python -def main( - pipeline_config: 
ClusteringPipelineConfig, - local: bool = False, - local_clustering_parallel: bool = False, - local_calc_distances_parallel: bool = False, - dataset_streaming: bool = False, - track_resources_calc_distances: bool = False, -) -> None: - """Submit clustering runs to SLURM.""" - - logger.set_format("console", "terse") - - # Validation - if local_clustering_parallel or local_calc_distances_parallel or track_resources_calc_distances: - assert local, ( - "local_clustering_parallel, local_calc_distances_parallel, track_resources_calc_distances " - "can only be set when running locally" - ) - - # Create ExecutionStamp for pipeline - execution_stamp = ExecutionStamp.create( - run_type="ensemble", - create_snapshot=pipeline_config.create_git_snapshot, - ) - pipeline_run_id = execution_stamp.run_id - logger.info(f"Pipeline run ID: {pipeline_run_id}") - - # Initialize storage - storage = ClusteringPipelineStorage(execution_stamp) - logger.info(f"Pipeline output directory: {storage.base_dir}") - - # Save pipeline config - pipeline_config.to_file(storage.pipeline_config_path) - logger.info(f"Pipeline config saved to {storage.pipeline_config_path}") - - # Create WandB workspace if requested - if pipeline_config.wandb_project is not None: - workspace_url = create_clustering_workspace_view( - ensemble_id=pipeline_run_id, - project=pipeline_config.wandb_project, - entity=pipeline_config.wandb_entity, - ) - logger.info(f"WandB workspace: {workspace_url}") - - # NEW: Precompute batches if multi-batch mode - batches_base_dir = precompute_batches_for_ensemble( - pipeline_config=pipeline_config, - pipeline_run_id=pipeline_run_id, - storage=storage, - ) - - # Generate commands for clustering runs - clustering_commands = generate_clustering_commands( - pipeline_config=pipeline_config, - pipeline_run_id=pipeline_run_id, - batches_base_dir=batches_base_dir, # NEW - dataset_streaming=dataset_streaming, - ) - - # Generate commands for calculating distances - calc_distances_commands = 
generate_calc_distances_commands( - pipeline_run_id=pipeline_run_id, - distances_methods=pipeline_config.distances_methods, - ) - - # ... rest of submission logic unchanged -``` - ---- - -## Testing Plan - -### Unit Tests -**File:** `tests/clustering/test_multi_batch.py` (new) - -1. Test `ActivationBatch` save/load -2. Test `BatchedActivations` cycling through batches -3. Test `recompute_coacts_from_scratch` produces correct shapes -4. Test NaN handling in merge pair samplers -5. Test backward compatibility (single batch, `recompute_costs_every=1`) - -### Integration Tests - -1. **Single-batch mode (backward compatibility):** - ```python - config = ClusteringRunConfig( - precomputed_activations_dir=None, - merge_config=MergeConfig(recompute_costs_every=1), - ... - ) - # Should behave exactly as before - ``` - -2. **Multi-batch mode:** - ```python - config = ClusteringRunConfig( - precomputed_activations_dir=Path("batches/"), - merge_config=MergeConfig(recompute_costs_every=20), - ... - ) - # Should use multiple batches - ``` - -3. 
**Ensemble with precomputation:** - - Run small ensemble (n=3) with multi-batch - - Verify batches are created correctly - - Verify clustering runs use precomputed batches - -### Manual Testing Checklist - -- [ ] Single run, single batch (original behavior) -- [ ] Single run, multi-batch with precomputed dir -- [ ] Ensemble run, single batch mode -- [ ] Ensemble run, multi-batch mode with precomputation -- [ ] Verify NaN masking doesn't break merge sampling -- [ ] Verify memory usage (model not kept during merge) -- [ ] Verify batch cycling works correctly - ---- - -## Summary - -**Total Changes:** -- New files: 1 (`spd/clustering/batched_activations.py`) -- Modified files: 5 -- Total new code: ~500 lines -- Backward compatible: Yes (defaults to original behavior) - -**Key Benefits:** -- Model loaded once, not kept during merge iteration -- Supports arbitrary number of batches -- Simple disk-based storage -- Minimal config changes (2 new fields) -- No complex scheduling or memory management - -**Dependencies:** None (PR #227 not required for this simplified version) diff --git a/TODO.md b/TODO.md deleted file mode 100644 index 9e6f14815..000000000 --- a/TODO.md +++ /dev/null @@ -1,73 +0,0 @@ -# TODO: Cluster Coactivation Matrix Implementation - -## What Was Changed - -### 1. Added `ClusterActivations` dataclass (`spd/clustering/dashboard/compute_max_act.py`) -- New dataclass to hold vectorized cluster activations for all clusters -- Contains `activations` tensor [n_samples, n_clusters] and `cluster_indices` list - -### 2. Added `compute_all_cluster_activations()` function -- Vectorized computation of all cluster activations at once -- Replaces the per-cluster loop for better performance -- Returns `ClusterActivations` object - -### 3. 
Added `compute_cluster_coactivations()` function -- Computes coactivation matrix from list of `ClusterActivations` across batches -- Binarizes activations (acts > 0) and computes matrix multiplication: `activation_mask.T @ activation_mask` -- Follows the pattern from `spd/clustering/merge.py:69` -- Returns tuple of (coactivation_matrix, cluster_indices) - -### 4. Modified `compute_max_activations()` function -- Now accumulates `ClusterActivations` from each batch in `all_cluster_activations` list -- Calls `compute_cluster_coactivations()` to compute the matrix -- **Changed return type**: now returns `tuple[DashboardData, np.ndarray, list[int]]` - - Added coactivation matrix and cluster_indices to return value - -### 5. Modified `spd/clustering/dashboard/run.py` -- Updated to handle new return value from `compute_max_activations()` -- Saves coactivation matrix as `coactivations.npz` in the dashboard output directory -- NPZ file contains: - - `coactivations`: the [n_clusters, n_clusters] matrix - - `cluster_indices`: array mapping matrix positions to cluster IDs - -## What Needs to be Checked - -### Testing -- [ ] **Run the dashboard pipeline** on a real clustering run to verify: - - Coactivation computation doesn't crash - - Coactivations are saved correctly to NPZ file - - Matrix dimensions are correct - - `cluster_indices` mapping is correct - -### Type Checking -- [ ] Run `make type` to ensure no type errors were introduced -- [ ] Verify jaxtyping annotations are correct - -### Verification -- [ ] Load a saved `coactivations.npz` file and verify: - ```python - data = np.load("coactivations.npz") - coact = data["coactivations"] - cluster_indices = data["cluster_indices"] - # Check: coact should be symmetric - # Check: diagonal should be >= off-diagonal (clusters coactivate with themselves most) - # Check: cluster_indices length should match coact.shape[0] - ``` - -### Performance -- [ ] Check if vectorization actually improved performance -- [ ] Monitor memory 
usage with large numbers of clusters - -### Edge Cases -- [ ] Test with clusters that have zero activations -- [ ] Test with single-batch runs -- [ ] Test with very large number of clusters - -### Integration -- [ ] Verify the coactivation matrix can be used in downstream analysis -- [ ] Consider if visualization of coactivations should be added to dashboard - -## Notes -- The coactivation matrix is computed over all samples processed (n_batches * batch_size * seq_len samples) -- Binarization threshold is currently hardcoded as `> 0` - may want to make this configurable -- The computation happens in the dashboard pipeline, NOT during the main clustering pipeline From a5505df317890ccb4c854fd87b98a831bd813a3b Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 13:27:50 +0000 Subject: [PATCH 37/61] change recompute_costs_every behavior (None for single-batch) --- spd/clustering/batched_activations.py | 4 ++-- spd/clustering/merge_config.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/spd/clustering/batched_activations.py b/spd/clustering/batched_activations.py index ede4cf774..2e14f9041 100644 --- a/spd/clustering/batched_activations.py +++ b/spd/clustering/batched_activations.py @@ -117,8 +117,8 @@ def precompute_batches_for_ensemble( """ # Check if multi-batch mode recompute_every = clustering_run_config.merge_config.recompute_costs_every - if recompute_every == 1: - logger.info("Single-batch mode (recompute_costs_every=1), skipping precomputation") + if recompute_every is None: + logger.info("Single-batch mode (recompute_costs_every=`None`), skipping precomputation") return None logger.info("Multi-batch mode detected, precomputing activation batches") diff --git a/spd/clustering/merge_config.py b/spd/clustering/merge_config.py index 8c2f44ca6..6a1c53069 100644 --- a/spd/clustering/merge_config.py +++ b/spd/clustering/merge_config.py @@ -72,9 +72,9 @@ class MergeConfig(BaseConfig): default=None, description="Filter 
for module names. Can be a string prefix, a set of names, or a callable that returns True for modules to include.", ) - recompute_costs_every: PositiveInt = Field( - default=1, - description="Number of merges before recomputing costs with new batch. Set to 1 for original behavior.", + recompute_costs_every: PositiveInt | None = Field( + default=None, + description="Number of merges before recomputing costs with new batch. Set to `None` to use a single batch throughout.", ) batch_size: PositiveInt = Field( default=64, From 3f078190d77750d232e96b3691df486834a8c3b5 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 13:32:24 +0000 Subject: [PATCH 38/61] simplify & dedupe valid mask computation --- spd/clustering/math/merge_pair_samplers.py | 26 +++++++++++++--------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/spd/clustering/math/merge_pair_samplers.py b/spd/clustering/math/merge_pair_samplers.py index 384122711..38da24af5 100644 --- a/spd/clustering/math/merge_pair_samplers.py +++ b/spd/clustering/math/merge_pair_samplers.py @@ -25,6 +25,19 @@ def __call__( ) -> MergePair: ... 
+def get_valid_mask( + costs: ClusterCoactivationShaped, +) -> ClusterCoactivationShaped: + """Get a boolean mask of valid merge pairs (non-NaN, non-diagonal).""" + k_groups: int = costs.shape[0] + return ( + ~torch.isnan(costs) # mask out NaN entries + & ~torch.eye( + k_groups, dtype=torch.bool, device=costs.device + ) # mask out diagonal (can't merge with self) + ) + + def range_sampler( costs: ClusterCoactivationShaped, threshold: float = 0.05, @@ -47,12 +60,7 @@ def range_sampler( k_groups: int = costs.shape[0] assert costs.shape[1] == k_groups, "Cost matrix must be square" - # Mask out NaN entries and diagonal - valid_mask: Bool[Tensor, "k_groups k_groups"] = ~torch.isnan(costs) - diag_mask: Bool[Tensor, "k_groups k_groups"] = ~torch.eye( - k_groups, dtype=torch.bool, device=costs.device - ) - valid_mask = valid_mask & diag_mask + valid_mask: ClusterCoactivationShaped = get_valid_mask(costs) # Get valid costs valid_costs: Float[Tensor, " n_valid"] = costs[valid_mask] @@ -101,11 +109,7 @@ def mcmc_sampler( k_groups: int = costs.shape[0] assert costs.shape[1] == k_groups, "Cost matrix must be square" - # Create mask for valid pairs (non-diagonal and non-NaN) - valid_mask: Bool[Tensor, "k_groups k_groups"] = ~torch.eye( - k_groups, dtype=torch.bool, device=costs.device - ) - valid_mask = valid_mask & ~torch.isnan(costs) + valid_mask: ClusterCoactivationShaped = get_valid_mask(costs) # Check if we have any valid pairs if not valid_mask.any(): From 76f7fcded25560487ae7fd677d8a366ce571b932 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 13:34:06 +0000 Subject: [PATCH 39/61] move non-nan count validity check --- spd/clustering/math/merge_pair_samplers.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/spd/clustering/math/merge_pair_samplers.py b/spd/clustering/math/merge_pair_samplers.py index 38da24af5..bb8eaa3e8 100644 --- a/spd/clustering/math/merge_pair_samplers.py +++ 
b/spd/clustering/math/merge_pair_samplers.py @@ -30,12 +30,15 @@ def get_valid_mask( ) -> ClusterCoactivationShaped: """Get a boolean mask of valid merge pairs (non-NaN, non-diagonal).""" k_groups: int = costs.shape[0] - return ( + valid_mask: ClusterCoactivationShaped = ( ~torch.isnan(costs) # mask out NaN entries & ~torch.eye( k_groups, dtype=torch.bool, device=costs.device ) # mask out diagonal (can't merge with self) ) + if not valid_mask.any(): + raise ValueError("All non-diagonal costs are NaN, cannot sample merge pair") + return valid_mask def range_sampler( @@ -65,9 +68,6 @@ def range_sampler( # Get valid costs valid_costs: Float[Tensor, " n_valid"] = costs[valid_mask] - if valid_costs.numel() == 0: - raise ValueError("All costs are NaN, cannot sample merge pair") - # Find the range of valid costs min_cost: float = float(valid_costs.min().item()) max_cost: float = float(valid_costs.max().item()) @@ -111,10 +111,6 @@ def mcmc_sampler( valid_mask: ClusterCoactivationShaped = get_valid_mask(costs) - # Check if we have any valid pairs - if not valid_mask.any(): - raise ValueError("All costs are NaN, cannot sample merge pair") - # Compute probabilities: exp(-cost/temperature) # Use stable softmax computation to avoid overflow costs_masked: ClusterCoactivationShaped = costs.clone() From 7c1724734fe0818ac417d578b72ef9c17b78921d Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 14:01:05 +0000 Subject: [PATCH 40/61] wip --- spd/clustering/batched_activations.py | 121 +++++++++++------- spd/clustering/merge.py | 4 +- spd/clustering/scripts/run_clustering.py | 2 +- tests/clustering/scripts/cluster_resid_mlp.py | 6 +- tests/clustering/scripts/cluster_ss.py | 4 +- tests/clustering/test_merge_integration.py | 12 +- 6 files changed, 91 insertions(+), 58 deletions(-) diff --git a/spd/clustering/batched_activations.py b/spd/clustering/batched_activations.py index 2e14f9041..199d6103c 100644 --- a/spd/clustering/batched_activations.py +++ 
b/spd/clustering/batched_activations.py @@ -6,92 +6,125 @@ """ import gc +import re +import zipfile +from collections.abc import Iterator from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING +import numpy as np import torch +from jaxtyping import Float from torch import Tensor from tqdm import tqdm from spd.clustering.activations import component_activations, process_activations +from spd.clustering.consts import BatchTensor, ComponentLabels from spd.clustering.dataset import load_dataset from spd.log import logger from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.spd_types import TaskName from spd.utils.distributed_utils import get_device if TYPE_CHECKING: from spd.clustering.clustering_run_config import ClusteringRunConfig +_BATCH_FORMAT: str = "batch_{idx:04}.zip" + + @dataclass class ActivationBatch: - """Single batch of activations - just tensors, no processing.""" + """Single batch of subcomponent activations""" - activations: Tensor # [samples, n_components] - labels: list[str] # ["module:idx", ...] 
+ activations: Float[Tensor, "samples n_components"] + labels: ComponentLabels def save(self, path: Path) -> None: - torch.save({"activations": self.activations, "labels": self.labels}, path) + zf: zipfile.ZipFile + with zipfile.ZipFile(path, "w") as zf: + with zf.open("activations.npy", "w") as f: + np.save(f, self.activations.cpu().numpy()) + zf.writestr("labels.txt", "\n".join(self.labels)) + + def save_idx(self, batch_dir: Path, idx: int) -> None: + self.save(batch_dir / _BATCH_FORMAT.format(idx=idx)) @staticmethod - def load(path: Path) -> "ActivationBatch": - data = torch.load(path, weights_only=False) + def read(path: Path) -> "ActivationBatch": + zf: zipfile.ZipFile + with zipfile.ZipFile(path, "r") as zf: + with zf.open("activations.npy", "r") as f: + activations_np = np.load(f) + labels_raw: list[str] = zf.read("labels.txt").decode("utf-8").splitlines() return ActivationBatch( - activations=data["activations"], - labels=data["labels"], + activations=torch.from_numpy(activations_np), + labels=ComponentLabels(labels_raw), ) -class BatchedActivations: +class BatchedActivations(Iterator[ActivationBatch]): """Iterator over activation batches from disk.""" def __init__(self, batch_dir: Path): - self.batch_dir = batch_dir - # Find all batch files: batch_0.pt, batch_1.pt, ... 
- self.batch_paths = sorted(batch_dir.glob("batch_*.pt")) + self.batch_dir: Path = batch_dir + # Find all batch files + _glob_pattern: str = re.sub(r"\{[^{}]*\}", "*", _BATCH_FORMAT) # returns `batch_*.zip` + self.batch_paths: list[Path] = sorted(batch_dir.glob(_glob_pattern)) assert len(self.batch_paths) > 0, f"No batch files found in {batch_dir}" - self.current_idx = 0 + self.current_idx: int = 0 + + # Verify naming + for i in range(len(self.batch_paths)): + expected_name = _BATCH_FORMAT.format(idx=i) + actual_name = self.batch_paths[i].name + assert expected_name == actual_name, ( + f"Expected batch file '{expected_name}', found '{actual_name}'" + ) @property def n_batches(self) -> int: return len(self.batch_paths) - def get_next_batch(self) -> ActivationBatch: + def _get_next_batch(self) -> ActivationBatch: """Load and return next batch, cycling through available batches.""" - batch = ActivationBatch.load(self.batch_paths[self.current_idx]) - self.current_idx = (self.current_idx + 1) % self.n_batches + batch: ActivationBatch = ActivationBatch.read( + self.batch_paths[self.current_idx % self.n_batches] + ) + self.current_idx += 1 return batch + def __next__(self) -> ActivationBatch: + return self._get_next_batch() -def batched_activations_from_tensor( - activations: Tensor, - labels: list[str], -) -> BatchedActivations: - """ - Create a BatchedActivations instance from a single activation tensor. + @classmethod + def from_tensor( + cls, activations: Tensor, labels: ComponentLabels | list[str] + ) -> "BatchedActivations": + """Create a BatchedActivations instance from a single activation tensor. - This is a helper for backward compatibility with tests and code that uses - single-batch mode. It creates a temporary directory with a single batch file. + This is a helper for backward compatibility with tests and code that uses + single-batch mode. It creates a temporary directory with a single batch file. 
- Args: - activations: Activation tensor [samples, n_components] - labels: Component labels ["module:idx", ...] + Args: + activations: Activation tensor [samples, n_components] + labels: Component labels ["module:idx", ...] - Returns: - BatchedActivations instance that cycles through the single batch - """ - import tempfile + Returns: + BatchedActivations instance that cycles through the single batch + """ + import tempfile - # Create a temporary directory - temp_dir = Path(tempfile.mkdtemp(prefix="batch_temp_")) + # Create a temporary directory + temp_dir = Path(tempfile.mkdtemp(prefix="batch_temp_")) - # Save the single batch - batch = ActivationBatch(activations=activations, labels=labels) - batch.save(temp_dir / "batch_0.pt") + # Save the single batch + batch = ActivationBatch(activations=activations, labels=ComponentLabels(labels)) + batch.save(temp_dir / _BATCH_FORMAT.format(idx=0)) - # Return BatchedActivations that will cycle through this single batch - return BatchedActivations(temp_dir) + # Return BatchedActivations that will cycle through this single batch + return BatchedActivations(temp_dir) def precompute_batches_for_ensemble( @@ -116,7 +149,7 @@ def precompute_batches_for_ensemble( or None if single-batch mode (recompute_costs_every=1) """ # Check if multi-batch mode - recompute_every = clustering_run_config.merge_config.recompute_costs_every + recompute_every: int | None = clustering_run_config.merge_config.recompute_costs_every if recompute_every is None: logger.info("Single-batch mode (recompute_costs_every=`None`), skipping precomputation") return None @@ -124,15 +157,15 @@ def precompute_batches_for_ensemble( logger.info("Multi-batch mode detected, precomputing activation batches") # Load model to determine number of components - device = get_device() - spd_run = SPDRunInfo.from_path(clustering_run_config.model_path) - model = ComponentModel.from_run_info(spd_run).to(device) - task_name = spd_run.config.task_config.task_name + device: str = 
get_device() + spd_run: SPDRunInfo = SPDRunInfo.from_path(clustering_run_config.model_path) + model: ComponentModel = ComponentModel.from_run_info(spd_run).to(device) + task_name: TaskName = spd_run.config.task_config.task_name # Get number of components (no filtering, so just count from model) # Load a sample to count components logger.info("Loading sample batch to count components") - sample_batch = load_dataset( + sample_batch: BatchTensor = load_dataset( model_path=clustering_run_config.model_path, task_name=task_name, batch_size=clustering_run_config.batch_size, diff --git a/spd/clustering/merge.py b/spd/clustering/merge.py index 0e7a1b800..6a5aca85c 100644 --- a/spd/clustering/merge.py +++ b/spd/clustering/merge.py @@ -105,7 +105,7 @@ def merge_iteration( # Load first batch # -------------------------------------------------- - first_batch: ActivationBatch = batched_activations.get_next_batch() + first_batch: ActivationBatch = batched_activations._get_next_batch() activations: Tensor = first_batch.activations # Compute initial coactivations @@ -204,7 +204,7 @@ def merge_iteration( ) % merge_config.recompute_costs_every == 0 and iter_idx + 1 < num_iters if should_recompute: - new_batch: ActivationBatch = batched_activations.get_next_batch() + new_batch: ActivationBatch = batched_activations._get_next_batch() activations = new_batch.activations # Recompute fresh coacts with current merge groups diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index 1ef7cd6cb..4241c43ee 100644 --- a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -287,7 +287,7 @@ def main(run_config: ClusteringRunConfig) -> Path: batched_activations = BatchedActivations(run_config.precomputed_activations_dir) # Get labels from first batch - first_batch = batched_activations.get_next_batch() + first_batch = batched_activations._get_next_batch() component_labels = ComponentLabels(first_batch.labels) 
logger.info(f"Loaded {batched_activations.n_batches} precomputed batches") diff --git a/tests/clustering/scripts/cluster_resid_mlp.py b/tests/clustering/scripts/cluster_resid_mlp.py index a5fdd6956..6a2d25d53 100644 --- a/tests/clustering/scripts/cluster_resid_mlp.py +++ b/tests/clustering/scripts/cluster_resid_mlp.py @@ -12,7 +12,7 @@ component_activations, process_activations, ) -from spd.clustering.batched_activations import batched_activations_from_tensor +from spd.clustering.batched_activations import BatchedActivations from spd.clustering.consts import ComponentLabels from spd.clustering.merge import merge_iteration from spd.clustering.merge_config import MergeConfig @@ -149,7 +149,7 @@ def _plot_func( ) -BATCHED_ACTIVATIONS = batched_activations_from_tensor( +BATCHED_ACTIVATIONS = BatchedActivations.from_tensor( activations=PROCESSED_ACTIVATIONS.activations, labels=list(PROCESSED_ACTIVATIONS.labels), ) @@ -177,7 +177,7 @@ def _plot_func( ENSEMBLE_SIZE: int = 4 HISTORIES: list[MergeHistory] = [] for _i in range(ENSEMBLE_SIZE): - batched_acts = batched_activations_from_tensor( + batched_acts = BatchedActivations.from_tensor( activations=PROCESSED_ACTIVATIONS.activations, labels=list(PROCESSED_ACTIVATIONS.labels), ) diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 8c7b42feb..8a3bbd033 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -16,7 +16,7 @@ component_activations, process_activations, ) -from spd.clustering.batched_activations import batched_activations_from_tensor +from spd.clustering.batched_activations import BatchedActivations from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.dataset import load_dataset from spd.clustering.merge import merge_iteration @@ -112,7 +112,7 @@ ENSEMBLE_SIZE: int = 2 HISTORIES: list[MergeHistory] = [] for _i in range(ENSEMBLE_SIZE): - batched_acts = batched_activations_from_tensor( + 
batched_acts = BatchedActivations.from_tensor( activations=PROCESSED_ACTIVATIONS.activations, labels=list(PROCESSED_ACTIVATIONS.labels), ) diff --git a/tests/clustering/test_merge_integration.py b/tests/clustering/test_merge_integration.py index d1b1a9571..09c1596bf 100644 --- a/tests/clustering/test_merge_integration.py +++ b/tests/clustering/test_merge_integration.py @@ -2,7 +2,7 @@ import torch -from spd.clustering.batched_activations import batched_activations_from_tensor +from spd.clustering.batched_activations import BatchedActivations from spd.clustering.consts import ComponentLabels from spd.clustering.merge import merge_iteration from spd.clustering.merge_config import MergeConfig @@ -30,7 +30,7 @@ def test_merge_with_range_sampler(self): ) # Run merge iteration - batched_activations = batched_activations_from_tensor( + batched_activations = BatchedActivations.from_tensor( activations=activations, labels=list(component_labels) ) history = merge_iteration( @@ -68,7 +68,7 @@ def test_merge_with_mcmc_sampler(self): ) # Run merge iteration - batched_activations = batched_activations_from_tensor( + batched_activations = BatchedActivations.from_tensor( activations=activations, labels=list(component_labels) ) history = merge_iteration( @@ -108,7 +108,7 @@ def test_merge_comparison_samplers(self): merge_pair_sampling_kwargs={"threshold": 0.0}, # Always select minimum ) - batched_activations_range = batched_activations_from_tensor( + batched_activations_range = BatchedActivations.from_tensor( activations=activations.clone(), labels=list(component_labels) ) history_range = merge_iteration( @@ -126,7 +126,7 @@ def test_merge_comparison_samplers(self): merge_pair_sampling_kwargs={"temperature": 0.01}, # Very low temp ) - batched_activations_mcmc = batched_activations_from_tensor( + batched_activations_mcmc = BatchedActivations.from_tensor( activations=activations.clone(), labels=list(component_labels) ) history_mcmc = merge_iteration( @@ -157,7 +157,7 @@ def 
test_merge_with_small_components(self): merge_pair_sampling_kwargs={"temperature": 2.0}, ) - batched_activations = batched_activations_from_tensor( + batched_activations = BatchedActivations.from_tensor( activations=activations, labels=list(component_labels) ) history = merge_iteration( From a7315c11e3e68b9f16d5627bc8552c1cf6920a6d Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 14:14:27 +0000 Subject: [PATCH 41/61] add back some type hints claude is extremely lazy. it reads STYLE.md and apparently decides to literally never use type hints for anything, and especially to never type hint anything with jaxtyping. So, I made it do a pass of adding back type hints --- spd/clustering/batched_activations.py | 41 ++++++++------ spd/clustering/merge.py | 14 ++--- spd/clustering/scripts/run_clustering.py | 20 ++++--- spd/clustering/scripts/run_pipeline.py | 12 ++--- tests/clustering/scripts/cluster_ss.py | 8 +-- tests/clustering/test_merge_integration.py | 63 +++++++++++----------- 6 files changed, 83 insertions(+), 75 deletions(-) diff --git a/spd/clustering/batched_activations.py b/spd/clustering/batched_activations.py index 199d6103c..31d394d5a 100644 --- a/spd/clustering/batched_activations.py +++ b/spd/clustering/batched_activations.py @@ -19,7 +19,11 @@ from torch import Tensor from tqdm import tqdm -from spd.clustering.activations import component_activations, process_activations +from spd.clustering.activations import ( + ProcessedActivations, + component_activations, + process_activations, +) from spd.clustering.consts import BatchTensor, ComponentLabels from spd.clustering.dataset import load_dataset from spd.log import logger @@ -56,7 +60,7 @@ def read(path: Path) -> "ActivationBatch": zf: zipfile.ZipFile with zipfile.ZipFile(path, "r") as zf: with zf.open("activations.npy", "r") as f: - activations_np = np.load(f) + activations_np: Float[np.ndarray, "samples n_components"] = np.load(f) labels_raw: list[str] = 
zf.read("labels.txt").decode("utf-8").splitlines() return ActivationBatch( activations=torch.from_numpy(activations_np), @@ -76,9 +80,10 @@ def __init__(self, batch_dir: Path): self.current_idx: int = 0 # Verify naming + i: int for i in range(len(self.batch_paths)): - expected_name = _BATCH_FORMAT.format(idx=i) - actual_name = self.batch_paths[i].name + expected_name: str = _BATCH_FORMAT.format(idx=i) + actual_name: str = self.batch_paths[i].name assert expected_name == actual_name, ( f"Expected batch file '{expected_name}', found '{actual_name}'" ) @@ -117,10 +122,10 @@ def from_tensor( import tempfile # Create a temporary directory - temp_dir = Path(tempfile.mkdtemp(prefix="batch_temp_")) + temp_dir: Path = Path(tempfile.mkdtemp(prefix="batch_temp_")) # Save the single batch - batch = ActivationBatch(activations=activations, labels=ComponentLabels(labels)) + batch: ActivationBatch = ActivationBatch(activations=activations, labels=ComponentLabels(labels)) batch.save(temp_dir / _BATCH_FORMAT.format(idx=0)) # Return BatchedActivations that will cycle through this single batch @@ -173,30 +178,32 @@ def precompute_batches_for_ensemble( ).to(device) with torch.no_grad(): - sample_acts = component_activations(model, device, sample_batch) + sample_acts: dict[str, Float[Tensor, "samples components"]] = component_activations(model, device, sample_batch) # Count total components across all modules - n_components = sum(act.shape[-1] for act in sample_acts.values()) + n_components: int = sum(act.shape[-1] for act in sample_acts.values()) # Calculate number of iterations - n_iters = clustering_run_config.merge_config.get_num_iters(n_components) + n_iters: int = clustering_run_config.merge_config.get_num_iters(n_components) # Calculate batches needed per run - n_batches_needed = (n_iters + recompute_every - 1) // recompute_every + n_batches_needed: int = (n_iters + recompute_every - 1) // recompute_every logger.info(f"Precomputing {n_batches_needed} batches per run for {n_runs} 
runs") logger.info(f"Total: {n_batches_needed * n_runs} batches") # Create batches directory - batches_base_dir = output_dir / "precomputed_batches" + batches_base_dir: Path = output_dir / "precomputed_batches" batches_base_dir.mkdir(exist_ok=True, parents=True) # For each run in ensemble + run_idx: int for run_idx in tqdm(range(n_runs), desc="Ensemble runs"): - run_batch_dir = batches_base_dir / f"run_{run_idx}" + run_batch_dir: Path = batches_base_dir / f"run_{run_idx}" run_batch_dir.mkdir(exist_ok=True) # Generate batches for this run + batch_idx: int for batch_idx in tqdm( range(n_batches_needed), desc=f" Run {run_idx} batches", @@ -204,10 +211,10 @@ def precompute_batches_for_ensemble( ): # Use unique seed: base_seed + run_idx * 1000 + batch_idx # This ensures different data for each run and each batch - seed = clustering_run_config.dataset_seed + run_idx * 1000 + batch_idx + seed: int = clustering_run_config.dataset_seed + run_idx * 1000 + batch_idx # Load data - batch_data = load_dataset( + batch_data: BatchTensor = load_dataset( model_path=clustering_run_config.model_path, task_name=task_name, batch_size=clustering_run_config.batch_size, @@ -216,10 +223,10 @@ def precompute_batches_for_ensemble( # Compute activations with torch.no_grad(): - acts_dict = component_activations(model, device, batch_data) + acts_dict: dict[str, Float[Tensor, "samples components"]] = component_activations(model, device, batch_data) # Process (concat, NO FILTERING) - processed = process_activations( + processed: ProcessedActivations = process_activations( activations=acts_dict, filter_dead_threshold=0.0, # NO FILTERING seq_mode="concat" if task_name == "lm" else None, @@ -227,7 +234,7 @@ def precompute_batches_for_ensemble( ) # Save as ActivationBatch - activation_batch = ActivationBatch( + activation_batch: ActivationBatch = ActivationBatch( activations=processed.activations.cpu(), # Move to CPU for storage labels=list(processed.labels), ) diff --git a/spd/clustering/merge.py 
b/spd/clustering/merge.py index 6a5aca85c..c654ff78b 100644 --- a/spd/clustering/merge.py +++ b/spd/clustering/merge.py @@ -8,7 +8,7 @@ from typing import Protocol import torch -from jaxtyping import Bool, Float +from jaxtyping import Bool, Float, Int from torch import Tensor from tqdm import tqdm @@ -47,28 +47,28 @@ def recompute_coacts_from_scratch( mask [samples, k_groups] for current groups """ # Apply threshold - activation_mask = ( + activation_mask: Bool[Tensor, "samples n_components"] = ( activations > activation_threshold if activation_threshold is not None else activations ) # Map component-level activations to group-level using scatter_add # This is more efficient than materializing the full merge matrix # current_merge.group_idxs: [n_components] with values 0 to k_groups-1 - n_samples = activation_mask.shape[0] - group_activations = torch.zeros( + n_samples: int = activation_mask.shape[0] + group_activations: Float[Tensor, "n_samples k_groups"] = torch.zeros( (n_samples, current_merge.k_groups), dtype=activation_mask.dtype, device=activation_mask.device, ) # Expand group_idxs to match batch dimension and scatter-add activations by group - group_idxs_expanded = ( + group_idxs_expanded: Int[Tensor, "n_samples n_components"] = ( current_merge.group_idxs.unsqueeze(0).expand(n_samples, -1).to(activation_mask.device) ) group_activations.scatter_add_(1, group_idxs_expanded, activation_mask) # Compute coactivations - coact = group_activations.float().T @ group_activations.float() + coact: ClusterCoactivationShaped = group_activations.float().T @ group_activations.float() return coact, group_activations @@ -205,7 +205,7 @@ def merge_iteration( if should_recompute: new_batch: ActivationBatch = batched_activations._get_next_batch() - activations = new_batch.activations + activations: Float[Tensor, "samples n_components"] = new_batch.activations # Recompute fresh coacts with current merge groups current_coact, current_act_mask = recompute_coacts_from_scratch( diff 
--git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index 4241c43ee..fc9f24627 100644 --- a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -219,16 +219,16 @@ def main(run_config: ClusteringRunConfig) -> Path: """ # Create ExecutionStamp and storage # don't create git snapshot -- if we are part of an ensemble, the snapshot should be created by the pipeline - execution_stamp = ExecutionStamp.create( + execution_stamp: ExecutionStamp = ExecutionStamp.create( run_type="cluster", create_snapshot=False, ) - storage = ClusteringRunStorage(execution_stamp) - clustering_run_id = execution_stamp.run_id + storage: ClusteringRunStorage = ClusteringRunStorage(execution_stamp) + clustering_run_id: str = execution_stamp.run_id logger.info(f"Clustering run ID: {clustering_run_id}") # Register with ensemble if this is part of a pipeline - assigned_idx: int | None + assigned_idx: int | None = None if run_config.ensemble_id: assigned_idx = register_clustering_run( pipeline_run_id=run_config.ensemble_id, @@ -243,8 +243,6 @@ def main(run_config: ClusteringRunConfig) -> Path: run_config, {"dataset_seed": run_config.dataset_seed + assigned_idx}, ) - else: - assigned_idx = None # save config run_config.to_file(storage.config_path) @@ -255,7 +253,7 @@ def main(run_config: ClusteringRunConfig) -> Path: logger.info(f"Output directory: {storage.base_dir}") device = get_device() - spd_run = SPDRunInfo.from_path(run_config.model_path) + spd_run: SPDRunInfo = SPDRunInfo.from_path(run_config.model_path) task_name: TaskName = spd_run.config.task_config.task_name # Setup WandB for this run @@ -287,7 +285,7 @@ def main(run_config: ClusteringRunConfig) -> Path: batched_activations = BatchedActivations(run_config.precomputed_activations_dir) # Get labels from first batch - first_batch = batched_activations._get_next_batch() + first_batch: ActivationBatch = batched_activations._get_next_batch() component_labels = 
ComponentLabels(first_batch.labels) logger.info(f"Loaded {batched_activations.n_batches} precomputed batches") @@ -298,7 +296,7 @@ def main(run_config: ClusteringRunConfig) -> Path: # Load model logger.info("Loading model") - model = ComponentModel.from_run_info(spd_run).to(device) + model: ComponentModel = ComponentModel.from_run_info(spd_run).to(device) # Load data logger.info("Loading dataset") @@ -338,10 +336,10 @@ def main(run_config: ClusteringRunConfig) -> Path: ) # Save as single batch to temp dir - temp_batch_dir = storage.base_dir / "temp_batch" + temp_batch_dir: Path = storage.base_dir / "temp_batch" temp_batch_dir.mkdir(exist_ok=True) - single_batch = ActivationBatch( + single_batch: ActivationBatch = ActivationBatch( activations=processed.activations, labels=list(processed.labels), ) diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index cf4971a63..92503229f 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -261,7 +261,7 @@ def main( logger.info(f"Pipeline run ID: {pipeline_run_id}") # Initialize storage - storage = ClusteringPipelineStorage(execution_stamp) + storage: ClusteringPipelineStorage = ClusteringPipelineStorage(execution_stamp) logger.info(f"Pipeline output directory: {storage.base_dir}") # Save pipeline config @@ -270,7 +270,7 @@ def main( # Create WandB workspace if requested if pipeline_config.wandb_project is not None: - workspace_url = create_clustering_workspace_view( + workspace_url: str = create_clustering_workspace_view( ensemble_id=pipeline_run_id, project=pipeline_config.wandb_project, entity=pipeline_config.wandb_entity, @@ -278,17 +278,17 @@ def main( logger.info(f"WandB workspace: {workspace_url}") # Precompute batches if multi-batch mode - clustering_run_config = ClusteringRunConfig.from_file( + clustering_run_config: ClusteringRunConfig = ClusteringRunConfig.from_file( pipeline_config.clustering_run_config_path ) - batches_base_dir = 
precompute_batches_for_ensemble( + batches_base_dir: Path | None = precompute_batches_for_ensemble( clustering_run_config=clustering_run_config, n_runs=pipeline_config.n_runs, output_dir=storage.base_dir, ) # Generate commands for clustering runs - clustering_commands = generate_clustering_commands( + clustering_commands: list[str] = generate_clustering_commands( pipeline_config=pipeline_config, pipeline_run_id=pipeline_run_id, batches_base_dir=batches_base_dir, @@ -296,7 +296,7 @@ def main( ) # Generate commands for calculating distances - calc_distances_commands = generate_calc_distances_commands( + calc_distances_commands: list[str] = generate_calc_distances_commands( pipeline_run_id=pipeline_run_id, distances_methods=pipeline_config.distances_methods, ) diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 8a3bbd033..5261484b3 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -18,12 +18,14 @@ ) from spd.clustering.batched_activations import BatchedActivations from spd.clustering.clustering_run_config import ClusteringRunConfig +from spd.clustering.consts import DistancesArray from spd.clustering.dataset import load_dataset from spd.clustering.merge import merge_iteration from spd.clustering.merge_config import MergeConfig from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble from spd.clustering.plotting.activations import plot_activations from spd.clustering.plotting.merge import plot_dists_distribution +from spd.configs import Config from spd.models.component_model import ComponentModel, SPDRunInfo DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu" @@ -44,7 +46,7 @@ SPD_RUN: SPDRunInfo = SPDRunInfo.from_path(MODEL_PATH) MODEL: ComponentModel = ComponentModel.from_pretrained(SPD_RUN.checkpoint_path) MODEL.to(DEVICE) -SPD_CONFIG = SPD_RUN.config +SPD_CONFIG: Config = SPD_RUN.config # Use load_dataset with RunConfig to get real data 
CONFIG: ClusteringRunConfig = ClusteringRunConfig( @@ -112,7 +114,7 @@ ENSEMBLE_SIZE: int = 2 HISTORIES: list[MergeHistory] = [] for _i in range(ENSEMBLE_SIZE): - batched_acts = BatchedActivations.from_tensor( + batched_acts: BatchedActivations = BatchedActivations.from_tensor( activations=PROCESSED_ACTIVATIONS.activations, labels=list(PROCESSED_ACTIVATIONS.labels), ) @@ -130,7 +132,7 @@ # %% # Compute and plot distances # ============================================================ -DISTANCES = ENSEMBLE.get_distances() +DISTANCES: DistancesArray = ENSEMBLE.get_distances() plot_dists_distribution( distances=DISTANCES, diff --git a/tests/clustering/test_merge_integration.py b/tests/clustering/test_merge_integration.py index 09c1596bf..47db195b7 100644 --- a/tests/clustering/test_merge_integration.py +++ b/tests/clustering/test_merge_integration.py @@ -6,6 +6,7 @@ from spd.clustering.consts import ComponentLabels from spd.clustering.merge import merge_iteration from spd.clustering.merge_config import MergeConfig +from spd.clustering.merge_history import MergeHistory class TestMergeIntegration: @@ -14,13 +15,13 @@ class TestMergeIntegration: def test_merge_with_range_sampler(self): """Test merge iteration with range sampler.""" # Create test data - n_samples = 100 - n_components = 10 - activations = torch.rand(n_samples, n_components) - component_labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + n_samples: int = 100 + n_components: int = 10 + activations: torch.Tensor = torch.rand(n_samples, n_components) + component_labels: ComponentLabels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) # Configure with range sampler - config = MergeConfig( + config: MergeConfig = MergeConfig( activation_threshold=0.1, alpha=1.0, iters=5, @@ -30,10 +31,10 @@ def test_merge_with_range_sampler(self): ) # Run merge iteration - batched_activations = BatchedActivations.from_tensor( + batched_activations: BatchedActivations = 
BatchedActivations.from_tensor( activations=activations, labels=list(component_labels) ) - history = merge_iteration( + history: MergeHistory = merge_iteration( batched_activations=batched_activations, merge_config=config, component_labels=component_labels, @@ -52,13 +53,13 @@ def test_merge_with_range_sampler(self): def test_merge_with_mcmc_sampler(self): """Test merge iteration with MCMC sampler.""" # Create test data - n_samples = 100 - n_components = 10 - activations = torch.rand(n_samples, n_components) - component_labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + n_samples: int = 100 + n_components: int = 10 + activations: torch.Tensor = torch.rand(n_samples, n_components) + component_labels: ComponentLabels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) # Configure with MCMC sampler - config = MergeConfig( + config: MergeConfig = MergeConfig( activation_threshold=0.1, alpha=1.0, iters=5, @@ -68,10 +69,10 @@ def test_merge_with_mcmc_sampler(self): ) # Run merge iteration - batched_activations = BatchedActivations.from_tensor( + batched_activations: BatchedActivations = BatchedActivations.from_tensor( activations=activations, labels=list(component_labels) ) - history = merge_iteration( + history: MergeHistory = merge_iteration( batched_activations=batched_activations, merge_config=config, component_labels=component_labels, @@ -89,18 +90,18 @@ def test_merge_with_mcmc_sampler(self): def test_merge_comparison_samplers(self): """Compare behavior of different samplers with same data.""" # Create test data with clear structure - n_samples = 100 - n_components = 8 - activations = torch.rand(n_samples, n_components) + n_samples: int = 100 + n_components: int = 8 + activations: torch.Tensor = torch.rand(n_samples, n_components) # Make some components more active to create cost structure activations[:, 0] *= 2 # Component 0 is very active activations[:, 1] *= 0.1 # Component 1 is rarely active - component_labels = 
ComponentLabels([f"comp_{i}" for i in range(n_components)]) + component_labels: ComponentLabels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) # Run with range sampler (threshold=0 for deterministic minimum selection) - config_range = MergeConfig( + config_range: MergeConfig = MergeConfig( activation_threshold=0.1, alpha=1.0, iters=3, @@ -108,17 +109,17 @@ def test_merge_comparison_samplers(self): merge_pair_sampling_kwargs={"threshold": 0.0}, # Always select minimum ) - batched_activations_range = BatchedActivations.from_tensor( + batched_activations_range: BatchedActivations = BatchedActivations.from_tensor( activations=activations.clone(), labels=list(component_labels) ) - history_range = merge_iteration( + history_range: MergeHistory = merge_iteration( batched_activations=batched_activations_range, merge_config=config_range, component_labels=ComponentLabels(component_labels.copy()), ) # Run with MCMC sampler (low temperature for near-deterministic) - config_mcmc = MergeConfig( + config_mcmc: MergeConfig = MergeConfig( activation_threshold=0.1, alpha=1.0, iters=3, @@ -126,10 +127,10 @@ def test_merge_comparison_samplers(self): merge_pair_sampling_kwargs={"temperature": 0.01}, # Very low temp ) - batched_activations_mcmc = BatchedActivations.from_tensor( + batched_activations_mcmc: BatchedActivations = BatchedActivations.from_tensor( activations=activations.clone(), labels=list(component_labels) ) - history_mcmc = merge_iteration( + history_mcmc: MergeHistory = merge_iteration( batched_activations=batched_activations_mcmc, merge_config=config_mcmc, component_labels=ComponentLabels(component_labels.copy()), @@ -144,12 +145,12 @@ def test_merge_comparison_samplers(self): def test_merge_with_small_components(self): """Test merge with very few components.""" # Edge case: only 3 components - n_samples = 50 - n_components = 3 - activations = torch.rand(n_samples, n_components) - component_labels = ComponentLabels([f"comp_{i}" for i in 
range(n_components)]) + n_samples: int = 50 + n_components: int = 3 + activations: torch.Tensor = torch.rand(n_samples, n_components) + component_labels: ComponentLabels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) - config = MergeConfig( + config: MergeConfig = MergeConfig( activation_threshold=0.1, alpha=1.0, iters=1, # Just one merge @@ -157,10 +158,10 @@ def test_merge_with_small_components(self): merge_pair_sampling_kwargs={"temperature": 2.0}, ) - batched_activations = BatchedActivations.from_tensor( + batched_activations: BatchedActivations = BatchedActivations.from_tensor( activations=activations, labels=list(component_labels) ) - history = merge_iteration( + history: MergeHistory = merge_iteration( batched_activations=batched_activations, merge_config=config, component_labels=component_labels, From 2a18389f4d435bda31b0b2c73cb4601e4f1e0848 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 14:15:55 +0000 Subject: [PATCH 42/61] make format --- spd/clustering/batched_activations.py | 12 +++++++++--- tests/clustering/test_merge_integration.py | 16 ++++++++++++---- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/spd/clustering/batched_activations.py b/spd/clustering/batched_activations.py index 31d394d5a..4768ec2a5 100644 --- a/spd/clustering/batched_activations.py +++ b/spd/clustering/batched_activations.py @@ -125,7 +125,9 @@ def from_tensor( temp_dir: Path = Path(tempfile.mkdtemp(prefix="batch_temp_")) # Save the single batch - batch: ActivationBatch = ActivationBatch(activations=activations, labels=ComponentLabels(labels)) + batch: ActivationBatch = ActivationBatch( + activations=activations, labels=ComponentLabels(labels) + ) batch.save(temp_dir / _BATCH_FORMAT.format(idx=0)) # Return BatchedActivations that will cycle through this single batch @@ -178,7 +180,9 @@ def precompute_batches_for_ensemble( ).to(device) with torch.no_grad(): - sample_acts: dict[str, Float[Tensor, "samples components"]] = 
component_activations(model, device, sample_batch) + sample_acts: dict[str, Float[Tensor, "samples components"]] = component_activations( + model, device, sample_batch + ) # Count total components across all modules n_components: int = sum(act.shape[-1] for act in sample_acts.values()) @@ -223,7 +227,9 @@ def precompute_batches_for_ensemble( # Compute activations with torch.no_grad(): - acts_dict: dict[str, Float[Tensor, "samples components"]] = component_activations(model, device, batch_data) + acts_dict: dict[str, Float[Tensor, "samples components"]] = component_activations( + model, device, batch_data + ) # Process (concat, NO FILTERING) processed: ProcessedActivations = process_activations( diff --git a/tests/clustering/test_merge_integration.py b/tests/clustering/test_merge_integration.py index 47db195b7..89bbb470d 100644 --- a/tests/clustering/test_merge_integration.py +++ b/tests/clustering/test_merge_integration.py @@ -18,7 +18,9 @@ def test_merge_with_range_sampler(self): n_samples: int = 100 n_components: int = 10 activations: torch.Tensor = torch.rand(n_samples, n_components) - component_labels: ComponentLabels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + component_labels: ComponentLabels = ComponentLabels( + [f"comp_{i}" for i in range(n_components)] + ) # Configure with range sampler config: MergeConfig = MergeConfig( @@ -56,7 +58,9 @@ def test_merge_with_mcmc_sampler(self): n_samples: int = 100 n_components: int = 10 activations: torch.Tensor = torch.rand(n_samples, n_components) - component_labels: ComponentLabels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + component_labels: ComponentLabels = ComponentLabels( + [f"comp_{i}" for i in range(n_components)] + ) # Configure with MCMC sampler config: MergeConfig = MergeConfig( @@ -98,7 +102,9 @@ def test_merge_comparison_samplers(self): activations[:, 0] *= 2 # Component 0 is very active activations[:, 1] *= 0.1 # Component 1 is rarely active - component_labels: 
ComponentLabels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + component_labels: ComponentLabels = ComponentLabels( + [f"comp_{i}" for i in range(n_components)] + ) # Run with range sampler (threshold=0 for deterministic minimum selection) config_range: MergeConfig = MergeConfig( @@ -148,7 +154,9 @@ def test_merge_with_small_components(self): n_samples: int = 50 n_components: int = 3 activations: torch.Tensor = torch.rand(n_samples, n_components) - component_labels: ComponentLabels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + component_labels: ComponentLabels = ComponentLabels( + [f"comp_{i}" for i in range(n_components)] + ) config: MergeConfig = MergeConfig( activation_threshold=0.1, From 28271c82b0e2c9d7561733ce622cfe66f2f17ad5 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 14:21:30 +0000 Subject: [PATCH 43/61] type hint fixes --- spd/clustering/batched_activations.py | 11 +++++---- spd/clustering/consts.py | 2 ++ spd/clustering/merge.py | 31 ++++++++++++------------ spd/clustering/scripts/run_clustering.py | 2 +- 4 files changed, 25 insertions(+), 21 deletions(-) diff --git a/spd/clustering/batched_activations.py b/spd/clustering/batched_activations.py index 4768ec2a5..ef2437409 100644 --- a/spd/clustering/batched_activations.py +++ b/spd/clustering/batched_activations.py @@ -11,7 +11,7 @@ from collections.abc import Iterator from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, override import numpy as np import torch @@ -24,7 +24,7 @@ component_activations, process_activations, ) -from spd.clustering.consts import BatchTensor, ComponentLabels +from spd.clustering.consts import ActivationsTensor, BatchTensor, ComponentLabels from spd.clustering.dataset import load_dataset from spd.log import logger from spd.models.component_model import ComponentModel, SPDRunInfo @@ -42,7 +42,7 @@ class ActivationBatch: """Single batch of 
subcomponent activations""" - activations: Float[Tensor, "samples n_components"] + activations: ActivationsTensor labels: ComponentLabels def save(self, path: Path) -> None: @@ -100,6 +100,7 @@ def _get_next_batch(self) -> ActivationBatch: self.current_idx += 1 return batch + @override def __next__(self) -> ActivationBatch: return self._get_next_batch() @@ -227,7 +228,7 @@ def precompute_batches_for_ensemble( # Compute activations with torch.no_grad(): - acts_dict: dict[str, Float[Tensor, "samples components"]] = component_activations( + acts_dict: dict[str, ActivationsTensor] = component_activations( model, device, batch_data ) @@ -242,7 +243,7 @@ def precompute_batches_for_ensemble( # Save as ActivationBatch activation_batch: ActivationBatch = ActivationBatch( activations=processed.activations.cpu(), # Move to CPU for storage - labels=list(processed.labels), + labels=ComponentLabels(list(processed.labels)), ) activation_batch.save(run_batch_dir / f"batch_{batch_idx}.pt") diff --git a/spd/clustering/consts.py b/spd/clustering/consts.py index 8a9647dc8..cc86ff432 100644 --- a/spd/clustering/consts.py +++ b/spd/clustering/consts.py @@ -8,6 +8,8 @@ from jaxtyping import Bool, Float, Int from torch import Tensor +# TODO: docstrings for all types below + # Merge arrays and distances (numpy-based for storage/analysis) MergesAtIterArray = Int[np.ndarray, "n_ens n_components"] MergesArray = Int[np.ndarray, "n_ens n_iters n_components"] diff --git a/spd/clustering/merge.py b/spd/clustering/merge.py index c654ff78b..65f6c2d01 100644 --- a/spd/clustering/merge.py +++ b/spd/clustering/merge.py @@ -106,7 +106,7 @@ def merge_iteration( # Load first batch # -------------------------------------------------- first_batch: ActivationBatch = batched_activations._get_next_batch() - activations: Tensor = first_batch.activations + activations: ActivationsTensor = first_batch.activations # Compute initial coactivations # -------------------------------------------------- @@ -199,20 
+199,21 @@ def merge_iteration( # Recompute from new batch if it's time # -------------------------------------------------- - should_recompute: bool = ( - iter_idx + 1 - ) % merge_config.recompute_costs_every == 0 and iter_idx + 1 < num_iters - - if should_recompute: - new_batch: ActivationBatch = batched_activations._get_next_batch() - activations: Float[Tensor, "samples n_components"] = new_batch.activations - - # Recompute fresh coacts with current merge groups - current_coact, current_act_mask = recompute_coacts_from_scratch( - activations=activations, - current_merge=current_merge, - activation_threshold=merge_config.activation_threshold, - ) + if merge_config.recompute_costs_every is not None: + should_recompute: bool = ( + (iter_idx + 1) % merge_config.recompute_costs_every == 0 + ) and (iter_idx + 1 < num_iters) + + if should_recompute: + new_batch: ActivationBatch = batched_activations._get_next_batch() + activations = new_batch.activations + + # Recompute fresh coacts with current merge groups + current_coact, current_act_mask = recompute_coacts_from_scratch( + activations=activations, + current_merge=current_merge, + activation_threshold=merge_config.activation_threshold, + ) # Compute metrics for logging # -------------------------------------------------- diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index fc9f24627..31d7c303b 100644 --- a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -341,7 +341,7 @@ def main(run_config: ClusteringRunConfig) -> Path: single_batch: ActivationBatch = ActivationBatch( activations=processed.activations, - labels=list(processed.labels), + labels=ComponentLabels(list(processed.labels)), ) single_batch.save(temp_batch_dir / "batch_0.pt") From 62a7dd080ae6dde15003bfba11543d54802ee13d Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 14:22:50 +0000 Subject: [PATCH 44/61] remove old wip file --- 
spd/clustering/dataset_multibatch.py | 183 --------------------------- 1 file changed, 183 deletions(-) delete mode 100644 spd/clustering/dataset_multibatch.py diff --git a/spd/clustering/dataset_multibatch.py b/spd/clustering/dataset_multibatch.py deleted file mode 100644 index c64b4f08b..000000000 --- a/spd/clustering/dataset_multibatch.py +++ /dev/null @@ -1,183 +0,0 @@ -""" -Loads and splits dataset into batches, returning them as an iterator. -""" - -# TODO: figure out this file vs spd/clustering/dataset.py -from collections.abc import Generator, Iterator -from typing import Any - -import torch -from muutils.spinner import SpinnerContext -from torch import Tensor -from torch.utils.data import DataLoader -from tqdm import tqdm - -from spd.clustering.clustering_run_config import ClusteringRunConfig -from spd.clustering.consts import BatchTensor -from spd.configs import Config -from spd.data import DatasetConfig, create_data_loader -from spd.experiments.lm.configs import LMTaskConfig -from spd.experiments.resid_mlp.configs import ResidMLPModelConfig, ResidMLPTaskConfig -from spd.experiments.resid_mlp.models import ResidMLP -from spd.models.component_model import ComponentModel, SPDRunInfo -from spd.spd_types import TaskName - - -def get_clustering_dataloader( - config: ClusteringRunConfig, - task_name: TaskName, - n_batches: int, - ddp_rank: int = 0, - ddp_world_size: int = 1, - **kwargs: Any, -) -> tuple[Iterator[BatchTensor], dict[str, Any]]: - """Split a dataset into n_batches of batch_size, returning iterator and config""" - ds: Generator[BatchTensor] - ds_config_dict: dict[str, Any] - match task_name: - case "lm": - ds, ds_config_dict = _get_dataloader_lm( - model_path=config.model_path, - batch_size=config.merge_config.batch_size, - ddp_rank=ddp_rank, - ddp_world_size=ddp_world_size, - **kwargs, - ) - case "resid_mlp": - ds, ds_config_dict = _get_dataloader_resid_mlp( - model_path=config.model_path, - batch_size=config.merge_config.batch_size, - 
ddp_rank=ddp_rank, - ddp_world_size=ddp_world_size, - **kwargs, - ) - case name: - raise ValueError( - f"Unsupported task name '{name}'. Supported tasks are 'lm' and 'resid_mlp'. {config.model_path=}, {name=}" - ) - - # Limit iterator to n_batches - def limited_iterator() -> Iterator[BatchTensor]: - batch_idx: int - batch: BatchTensor - for batch_idx, batch in tqdm(enumerate(ds), total=n_batches, unit="batch"): - if batch_idx >= n_batches: - break - yield batch - - return limited_iterator(), ds_config_dict - - -def _get_dataloader_lm( - model_path: str, - batch_size: int, - config_kwargs: dict[str, Any] | None = None, - ddp_rank: int = 0, - ddp_world_size: int = 1, -) -> tuple[Generator[BatchTensor], dict[str, Any]]: - """split up a SS dataset into n_batches of batch_size, returned the saved paths - - 1. load the config for a SimpleStories SPD Run given by model_path - 2. create a DataLoader for the dataset - 3. iterate over the DataLoader and save each batch to a file - - - """ - with SpinnerContext(message=f"Loading SPD Run Config for '{model_path}'"): - spd_run: SPDRunInfo = SPDRunInfo.from_path(model_path) - cfg: Config = spd_run.config - - try: - pretrained_model_name: str = cfg.pretrained_model_name # pyright: ignore[reportAssignmentType] - assert pretrained_model_name is not None - except Exception as e: - raise AttributeError( - "Could not find 'pretrained_model_name' in the SPD Run config, but called `_get_dataloader_lm`" - ) from e - - assert isinstance(cfg.task_config, LMTaskConfig), ( - f"Expected task_config to be of type LMTaskConfig since using `_get_dataloader_lm`, but got {type(cfg.task_config) = }" - ) - - config_kwargs_: dict[str, Any] = { - **dict( - is_tokenized=False, - streaming=False, - seed=0, - ), - **(config_kwargs or {}), - } - - dataset_config: DatasetConfig = DatasetConfig( - name=cfg.task_config.dataset_name, - hf_tokenizer_path=pretrained_model_name, - split=cfg.task_config.train_data_split, - n_ctx=cfg.task_config.max_seq_len, - 
column_name=cfg.task_config.column_name, - **config_kwargs_, - ) - - with SpinnerContext(message="getting dataloader..."): - dataloader: DataLoader[dict[str, torch.Tensor]] - dataloader, _tokenizer = create_data_loader( - dataset_config=dataset_config, - batch_size=batch_size, - buffer_size=cfg.task_config.buffer_size, - global_seed=cfg.seed, - ddp_rank=ddp_rank, - ddp_world_size=ddp_world_size, - ) - - return (batch["input_ids"] for batch in dataloader), dataset_config.model_dump(mode="json") - - -def _get_dataloader_resid_mlp( - model_path: str, - batch_size: int, - ddp_rank: int = 0, - ddp_world_size: int = 1, -) -> tuple[Generator[torch.Tensor], dict[str, Any]]: - """Split a ResidMLP dataset into n_batches of batch_size and save the batches.""" - from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset - from spd.utils.data_utils import DatasetGeneratedDataLoader - - # TODO: this is a hack. idk what the best way to handle this is - shuffle_data: bool = ddp_world_size <= 1 - - assert ddp_rank >= 0 - - with SpinnerContext(message=f"Loading SPD Run Config for '{model_path}'"): - spd_run: SPDRunInfo = SPDRunInfo.from_path(model_path) - # SPD_RUN = SPDRunInfo.from_path(EXPERIMENT_REGISTRY["resid_mlp3"].canonical_run) - component_model: ComponentModel = ComponentModel.from_pretrained(spd_run.checkpoint_path) - cfg: Config = spd_run.config - - with SpinnerContext(message="Creating ResidMLPDataset..."): - assert isinstance(cfg.task_config, ResidMLPTaskConfig), ( - f"Expected task_config to be of type ResidMLPTaskConfig since using `_get_dataloader_resid_mlp`, but got {type(cfg.task_config) = }" - ) - assert isinstance(component_model.target_model, ResidMLP), ( - f"Expected patched_model to be of type ResidMLP since using `_get_dataloader_resid_mlp`, but got {type(component_model.patched_model) = }" - ) - - assert isinstance(component_model.target_model.config, ResidMLPModelConfig), ( - f"Expected patched_model.config to be of type ResidMLPModelConfig 
since using `_get_dataloader_resid_mlp`, but got {type(component_model.target_model.config) = }" - ) - resid_mlp_dataset_kwargs: dict[str, Any] = dict( - n_features=component_model.target_model.config.n_features, - feature_probability=cfg.task_config.feature_probability, - device="cpu", - calc_labels=False, - label_type=None, - act_fn_name=None, - label_fn_seed=None, - label_coeffs=None, - data_generation_type=cfg.task_config.data_generation_type, - ) - dataset: ResidMLPDataset = ResidMLPDataset(**resid_mlp_dataset_kwargs) - - dataloader: DatasetGeneratedDataLoader[tuple[Tensor, Tensor]] = DatasetGeneratedDataLoader( - dataset, batch_size=batch_size, shuffle=shuffle_data - ) - - return (batch[0] for batch in dataloader), resid_mlp_dataset_kwargs From 24d4b8cbda8ee66f7a7d4551ac6e395c657c69fa Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 14:55:35 +0000 Subject: [PATCH 45/61] wip --- spd/clustering/batched_activations.py | 1 - tests/clustering/scripts/cluster_resid_mlp.py | 8 +++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/spd/clustering/batched_activations.py b/spd/clustering/batched_activations.py index ef2437409..97e14a3f4 100644 --- a/spd/clustering/batched_activations.py +++ b/spd/clustering/batched_activations.py @@ -232,7 +232,6 @@ def precompute_batches_for_ensemble( model, device, batch_data ) - # Process (concat, NO FILTERING) processed: ProcessedActivations = process_activations( activations=acts_dict, filter_dead_threshold=0.0, # NO FILTERING diff --git a/tests/clustering/scripts/cluster_resid_mlp.py b/tests/clustering/scripts/cluster_resid_mlp.py index 6a2d25d53..3013ea150 100644 --- a/tests/clustering/scripts/cluster_resid_mlp.py +++ b/tests/clustering/scripts/cluster_resid_mlp.py @@ -75,7 +75,9 @@ data_generation_type=DATASET.data_generation_type, ) ) -DATALOADER = DatasetGeneratedDataLoader(DATASET, batch_size=N_SAMPLES, shuffle=False) +DATALOADER: DatasetGeneratedDataLoader = 
DatasetGeneratedDataLoader( + DATASET, batch_size=N_SAMPLES, shuffle=False +) # %% # Get component activations @@ -149,7 +151,7 @@ def _plot_func( ) -BATCHED_ACTIVATIONS = BatchedActivations.from_tensor( +BATCHED_ACTIVATIONS: BatchedActivations = BatchedActivations.from_tensor( activations=PROCESSED_ACTIVATIONS.activations, labels=list(PROCESSED_ACTIVATIONS.labels), ) @@ -177,7 +179,7 @@ def _plot_func( ENSEMBLE_SIZE: int = 4 HISTORIES: list[MergeHistory] = [] for _i in range(ENSEMBLE_SIZE): - batched_acts = BatchedActivations.from_tensor( + batched_acts: BatchedActivations = BatchedActivations.from_tensor( activations=PROCESSED_ACTIVATIONS.activations, labels=list(PROCESSED_ACTIVATIONS.labels), ) From 8491c324baa855e07c3a2925ffed4fc95befcb8a Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 15:04:00 +0000 Subject: [PATCH 46/61] update STYLE.md claude really loves to be lazy and not include type hints even in cases where they are useful. claude also loves to then write code that doesnt work by hallucinating the types of those variables --- STYLE.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/STYLE.md b/STYLE.md index d94a92ed5..fcf39bc56 100644 --- a/STYLE.md +++ b/STYLE.md @@ -18,10 +18,12 @@ If there's an assumption you're making while writing code, assert it. - If you were wrong, then the code **should** fail. ## Type Annotations -- Use jaxtyping for tensor shapes (though for now we don't do runtime checking) +- **Always** use jaxtyping for tensor or numpy array shapes (though for now we don't do runtime checking) - Always use the PEP 604 typing format of `|` for unions and `type | None` over `Optional`. - Use `dict`, `list` and `tuple` not `Dict`, `List` and `Tuple` -- Don't add type annotations when they're redundant. (i.e. `my_thing: Thing = Thing()` or `name: str = "John Doe"`) +- Don't add type annotations only when they're redundant. + - i.e. 
`my_thing: Thing = Thing()` or `name: str = "John Doe"` don't need type annotations. + - however, `var = foo()` or `result = thing.bar()` should be annotated! ## Tensor Operations - Try to use einops by default for clarity. From 40df68285fc633576de2fb0da408edc604f0ff32 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 15:12:44 +0000 Subject: [PATCH 47/61] more updates to STYLE.md claude loves to waste time doing what ruff does instantly --- STYLE.md | 1 + 1 file changed, 1 insertion(+) diff --git a/STYLE.md b/STYLE.md index fcf39bc56..b5dc38b5e 100644 --- a/STYLE.md +++ b/STYLE.md @@ -24,6 +24,7 @@ If there's an assumption you're making while writing code, assert it. - Don't add type annotations only when they're redundant. - i.e. `my_thing: Thing = Thing()` or `name: str = "John Doe"` don't need type annotations. - however, `var = foo()` or `result = thing.bar()` should be annotated! +- FOR CLAUDE: don't worry about cleaning up unused imports, we do this automatically with ruff using `make format` ## Tensor Operations - Try to use einops by default for clarity. 
From 89944902592d8e41ec09f44eefc24ade579b4f43 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 15:13:00 +0000 Subject: [PATCH 48/61] refactor getting batches there was a lot of duplications previously --- spd/clustering/batched_activations.py | 243 ++++++++++++++++------- spd/clustering/scripts/run_clustering.py | 116 +++-------- 2 files changed, 191 insertions(+), 168 deletions(-) diff --git a/spd/clustering/batched_activations.py b/spd/clustering/batched_activations.py index 97e14a3f4..a69e0007f 100644 --- a/spd/clustering/batched_activations.py +++ b/spd/clustering/batched_activations.py @@ -26,6 +26,7 @@ ) from spd.clustering.consts import ActivationsTensor, BatchTensor, ComponentLabels from spd.clustering.dataset import load_dataset +from spd.clustering.util import ModuleFilterFunc from spd.log import logger from spd.models.component_model import ComponentModel, SPDRunInfo from spd.spd_types import TaskName @@ -135,34 +136,96 @@ def from_tensor( return BatchedActivations(temp_dir) -def precompute_batches_for_ensemble( +def _generate_activation_batches( + model: ComponentModel, + device: str, + task_name: TaskName, + model_path: str, + batch_size: int, + n_batches: int, + output_dir: Path, + base_seed: int, + filter_dead_threshold: float, + filter_modules: ModuleFilterFunc | None, +) -> None: + """Core function to generate activation batches. 
+ + Args: + model: ComponentModel to compute activations + device: Device to use for computation + task_name: Task name for dataset loading + model_path: Path to model for dataset loading (as string) + batch_size: Batch size for dataset + n_batches: Number of batches to generate + output_dir: Directory to save batches + base_seed: Base seed for dataset loading (each batch gets base_seed + batch_idx) + filter_dead_threshold: Threshold for filtering dead components (0.0 = no filtering) + filter_modules: Module filter function (None = no filtering) + """ + + batch_idx: int + for batch_idx in tqdm(range(n_batches), desc="Generating batches", leave=False): + # Use unique seed for each batch + seed: int = base_seed + batch_idx + + # Load data + batch_data: BatchTensor = load_dataset( + model_path=model_path, + task_name=task_name, + batch_size=batch_size, + seed=seed, + ).to(device) + + # Compute activations + with torch.no_grad(): + acts_dict: dict[str, ActivationsTensor] = component_activations( + model, device, batch_data + ) + + # Process activations + processed: ProcessedActivations = process_activations( + activations=acts_dict, + filter_dead_threshold=filter_dead_threshold, + seq_mode="concat" if task_name == "lm" else None, + filter_modules=filter_modules, + ) + + # Save as ActivationBatch + activation_batch: ActivationBatch = ActivationBatch( + activations=processed.activations.cpu(), # Move to CPU for storage + labels=ComponentLabels(list(processed.labels)), + ) + activation_batch.save(output_dir / _BATCH_FORMAT.format(idx=batch_idx)) + + # Clean up + del batch_data, acts_dict, processed, activation_batch + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def precompute_batches_for_single_run( clustering_run_config: "ClusteringRunConfig", - n_runs: int, output_dir: Path, -) -> Path | None: + base_seed: int, + apply_filtering: bool, +) -> int: """ - Precompute activation batches for all runs in ensemble. 
+ Precompute activation batches for a single clustering run. - This loads the model ONCE and generates all batches for all runs, - then saves them to disk. Each clustering run will load batches - from disk without needing the model. + This loads the model ONCE, calculates how many batches are needed + (based on recompute_costs_every and n_iters), generates all batches, + and saves them to disk. Args: - clustering_run_config: Configuration for clustering runs - n_runs: Number of runs in the ensemble - output_dir: Base directory to save precomputed batches + clustering_run_config: Configuration for clustering run + output_dir: Directory to save batches (will contain batch_0000.zip, batch_0001.zip, etc.) + base_seed: Base seed for dataset loading (each batch gets base_seed + batch_idx) + apply_filtering: Whether to apply filter_dead_threshold and filter_modules from config Returns: - Path to base directory containing batches for all runs, - or None if single-batch mode (recompute_costs_every=1) + Number of batches generated """ - # Check if multi-batch mode - recompute_every: int | None = clustering_run_config.merge_config.recompute_costs_every - if recompute_every is None: - logger.info("Single-batch mode (recompute_costs_every=`None`), skipping precomputation") - return None - - logger.info("Multi-batch mode detected, precomputing activation batches") + output_dir.mkdir(exist_ok=True, parents=True) # Load model to determine number of components device: str = get_device() @@ -170,8 +233,7 @@ def precompute_batches_for_ensemble( model: ComponentModel = ComponentModel.from_run_info(spd_run).to(device) task_name: TaskName = spd_run.config.task_config.task_name - # Get number of components (no filtering, so just count from model) - # Load a sample to count components + # Load a sample batch to count components logger.info("Loading sample batch to count components") sample_batch: BatchTensor = load_dataset( model_path=clustering_run_config.model_path, @@ -188,75 +250,104 
@@ def precompute_batches_for_ensemble( # Count total components across all modules n_components: int = sum(act.shape[-1] for act in sample_acts.values()) - # Calculate number of iterations + # Calculate number of iterations and batches needed n_iters: int = clustering_run_config.merge_config.get_num_iters(n_components) + recompute_every: int | None = clustering_run_config.merge_config.recompute_costs_every + + n_batches_needed: int + if recompute_every is None: + # Single-batch mode: generate 1 batch, reuse for all iterations + n_batches_needed = 1 + logger.info(f"Single-batch mode: generating 1 batch for {n_iters} iterations") + else: + # Multi-batch mode: generate enough batches to cover all iterations + n_batches_needed = (n_iters + recompute_every - 1) // recompute_every + logger.info( + f"Multi-batch mode: generating {n_batches_needed} batches for {n_iters} iterations (recompute_every={recompute_every})" + ) + + # Determine filtering parameters + filter_dead_threshold: float + filter_modules: ModuleFilterFunc | None + if apply_filtering: + filter_dead_threshold = clustering_run_config.merge_config.filter_dead_threshold + filter_modules = clustering_run_config.merge_config.filter_modules + else: + filter_dead_threshold = 0.0 + filter_modules = None + + # Generate batches + _generate_activation_batches( + model=model, + device=device, + task_name=task_name, + model_path=clustering_run_config.model_path, + batch_size=clustering_run_config.batch_size, + n_batches=n_batches_needed, + output_dir=output_dir, + base_seed=base_seed, + filter_dead_threshold=filter_dead_threshold, + filter_modules=filter_modules, + ) + + # Clean up model + del model, sample_batch, sample_acts + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + logger.info(f"Generated {n_batches_needed} batches and saved to {output_dir}") + return n_batches_needed + + +def precompute_batches_for_ensemble( + clustering_run_config: "ClusteringRunConfig", + n_runs: int, + 
output_dir: Path, +) -> Path | None: + """ + Precompute activation batches for all runs in ensemble. - # Calculate batches needed per run - n_batches_needed: int = (n_iters + recompute_every - 1) // recompute_every + This generates all batches for all runs by calling precompute_batches_for_single_run() + for each run with a unique seed offset. - logger.info(f"Precomputing {n_batches_needed} batches per run for {n_runs} runs") - logger.info(f"Total: {n_batches_needed * n_runs} batches") + Args: + clustering_run_config: Configuration for clustering runs + n_runs: Number of runs in the ensemble + output_dir: Base directory to save precomputed batches + + Returns: + Path to base directory containing batches for all runs, + or None if single-batch mode (recompute_costs_every=None) + """ + # Check if multi-batch mode + recompute_every: int | None = clustering_run_config.merge_config.recompute_costs_every + if recompute_every is None: + logger.info("Single-batch mode (recompute_costs_every=`None`), skipping precomputation") + return None + + logger.info("Multi-batch mode detected, precomputing activation batches") # Create batches directory batches_base_dir: Path = output_dir / "precomputed_batches" batches_base_dir.mkdir(exist_ok=True, parents=True) - # For each run in ensemble + # Generate batches for each run run_idx: int for run_idx in tqdm(range(n_runs), desc="Ensemble runs"): run_batch_dir: Path = batches_base_dir / f"run_{run_idx}" run_batch_dir.mkdir(exist_ok=True) - # Generate batches for this run - batch_idx: int - for batch_idx in tqdm( - range(n_batches_needed), - desc=f" Run {run_idx} batches", - leave=False, - ): - # Use unique seed: base_seed + run_idx * 1000 + batch_idx - # This ensures different data for each run and each batch - seed: int = clustering_run_config.dataset_seed + run_idx * 1000 + batch_idx - - # Load data - batch_data: BatchTensor = load_dataset( - model_path=clustering_run_config.model_path, - task_name=task_name, - 
batch_size=clustering_run_config.batch_size, - seed=seed, - ).to(device) - - # Compute activations - with torch.no_grad(): - acts_dict: dict[str, ActivationsTensor] = component_activations( - model, device, batch_data - ) - - processed: ProcessedActivations = process_activations( - activations=acts_dict, - filter_dead_threshold=0.0, # NO FILTERING - seq_mode="concat" if task_name == "lm" else None, - filter_modules=None, - ) + # Use unique seed offset for this run + run_seed: int = clustering_run_config.dataset_seed + run_idx * 1000 - # Save as ActivationBatch - activation_batch: ActivationBatch = ActivationBatch( - activations=processed.activations.cpu(), # Move to CPU for storage - labels=ComponentLabels(list(processed.labels)), - ) - activation_batch.save(run_batch_dir / f"batch_{batch_idx}.pt") - - # Clean up - del batch_data, acts_dict, processed, activation_batch - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - # Clean up model - del model, sample_batch, sample_acts - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() + # Generate all batches for this run (no filtering in ensemble mode) + precompute_batches_for_single_run( + clustering_run_config=clustering_run_config, + output_dir=run_batch_dir, + base_seed=run_seed, + apply_filtering=False, + ) logger.info(f"All batches precomputed and saved to {batches_base_dir}") - return batches_base_dir diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index 31d7c303b..ce8cd6067 100644 --- a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -10,7 +10,6 @@ """ import argparse -import gc import os import tempfile from collections.abc import Callable @@ -26,32 +25,27 @@ from torch import Tensor from wandb.sdk.wandb_run import Run -from spd.clustering.activations import ( - ProcessedActivations, - component_activations, - process_activations, +from spd.clustering.batched_activations import ( + 
ActivationBatch, + BatchedActivations, + precompute_batches_for_single_run, ) -from spd.clustering.batched_activations import ActivationBatch, BatchedActivations from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.consts import ( - BatchTensor, ClusterCoactivationShaped, ComponentLabels, ) -from spd.clustering.dataset import load_dataset from spd.clustering.ensemble_registry import _ENSEMBLE_REGISTRY_DB, register_clustering_run from spd.clustering.math.merge_matrix import GroupMerge from spd.clustering.math.semilog import semilog from spd.clustering.merge import merge_iteration from spd.clustering.merge_history import MergeHistory -from spd.clustering.plotting.activations import plot_activations from spd.clustering.plotting.merge import plot_merge_history_cluster_sizes, plot_merge_iteration from spd.clustering.storage import StorageBase from spd.clustering.wandb_tensor_info import wandb_log_tensor from spd.log import logger -from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.models.component_model import SPDRunInfo from spd.spd_types import TaskName -from spd.utils.distributed_utils import get_device from spd.utils.general_utils import replace_pydantic_model from spd.utils.run_utils import _NO_ARG_PARSSED_SENTINEL, ExecutionStamp, read_noneable_str @@ -251,7 +245,6 @@ def main(run_config: ClusteringRunConfig) -> Path: # start logger.info("Starting clustering run") logger.info(f"Output directory: {storage.base_dir}") - device = get_device() spd_run: SPDRunInfo = SPDRunInfo.from_path(run_config.model_path) task_name: TaskName = spd_run.config.task_config.task_name @@ -280,94 +273,33 @@ def main(run_config: ClusteringRunConfig) -> Path: component_labels: ComponentLabels if run_config.precomputed_activations_dir is not None: - # Case 1: Use precomputed batches from disk + # Case 1: Use precomputed batches from disk (from ensemble pipeline) logger.info(f"Loading precomputed batches from 
{run_config.precomputed_activations_dir}") batched_activations = BatchedActivations(run_config.precomputed_activations_dir) - - # Get labels from first batch - first_batch: ActivationBatch = batched_activations._get_next_batch() - component_labels = ComponentLabels(first_batch.labels) - logger.info(f"Loaded {batched_activations.n_batches} precomputed batches") else: - # Case 2: Compute single batch on-the-fly (original behavior) - logger.info(f"Computing single batch (seed={run_config.dataset_seed})") - - # Load model - logger.info("Loading model") - model: ComponentModel = ComponentModel.from_run_info(spd_run).to(device) - - # Load data - logger.info("Loading dataset") - load_dataset_kwargs: dict[str, Any] = dict() - if run_config.dataset_streaming: - logger.info("Using streaming dataset loading") - load_dataset_kwargs["config_kwargs"] = dict(streaming=True) - assert task_name == "lm", ( - f"Streaming dataset loading only supported for 'lm' task, got '{task_name = }'. Remove dataset_streaming=True from config or use a different task." 
- ) - - batch: BatchTensor = load_dataset( - model_path=run_config.model_path, - task_name=task_name, - batch_size=run_config.batch_size, - seed=run_config.dataset_seed, - **load_dataset_kwargs, - ).to(device) - - # Compute activations - logger.info("Computing activations") - activations_dict: ( - dict[str, Float[Tensor, "batch seq C"]] | dict[str, Float[Tensor, "batch C"]] - ) = component_activations( - model=model, - batch=batch, - device=device, - ) - - # Process (concat modules, with filtering) - logger.info("Processing activations") - processed: ProcessedActivations = process_activations( - activations=activations_dict, - filter_dead_threshold=run_config.merge_config.filter_dead_threshold, - seq_mode="concat" if task_name == "lm" else None, - filter_modules=run_config.merge_config.filter_modules, + # Case 2: Generate batches for this single run + logger.info(f"Generating activation batches (seed={run_config.dataset_seed})") + + batch_dir: Path = storage.base_dir / "batches" + batch_dir.mkdir(exist_ok=True) + + # Generate all needed batches (respects recompute_costs_every) + n_batches: int = precompute_batches_for_single_run( + clustering_run_config=run_config, + output_dir=batch_dir, + base_seed=run_config.dataset_seed, + apply_filtering=True, # Apply config filtering for single runs ) - # Save as single batch to temp dir - temp_batch_dir: Path = storage.base_dir / "temp_batch" - temp_batch_dir.mkdir(exist_ok=True) - - single_batch: ActivationBatch = ActivationBatch( - activations=processed.activations, - labels=ComponentLabels(list(processed.labels)), - ) - single_batch.save(temp_batch_dir / "batch_0.pt") - - batched_activations = BatchedActivations(temp_batch_dir) - component_labels = processed.labels - - # Log activations to WandB (if enabled) - if wandb_run is not None: - logger.info("Plotting activations") - plot_activations( - processed_activations=processed, - save_dir=None, - n_samples_max=256, - wandb_run=wandb_run, - ) - wandb_log_tensor( - 
wandb_run, - processed.activations, - "activations", - 0, - single=True, - ) + # Load batches + batched_activations = BatchedActivations(batch_dir) + logger.info(f"Generated and loaded {n_batches} batches") - # Clean up memory - del model, batch, activations_dict, processed - gc.collect() + # Get labels from first batch + first_batch: ActivationBatch = batched_activations._get_next_batch() + component_labels = ComponentLabels(first_batch.labels) # Run merge iteration # ===================================== From c9a8007a6e6be501e6a794bab9677c59681aa0e0 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 15:24:05 +0000 Subject: [PATCH 49/61] wip --- spd/clustering/batched_activations.py | 18 ++++++++++++++++++ spd/clustering/scripts/run_clustering.py | 7 ++----- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/spd/clustering/batched_activations.py b/spd/clustering/batched_activations.py index a69e0007f..15c906eb6 100644 --- a/spd/clustering/batched_activations.py +++ b/spd/clustering/batched_activations.py @@ -37,6 +37,7 @@ _BATCH_FORMAT: str = "batch_{idx:04}.zip" +_LABELS_FILE: str = "labels.txt" @dataclass @@ -89,6 +90,18 @@ def __init__(self, batch_dir: Path): f"Expected batch file '{expected_name}', found '{actual_name}'" ) + # Load labels from file + labels_path: Path = batch_dir / _LABELS_FILE + assert labels_path.exists(), f"Labels file not found: {labels_path}" + self._labels: ComponentLabels = ComponentLabels( + labels_path.read_text().strip().splitlines() + ) + + @property + def labels(self) -> ComponentLabels: + """Get component labels for all batches.""" + return self._labels + @property def n_batches(self) -> int: return len(self.batch_paths) @@ -190,6 +203,11 @@ def _generate_activation_batches( filter_modules=filter_modules, ) + # Save labels file (once, from first batch) + if batch_idx == 0: + labels_path: Path = output_dir / _LABELS_FILE + labels_path.write_text("\n".join(processed.labels)) + # Save as 
ActivationBatch activation_batch: ActivationBatch = ActivationBatch( activations=processed.activations.cpu(), # Move to CPU for storage diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index ce8cd6067..e09cb201a 100644 --- a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -26,7 +26,6 @@ from wandb.sdk.wandb_run import Run from spd.clustering.batched_activations import ( - ActivationBatch, BatchedActivations, precompute_batches_for_single_run, ) @@ -270,7 +269,6 @@ def main(run_config: ClusteringRunConfig) -> Path: # Load or compute activations # ===================================== batched_activations: BatchedActivations - component_labels: ComponentLabels if run_config.precomputed_activations_dir is not None: # Case 1: Use precomputed batches from disk (from ensemble pipeline) @@ -297,9 +295,8 @@ def main(run_config: ClusteringRunConfig) -> Path: batched_activations = BatchedActivations(batch_dir) logger.info(f"Generated and loaded {n_batches} batches") - # Get labels from first batch - first_batch: ActivationBatch = batched_activations._get_next_batch() - component_labels = ComponentLabels(first_batch.labels) + # Get labels from batches + component_labels: ComponentLabels = batched_activations.labels # Run merge iteration # ===================================== From 299aa4850a1a9e884cf1a455fccf62c689bdc207 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 15:24:47 +0000 Subject: [PATCH 50/61] pyright fix --- tests/clustering/scripts/cluster_resid_mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/clustering/scripts/cluster_resid_mlp.py b/tests/clustering/scripts/cluster_resid_mlp.py index 3013ea150..d2c1efdc0 100644 --- a/tests/clustering/scripts/cluster_resid_mlp.py +++ b/tests/clustering/scripts/cluster_resid_mlp.py @@ -75,7 +75,7 @@ data_generation_type=DATASET.data_generation_type, ) ) -DATALOADER: 
DatasetGeneratedDataLoader = DatasetGeneratedDataLoader( +DATALOADER: DatasetGeneratedDataLoader[Any] = DatasetGeneratedDataLoader( DATASET, batch_size=N_SAMPLES, shuffle=False ) From f028eb4792f4b4d3275adac45ebcd6734b087b4b Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 15:33:42 +0000 Subject: [PATCH 51/61] wip --- spd/clustering/batched_activations.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/spd/clustering/batched_activations.py b/spd/clustering/batched_activations.py index 15c906eb6..53a1a8cc6 100644 --- a/spd/clustering/batched_activations.py +++ b/spd/clustering/batched_activations.py @@ -139,10 +139,15 @@ def from_tensor( # Create a temporary directory temp_dir: Path = Path(tempfile.mkdtemp(prefix="batch_temp_")) + # Normalize labels + normalized_labels: ComponentLabels = ComponentLabels(labels) + + # Save labels file + labels_path: Path = temp_dir / _LABELS_FILE + labels_path.write_text("\n".join(normalized_labels)) + # Save the single batch - batch: ActivationBatch = ActivationBatch( - activations=activations, labels=ComponentLabels(labels) - ) + batch: ActivationBatch = ActivationBatch(activations=activations, labels=normalized_labels) batch.save(temp_dir / _BATCH_FORMAT.format(idx=0)) # Return BatchedActivations that will cycle through this single batch From 9f2e83f5ff897001d9df98328a4122a7da28f5f9 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 16:47:36 +0000 Subject: [PATCH 52/61] wip, broken --- spd/clustering/batched_activations.py | 38 ++++++++-------------- spd/clustering/merge.py | 36 ++++++++++---------- spd/clustering/scripts/run_clustering.py | 1 - tests/clustering/test_merge_integration.py | 6 ++-- 4 files changed, 35 insertions(+), 46 deletions(-) diff --git a/spd/clustering/batched_activations.py b/spd/clustering/batched_activations.py index 53a1a8cc6..1bb3e354c 100644 --- a/spd/clustering/batched_activations.py +++ b/spd/clustering/batched_activations.py 
@@ -26,7 +26,6 @@ ) from spd.clustering.consts import ActivationsTensor, BatchTensor, ComponentLabels from spd.clustering.dataset import load_dataset -from spd.clustering.util import ModuleFilterFunc from spd.log import logger from spd.models.component_model import ComponentModel, SPDRunInfo from spd.spd_types import TaskName @@ -163,11 +162,12 @@ def _generate_activation_batches( n_batches: int, output_dir: Path, base_seed: int, - filter_dead_threshold: float, - filter_modules: ModuleFilterFunc | None, ) -> None: """Core function to generate activation batches. + Batches are saved WITHOUT filtering - they contain raw/unfiltered activations. + This is required for merge_iteration to correctly recompute costs from fresh batches. + Args: model: ComponentModel to compute activations device: Device to use for computation @@ -177,8 +177,6 @@ def _generate_activation_batches( n_batches: Number of batches to generate output_dir: Directory to save batches base_seed: Base seed for dataset loading (each batch gets base_seed + batch_idx) - filter_dead_threshold: Threshold for filtering dead components (0.0 = no filtering) - filter_modules: Module filter function (None = no filtering) """ batch_idx: int @@ -200,12 +198,14 @@ def _generate_activation_batches( model, device, batch_data ) - # Process activations + # Process activations WITHOUT filtering + # Batches must contain raw/unfiltered activations because merge_iteration + # expects to reload unfiltered data when recomputing costs processed: ProcessedActivations = process_activations( activations=acts_dict, - filter_dead_threshold=filter_dead_threshold, + filter_dead_threshold=0.0, # Never filter when saving batches seq_mode="concat" if task_name == "lm" else None, - filter_modules=filter_modules, + filter_modules=None, # Never filter modules when saving batches ) # Save labels file (once, from first batch) @@ -230,7 +230,6 @@ def precompute_batches_for_single_run( clustering_run_config: "ClusteringRunConfig", output_dir: 
Path, base_seed: int, - apply_filtering: bool, ) -> int: """ Precompute activation batches for a single clustering run. @@ -239,11 +238,13 @@ def precompute_batches_for_single_run( (based on recompute_costs_every and n_iters), generates all batches, and saves them to disk. + Batches are saved WITHOUT filtering to ensure merge_iteration can correctly + recompute costs from fresh batches. + Args: clustering_run_config: Configuration for clustering run output_dir: Directory to save batches (will contain batch_0000.zip, batch_0001.zip, etc.) base_seed: Base seed for dataset loading (each batch gets base_seed + batch_idx) - apply_filtering: Whether to apply filter_dead_threshold and filter_modules from config Returns: Number of batches generated @@ -289,17 +290,7 @@ def precompute_batches_for_single_run( f"Multi-batch mode: generating {n_batches_needed} batches for {n_iters} iterations (recompute_every={recompute_every})" ) - # Determine filtering parameters - filter_dead_threshold: float - filter_modules: ModuleFilterFunc | None - if apply_filtering: - filter_dead_threshold = clustering_run_config.merge_config.filter_dead_threshold - filter_modules = clustering_run_config.merge_config.filter_modules - else: - filter_dead_threshold = 0.0 - filter_modules = None - - # Generate batches + # Generate batches (no filtering applied) _generate_activation_batches( model=model, device=device, @@ -309,8 +300,6 @@ def precompute_batches_for_single_run( n_batches=n_batches_needed, output_dir=output_dir, base_seed=base_seed, - filter_dead_threshold=filter_dead_threshold, - filter_modules=filter_modules, ) # Clean up model @@ -364,12 +353,11 @@ def precompute_batches_for_ensemble( # Use unique seed offset for this run run_seed: int = clustering_run_config.dataset_seed + run_idx * 1000 - # Generate all batches for this run (no filtering in ensemble mode) + # Generate all batches for this run precompute_batches_for_single_run( clustering_run_config=clustering_run_config, 
output_dir=run_batch_dir, base_seed=run_seed, - apply_filtering=False, ) logger.info(f"All batches precomputed and saved to {batches_base_dir}") diff --git a/spd/clustering/merge.py b/spd/clustering/merge.py index 65f6c2d01..78bd1dd70 100644 --- a/spd/clustering/merge.py +++ b/spd/clustering/merge.py @@ -148,6 +148,24 @@ def merge_iteration( total=num_iters, ) for iter_idx in pbar: + # Recompute from new batch if it's time (do this BEFORE computing costs) + # -------------------------------------------------- + if merge_config.recompute_costs_every is not None: + should_recompute: bool = ( + iter_idx % merge_config.recompute_costs_every == 0 + ) and (iter_idx > 0) + + if should_recompute: + new_batch: ActivationBatch = batched_activations._get_next_batch() + activations = new_batch.activations + + # Recompute fresh coacts with current merge groups + current_coact, current_act_mask = recompute_coacts_from_scratch( + activations=activations, + current_merge=current_merge, + activation_threshold=merge_config.activation_threshold, + ) + # compute costs, figure out what to merge # -------------------------------------------------- # HACK: this is messy @@ -197,24 +215,6 @@ def merge_iteration( current_merge=current_merge, ) - # Recompute from new batch if it's time - # -------------------------------------------------- - if merge_config.recompute_costs_every is not None: - should_recompute: bool = ( - (iter_idx + 1) % merge_config.recompute_costs_every == 0 - ) and (iter_idx + 1 < num_iters) - - if should_recompute: - new_batch: ActivationBatch = batched_activations._get_next_batch() - activations = new_batch.activations - - # Recompute fresh coacts with current merge groups - current_coact, current_act_mask = recompute_coacts_from_scratch( - activations=activations, - current_merge=current_merge, - activation_threshold=merge_config.activation_threshold, - ) - # Compute metrics for logging # -------------------------------------------------- # the MDL loss computed here 
is the *cost of the current merge*, a single scalar value diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index e09cb201a..110469094 100644 --- a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -288,7 +288,6 @@ def main(run_config: ClusteringRunConfig) -> Path: clustering_run_config=run_config, output_dir=batch_dir, base_seed=run_config.dataset_seed, - apply_filtering=True, # Apply config filtering for single runs ) # Load batches diff --git a/tests/clustering/test_merge_integration.py b/tests/clustering/test_merge_integration.py index 89bbb470d..af33ef1fd 100644 --- a/tests/clustering/test_merge_integration.py +++ b/tests/clustering/test_merge_integration.py @@ -29,7 +29,7 @@ def test_merge_with_range_sampler(self): iters=5, merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.1}, - filter_dead_threshold=0.001, + recompute_costs_every=2, # Recompute every 2 iterations for single-batch tests ) # Run merge iteration @@ -69,7 +69,7 @@ def test_merge_with_mcmc_sampler(self): iters=5, merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 1.0}, - filter_dead_threshold=0.001, + recompute_costs_every=2, # Recompute every 2 iterations for single-batch tests ) # Run merge iteration @@ -113,6 +113,7 @@ def test_merge_comparison_samplers(self): iters=3, merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.0}, # Always select minimum + recompute_costs_every=2, # Recompute every 2 iterations for single-batch tests ) batched_activations_range: BatchedActivations = BatchedActivations.from_tensor( @@ -131,6 +132,7 @@ def test_merge_comparison_samplers(self): iters=3, merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 0.01}, # Very low temp + recompute_costs_every=2, # Recompute every 2 iterations for single-batch tests ) batched_activations_mcmc: BatchedActivations = 
BatchedActivations.from_tensor( From ea0ff13dd4db057c0ed6cb5fddc80042ba6ca003 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 10:01:59 +0000 Subject: [PATCH 53/61] wip --- spd/clustering/merge.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/spd/clustering/merge.py b/spd/clustering/merge.py index 78bd1dd70..a7b2966d0 100644 --- a/spd/clustering/merge.py +++ b/spd/clustering/merge.py @@ -148,23 +148,30 @@ def merge_iteration( total=num_iters, ) for iter_idx in pbar: - # Recompute from new batch if it's time (do this BEFORE computing costs) + # Recompute from batch if needed (do this BEFORE computing costs) # -------------------------------------------------- - if merge_config.recompute_costs_every is not None: - should_recompute: bool = ( - iter_idx % merge_config.recompute_costs_every == 0 - ) and (iter_idx > 0) + # With NaN masking, we must recompute before every iteration (except first) + # because the coact matrix is invalidated after each merge. + # When recompute_costs_every is set, we cycle through batches; + # otherwise we reuse the same batch. 
+ if iter_idx > 0: + # Check if we should load a new batch + should_load_new_batch: bool = ( + merge_config.recompute_costs_every is not None + and iter_idx % merge_config.recompute_costs_every == 0 + ) - if should_recompute: + if should_load_new_batch: new_batch: ActivationBatch = batched_activations._get_next_batch() activations = new_batch.activations - # Recompute fresh coacts with current merge groups - current_coact, current_act_mask = recompute_coacts_from_scratch( - activations=activations, - current_merge=current_merge, - activation_threshold=merge_config.activation_threshold, - ) + # Always recompute coacts from current activations after iteration 0 + # (needed because NaN masking invalidates the matrix) + current_coact, current_act_mask = recompute_coacts_from_scratch( + activations=activations, + current_merge=current_merge, + activation_threshold=merge_config.activation_threshold, + ) # compute costs, figure out what to merge # -------------------------------------------------- From 6376ffd45f99300218b9f138524f698bcc26ee96 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 10:24:46 +0000 Subject: [PATCH 54/61] [temp] remove CI timeout --- .github/workflows/checks.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index dfeafd731..b4034fedf 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -16,7 +16,7 @@ jobs: runs-on: ubuntu-latest container: image: ghcr.io/${{ github.repository }}/ci-mpi:latest - timeout-minutes: 15 + # timeout-minutes: 15 steps: - name: Checkout uses: actions/checkout@v4 From b4e8d82257b036cd0cfaaab2431eb030319ba8c9 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 10:39:28 +0000 Subject: [PATCH 55/61] change component indices dtype int16 -> int32 finally got to do the bit from the original comment on this: > if you have more than 32k components, change this to np.int32 > if 
you have more than 2.1b components, rethink your life choices --- spd/clustering/consts.py | 8 ++++++++ spd/clustering/math/merge_matrix.py | 6 +++--- spd/clustering/merge_history.py | 13 ++++++------- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/spd/clustering/consts.py b/spd/clustering/consts.py index cc86ff432..fe2f4ebde 100644 --- a/spd/clustering/consts.py +++ b/spd/clustering/consts.py @@ -5,11 +5,19 @@ from typing import Literal, NewType import numpy as np +import torch from jaxtyping import Bool, Float, Int from torch import Tensor # TODO: docstrings for all types below +ComponentIndexDtype = np.int32 +ComponentIndexDtypeTorch = torch.int32 +# if you have more than 2.1b components, change this to np.int64 +# (and rethink your life choices) +# note 2025-10-29 10:37 -- obviously we will need to handle components in the billions, +# but this makes the current method infeasible -- we will need to cluster within layers first + # Merge arrays and distances (numpy-based for storage/analysis) MergesAtIterArray = Int[np.ndarray, "n_ens n_components"] MergesArray = Int[np.ndarray, "n_ens n_iters n_components"] diff --git a/spd/clustering/math/merge_matrix.py b/spd/clustering/math/merge_matrix.py index 118f575e2..ddb3ced5f 100644 --- a/spd/clustering/math/merge_matrix.py +++ b/spd/clustering/math/merge_matrix.py @@ -5,7 +5,7 @@ from muutils.tensor_info import array_summary from torch import Tensor -from spd.clustering.consts import GroupIdxsTensor +from spd.clustering.consts import ComponentIndexDtypeTorch, GroupIdxsTensor # pyright: reportUnnecessaryTypeIgnoreComment=false @@ -200,8 +200,8 @@ def summary(self) -> dict[str, int | str | None]: def init_empty(cls, batch_size: int, n_components: int) -> "BatchedGroupMerge": """Initialize an empty BatchedGroupMerge with the given batch size and number of components.""" return cls( - group_idxs=torch.full((batch_size, n_components), -1, dtype=torch.int16),
k_groups=torch.zeros(batch_size, dtype=torch.int16), + group_idxs=torch.full((batch_size, n_components), -1, dtype=ComponentIndexDtypeTorch), + k_groups=torch.zeros(batch_size, dtype=ComponentIndexDtypeTorch), ) @property diff --git a/spd/clustering/merge_history.py b/spd/clustering/merge_history.py index bbff78893..7bbbd046b 100644 --- a/spd/clustering/merge_history.py +++ b/spd/clustering/merge_history.py @@ -11,6 +11,7 @@ from muutils.dbg import dbg_tensor from spd.clustering.consts import ( + ComponentIndexDtype, ComponentLabels, DistancesArray, DistancesMethod, @@ -74,7 +75,7 @@ def from_config( return MergeHistory( labels=labels, n_iters_current=0, - selected_pairs=np.full((n_iters_target, 2), -1, dtype=np.int16), + selected_pairs=np.full((n_iters_target, 2), -1, dtype=ComponentIndexDtype), merges=BatchedGroupMerge.init_empty( batch_size=n_iters_target, n_components=n_components ), @@ -108,7 +109,7 @@ def add_iteration( current_merge: GroupMerge, ) -> None: """Add data for one iteration.""" - self.selected_pairs[idx] = np.array(selected_pair, dtype=np.int16) + self.selected_pairs[idx] = np.array(selected_pair, dtype=ComponentIndexDtype) self.merges[idx] = current_merge assert self.n_iters_current == idx @@ -339,9 +340,7 @@ def merges_array(self) -> MergesArray: output: MergesArray = np.full( (n_ens, n_iters, c_components), fill_value=-1, - dtype=np.int16, - # if you have more than 32k components, change this to np.int32 - # if you have more than 2.1b components, rethink your life choices + dtype=ComponentIndexDtype, ) for i_ens, history in enumerate(self.data): for i_iter, merge in enumerate(history.merges): @@ -373,7 +372,7 @@ def normalized(self) -> tuple[MergesArray, dict[str, Any]]: merges_array: MergesArray = np.full( (self.n_ensemble, self.n_iters_min, c_components), fill_value=-1, - dtype=np.int16, + dtype=ComponentIndexDtype, ) except Exception as e: err_msg = ( @@ -418,7 +417,7 @@ def normalized(self) -> tuple[MergesArray, dict[str, Any]]: 
merges_array[i_ens, :, i_comp_new_relabel] = np.full( self.n_iters_min, fill_value=idx_missing + hist_n_components, - dtype=np.int16, + dtype=ComponentIndexDtype, ) # TODO: Consider logging overlap_stats to WandB if run is available From 595c73778b8735c16804340b839a274192d9ab30 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 10:50:16 +0000 Subject: [PATCH 56/61] fix: allow dataset streaming in precompute_batches_for_single_run --- spd/clustering/batched_activations.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/spd/clustering/batched_activations.py b/spd/clustering/batched_activations.py index 1bb3e354c..50491c756 100644 --- a/spd/clustering/batched_activations.py +++ b/spd/clustering/batched_activations.py @@ -264,6 +264,9 @@ def precompute_batches_for_single_run( task_name=task_name, batch_size=clustering_run_config.batch_size, seed=0, + config_kwargs=dict( + steaming=clustering_run_config.dataset_streaming, + ), ).to(device) with torch.no_grad(): From 791d0d160620b0d41def0dea75a31cd875ca1d9e Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 10:56:50 +0000 Subject: [PATCH 57/61] "steaming" -> "streaming" typo fix ... 
--- spd/clustering/batched_activations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spd/clustering/batched_activations.py b/spd/clustering/batched_activations.py index 50491c756..6907740a7 100644 --- a/spd/clustering/batched_activations.py +++ b/spd/clustering/batched_activations.py @@ -265,7 +265,7 @@ def precompute_batches_for_single_run( batch_size=clustering_run_config.batch_size, seed=0, config_kwargs=dict( - steaming=clustering_run_config.dataset_streaming, + streaming=clustering_run_config.dataset_streaming, ), ).to(device) From 58ed42a6607720c18229eea1fa50883656ac73f7 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 11:30:37 +0000 Subject: [PATCH 58/61] various dataset streaming fixes --- spd/clustering/dataset.py | 25 +++++++++++++++++++++++-- spd/clustering/scripts/run_pipeline.py | 14 +++++++++++++- spd/clustering/util.py | 9 ++++++--- 3 files changed, 42 insertions(+), 6 deletions(-) diff --git a/spd/clustering/dataset.py b/spd/clustering/dataset.py index ea9b9f904..22799e34e 100644 --- a/spd/clustering/dataset.py +++ b/spd/clustering/dataset.py @@ -3,6 +3,7 @@ Each clustering run loads its own dataset batch, seeded by the run index. 
""" +import warnings from typing import Any from spd.clustering.consts import BatchTensor @@ -54,7 +55,10 @@ def load_dataset( def _load_lm_batch( - model_path: str, batch_size: int, seed: int, config_kwargs: dict[str, Any] | None = None + model_path: str, + batch_size: int, + seed: int, + config_kwargs: dict[str, Any] | None = None, ) -> BatchTensor: """Load a batch for language model task.""" spd_run = SPDRunInfo.from_path(model_path) @@ -102,7 +106,12 @@ def _load_lm_batch( return batch["input_ids"] -def _load_resid_mlp_batch(model_path: str, batch_size: int, seed: int) -> BatchTensor: +def _load_resid_mlp_batch( + model_path: str, + batch_size: int, + seed: int, + config_kwargs: dict[str, Any] | None = None, +) -> BatchTensor: """Load a batch for ResidMLP task.""" from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset from spd.utils.data_utils import DatasetGeneratedDataLoader @@ -118,6 +127,18 @@ def _load_resid_mlp_batch(model_path: str, batch_size: int, seed: int) -> BatchT f"Expected target_model to be of type ResidMLP, but got {type(component_model.target_model) = }" ) + if config_kwargs is not None: + if "streaming" in config_kwargs: + warnings.warn( + "The 'streaming' option is not supported for ResidMLPDataset and will be ignored.", + stacklevel=1, + ) + config_kwargs.pop("streaming") + + assert len(config_kwargs) == 0, ( + f"Unsupported config_kwargs for ResidMLPDataset: {config_kwargs=}" + ) + # Create dataset with run-specific seed dataset = ResidMLPDataset( n_features=component_model.target_model.config.n_features, diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index 92503229f..5e3db08dc 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -277,16 +277,28 @@ def main( ) logger.info(f"WandB workspace: {workspace_url}") - # Precompute batches if multi-batch mode clustering_run_config: ClusteringRunConfig = ClusteringRunConfig.from_file( 
pipeline_config.clustering_run_config_path ) + + # Precompute batches if multi-batch mode + # ========================================================================================== + + # pass streaming to the crc + clustering_run_config = replace_pydantic_model( + clustering_run_config, + {"dataset_streaming": dataset_streaming}, + ) + batches_base_dir: Path | None = precompute_batches_for_ensemble( clustering_run_config=clustering_run_config, n_runs=pipeline_config.n_runs, output_dir=storage.base_dir, ) + # run + # ========================================================================================== + # Generate commands for clustering runs clustering_commands: list[str] = generate_clustering_commands( pipeline_config=pipeline_config, diff --git a/spd/clustering/util.py b/spd/clustering/util.py index bd11e2fd4..0c1300640 100644 --- a/spd/clustering/util.py +++ b/spd/clustering/util.py @@ -8,10 +8,13 @@ def format_scientific_latex(value: float) -> str: import math - exponent: int = int(math.floor(math.log10(abs(value)))) - mantissa: float = value / (10**exponent) + try: + exponent: int = int(math.floor(math.log10(abs(value)))) + mantissa: float = value / (10**exponent) - return f"${mantissa:.2f} \\times 10^{{{exponent}}}$" + return f"${mantissa:.2f} \\times 10^{{{exponent}}}$" + except Exception: + return f"${value}$" ModuleFilterSource = str | Callable[[str], bool] | set[str] | None From b67c3addc286f0b4d81778d5d697b4287db9fac7 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 12:40:54 +0000 Subject: [PATCH 59/61] refactor to return dataloaders instead of single batches like before --- spd/clustering/batched_activations.py | 65 +++++++++++----------- spd/clustering/dataset.py | 78 ++++++++++++++++++++------- 2 files changed, 92 insertions(+), 51 deletions(-) diff --git a/spd/clustering/batched_activations.py b/spd/clustering/batched_activations.py index 6907740a7..e6d51a8b0 100644 --- a/spd/clustering/batched_activations.py +++ 
b/spd/clustering/batched_activations.py @@ -25,7 +25,8 @@ process_activations, ) from spd.clustering.consts import ActivationsTensor, BatchTensor, ComponentLabels -from spd.clustering.dataset import load_dataset +from spd.clustering.dataset import create_dataset_loader +from spd.data import loop_dataloader from spd.log import logger from spd.models.component_model import ComponentModel, SPDRunInfo from spd.spd_types import TaskName @@ -162,6 +163,7 @@ def _generate_activation_batches( n_batches: int, output_dir: Path, base_seed: int, + dataset_streaming: bool = False, ) -> None: """Core function to generate activation batches. @@ -176,21 +178,36 @@ def _generate_activation_batches( batch_size: Batch size for dataset n_batches: Number of batches to generate output_dir: Directory to save batches - base_seed: Base seed for dataset loading (each batch gets base_seed + batch_idx) + base_seed: Base seed for dataset loading + dataset_streaming: Whether to use streaming for dataset loading """ + # Create dataloader ONCE instead of reloading for each batch + dataloader = create_dataset_loader( + model_path=model_path, + task_name=task_name, + batch_size=batch_size, + seed=base_seed, + config_kwargs=dict( + streaming=dataset_streaming, + ), + ) + + # Use loop_dataloader for efficient iteration that handles exhaustion + batch_iterator = loop_dataloader(dataloader) + batch_idx: int for batch_idx in tqdm(range(n_batches), desc="Generating batches", leave=False): - # Use unique seed for each batch - seed: int = base_seed + batch_idx + # Get next batch from iterator + batch_data_raw = next(batch_iterator) - # Load data - batch_data: BatchTensor = load_dataset( - model_path=model_path, - task_name=task_name, - batch_size=batch_size, - seed=seed, - ).to(device) + # Extract input based on task type + if task_name == "lm": + batch_data: BatchTensor = batch_data_raw["input_ids"].to(device) + elif task_name == "resid_mlp": + batch_data = batch_data_raw[0].to(device) # (batch, labels) 
tuple + else: + raise ValueError(f"Unsupported task: {task_name}") # Compute activations with torch.no_grad(): @@ -244,7 +261,7 @@ def precompute_batches_for_single_run( Args: clustering_run_config: Configuration for clustering run output_dir: Directory to save batches (will contain batch_0000.zip, batch_0001.zip, etc.) - base_seed: Base seed for dataset loading (each batch gets base_seed + batch_idx) + base_seed: Base seed for dataset loading Returns: Number of batches generated @@ -257,25 +274,8 @@ def precompute_batches_for_single_run( model: ComponentModel = ComponentModel.from_run_info(spd_run).to(device) task_name: TaskName = spd_run.config.task_config.task_name - # Load a sample batch to count components - logger.info("Loading sample batch to count components") - sample_batch: BatchTensor = load_dataset( - model_path=clustering_run_config.model_path, - task_name=task_name, - batch_size=clustering_run_config.batch_size, - seed=0, - config_kwargs=dict( - streaming=clustering_run_config.dataset_streaming, - ), - ).to(device) - - with torch.no_grad(): - sample_acts: dict[str, Float[Tensor, "samples components"]] = component_activations( - model, device, sample_batch - ) - - # Count total components across all modules - n_components: int = sum(act.shape[-1] for act in sample_acts.values()) + # Count total components directly from model (sum C across all component modules) + n_components: int = sum(comp.C for comp in model.components.values()) # Calculate number of iterations and batches needed n_iters: int = clustering_run_config.merge_config.get_num_iters(n_components) @@ -303,10 +303,11 @@ def precompute_batches_for_single_run( n_batches=n_batches_needed, output_dir=output_dir, base_seed=base_seed, + dataset_streaming=clustering_run_config.dataset_streaming, ) # Clean up model - del model, sample_batch, sample_acts + del model gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() diff --git a/spd/clustering/dataset.py b/spd/clustering/dataset.py 
index 22799e34e..1b3adc548 100644 --- a/spd/clustering/dataset.py +++ b/spd/clustering/dataset.py @@ -6,6 +6,8 @@ import warnings from typing import Any +from torch.utils.data import DataLoader + from spd.clustering.consts import BatchTensor from spd.data import DatasetConfig, create_data_loader from spd.experiments.lm.configs import LMTaskConfig @@ -15,16 +17,17 @@ from spd.spd_types import TaskName -def load_dataset( +def create_dataset_loader( model_path: str, task_name: TaskName, batch_size: int, seed: int, **kwargs: Any, -) -> BatchTensor: - """Load a single batch for clustering. +) -> DataLoader[Any]: + """Create a dataloader for clustering that can be iterated multiple times. - Each run gets its own dataset batch, seeded by index in ensemble. + This is more efficient than load_dataset() when you need multiple batches, + as it creates the dataloader once and allows iteration through many batches. Args: model_path: Path to decomposed model @@ -33,18 +36,18 @@ def load_dataset( seed: Random seed for dataset Returns: - Single batch of data + DataLoader that can be iterated to get multiple batches """ match task_name: case "lm": - return _load_lm_batch( + return _create_lm_dataloader( model_path=model_path, batch_size=batch_size, seed=seed, **kwargs, ) case "resid_mlp": - return _load_resid_mlp_batch( + return _create_resid_mlp_dataloader( model_path=model_path, batch_size=batch_size, seed=seed, @@ -54,13 +57,53 @@ def load_dataset( raise ValueError(f"Unsupported task: {task_name}") -def _load_lm_batch( +def load_dataset( model_path: str, + task_name: TaskName, batch_size: int, seed: int, - config_kwargs: dict[str, Any] | None = None, + **kwargs: Any, ) -> BatchTensor: - """Load a batch for language model task.""" + """Load a single batch for clustering. + + This is a convenience wrapper around create_dataset_loader() that extracts + just the first batch. Use create_dataset_loader() directly if you need + multiple batches for better efficiency. 
+ + Args: + model_path: Path to decomposed model + task_name: Task type + batch_size: Batch size + seed: Random seed for dataset + + Returns: + Single batch of data + """ + dataloader = create_dataset_loader( + model_path=model_path, + task_name=task_name, + batch_size=batch_size, + seed=seed, + **kwargs, + ) + + # Extract first batch based on task type + batch = next(iter(dataloader)) + if task_name == "lm": + return batch["input_ids"] + elif task_name == "resid_mlp": + return batch[0] # ResidMLP returns (batch, labels) tuple + else: + raise ValueError(f"Unsupported task: {task_name}") + + +def _create_lm_dataloader( + model_path: str, + batch_size: int, + seed: int, + config_kwargs: dict[str, Any] | None = None, +) -> DataLoader[Any]: + """Create a dataloader for language model task.""" spd_run = SPDRunInfo.from_path(model_path) cfg = spd_run.config @@ -101,18 +144,16 @@ def _load_lm_batch( ddp_world_size=1, ) - # Get first batch - batch = next(iter(dataloader)) - return batch["input_ids"] + return dataloader -def _load_resid_mlp_batch( +def _create_resid_mlp_dataloader( model_path: str, batch_size: int, seed: int, config_kwargs: dict[str, Any] | None = None, -) -> BatchTensor: - """Load a batch for ResidMLP task.""" +) -> DataLoader[Any]: + """Create a dataloader for ResidMLP task.""" from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset from spd.utils.data_utils import DatasetGeneratedDataLoader @@ -152,7 +193,6 @@ def _load_resid_mlp_batch( data_generation_type=cfg.task_config.data_generation_type, ) - # Generate batch + # Create dataloader dataloader = DatasetGeneratedDataLoader(dataset, batch_size=batch_size, shuffle=False) - batch, _ = next(iter(dataloader)) - return batch + return dataloader From 1e2fe430fe56db048e7e6d3770f5bbb2612aee1f Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 12:49:58 +0000 Subject: [PATCH 60/61] memory issues? 
--- spd/clustering/batched_activations.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/spd/clustering/batched_activations.py b/spd/clustering/batched_activations.py index e6d51a8b0..571ec13ec 100644 --- a/spd/clustering/batched_activations.py +++ b/spd/clustering/batched_activations.py @@ -238,10 +238,15 @@ def _generate_activation_batches( activation_batch.save(output_dir / _BATCH_FORMAT.format(idx=batch_idx)) # Clean up - del batch_data, acts_dict, processed, activation_batch + del batch_data, batch_data_raw, acts_dict, processed, activation_batch if torch.cuda.is_available(): torch.cuda.empty_cache() + del dataloader, batch_iterator + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + def precompute_batches_for_single_run( clustering_run_config: "ClusteringRunConfig", From c886bd3365a5e7446206c8a66172b1d5e12e8327 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 13:29:58 +0000 Subject: [PATCH 61/61] :( --- spd/clustering/batched_activations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/spd/clustering/batched_activations.py b/spd/clustering/batched_activations.py index 571ec13ec..bd0137ba4 100644 --- a/spd/clustering/batched_activations.py +++ b/spd/clustering/batched_activations.py @@ -237,8 +237,9 @@ def _generate_activation_batches( ) activation_batch.save(output_dir / _BATCH_FORMAT.format(idx=batch_idx)) - # Clean up + # Clean up immediately after saving to avoid memory accumulation del batch_data, batch_data_raw, acts_dict, processed, activation_batch + gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache()