diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index dfeafd731..b4034fedf 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -16,7 +16,7 @@ jobs: runs-on: ubuntu-latest container: image: ghcr.io/${{ github.repository }}/ci-mpi:latest - timeout-minutes: 15 + # timeout-minutes: 15 steps: - name: Checkout uses: actions/checkout@v4 diff --git a/STYLE.md b/STYLE.md index d94a92ed5..b5dc38b5e 100644 --- a/STYLE.md +++ b/STYLE.md @@ -18,10 +18,13 @@ If there's an assumption you're making while writing code, assert it. - If you were wrong, then the code **should** fail. ## Type Annotations -- Use jaxtyping for tensor shapes (though for now we don't do runtime checking) +- **Always** use jaxtyping for tensor or numpy array shapes (though for now we don't do runtime checking) - Always use the PEP 604 typing format of `|` for unions and `type | None` over `Optional`. - Use `dict`, `list` and `tuple` not `Dict`, `List` and `Tuple` -- Don't add type annotations when they're redundant. (i.e. `my_thing: Thing = Thing()` or `name: str = "John Doe"`) +- Omit type annotations only when they're redundant. + - e.g. `my_thing: Thing = Thing()` or `name: str = "John Doe"` don't need type annotations. + - However, `var = foo()` or `result = thing.bar()` should be annotated! +- FOR CLAUDE: don't worry about cleaning up unused imports, we do this automatically with ruff using `make format` ## Tensor Operations - Try to use einops by default for clarity. diff --git a/TODO.md b/TODO.md deleted file mode 100644 index 9e6f14815..000000000 --- a/TODO.md +++ /dev/null @@ -1,73 +0,0 @@ -# TODO: Cluster Coactivation Matrix Implementation - -## What Was Changed - -### 1. Added `ClusterActivations` dataclass (`spd/clustering/dashboard/compute_max_act.py`) -- New dataclass to hold vectorized cluster activations for all clusters -- Contains `activations` tensor [n_samples, n_clusters] and `cluster_indices` list - -### 2. 
Added `compute_all_cluster_activations()` function -- Vectorized computation of all cluster activations at once -- Replaces the per-cluster loop for better performance -- Returns `ClusterActivations` object - -### 3. Added `compute_cluster_coactivations()` function -- Computes coactivation matrix from list of `ClusterActivations` across batches -- Binarizes activations (acts > 0) and computes matrix multiplication: `activation_mask.T @ activation_mask` -- Follows the pattern from `spd/clustering/merge.py:69` -- Returns tuple of (coactivation_matrix, cluster_indices) - -### 4. Modified `compute_max_activations()` function -- Now accumulates `ClusterActivations` from each batch in `all_cluster_activations` list -- Calls `compute_cluster_coactivations()` to compute the matrix -- **Changed return type**: now returns `tuple[DashboardData, np.ndarray, list[int]]` - - Added coactivation matrix and cluster_indices to return value - -### 5. Modified `spd/clustering/dashboard/run.py` -- Updated to handle new return value from `compute_max_activations()` -- Saves coactivation matrix as `coactivations.npz` in the dashboard output directory -- NPZ file contains: - - `coactivations`: the [n_clusters, n_clusters] matrix - - `cluster_indices`: array mapping matrix positions to cluster IDs - -## What Needs to be Checked - -### Testing -- [ ] **Run the dashboard pipeline** on a real clustering run to verify: - - Coactivation computation doesn't crash - - Coactivations are saved correctly to NPZ file - - Matrix dimensions are correct - - `cluster_indices` mapping is correct - -### Type Checking -- [ ] Run `make type` to ensure no type errors were introduced -- [ ] Verify jaxtyping annotations are correct - -### Verification -- [ ] Load a saved `coactivations.npz` file and verify: - ```python - data = np.load("coactivations.npz") - coact = data["coactivations"] - cluster_indices = data["cluster_indices"] - # Check: coact should be symmetric - # Check: diagonal should be >= 
off-diagonal (clusters coactivate with themselves most) - # Check: cluster_indices length should match coact.shape[0] - ``` - -### Performance -- [ ] Check if vectorization actually improved performance -- [ ] Monitor memory usage with large numbers of clusters - -### Edge Cases -- [ ] Test with clusters that have zero activations -- [ ] Test with single-batch runs -- [ ] Test with very large number of clusters - -### Integration -- [ ] Verify the coactivation matrix can be used in downstream analysis -- [ ] Consider if visualization of coactivations should be added to dashboard - -## Notes -- The coactivation matrix is computed over all samples processed (n_batches * batch_size * seq_len samples) -- Binarization threshold is currently hardcoded as `> 0` - may want to make this configurable -- The computation happens in the dashboard pipeline, NOT during the main clustering pipeline diff --git a/spd/clustering/batched_activations.py b/spd/clustering/batched_activations.py new file mode 100644 index 000000000..bd0137ba4 --- /dev/null +++ b/spd/clustering/batched_activations.py @@ -0,0 +1,374 @@ +"""Activation batch storage and precomputation for multi-batch clustering. + +This module provides: +1. Data structures for storing and loading activation batches (ActivationBatch, BatchedActivations) +2. 
Precomputation logic to generate batches for ensemble runs (precompute_batches_for_ensemble) +""" + +import gc +import re +import zipfile +from collections.abc import Iterator +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, override + +import numpy as np +import torch +from jaxtyping import Float +from torch import Tensor +from tqdm import tqdm + +from spd.clustering.activations import ( + ProcessedActivations, + component_activations, + process_activations, +) +from spd.clustering.consts import ActivationsTensor, BatchTensor, ComponentLabels +from spd.clustering.dataset import create_dataset_loader +from spd.data import loop_dataloader +from spd.log import logger +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.spd_types import TaskName +from spd.utils.distributed_utils import get_device + +if TYPE_CHECKING: + from spd.clustering.clustering_run_config import ClusteringRunConfig + + +_BATCH_FORMAT: str = "batch_{idx:04}.zip" +_LABELS_FILE: str = "labels.txt" + + +@dataclass +class ActivationBatch: + """Single batch of subcomponent activations""" + + activations: ActivationsTensor + labels: ComponentLabels + + def save(self, path: Path) -> None: + zf: zipfile.ZipFile + with zipfile.ZipFile(path, "w") as zf: + with zf.open("activations.npy", "w") as f: + np.save(f, self.activations.cpu().numpy()) + zf.writestr("labels.txt", "\n".join(self.labels)) + + def save_idx(self, batch_dir: Path, idx: int) -> None: + self.save(batch_dir / _BATCH_FORMAT.format(idx=idx)) + + @staticmethod + def read(path: Path) -> "ActivationBatch": + zf: zipfile.ZipFile + with zipfile.ZipFile(path, "r") as zf: + with zf.open("activations.npy", "r") as f: + activations_np: Float[np.ndarray, "samples n_components"] = np.load(f) + labels_raw: list[str] = zf.read("labels.txt").decode("utf-8").splitlines() + return ActivationBatch( + activations=torch.from_numpy(activations_np), + labels=ComponentLabels(labels_raw), + ) + + 
+class BatchedActivations(Iterator[ActivationBatch]): + """Iterator over activation batches from disk.""" + + def __init__(self, batch_dir: Path): + self.batch_dir: Path = batch_dir + # Find all batch files + _glob_pattern: str = re.sub(r"\{[^{}]*\}", "*", _BATCH_FORMAT) # returns `batch_*.zip` + self.batch_paths: list[Path] = sorted(batch_dir.glob(_glob_pattern)) + assert len(self.batch_paths) > 0, f"No batch files found in {batch_dir}" + self.current_idx: int = 0 + + # Verify naming + i: int + for i in range(len(self.batch_paths)): + expected_name: str = _BATCH_FORMAT.format(idx=i) + actual_name: str = self.batch_paths[i].name + assert expected_name == actual_name, ( + f"Expected batch file '{expected_name}', found '{actual_name}'" + ) + + # Load labels from file + labels_path: Path = batch_dir / _LABELS_FILE + assert labels_path.exists(), f"Labels file not found: {labels_path}" + self._labels: ComponentLabels = ComponentLabels( + labels_path.read_text().strip().splitlines() + ) + + @property + def labels(self) -> ComponentLabels: + """Get component labels for all batches.""" + return self._labels + + @property + def n_batches(self) -> int: + return len(self.batch_paths) + + def _get_next_batch(self) -> ActivationBatch: + """Load and return next batch, cycling through available batches.""" + batch: ActivationBatch = ActivationBatch.read( + self.batch_paths[self.current_idx % self.n_batches] + ) + self.current_idx += 1 + return batch + + @override + def __next__(self) -> ActivationBatch: + return self._get_next_batch() + + @classmethod + def from_tensor( + cls, activations: Tensor, labels: ComponentLabels | list[str] + ) -> "BatchedActivations": + """Create a BatchedActivations instance from a single activation tensor. + + This is a helper for backward compatibility with tests and code that uses + single-batch mode. It creates a temporary directory with a single batch file. 
+ + Args: + activations: Activation tensor [samples, n_components] + labels: Component labels ["module:idx", ...] + + Returns: + BatchedActivations instance that cycles through the single batch + """ + import tempfile + + # Create a temporary directory + temp_dir: Path = Path(tempfile.mkdtemp(prefix="batch_temp_")) + + # Normalize labels + normalized_labels: ComponentLabels = ComponentLabels(labels) + + # Save labels file + labels_path: Path = temp_dir / _LABELS_FILE + labels_path.write_text("\n".join(normalized_labels)) + + # Save the single batch + batch: ActivationBatch = ActivationBatch(activations=activations, labels=normalized_labels) + batch.save(temp_dir / _BATCH_FORMAT.format(idx=0)) + + # Return BatchedActivations that will cycle through this single batch + return BatchedActivations(temp_dir) + + +def _generate_activation_batches( + model: ComponentModel, + device: str, + task_name: TaskName, + model_path: str, + batch_size: int, + n_batches: int, + output_dir: Path, + base_seed: int, + dataset_streaming: bool = False, +) -> None: + """Core function to generate activation batches. + + Batches are saved WITHOUT filtering - they contain raw/unfiltered activations. + This is required for merge_iteration to correctly recompute costs from fresh batches. 
+ + Args: + model: ComponentModel to compute activations + device: Device to use for computation + task_name: Task name for dataset loading + model_path: Path to model for dataset loading (as string) + batch_size: Batch size for dataset + n_batches: Number of batches to generate + output_dir: Directory to save batches + base_seed: Base seed for dataset loading + dataset_streaming: Whether to use streaming for dataset loading + """ + + # Create dataloader ONCE instead of reloading for each batch + dataloader = create_dataset_loader( + model_path=model_path, + task_name=task_name, + batch_size=batch_size, + seed=base_seed, + config_kwargs=dict( + streaming=dataset_streaming, + ), + ) + + # Use loop_dataloader for efficient iteration that handles exhaustion + batch_iterator = loop_dataloader(dataloader) + + batch_idx: int + for batch_idx in tqdm(range(n_batches), desc="Generating batches", leave=False): + # Get next batch from iterator + batch_data_raw = next(batch_iterator) + + # Extract input based on task type + if task_name == "lm": + batch_data: BatchTensor = batch_data_raw["input_ids"].to(device) + elif task_name == "resid_mlp": + batch_data = batch_data_raw[0].to(device) # (batch, labels) tuple + else: + raise ValueError(f"Unsupported task: {task_name}") + + # Compute activations + with torch.no_grad(): + acts_dict: dict[str, ActivationsTensor] = component_activations( + model, device, batch_data + ) + + # Process activations WITHOUT filtering + # Batches must contain raw/unfiltered activations because merge_iteration + # expects to reload unfiltered data when recomputing costs + processed: ProcessedActivations = process_activations( + activations=acts_dict, + filter_dead_threshold=0.0, # Never filter when saving batches + seq_mode="concat" if task_name == "lm" else None, + filter_modules=None, # Never filter modules when saving batches + ) + + # Save labels file (once, from first batch) + if batch_idx == 0: + labels_path: Path = output_dir / _LABELS_FILE + 
labels_path.write_text("\n".join(processed.labels)) + + # Save as ActivationBatch + activation_batch: ActivationBatch = ActivationBatch( + activations=processed.activations.cpu(), # Move to CPU for storage + labels=ComponentLabels(list(processed.labels)), + ) + activation_batch.save(output_dir / _BATCH_FORMAT.format(idx=batch_idx)) + + # Clean up immediately after saving to avoid memory accumulation + del batch_data, batch_data_raw, acts_dict, processed, activation_batch + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + del dataloader, batch_iterator + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def precompute_batches_for_single_run( + clustering_run_config: "ClusteringRunConfig", + output_dir: Path, + base_seed: int, +) -> int: + """ + Precompute activation batches for a single clustering run. + + This loads the model ONCE, calculates how many batches are needed + (based on recompute_costs_every and n_iters), generates all batches, + and saves them to disk. + + Batches are saved WITHOUT filtering to ensure merge_iteration can correctly + recompute costs from fresh batches. + + Args: + clustering_run_config: Configuration for clustering run + output_dir: Directory to save batches (will contain batch_0000.zip, batch_0001.zip, etc.) 
+ base_seed: Base seed for dataset loading + + Returns: + Number of batches generated + """ + output_dir.mkdir(exist_ok=True, parents=True) + + # Load model to determine number of components + device: str = get_device() + spd_run: SPDRunInfo = SPDRunInfo.from_path(clustering_run_config.model_path) + model: ComponentModel = ComponentModel.from_run_info(spd_run).to(device) + task_name: TaskName = spd_run.config.task_config.task_name + + # Count total components directly from model (sum C across all component modules) + n_components: int = sum(comp.C for comp in model.components.values()) + + # Calculate number of iterations and batches needed + n_iters: int = clustering_run_config.merge_config.get_num_iters(n_components) + recompute_every: int | None = clustering_run_config.merge_config.recompute_costs_every + + n_batches_needed: int + if recompute_every is None: + # Single-batch mode: generate 1 batch, reuse for all iterations + n_batches_needed = 1 + logger.info(f"Single-batch mode: generating 1 batch for {n_iters} iterations") + else: + # Multi-batch mode: generate enough batches to cover all iterations + n_batches_needed = (n_iters + recompute_every - 1) // recompute_every + logger.info( + f"Multi-batch mode: generating {n_batches_needed} batches for {n_iters} iterations (recompute_every={recompute_every})" + ) + + # Generate batches (no filtering applied) + _generate_activation_batches( + model=model, + device=device, + task_name=task_name, + model_path=clustering_run_config.model_path, + batch_size=clustering_run_config.batch_size, + n_batches=n_batches_needed, + output_dir=output_dir, + base_seed=base_seed, + dataset_streaming=clustering_run_config.dataset_streaming, + ) + + # Clean up model + del model + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + logger.info(f"Generated {n_batches_needed} batches and saved to {output_dir}") + return n_batches_needed + + +def precompute_batches_for_ensemble( + clustering_run_config: 
"ClusteringRunConfig", + n_runs: int, + output_dir: Path, +) -> Path | None: + """ + Precompute activation batches for all runs in ensemble. + + This generates all batches for all runs by calling precompute_batches_for_single_run() + for each run with a unique seed offset. + + Args: + clustering_run_config: Configuration for clustering runs + n_runs: Number of runs in the ensemble + output_dir: Base directory to save precomputed batches + + Returns: + Path to base directory containing batches for all runs, + or None if single-batch mode (recompute_costs_every=None) + """ + # Check if multi-batch mode + recompute_every: int | None = clustering_run_config.merge_config.recompute_costs_every + if recompute_every is None: + logger.info("Single-batch mode (recompute_costs_every=`None`), skipping precomputation") + return None + + logger.info("Multi-batch mode detected, precomputing activation batches") + + # Create batches directory + batches_base_dir: Path = output_dir / "precomputed_batches" + batches_base_dir.mkdir(exist_ok=True, parents=True) + + # Generate batches for each run + run_idx: int + for run_idx in tqdm(range(n_runs), desc="Ensemble runs"): + run_batch_dir: Path = batches_base_dir / f"run_{run_idx}" + run_batch_dir.mkdir(exist_ok=True) + + # Use unique seed offset for this run + run_seed: int = clustering_run_config.dataset_seed + run_idx * 1000 + + # Generate all batches for this run + precompute_batches_for_single_run( + clustering_run_config=clustering_run_config, + output_dir=run_batch_dir, + base_seed=run_seed, + ) + + logger.info(f"All batches precomputed and saved to {batches_base_dir}") + return batches_base_dir diff --git a/spd/clustering/clustering_run_config.py b/spd/clustering/clustering_run_config.py index 95d72f9bd..1da579488 100644 --- a/spd/clustering/clustering_run_config.py +++ b/spd/clustering/clustering_run_config.py @@ -69,6 +69,10 @@ class ClusteringRunConfig(BaseConfig): default=False, description="Whether to use streaming dataset 
loading (if supported by the dataset). see https://github.com/goodfire-ai/spd/pull/199", ) + precomputed_activations_dir: Path | None = Field( + default=None, + description="Path to directory containing precomputed activation batches. If None, batches will be auto-generated before merging starts.", + ) @model_validator(mode="before") def process_experiment_key(cls, values: dict[str, Any]) -> dict[str, Any]: diff --git a/spd/clustering/consts.py b/spd/clustering/consts.py index 8a9647dc8..fe2f4ebde 100644 --- a/spd/clustering/consts.py +++ b/spd/clustering/consts.py @@ -5,9 +5,19 @@ from typing import Literal, NewType import numpy as np +import torch from jaxtyping import Bool, Float, Int from torch import Tensor +# TODO: docstrings for all types below + +ComponentIndexDtype = np.int32 +ComponentIndexDtypeTorch = torch.int32 +# int32 holds component indices up to ~2.1b (the previous int16 capped out at ~32k) +# if you have more than 2.1b components, rethink your life choices +# note 2025-10-29 10:37 -- obviously we will need to handle components in the billions, +# but this makes the current method infeasible -- we will need to cluster within layers first + # Merge arrays and distances (numpy-based for storage/analysis) MergesAtIterArray = Int[np.ndarray, "n_ens n_components"] MergesArray = Int[np.ndarray, "n_ens n_iters n_components"] diff --git a/spd/clustering/dataset.py b/spd/clustering/dataset.py index ea9b9f904..1b3adc548 100644 --- a/spd/clustering/dataset.py +++ b/spd/clustering/dataset.py @@ -3,8 +3,11 @@ Each clustering run loads its own dataset batch, seeded by the run index. 
""" +import warnings from typing import Any +from torch.utils.data import DataLoader + from spd.clustering.consts import BatchTensor from spd.data import DatasetConfig, create_data_loader from spd.experiments.lm.configs import LMTaskConfig @@ -14,16 +17,17 @@ from spd.spd_types import TaskName -def load_dataset( +def create_dataset_loader( model_path: str, task_name: TaskName, batch_size: int, seed: int, **kwargs: Any, -) -> BatchTensor: - """Load a single batch for clustering. +) -> DataLoader[Any]: + """Create a dataloader for clustering that can be iterated multiple times. - Each run gets its own dataset batch, seeded by index in ensemble. + This is more efficient than load_dataset() when you need multiple batches, + as it creates the dataloader once and allows iteration through many batches. Args: model_path: Path to decomposed model @@ -32,18 +36,18 @@ def load_dataset( seed: Random seed for dataset Returns: - Single batch of data + DataLoader that can be iterated to get multiple batches """ match task_name: case "lm": - return _load_lm_batch( + return _create_lm_dataloader( model_path=model_path, batch_size=batch_size, seed=seed, **kwargs, ) case "resid_mlp": - return _load_resid_mlp_batch( + return _create_resid_mlp_dataloader( model_path=model_path, batch_size=batch_size, seed=seed, @@ -53,10 +57,53 @@ def load_dataset( raise ValueError(f"Unsupported task: {task_name}") -def _load_lm_batch( - model_path: str, batch_size: int, seed: int, config_kwargs: dict[str, Any] | None = None +def load_dataset( + model_path: str, + task_name: TaskName, + batch_size: int, + seed: int, + **kwargs: Any, ) -> BatchTensor: - """Load a batch for language model task.""" + """Load a single batch for clustering. + + This is a convenience wrapper around create_dataset_loader() that extracts + just the first batch. Use create_dataset_loader() directly if you need + multiple batches for better efficiency. 
+ + Args: + model_path: Path to decomposed model + task_name: Task type + batch_size: Batch size + seed: Random seed for dataset + + Returns: + Single batch of data + """ + dataloader = create_dataset_loader( + model_path=model_path, + task_name=task_name, + batch_size=batch_size, + seed=seed, + **kwargs, + ) + + # Extract first batch based on task type + batch = next(iter(dataloader)) + if task_name == "lm": + return batch["input_ids"] + elif task_name == "resid_mlp": + return batch[0] # ResidMLP returns (batch, labels) tuple + else: + raise ValueError(f"Unsupported task: {task_name}") + + +def _create_lm_dataloader( + model_path: str, + batch_size: int, + seed: int, + config_kwargs: dict[str, Any] | None = None, +) -> DataLoader[Any]: + """Create a dataloader for language model task.""" spd_run = SPDRunInfo.from_path(model_path) cfg = spd_run.config @@ -97,13 +144,16 @@ def _load_lm_batch( ddp_world_size=1, ) - # Get first batch - batch = next(iter(dataloader)) - return batch["input_ids"] + return dataloader -def _load_resid_mlp_batch(model_path: str, batch_size: int, seed: int) -> BatchTensor: - """Load a batch for ResidMLP task.""" +def _create_resid_mlp_dataloader( + model_path: str, + batch_size: int, + seed: int, + config_kwargs: dict[str, Any] | None = None, +) -> DataLoader[Any]: + """Create a dataloader for ResidMLP task.""" from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset from spd.utils.data_utils import DatasetGeneratedDataLoader @@ -118,6 +168,18 @@ def _load_resid_mlp_batch(model_path: str, batch_size: int, seed: int) -> BatchT f"Expected target_model to be of type ResidMLP, but got {type(component_model.target_model) = }" ) + if config_kwargs is not None: + if "streaming" in config_kwargs: + warnings.warn( + "The 'streaming' option is not supported for ResidMLPDataset and will be ignored.", + stacklevel=2, + ) + config_kwargs = {k: v for k, v in config_kwargs.items() if k != "streaming"} + + assert len(config_kwargs) == 0, ( + f"Unsupported config_kwargs for 
ResidMLPDataset: {config_kwargs=}" + ) + # Create dataset with run-specific seed dataset = ResidMLPDataset( n_features=component_model.target_model.config.n_features, @@ -131,7 +193,6 @@ def _load_resid_mlp_batch(model_path: str, batch_size: int, seed: int) -> BatchT data_generation_type=cfg.task_config.data_generation_type, ) - # Generate batch + # Create dataloader dataloader = DatasetGeneratedDataLoader(dataset, batch_size=batch_size, shuffle=False) - batch, _ = next(iter(dataloader)) - return batch + return dataloader diff --git a/spd/clustering/math/merge_matrix.py b/spd/clustering/math/merge_matrix.py index 118f575e2..ddb3ced5f 100644 --- a/spd/clustering/math/merge_matrix.py +++ b/spd/clustering/math/merge_matrix.py @@ -5,7 +5,7 @@ from muutils.tensor_info import array_summary from torch import Tensor -from spd.clustering.consts import GroupIdxsTensor +from spd.clustering.consts import ComponentIndexDtypeTorch, GroupIdxsTensor # pyright: reportUnnecessaryTypeIgnoreComment=false @@ -200,8 +200,8 @@ def summary(self) -> dict[str, int | str | None]: def init_empty(cls, batch_size: int, n_components: int) -> "BatchedGroupMerge": """Initialize an empty BatchedGroupMerge with the given batch size and number of components.""" return cls( - group_idxs=torch.full((batch_size, n_components), -1, dtype=torch.int16), - k_groups=torch.zeros(batch_size, dtype=torch.int16), + group_idxs=torch.full((batch_size, n_components), -1, dtype=ComponentIndexDtypeTorch), + k_groups=torch.zeros(batch_size, dtype=ComponentIndexDtypeTorch), ) @property diff --git a/spd/clustering/math/merge_pair_samplers.py b/spd/clustering/math/merge_pair_samplers.py index 24c050d36..bb8eaa3e8 100644 --- a/spd/clustering/math/merge_pair_samplers.py +++ b/spd/clustering/math/merge_pair_samplers.py @@ -25,6 +25,22 @@ def __call__( ) -> MergePair: ... 
+def get_valid_mask( + costs: ClusterCoactivationShaped, +) -> ClusterCoactivationShaped: + """Get a boolean mask of valid merge pairs (non-NaN, non-diagonal).""" + k_groups: int = costs.shape[0] + valid_mask: ClusterCoactivationShaped = ( + ~torch.isnan(costs) # mask out NaN entries + & ~torch.eye( + k_groups, dtype=torch.bool, device=costs.device + ) # mask out diagonal (can't merge with self) + ) + if not valid_mask.any(): + raise ValueError("All non-diagonal costs are NaN, cannot sample merge pair") + return valid_mask + + def range_sampler( costs: ClusterCoactivationShaped, threshold: float = 0.05, @@ -36,7 +52,7 @@ def range_sampler( of the range of non-diagonal costs, then randomly selects one. Args: - costs: Cost matrix for all possible merges + costs: Cost matrix for all possible merges (may contain NaN for invalid pairs) k_groups: Number of current groups threshold: Fraction of cost range to consider (0=min only, 1=all pairs) @@ -47,22 +63,26 @@ def range_sampler( k_groups: int = costs.shape[0] assert costs.shape[1] == k_groups, "Cost matrix must be square" - # Find the range of non-diagonal costs - non_diag_costs: Float[Tensor, " k_groups_squared_minus_k"] = costs[ - ~torch.eye(k_groups, dtype=torch.bool, device=costs.device) - ] - min_cost: float = float(non_diag_costs.min().item()) - max_cost: float = float(non_diag_costs.max().item()) + valid_mask: ClusterCoactivationShaped = get_valid_mask(costs) + + # Get valid costs + valid_costs: Float[Tensor, " n_valid"] = costs[valid_mask] + + # Find the range of valid costs + min_cost: float = float(valid_costs.min().item()) + max_cost: float = float(valid_costs.max().item()) # Calculate threshold cost max_considered_cost: float = (max_cost - min_cost) * threshold + min_cost - # Find all pairs below threshold - considered_idxs: Int[Tensor, "n_considered 2"] = torch.stack( - torch.where(costs <= max_considered_cost), dim=1 - ) - # Remove diagonal entries (i == j) - considered_idxs = 
considered_idxs[considered_idxs[:, 0] != considered_idxs[:, 1]] + # Find all valid pairs below threshold + within_range: Bool[Tensor, "k_groups k_groups"] = (costs <= max_considered_cost) & valid_mask + + # Get indices of candidate pairs + considered_idxs: Int[Tensor, "n_considered 2"] = torch.stack(torch.where(within_range), dim=1) + + if considered_idxs.shape[0] == 0: + raise ValueError("No valid pairs within threshold range") # Randomly select one of the considered pairs selected_idx: int = random.randint(0, considered_idxs.shape[0] - 1) @@ -78,7 +98,7 @@ def mcmc_sampler( """Sample a merge pair using MCMC with probability proportional to exp(-cost/temperature). Args: - costs: Cost matrix for all possible merges + costs: Cost matrix for all possible merges (may contain NaN for invalid pairs) k_groups: Number of current groups temperature: Temperature parameter for softmax (higher = more uniform sampling) @@ -89,21 +109,18 @@ def mcmc_sampler( k_groups: int = costs.shape[0] assert costs.shape[1] == k_groups, "Cost matrix must be square" - # Create mask for valid pairs (non-diagonal) - valid_mask: Bool[Tensor, "k_groups k_groups"] = ~torch.eye( - k_groups, dtype=torch.bool, device=costs.device - ) + valid_mask: ClusterCoactivationShaped = get_valid_mask(costs) # Compute probabilities: exp(-cost/temperature) # Use stable softmax computation to avoid overflow costs_masked: ClusterCoactivationShaped = costs.clone() - costs_masked[~valid_mask] = float("inf") # Set diagonal to inf so exp gives 0 + costs_masked[~valid_mask] = float("inf") # Set invalid entries to inf so exp gives 0 # Subtract min for numerical stability min_cost: float = float(costs_masked[valid_mask].min()) probs: ClusterCoactivationShaped = ( torch.exp((min_cost - costs_masked) / temperature) * valid_mask - ) # Zero out diagonal + ) # Zero out invalid entries probs_flatten: Float[Tensor, " k_groups_squared"] = probs.flatten() probs_flatten = probs_flatten / probs_flatten.sum() diff --git 
a/spd/clustering/merge.py b/spd/clustering/merge.py index dba55c878..a7b2966d0 100644 --- a/spd/clustering/merge.py +++ b/spd/clustering/merge.py @@ -8,14 +8,14 @@ from typing import Protocol import torch -from jaxtyping import Bool, Float +from jaxtyping import Bool, Float, Int from torch import Tensor from tqdm import tqdm +from spd.clustering.batched_activations import ActivationBatch, BatchedActivations from spd.clustering.compute_costs import ( compute_mdl_cost, compute_merge_costs, - recompute_coacts_merge_pair, ) from spd.clustering.consts import ( ActivationsTensor, @@ -29,6 +29,50 @@ from spd.clustering.merge_history import MergeHistory +def recompute_coacts_from_scratch( + activations: Tensor, + current_merge: GroupMerge, + activation_threshold: float | None, +) -> tuple[Tensor, Tensor]: + """ + Recompute coactivations from fresh activations using current merge state. + + Args: + activations: Fresh activation tensor [samples, n_components_original] + current_merge: Current merge state mapping original -> groups + activation_threshold: Threshold for binarizing activations + + Returns: + (coact, activation_mask) - coact matrix [k_groups, k_groups] and + mask [samples, k_groups] for current groups + """ + # Apply threshold + activation_mask: Bool[Tensor, "samples n_components"] = ( + activations > activation_threshold if activation_threshold is not None else activations + ) + + # Map component-level activations to group-level using scatter_add + # This is more efficient than materializing the full merge matrix + # current_merge.group_idxs: [n_components] with values 0 to k_groups-1 + n_samples: int = activation_mask.shape[0] + group_activations: Float[Tensor, "n_samples k_groups"] = torch.zeros( + (n_samples, current_merge.k_groups), + dtype=activation_mask.dtype, + device=activation_mask.device, + ) + + # Expand group_idxs to match batch dimension and scatter-add activations by group + group_idxs_expanded: Int[Tensor, "n_samples n_components"] = ( + 
current_merge.group_idxs.unsqueeze(0).expand(n_samples, -1).to(activation_mask.device) + ) + group_activations.scatter_add_(1, group_idxs_expanded, activation_mask) + + # Compute coactivations + coact: ClusterCoactivationShaped = group_activations.float().T @ group_activations.float() + + return coact, group_activations + + class LogCallback(Protocol): def __call__( self, @@ -48,20 +92,25 @@ def __call__( def merge_iteration( merge_config: MergeConfig, - activations: ActivationsTensor, + batched_activations: BatchedActivations, component_labels: ComponentLabels, log_callback: LogCallback | None = None, ) -> MergeHistory: """ - Merge iteration with optional logging/plotting callbacks. + Merge iteration with multi-batch support and optional logging/plotting callbacks. - This wraps the pure computation with logging capabilities while maintaining - the same core algorithm logic. + This implementation uses NaN masking to track invalid coactivation entries + and periodically recomputes the full coactivation matrix from fresh batches. """ - # compute coactivations + # Load first batch # -------------------------------------------------- - activation_mask_orig: BoolActivationsTensor | ActivationsTensor | None = ( + first_batch: ActivationBatch = batched_activations._get_next_batch() + activations: ActivationsTensor = first_batch.activations + + # Compute initial coactivations + # -------------------------------------------------- + activation_mask_orig: BoolActivationsTensor | ActivationsTensor = ( activations > merge_config.activation_threshold if merge_config.activation_threshold is not None else activations @@ -99,6 +148,31 @@ def merge_iteration( total=num_iters, ) for iter_idx in pbar: + # Recompute from batch if needed (do this BEFORE computing costs) + # -------------------------------------------------- + # With NaN masking, we must recompute before every iteration (except first) + # because the coact matrix is invalidated after each merge. 
+ # When recompute_costs_every is set, we cycle through batches; + # otherwise we reuse the same batch. + if iter_idx > 0: + # Check if we should load a new batch + should_load_new_batch: bool = ( + merge_config.recompute_costs_every is not None + and iter_idx % merge_config.recompute_costs_every == 0 + ) + + if should_load_new_batch: + new_batch: ActivationBatch = batched_activations._get_next_batch() + activations = new_batch.activations + + # Always recompute coacts from current activations after iteration 0 + # (needed because NaN masking invalidates the matrix) + current_coact, current_act_mask = recompute_coacts_from_scratch( + activations=activations, + current_merge=current_merge, + activation_threshold=merge_config.activation_threshold, + ) + # compute costs, figure out what to merge # -------------------------------------------------- # HACK: this is messy @@ -110,20 +184,38 @@ def merge_iteration( merge_pair: MergePair = merge_config.merge_pair_sample(costs) + # Store merge pair cost before updating + # -------------------------------------------------- + merge_pair_cost: float = float(costs[merge_pair].item()) + # merge the pair # -------------------------------------------------- - # we do this *before* logging, so we can see how the sampled pair cost compares - # to the costs of all the other possible pairs - current_merge, current_coact, current_act_mask = recompute_coacts_merge_pair( - coact=current_coact, - merges=current_merge, - merge_pair=merge_pair, - activation_mask=current_act_mask, + # Update merge state BEFORE NaN-ing out + current_merge = current_merge.merge_groups(merge_pair[0], merge_pair[1]) + + # NaN out the merged components' rows/cols + i, j = merge_pair + new_idx: int = min(i, j) + remove_idx: int = max(i, j) + + # Mark affected entries as invalid (can't compute cost anymore without recompute) + current_coact[remove_idx, :] = float("nan") + current_coact[:, remove_idx] = float("nan") + current_coact[new_idx, :] = float("nan") + 
current_coact[:, new_idx] = float("nan") + + # Remove the deleted row/col to maintain shape consistency + mask: Bool[Tensor, " k_groups"] = torch.ones( + k_groups, dtype=torch.bool, device=current_coact.device ) + mask[remove_idx] = False + current_coact = current_coact[mask, :][:, mask] + current_act_mask = current_act_mask[:, mask] + + k_groups -= 1 - # metrics and logging - # -------------------------------------------------- # Store in history + # -------------------------------------------------- merge_history.add_iteration( idx=iter_idx, selected_pair=merge_pair, @@ -131,6 +223,7 @@ def merge_iteration( ) # Compute metrics for logging + # -------------------------------------------------- # the MDL loss computed here is the *cost of the current merge*, a single scalar value # rather than the *delta in cost from merging a specific pair* (which is what `costs` matrix contains) diag_acts: Float[Tensor, " k_groups"] = torch.diag(current_coact) @@ -140,8 +233,6 @@ def merge_iteration( alpha=merge_config.alpha, ) mdl_loss_norm: float = mdl_loss / current_act_mask.shape[0] - # this is the cost for the selected pair - merge_pair_cost: float = float(costs[merge_pair].item()) # Update progress bar pbar.set_description(f"k={k_groups}, mdl={mdl_loss_norm:.4f}, pair={merge_pair_cost:.4f}") @@ -161,9 +252,8 @@ def merge_iteration( diag_acts=diag_acts, ) - # iterate and sanity checks + # Sanity checks # -------------------------------------------------- - k_groups -= 1 assert current_coact.shape[0] == k_groups, ( "Coactivation matrix shape should match number of groups" ) diff --git a/spd/clustering/merge_config.py b/spd/clustering/merge_config.py index f471879b2..6a1c53069 100644 --- a/spd/clustering/merge_config.py +++ b/spd/clustering/merge_config.py @@ -72,6 +72,14 @@ class MergeConfig(BaseConfig): default=None, description="Filter for module names. 
Can be a string prefix, a set of names, or a callable that returns True for modules to include.", ) + recompute_costs_every: PositiveInt | None = Field( + default=None, + description="Number of merges before recomputing costs with new batch. Set to `None` to use a single batch throughout.", + ) + batch_size: PositiveInt = Field( + default=64, + description="Size of each batch for processing", + ) @property def merge_pair_sample_func(self) -> MergePairSampler: diff --git a/spd/clustering/merge_history.py b/spd/clustering/merge_history.py index bbff78893..7bbbd046b 100644 --- a/spd/clustering/merge_history.py +++ b/spd/clustering/merge_history.py @@ -11,6 +11,7 @@ from muutils.dbg import dbg_tensor from spd.clustering.consts import ( + ComponentIndexDtype, ComponentLabels, DistancesArray, DistancesMethod, @@ -74,7 +75,7 @@ def from_config( return MergeHistory( labels=labels, n_iters_current=0, - selected_pairs=np.full((n_iters_target, 2), -1, dtype=np.int16), + selected_pairs=np.full((n_iters_target, 2), -1, dtype=ComponentIndexDtype), merges=BatchedGroupMerge.init_empty( batch_size=n_iters_target, n_components=n_components ), @@ -108,7 +109,7 @@ def add_iteration( current_merge: GroupMerge, ) -> None: """Add data for one iteration.""" - self.selected_pairs[idx] = np.array(selected_pair, dtype=np.int16) + self.selected_pairs[idx] = np.array(selected_pair, dtype=ComponentIndexDtype) self.merges[idx] = current_merge assert self.n_iters_current == idx @@ -339,9 +340,7 @@ def merges_array(self) -> MergesArray: output: MergesArray = np.full( (n_ens, n_iters, c_components), fill_value=-1, - dtype=np.int16, - # if you have more than 32k components, change this to np.int32 - # if you have more than 2.1b components, rethink your life choices + dtype=ComponentIndexDtype, ) for i_ens, history in enumerate(self.data): for i_iter, merge in enumerate(history.merges): @@ -373,7 +372,7 @@ def normalized(self) -> tuple[MergesArray, dict[str, Any]]: merges_array: MergesArray = 
np.full( (self.n_ensemble, self.n_iters_min, c_components), fill_value=-1, - dtype=np.int16, + dtype=ComponentIndexDtype, ) except Exception as e: err_msg = ( @@ -418,7 +417,7 @@ def normalized(self) -> tuple[MergesArray, dict[str, Any]]: merges_array[i_ens, :, i_comp_new_relabel] = np.full( self.n_iters_min, fill_value=idx_missing + hist_n_components, - dtype=np.int16, + dtype=ComponentIndexDtype, ) # TODO: Consider logging overlap_stats to WandB if run is available diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index 54f0805c6..110469094 100644 --- a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -10,7 +10,6 @@ """ import argparse -import gc import os import tempfile from collections.abc import Callable @@ -26,32 +25,26 @@ from torch import Tensor from wandb.sdk.wandb_run import Run -from spd.clustering.activations import ( - ProcessedActivations, - component_activations, - process_activations, +from spd.clustering.batched_activations import ( + BatchedActivations, + precompute_batches_for_single_run, ) from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.consts import ( - ActivationsTensor, - BatchTensor, ClusterCoactivationShaped, ComponentLabels, ) -from spd.clustering.dataset import load_dataset from spd.clustering.ensemble_registry import _ENSEMBLE_REGISTRY_DB, register_clustering_run from spd.clustering.math.merge_matrix import GroupMerge from spd.clustering.math.semilog import semilog from spd.clustering.merge import merge_iteration from spd.clustering.merge_history import MergeHistory -from spd.clustering.plotting.activations import plot_activations from spd.clustering.plotting.merge import plot_merge_history_cluster_sizes, plot_merge_iteration from spd.clustering.storage import StorageBase from spd.clustering.wandb_tensor_info import wandb_log_tensor from spd.log import logger -from spd.models.component_model import 
ComponentModel, SPDRunInfo +from spd.models.component_model import SPDRunInfo from spd.spd_types import TaskName -from spd.utils.distributed_utils import get_device from spd.utils.general_utils import replace_pydantic_model from spd.utils.run_utils import _NO_ARG_PARSSED_SENTINEL, ExecutionStamp, read_noneable_str @@ -219,16 +212,16 @@ def main(run_config: ClusteringRunConfig) -> Path: """ # Create ExecutionStamp and storage # don't create git snapshot -- if we are part of an ensemble, the snapshot should be created by the pipeline - execution_stamp = ExecutionStamp.create( + execution_stamp: ExecutionStamp = ExecutionStamp.create( run_type="cluster", create_snapshot=False, ) - storage = ClusteringRunStorage(execution_stamp) - clustering_run_id = execution_stamp.run_id + storage: ClusteringRunStorage = ClusteringRunStorage(execution_stamp) + clustering_run_id: str = execution_stamp.run_id logger.info(f"Clustering run ID: {clustering_run_id}") # Register with ensemble if this is part of a pipeline - assigned_idx: int | None + assigned_idx: int | None = None if run_config.ensemble_id: assigned_idx = register_clustering_run( pipeline_run_id=run_config.ensemble_id, @@ -243,8 +236,6 @@ def main(run_config: ClusteringRunConfig) -> Path: run_config, {"dataset_seed": run_config.dataset_seed + assigned_idx}, ) - else: - assigned_idx = None # save config run_config.to_file(storage.config_path) @@ -253,30 +244,11 @@ def main(run_config: ClusteringRunConfig) -> Path: # start logger.info("Starting clustering run") logger.info(f"Output directory: {storage.base_dir}") - device = get_device() - spd_run = SPDRunInfo.from_path(run_config.model_path) + spd_run: SPDRunInfo = SPDRunInfo.from_path(run_config.model_path) task_name: TaskName = spd_run.config.task_config.task_name - # 1. 
Load dataset - logger.info(f"Loading dataset (seed={run_config.dataset_seed})") - load_dataset_kwargs: dict[str, Any] = dict() - if run_config.dataset_streaming: - logger.info("Using streaming dataset loading") - load_dataset_kwargs["config_kwargs"] = dict(streaming=True) - assert task_name == "lm", ( - f"Streaming dataset loading only supported for 'lm' task, got '{task_name = }'. Remove dataset_streaming=True from config or use a different task." - ) - batch: BatchTensor = load_dataset( - model_path=run_config.model_path, - task_name=task_name, - batch_size=run_config.batch_size, - seed=run_config.dataset_seed, - **load_dataset_kwargs, - ) - batch = batch.to(device) - - # 2. Setup WandB for this run + # Setup WandB for this run wandb_run: Run | None = None if run_config.wandb_project is not None: wandb_run = wandb.init( @@ -293,58 +265,40 @@ def main(run_config: ClusteringRunConfig) -> Path: f"assigned_idx:{assigned_idx}", ], ) - # logger.info(f"WandB run: {wandb_run.url}") - - # 3. Load model - logger.info("Loading model") - model = ComponentModel.from_run_info(spd_run).to(device) - - # 4. Compute activations - logger.info("Computing activations") - activations_dict: ( - dict[str, Float[Tensor, "batch seq C"]] | dict[str, Float[Tensor, "batch C"]] - ) = component_activations( - model=model, - batch=batch, - device=device, - ) - # 5. Process activations - logger.info("Processing activations") - processed_activations: ProcessedActivations = process_activations( - activations=activations_dict, - filter_dead_threshold=run_config.merge_config.filter_dead_threshold, - seq_mode="concat" if task_name == "lm" else None, - filter_modules=run_config.merge_config.filter_modules, - ) + # Load or compute activations + # ===================================== + batched_activations: BatchedActivations - # 6. 
Log activations (if WandB enabled) - if wandb_run is not None: - logger.info("Plotting activations") - plot_activations( - processed_activations=processed_activations, - save_dir=None, # Don't save to disk, only WandB - n_samples_max=256, - wandb_run=wandb_run, - ) - wandb_log_tensor( - wandb_run, - processed_activations.activations, - "activations", - 0, - single=True, + if run_config.precomputed_activations_dir is not None: + # Case 1: Use precomputed batches from disk (from ensemble pipeline) + logger.info(f"Loading precomputed batches from {run_config.precomputed_activations_dir}") + batched_activations = BatchedActivations(run_config.precomputed_activations_dir) + logger.info(f"Loaded {batched_activations.n_batches} precomputed batches") + + else: + # Case 2: Generate batches for this single run + logger.info(f"Generating activation batches (seed={run_config.dataset_seed})") + + batch_dir: Path = storage.base_dir / "batches" + batch_dir.mkdir(exist_ok=True) + + # Generate all needed batches (respects recompute_costs_every) + n_batches: int = precompute_batches_for_single_run( + clustering_run_config=run_config, + output_dir=batch_dir, + base_seed=run_config.dataset_seed, ) - # Clean up memory - activations: ActivationsTensor = processed_activations.activations - component_labels: ComponentLabels = ComponentLabels(processed_activations.labels.copy()) - del processed_activations - del activations_dict - del model - del batch - gc.collect() + # Load batches + batched_activations = BatchedActivations(batch_dir) + logger.info(f"Generated and loaded {n_batches} batches") + + # Get labels from batches + component_labels: ComponentLabels = batched_activations.labels - # 7. 
Run merge iteration + # Run merge iteration + # ===================================== logger.info("Starting merging") log_callback: LogCallback | None = ( partial(_log_callback, run=wandb_run, run_config=run_config) @@ -354,7 +308,7 @@ def main(run_config: ClusteringRunConfig) -> Path: history: MergeHistory = merge_iteration( merge_config=run_config.merge_config, - activations=activations, + batched_activations=batched_activations, component_labels=component_labels, log_callback=log_callback, ) @@ -412,6 +366,12 @@ def cli() -> None: action="store_true", help="Whether to use streaming dataset loading (if supported by the dataset)", ) + parser.add_argument( + "--precomputed-activations-dir", + type=Path, + default=None, + help="Path to directory containing precomputed activation batches", + ) args: argparse.Namespace = parser.parse_args() @@ -431,6 +391,8 @@ def cli() -> None: overrides["wandb_project"] = args.wandb_project if args.wandb_entity is not None: overrides["wandb_entity"] = args.wandb_entity + if args.precomputed_activations_dir is not None: + overrides["precomputed_activations_dir"] = args.precomputed_activations_dir run_config = replace_pydantic_model(run_config, overrides) diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index 179bc8bca..5e3db08dc 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -28,6 +28,7 @@ from pydantic import Field, PositiveInt, field_validator, model_validator from spd.base_config import BaseConfig +from spd.clustering.batched_activations import precompute_batches_for_ensemble from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.consts import DistancesMethod from spd.clustering.storage import StorageBase @@ -151,6 +152,7 @@ def create_clustering_workspace_view(ensemble_id: str, project: str, entity: str def generate_clustering_commands( pipeline_config: ClusteringPipelineConfig, pipeline_run_id: str, 
+ batches_base_dir: Path | None = None, dataset_streaming: bool = False, ) -> list[str]: """Generate commands for each clustering run. @@ -158,6 +160,7 @@ def generate_clustering_commands( Args: pipeline_config: Pipeline configuration pipeline_run_id: Pipeline run ID (each run will create its own ExecutionStamp) + batches_base_dir: Path to precomputed batches directory, or None for single-batch mode dataset_streaming: Whether to use dataset streaming Returns: @@ -180,6 +183,12 @@ def generate_clustering_commands( "--wandb-entity", pipeline_config.wandb_entity, ] + + # Add precomputed batches path if available + if batches_base_dir is not None: + run_batch_dir = batches_base_dir / f"run_{idx}" + cmd_parts.extend(["--precomputed-activations-dir", str(run_batch_dir)]) + if dataset_streaming: cmd_parts.append("--dataset-streaming") @@ -252,7 +261,7 @@ def main( logger.info(f"Pipeline run ID: {pipeline_run_id}") # Initialize storage - storage = ClusteringPipelineStorage(execution_stamp) + storage: ClusteringPipelineStorage = ClusteringPipelineStorage(execution_stamp) logger.info(f"Pipeline output directory: {storage.base_dir}") # Save pipeline config @@ -261,22 +270,45 @@ def main( # Create WandB workspace if requested if pipeline_config.wandb_project is not None: - workspace_url = create_clustering_workspace_view( + workspace_url: str = create_clustering_workspace_view( ensemble_id=pipeline_run_id, project=pipeline_config.wandb_project, entity=pipeline_config.wandb_entity, ) logger.info(f"WandB workspace: {workspace_url}") + clustering_run_config: ClusteringRunConfig = ClusteringRunConfig.from_file( + pipeline_config.clustering_run_config_path + ) + + # Precompute batches if multi-batch mode + # ========================================================================================== + + # pass streaming to the crc + clustering_run_config = replace_pydantic_model( + clustering_run_config, + {"dataset_streaming": dataset_streaming}, + ) + + batches_base_dir: Path | 
None = precompute_batches_for_ensemble( + clustering_run_config=clustering_run_config, + n_runs=pipeline_config.n_runs, + output_dir=storage.base_dir, + ) + + # run + # ========================================================================================== + # Generate commands for clustering runs - clustering_commands = generate_clustering_commands( + clustering_commands: list[str] = generate_clustering_commands( pipeline_config=pipeline_config, pipeline_run_id=pipeline_run_id, + batches_base_dir=batches_base_dir, dataset_streaming=dataset_streaming, ) # Generate commands for calculating distances - calc_distances_commands = generate_calc_distances_commands( + calc_distances_commands: list[str] = generate_calc_distances_commands( pipeline_run_id=pipeline_run_id, distances_methods=pipeline_config.distances_methods, ) diff --git a/spd/clustering/util.py b/spd/clustering/util.py index bd11e2fd4..0c1300640 100644 --- a/spd/clustering/util.py +++ b/spd/clustering/util.py @@ -8,10 +8,13 @@ def format_scientific_latex(value: float) -> str: import math - exponent: int = int(math.floor(math.log10(abs(value)))) - mantissa: float = value / (10**exponent) + try: + exponent: int = int(math.floor(math.log10(abs(value)))) + mantissa: float = value / (10**exponent) - return f"${mantissa:.2f} \\times 10^{{{exponent}}}$" + return f"${mantissa:.2f} \\times 10^{{{exponent}}}$" + except Exception: + return f"${value}$" ModuleFilterSource = str | Callable[[str], bool] | set[str] | None diff --git a/tests/clustering/scripts/cluster_resid_mlp.py b/tests/clustering/scripts/cluster_resid_mlp.py index bbfb5259e..d2c1efdc0 100644 --- a/tests/clustering/scripts/cluster_resid_mlp.py +++ b/tests/clustering/scripts/cluster_resid_mlp.py @@ -12,6 +12,7 @@ component_activations, process_activations, ) +from spd.clustering.batched_activations import BatchedActivations from spd.clustering.consts import ComponentLabels from spd.clustering.merge import merge_iteration from 
spd.clustering.merge_config import MergeConfig @@ -74,7 +75,9 @@ data_generation_type=DATASET.data_generation_type, ) ) -DATALOADER = DatasetGeneratedDataLoader(DATASET, batch_size=N_SAMPLES, shuffle=False) +DATALOADER: DatasetGeneratedDataLoader[Any] = DatasetGeneratedDataLoader( + DATASET, batch_size=N_SAMPLES, shuffle=False +) # %% # Get component activations @@ -148,9 +151,13 @@ def _plot_func( ) +BATCHED_ACTIVATIONS: BatchedActivations = BatchedActivations.from_tensor( + activations=PROCESSED_ACTIVATIONS.activations, + labels=list(PROCESSED_ACTIVATIONS.labels), +) MERGE_HIST: MergeHistory = merge_iteration( merge_config=MERGE_CFG, - activations=PROCESSED_ACTIVATIONS.activations, + batched_activations=BATCHED_ACTIVATIONS, component_labels=PROCESSED_ACTIVATIONS.labels, log_callback=_plot_func, ) @@ -172,9 +179,13 @@ def _plot_func( ENSEMBLE_SIZE: int = 4 HISTORIES: list[MergeHistory] = [] for _i in range(ENSEMBLE_SIZE): + batched_acts: BatchedActivations = BatchedActivations.from_tensor( + activations=PROCESSED_ACTIVATIONS.activations, + labels=list(PROCESSED_ACTIVATIONS.labels), + ) HISTORY: MergeHistory = merge_iteration( merge_config=MERGE_CFG, - activations=PROCESSED_ACTIVATIONS.activations, + batched_activations=batched_acts, component_labels=PROCESSED_ACTIVATIONS.labels, log_callback=None, ) diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 0b7f8de97..5261484b3 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -16,13 +16,16 @@ component_activations, process_activations, ) +from spd.clustering.batched_activations import BatchedActivations from spd.clustering.clustering_run_config import ClusteringRunConfig +from spd.clustering.consts import DistancesArray from spd.clustering.dataset import load_dataset from spd.clustering.merge import merge_iteration from spd.clustering.merge_config import MergeConfig from spd.clustering.merge_history import MergeHistory, 
MergeHistoryEnsemble from spd.clustering.plotting.activations import plot_activations from spd.clustering.plotting.merge import plot_dists_distribution +from spd.configs import Config from spd.models.component_model import ComponentModel, SPDRunInfo DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu" @@ -43,11 +46,11 @@ SPD_RUN: SPDRunInfo = SPDRunInfo.from_path(MODEL_PATH) MODEL: ComponentModel = ComponentModel.from_pretrained(SPD_RUN.checkpoint_path) MODEL.to(DEVICE) -SPD_CONFIG = SPD_RUN.config +SPD_CONFIG: Config = SPD_RUN.config # Use load_dataset with RunConfig to get real data CONFIG: ClusteringRunConfig = ClusteringRunConfig( - merge_config=MergeConfig(), + merge_config=MergeConfig(batch_size=2), model_path=MODEL_PATH, batch_size=2, dataset_seed=42, @@ -111,9 +114,13 @@ ENSEMBLE_SIZE: int = 2 HISTORIES: list[MergeHistory] = [] for _i in range(ENSEMBLE_SIZE): + batched_acts: BatchedActivations = BatchedActivations.from_tensor( + activations=PROCESSED_ACTIVATIONS.activations, + labels=list(PROCESSED_ACTIVATIONS.labels), + ) HISTORY: MergeHistory = merge_iteration( merge_config=MERGE_CFG, - activations=PROCESSED_ACTIVATIONS.activations, + batched_activations=batched_acts, component_labels=PROCESSED_ACTIVATIONS.labels, log_callback=None, ) @@ -125,7 +132,7 @@ # %% # Compute and plot distances # ============================================================ -DISTANCES = ENSEMBLE.get_distances() +DISTANCES: DistancesArray = ENSEMBLE.get_distances() plot_dists_distribution( distances=DISTANCES, diff --git a/tests/clustering/test_merge_integration.py b/tests/clustering/test_merge_integration.py index 8492300de..af33ef1fd 100644 --- a/tests/clustering/test_merge_integration.py +++ b/tests/clustering/test_merge_integration.py @@ -2,9 +2,11 @@ import torch +from spd.clustering.batched_activations import BatchedActivations from spd.clustering.consts import ComponentLabels from spd.clustering.merge import merge_iteration from spd.clustering.merge_config import 
MergeConfig +from spd.clustering.merge_history import MergeHistory class TestMergeIntegration: @@ -13,24 +15,31 @@ class TestMergeIntegration: def test_merge_with_range_sampler(self): """Test merge iteration with range sampler.""" # Create test data - n_samples = 100 - n_components = 10 - activations = torch.rand(n_samples, n_components) - component_labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + n_samples: int = 100 + n_components: int = 10 + activations: torch.Tensor = torch.rand(n_samples, n_components) + component_labels: ComponentLabels = ComponentLabels( + [f"comp_{i}" for i in range(n_components)] + ) # Configure with range sampler - config = MergeConfig( + config: MergeConfig = MergeConfig( activation_threshold=0.1, alpha=1.0, iters=5, merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.1}, - filter_dead_threshold=0.001, + recompute_costs_every=2, # Recompute every 2 iterations for single-batch tests ) # Run merge iteration - history = merge_iteration( - activations=activations, merge_config=config, component_labels=component_labels + batched_activations: BatchedActivations = BatchedActivations.from_tensor( + activations=activations, labels=list(component_labels) + ) + history: MergeHistory = merge_iteration( + batched_activations=batched_activations, + merge_config=config, + component_labels=component_labels, ) # Check results @@ -46,24 +55,31 @@ def test_merge_with_range_sampler(self): def test_merge_with_mcmc_sampler(self): """Test merge iteration with MCMC sampler.""" # Create test data - n_samples = 100 - n_components = 10 - activations = torch.rand(n_samples, n_components) - component_labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + n_samples: int = 100 + n_components: int = 10 + activations: torch.Tensor = torch.rand(n_samples, n_components) + component_labels: ComponentLabels = ComponentLabels( + [f"comp_{i}" for i in range(n_components)] + ) # Configure with MCMC sampler - 
config = MergeConfig( + config: MergeConfig = MergeConfig( activation_threshold=0.1, alpha=1.0, iters=5, merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 1.0}, - filter_dead_threshold=0.001, + recompute_costs_every=2, # Recompute every 2 iterations for single-batch tests ) # Run merge iteration - history = merge_iteration( - activations=activations, merge_config=config, component_labels=component_labels + batched_activations: BatchedActivations = BatchedActivations.from_tensor( + activations=activations, labels=list(component_labels) + ) + history: MergeHistory = merge_iteration( + batched_activations=batched_activations, + merge_config=config, + component_labels=component_labels, ) # Check results @@ -78,42 +94,52 @@ def test_merge_with_mcmc_sampler(self): def test_merge_comparison_samplers(self): """Compare behavior of different samplers with same data.""" # Create test data with clear structure - n_samples = 100 - n_components = 8 - activations = torch.rand(n_samples, n_components) + n_samples: int = 100 + n_components: int = 8 + activations: torch.Tensor = torch.rand(n_samples, n_components) # Make some components more active to create cost structure activations[:, 0] *= 2 # Component 0 is very active activations[:, 1] *= 0.1 # Component 1 is rarely active - component_labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + component_labels: ComponentLabels = ComponentLabels( + [f"comp_{i}" for i in range(n_components)] + ) # Run with range sampler (threshold=0 for deterministic minimum selection) - config_range = MergeConfig( + config_range: MergeConfig = MergeConfig( activation_threshold=0.1, alpha=1.0, iters=3, merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.0}, # Always select minimum + recompute_costs_every=2, # Recompute every 2 iterations for single-batch tests ) - history_range = merge_iteration( - activations=activations.clone(), + batched_activations_range: BatchedActivations = 
BatchedActivations.from_tensor( + activations=activations.clone(), labels=list(component_labels) + ) + history_range: MergeHistory = merge_iteration( + batched_activations=batched_activations_range, merge_config=config_range, component_labels=ComponentLabels(component_labels.copy()), ) # Run with MCMC sampler (low temperature for near-deterministic) - config_mcmc = MergeConfig( + config_mcmc: MergeConfig = MergeConfig( activation_threshold=0.1, alpha=1.0, iters=3, merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 0.01}, # Very low temp + recompute_costs_every=2, # Recompute every 2 iterations for single-batch tests ) - history_mcmc = merge_iteration( - activations=activations.clone(), + batched_activations_mcmc: BatchedActivations = BatchedActivations.from_tensor( + activations=activations.clone(), labels=list(component_labels) + ) + history_mcmc: MergeHistory = merge_iteration( + batched_activations=batched_activations_mcmc, merge_config=config_mcmc, component_labels=ComponentLabels(component_labels.copy()), ) @@ -127,12 +153,14 @@ def test_merge_comparison_samplers(self): def test_merge_with_small_components(self): """Test merge with very few components.""" # Edge case: only 3 components - n_samples = 50 - n_components = 3 - activations = torch.rand(n_samples, n_components) - component_labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + n_samples: int = 50 + n_components: int = 3 + activations: torch.Tensor = torch.rand(n_samples, n_components) + component_labels: ComponentLabels = ComponentLabels( + [f"comp_{i}" for i in range(n_components)] + ) - config = MergeConfig( + config: MergeConfig = MergeConfig( activation_threshold=0.1, alpha=1.0, iters=1, # Just one merge @@ -140,8 +168,13 @@ def test_merge_with_small_components(self): merge_pair_sampling_kwargs={"temperature": 2.0}, ) - history = merge_iteration( - activations=activations, merge_config=config, component_labels=component_labels + batched_activations: 
BatchedActivations = BatchedActivations.from_tensor( + activations=activations, labels=list(component_labels) + ) + history: MergeHistory = merge_iteration( + batched_activations=batched_activations, + merge_config=config, + component_labels=component_labels, ) # First entry is after first merge, so should be 3 - 1 = 2