From fd6c66bba40352fe0e429a42393526cc4d38d4aa Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 10 Oct 2025 13:22:06 +0100 Subject: [PATCH 01/20] timing cluster_ss.py nearly identical to 26c2957 but accidentally committed that to clustering/main --- .github/workflows/checks.yaml | 5 + tests/clustering/scripts/cluster_ss.py | 148 +++++++++++++------------ 2 files changed, 85 insertions(+), 68 deletions(-) diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index 258935d2b..2b651b334 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -49,6 +49,11 @@ jobs: - name: Run ruff format run: uv run ruff format . + - name: "[TEMP] run cluster_ss.py" + run: uv run python tests/clustering/scripts/cluster_ss.py + env: + WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} + - name: Run tests run: uv run pytest tests/ --runslow --durations 10 --numprocesses auto env: diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 6ede368f0..fda2eec63 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -5,6 +5,7 @@ import torch from jaxtyping import Int from muutils.dbg import dbg_auto +from muutils.spinner import SpinnerContext from torch import Tensor from spd.clustering.activations import ( @@ -34,96 +35,107 @@ # %% # Load model and dataset # ============================================================ -MODEL_PATH: str = "wandb:goodfire/spd-pre-Sep-2025/runs/ioprgffh" - -SPD_RUN: SPDRunInfo = SPDRunInfo.from_path(MODEL_PATH) -MODEL: ComponentModel = ComponentModel.from_pretrained(SPD_RUN.checkpoint_path) -MODEL.to(DEVICE) -SPD_CONFIG = SPD_RUN.config - -# Use split_dataset with RunConfig to get real data -CONFIG: ClusteringRunConfig = ClusteringRunConfig( - merge_config=MergeConfig(), - model_path=MODEL_PATH, - task_name="lm", - n_batches=1, - batch_size=2, -) -BATCHES, _ = split_dataset(config=CONFIG) +with SpinnerContext(message="Load model"): + MODEL_PATH: str = "wandb:goodfire/spd-pre-Sep-2025/runs/ioprgffh" + + SPD_RUN: SPDRunInfo = SPDRunInfo.from_path(MODEL_PATH) + MODEL: ComponentModel = ComponentModel.from_pretrained(SPD_RUN.checkpoint_path) + MODEL.to(DEVICE) + SPD_CONFIG = SPD_RUN.config + + # Use split_dataset with RunConfig to get real data + CONFIG: ClusteringRunConfig = ClusteringRunConfig( + merge_config=MergeConfig(), + model_path=MODEL_PATH, + task_name="lm", + n_batches=1, + batch_size=2, + ) + +with SpinnerContext(message="Load data"): + BATCHES, _ = split_dataset(config=CONFIG) # %% # Load data batch # ============================================================ -DATA_BATCH: Int[Tensor, "batch_size n_ctx"] = next(BATCHES) +with SpinnerContext(message="Load data batch"): + DATA_BATCH: Int[Tensor, "batch_size n_ctx"] = next(BATCHES) # %% # Get component activations # ============================================================ -COMPONENT_ACTS: dict[str, Tensor] = component_activations( - model=MODEL, - batch=DATA_BATCH, - device=DEVICE, - sigmoid_type="hard", -) +with SpinnerContext(message="Get component activations"): + COMPONENT_ACTS: dict[str, Tensor] = component_activations( + model=MODEL, + batch=DATA_BATCH, + device=DEVICE, + sigmoid_type="hard", + ) -_ = dbg_auto(COMPONENT_ACTS) + _ = dbg_auto(COMPONENT_ACTS) # %% # Process activations # ============================================================ -FILTER_DEAD_THRESHOLD: float = 0.001 -FILTER_MODULES: str = "model.layers.0" - -PROCESSED_ACTIVATIONS: ProcessedActivations = process_activations( - activations=COMPONENT_ACTS, - filter_dead_threshold=FILTER_DEAD_THRESHOLD, - filter_modules=lambda x: x.startswith(FILTER_MODULES), - seq_mode="concat", -) +with SpinnerContext(message="Process activations"): + FILTER_DEAD_THRESHOLD: float = 0.001 + FILTER_MODULES: str = "model.layers.0" + + PROCESSED_ACTIVATIONS: ProcessedActivations = process_activations( + activations=COMPONENT_ACTS, + filter_dead_threshold=FILTER_DEAD_THRESHOLD, + filter_modules=lambda x: x.startswith(FILTER_MODULES), + seq_mode="concat", + ) -plot_activations( - processed_activations=PROCESSED_ACTIVATIONS, - save_dir=TEMP_DIR, - n_samples_max=256, - wandb_run=None, -) +with SpinnerContext(message="Plot activations"): + plot_activations( + processed_activations=PROCESSED_ACTIVATIONS, + save_dir=TEMP_DIR, + n_samples_max=256, + wandb_run=None, + ) # %% # Compute ensemble merge iterations # ============================================================ -MERGE_CFG: MergeConfig = MergeConfig( - activation_threshold=0.01, - alpha=0.01, - iters=2, - merge_pair_sampling_method="range", - merge_pair_sampling_kwargs={"threshold": 0.1}, - pop_component_prob=0, - module_name_filter=FILTER_MODULES, - filter_dead_threshold=FILTER_DEAD_THRESHOLD, -) - -# Modern approach: run merge_iteration multiple times to create ensemble -ENSEMBLE_SIZE: int = 2 -HISTORIES: list[MergeHistory] = [] -for i in range(ENSEMBLE_SIZE): - HISTORY: MergeHistory = merge_iteration( - merge_config=MERGE_CFG, - batch_id=f"batch_{i}", - activations=PROCESSED_ACTIVATIONS.activations, - component_labels=PROCESSED_ACTIVATIONS.labels, - log_callback=None, +with SpinnerContext(message="Compute merge iterations"): + MERGE_CFG: MergeConfig = MergeConfig( + activation_threshold=0.01, + alpha=0.01, + iters=2, + merge_pair_sampling_method="range", + merge_pair_sampling_kwargs={"threshold": 0.1}, + pop_component_prob=0, + module_name_filter=FILTER_MODULES, + filter_dead_threshold=FILTER_DEAD_THRESHOLD, ) - HISTORIES.append(HISTORY) -ENSEMBLE: MergeHistoryEnsemble = MergeHistoryEnsemble(data=HISTORIES) + # Modern approach: run merge_iteration multiple times to create ensemble + ENSEMBLE_SIZE: int = 2 + HISTORIES: list[MergeHistory] = [] + for i in range(ENSEMBLE_SIZE): + HISTORY: MergeHistory = merge_iteration( + merge_config=MERGE_CFG, + batch_id=f"batch_{i}", + activations=PROCESSED_ACTIVATIONS.activations, + component_labels=PROCESSED_ACTIVATIONS.labels, + log_callback=None, + ) + HISTORIES.append(HISTORY) + + ENSEMBLE: MergeHistoryEnsemble = MergeHistoryEnsemble(data=HISTORIES) # %% # Compute and plot distances # ============================================================ -DISTANCES = ENSEMBLE.get_distances() +with SpinnerContext(message="compute distances"): + DISTANCES = ENSEMBLE.get_distances() -plot_dists_distribution( - distances=DISTANCES, - mode="points", -) -plt.legend() + +with SpinnerContext(message="plot distances"): + plot_dists_distribution( + distances=DISTANCES, + mode="points", + ) + plt.legend() From 05155dcb7b815e5fa6a688860307545fc291200e Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 10 Oct 2025 13:33:23 +0100 Subject: [PATCH 02/20] better timing --- .github/workflows/checks.yaml | 9 ++--- tests/clustering/scripts/cluster_ss.py | 49 ++++++++++++++++++++------ 2 files changed, 44 insertions(+), 14 deletions(-) diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index 2b651b334..e3e865b39 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -40,6 +40,11 @@ jobs: - name: Print dependencies run: uv pip list + - name: "[TEMP] run cluster_ss.py" + run: uv run python tests/clustering/scripts/cluster_ss.py + env: + WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} + - name: Run basedpyright run: uv run basedpyright @@ -49,10 +54,6 @@ jobs: - name: Run ruff format run: uv run ruff format . - - name: "[TEMP] run cluster_ss.py" - run: uv run python tests/clustering/scripts/cluster_ss.py - env: - WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} - name: Run tests run: uv run pytest tests/ --runslow --durations 10 --numprocesses auto diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index fda2eec63..07ede3db5 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -5,7 +5,6 @@ import torch from jaxtyping import Int from muutils.dbg import dbg_auto -from muutils.spinner import SpinnerContext from torch import Tensor from spd.clustering.activations import ( @@ -28,6 +27,30 @@ ) # save to an actual dir that is gitignored, so users can view plots TEMP_DIR.mkdir(parents=True, exist_ok=True) + +TIMER_RECORDS: list[tuple[str, float]] = list() + + +class TimerContext: + def __init__(self, message: str): + self.message = message + self.start_time = None + self.end_time = None + + def __enter__(self): + print(f"[TIMER START] {self.message}") + self.start_time = torch.cuda.Event(enable_timing=True) + self.end_time = torch.cuda.Event(enable_timing=True) + self.start_time.record() + + def __exit__(self, exc_type, exc_value, traceback): + self.end_time.record() + torch.cuda.synchronize() + elapsed_time = self.start_time.elapsed_time(self.end_time) / 1000.0 + TIMER_RECORDS.append((self.message, elapsed_time)) + print(f"[TIMER END] {self.message} - Elapsed time: {elapsed_time:.2f} seconds") + + # magic autoreload # %load_ext autoreload # %autoreload 2 @@ -35,7 +58,7 @@ # %% # Load model and dataset # ============================================================ -with SpinnerContext(message="Load model"): +with TimerContext(message="Load model"): MODEL_PATH: str = "wandb:goodfire/spd-pre-Sep-2025/runs/ioprgffh" SPD_RUN: SPDRunInfo = SPDRunInfo.from_path(MODEL_PATH) @@ -52,19 +75,19 @@ batch_size=2, ) -with SpinnerContext(message="Load data"): +with TimerContext(message="Load data"): BATCHES, _ = split_dataset(config=CONFIG) # %% # Load data batch # ============================================================ -with SpinnerContext(message="Load data batch"): +with TimerContext(message="Load data batch"): DATA_BATCH: Int[Tensor, "batch_size n_ctx"] = next(BATCHES) # %% # Get component activations # ============================================================ -with SpinnerContext(message="Get component activations"): +with TimerContext(message="Get component activations"): COMPONENT_ACTS: dict[str, Tensor] = component_activations( model=MODEL, batch=DATA_BATCH, @@ -76,7 +99,7 @@ # %% # Process activations # ============================================================ -with SpinnerContext(message="Process activations"): +with TimerContext(message="Process activations"): FILTER_DEAD_THRESHOLD: float = 0.001 FILTER_MODULES: str = "model.layers.0" @@ -87,7 +110,7 @@ seq_mode="concat", ) -with SpinnerContext(message="Plot activations"): +with TimerContext(message="Plot activations"): plot_activations( processed_activations=PROCESSED_ACTIVATIONS, save_dir=TEMP_DIR, @@ -98,7 +121,7 @@ # %% # Compute ensemble merge iterations # ============================================================ -with SpinnerContext(message="Compute merge iterations"): +with TimerContext(message="Compute merge iterations"): MERGE_CFG: MergeConfig = MergeConfig( activation_threshold=0.01, alpha=0.01, @@ -129,13 +152,19 @@ # %% # Compute and plot distances # ============================================================ -with SpinnerContext(message="compute distances"): +with TimerContext(message="compute distances"): DISTANCES = ENSEMBLE.get_distances() -with SpinnerContext(message="plot distances"): +with TimerContext(message="plot distances"): plot_dists_distribution( distances=DISTANCES, mode="points", ) plt.legend() + + +# %% +print("Timer records (s):") +for key, value in TIMER_RECORDS: + print(f"{key:<30}{value:10.2f}") From 6eb38a72518d3bb75a108de06ac6c333dbacaff9 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 10 Oct 2025 13:59:39 +0100 Subject: [PATCH 03/20] fix timing --- tests/clustering/scripts/cluster_ss.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 07ede3db5..3f339d6ac 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -1,4 +1,5 @@ # %% +import time from pathlib import Path import matplotlib.pyplot as plt @@ -33,20 +34,15 @@ class TimerContext: def __init__(self, message: str): - self.message = message - self.start_time = None - self.end_time = None + self.message: str = message + self.start_time: float def __enter__(self): print(f"[TIMER START] {self.message}") - self.start_time = torch.cuda.Event(enable_timing=True) - self.end_time = torch.cuda.Event(enable_timing=True) - self.start_time.record() - - def __exit__(self, exc_type, exc_value, traceback): - self.end_time.record() - torch.cuda.synchronize() - elapsed_time = self.start_time.elapsed_time(self.end_time) / 1000.0 + self.start_time = time.perf_counter() + + def __exit__(self, exc_type, exc_value, traceback): # pyright: ignore[reportUnknownParameterType, reportMissingParameterType] + elapsed_time: float = time.perf_counter() - self.start_time TIMER_RECORDS.append((self.message, elapsed_time)) print(f"[TIMER END] {self.message} - Elapsed time: {elapsed_time:.2f} seconds") From f1e63f2a702b9783fbd741bbe70a014c971c76ce Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 10 Oct 2025 14:06:21 +0100 Subject: [PATCH 04/20] allow saving in formats besides pdf --- spd/clustering/plotting/activations.py | 18 ++++++++++-------- spd/clustering/plotting/merge.py | 6 ++++-- tests/clustering/scripts/cluster_ss.py | 1 + 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/spd/clustering/plotting/activations.py b/spd/clustering/plotting/activations.py index 2411eca38..dc8c8658c 100644 --- a/spd/clustering/plotting/activations.py +++ b/spd/clustering/plotting/activations.py @@ -2,6 +2,7 @@ from collections.abc import Sequence from pathlib import Path +from typing import Literal import matplotlib as mpl import matplotlib.pyplot as plt @@ -20,7 +21,7 @@ def plot_activations( processed_activations: ProcessedActivations, save_dir: Path, n_samples_max: int, - pdf_prefix: str = "activations", + figure_prefix: str = "activations", figsize_raw: tuple[int, int] = (12, 4), figsize_concat: tuple[int, int] = (12, 2), figsize_coact: tuple[int, int] = (8, 6), @@ -28,6 +29,7 @@ def plot_activations( hist_bins: int = 100, do_sorted_samples: bool = False, wandb_run: wandb.sdk.wandb_run.Run | None = None, + save_fmt: Literal["pdf", "png", "svg"] = "pdf", ) -> None: """Plot activation visualizations including raw, concatenated, sorted, and coactivations. @@ -37,7 +39,7 @@ def plot_activations( coact: Coactivation matrix labels: Component labels save_dir: The directory to save the plots to - pdf_prefix: Prefix for PDF filenames + figure_prefix: Prefix for figure filenames figsize_raw: Figure size for raw activations figsize_concat: Figure size for concatenated activations figsize_coact: Figure size for coactivations @@ -77,7 +79,7 @@ def plot_activations( axs_act[i].set_ylabel(f"components\n{key}") axs_act[i].set_title(f"Raw Activations: {key} (shape: {act_raw_data.shape})") - fig1_fname = save_dir / f"{pdf_prefix}_raw.pdf" + fig1_fname = save_dir / f"{figure_prefix}_raw.{save_fmt}" _fig1.savefig(fig1_fname, bbox_inches="tight", dpi=300) # Log to WandB if available @@ -100,7 +102,7 @@ def plot_activations( plt.colorbar(im2) - fig2_fname: Path = save_dir / f"{pdf_prefix}_concatenated.pdf" + fig2_fname: Path = save_dir / f"{figure_prefix}_concatenated.{save_fmt}" fig2.savefig(fig2_fname, bbox_inches="tight", dpi=300) # Log to WandB if available @@ -169,7 +171,7 @@ def plot_activations( plt.colorbar(im3) - fig3_fname: Path = save_dir / f"{pdf_prefix}_concatenated_sorted.pdf" + fig3_fname: Path = save_dir / f"{figure_prefix}_concatenated_sorted.{save_fmt}" fig3.savefig(fig3_fname, bbox_inches="tight", dpi=300) # Log to WandB if available @@ -193,7 +195,7 @@ def plot_activations( plt.colorbar(im4) - fig4_fname: Path = save_dir / f"{pdf_prefix}_coactivations.pdf" + fig4_fname: Path = save_dir / f"{figure_prefix}_coactivations.{save_fmt}" fig4.savefig(fig4_fname, bbox_inches="tight", dpi=300) # Log to WandB if available @@ -217,7 +219,7 @@ def plot_activations( add_component_labeling(ax4_log, labels, axis="x") add_component_labeling(ax4_log, labels, axis="y") plt.colorbar(im4_log) - fig4_log_fname: Path = save_dir / f"{pdf_prefix}_coactivations_log.pdf" + fig4_log_fname: Path = save_dir / f"{figure_prefix}_coactivations_log.{save_fmt}" fig4_log.savefig(fig4_log_fname, bbox_inches="tight", dpi=300) # Log to WandB if available @@ -312,7 +314,7 @@ def plot_activations( plt.tight_layout() - fig5_fname: Path = save_dir / f"{pdf_prefix}_histograms.pdf" + fig5_fname: Path = save_dir / f"{figure_prefix}_histograms.{save_fmt}" fig5.savefig(fig5_fname, bbox_inches="tight", dpi=300) # Log to WandB if available diff --git a/spd/clustering/plotting/merge.py b/spd/clustering/plotting/merge.py index e470b3114..8a2cc20df 100644 --- a/spd/clustering/plotting/merge.py +++ b/spd/clustering/plotting/merge.py @@ -17,7 +17,7 @@ figsize=(16, 10), tick_spacing=5, save_pdf=False, - pdf_prefix="merge_iteration", + figure_prefix="merge_iteration", ) @@ -168,7 +168,9 @@ def plot_merge_iteration( if plot_config_["save_pdf"]: fig.savefig( - f"{plot_config_['pdf_prefix']}_iter_{iteration:03d}.pdf", bbox_inches="tight", dpi=300 + f"{plot_config_['figure_prefix']}_iter_{iteration:03d}.pdf", + bbox_inches="tight", + dpi=300, ) if show: diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 3f339d6ac..0f575fb1d 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -112,6 +112,7 @@ def __exit__(self, exc_type, exc_value, traceback): # pyright: ignore[reportUnk save_dir=TEMP_DIR, n_samples_max=256, wandb_run=None, + save_fmt="svg", ) # %% From 7c48ded780880fb02f83b08f76da40e0f5713c47 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 10 Oct 2025 14:21:05 +0100 Subject: [PATCH 05/20] streaming dataset for notebook to avoid download? --- spd/clustering/pipeline/s1_split_dataset.py | 18 ++++-- tests/clustering/scripts/cluster_ss.py | 67 +++++++++++---------- 2 files changed, 50 insertions(+), 35 deletions(-) diff --git a/spd/clustering/pipeline/s1_split_dataset.py b/spd/clustering/pipeline/s1_split_dataset.py index d5427e600..622d0539b 100644 --- a/spd/clustering/pipeline/s1_split_dataset.py +++ b/spd/clustering/pipeline/s1_split_dataset.py @@ -21,7 +21,9 @@ from spd.models.component_model import ComponentModel, SPDRunInfo -def split_dataset(config: ClusteringRunConfig) -> tuple[Iterator[BatchTensor], dict[str, Any]]: +def split_dataset( + config: ClusteringRunConfig, **kwargs: dict[str, Any] +) -> tuple[Iterator[BatchTensor], dict[str, Any]]: """Split a dataset into n_batches of batch_size, returning iterator and config""" ds: Generator[BatchTensor, None, None] ds_config_dict: dict[str, Any] @@ -30,11 +32,13 @@ def split_dataset(config: ClusteringRunConfig) -> tuple[Iterator[BatchTensor], d ds, ds_config_dict = _get_dataloader_lm( model_path=config.model_path, batch_size=config.batch_size, + **kwargs, ) case "resid_mlp": ds, ds_config_dict = _get_dataloader_resid_mlp( model_path=config.model_path, batch_size=config.batch_size, + **kwargs, ) case name: raise ValueError( @@ -56,6 +60,7 @@ def limited_iterator() -> Iterator[BatchTensor]: def _get_dataloader_lm( model_path: str, batch_size: int, + config_kwargs: dict[str, Any] | None = None, ) -> tuple[Generator[BatchTensor, None, None], dict[str, Any]]: """split up a SS dataset into n_batches of batch_size, returned the saved paths @@ -81,15 +86,20 @@ def _get_dataloader_lm( f"Expected task_config to be of type LMTaskConfig since using `_get_dataloader_lm`, but got {type(cfg.task_config) = }" ) + config_kwargs_: dict[str, Any] = dict( + is_tokenized=False, + streaming=False, + seed=0, + **(config_kwargs or {}), # allow overrides + ) + dataset_config: DatasetConfig = DatasetConfig( name=cfg.task_config.dataset_name, hf_tokenizer_path=pretrained_model_name, split=cfg.task_config.train_data_split, n_ctx=cfg.task_config.max_seq_len, - is_tokenized=False, - streaming=False, - seed=0, column_name=cfg.task_config.column_name, + **config_kwargs_, ) with SpinnerContext(message="getting dataloader..."): diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 0f575fb1d..5935e615b 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -1,35 +1,5 @@ # %% import time -from pathlib import Path - -import matplotlib.pyplot as plt -import torch -from jaxtyping import Int -from muutils.dbg import dbg_auto -from torch import Tensor - -from spd.clustering.activations import ( - ProcessedActivations, - component_activations, - process_activations, -) -from spd.clustering.merge import merge_iteration -from spd.clustering.merge_config import MergeConfig -from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble -from spd.clustering.merge_run_config import ClusteringRunConfig -from spd.clustering.pipeline.s1_split_dataset import split_dataset -from spd.clustering.plotting.activations import plot_activations -from spd.clustering.plotting.merge import plot_dists_distribution -from spd.models.component_model import ComponentModel, SPDRunInfo - -DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu" -TEMP_DIR: Path = Path( - "tests/.temp" -) # save to an actual dir that is gitignored, so users can view plots -TEMP_DIR.mkdir(parents=True, exist_ok=True) - - -TIMER_RECORDS: list[tuple[str, float]] = list() class TimerContext: @@ -47,6 +17,38 @@ def __exit__(self, exc_type, exc_value, traceback): # pyright: ignore[reportUnk print(f"[TIMER END] {self.message} - Elapsed time: {elapsed_time:.2f} seconds") +with TimerContext(message="Import modules"): + from pathlib import Path + + import matplotlib.pyplot as plt + import torch + from jaxtyping import Int + from muutils.dbg import dbg_auto + from torch import Tensor + + from spd.clustering.activations import ( + ProcessedActivations, + component_activations, + process_activations, + ) + from spd.clustering.merge import merge_iteration + from spd.clustering.merge_config import MergeConfig + from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble + from spd.clustering.merge_run_config import ClusteringRunConfig + from spd.clustering.pipeline.s1_split_dataset import split_dataset + from spd.clustering.plotting.activations import plot_activations + from spd.clustering.plotting.merge import plot_dists_distribution + from spd.models.component_model import ComponentModel, SPDRunInfo + + DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu" + TEMP_DIR: Path = Path( + "tests/.temp" + ) # save to an actual dir that is gitignored, so users can view plots + TEMP_DIR.mkdir(parents=True, exist_ok=True) + + TIMER_RECORDS: list[tuple[str, float]] = list() + + # magic autoreload # %load_ext autoreload # %autoreload 2 @@ -72,7 +74,10 @@ def __exit__(self, exc_type, exc_value, traceback): # pyright: ignore[reportUnk ) with TimerContext(message="Load data"): - BATCHES, _ = split_dataset(config=CONFIG) + BATCHES, _ = split_dataset( + config=CONFIG, + config_kwargs=dict(streaming=True), + ) # %% # Load data batch From bb737fae2705ccd16e2e4d35988947452cbd7d89 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 10 Oct 2025 14:24:20 +0100 Subject: [PATCH 06/20] oops --- spd/clustering/pipeline/s1_split_dataset.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/spd/clustering/pipeline/s1_split_dataset.py b/spd/clustering/pipeline/s1_split_dataset.py index 622d0539b..307c590ea 100644 --- a/spd/clustering/pipeline/s1_split_dataset.py +++ b/spd/clustering/pipeline/s1_split_dataset.py @@ -86,12 +86,14 @@ def _get_dataloader_lm( f"Expected task_config to be of type LMTaskConfig since using `_get_dataloader_lm`, but got {type(cfg.task_config) = }" ) - config_kwargs_: dict[str, Any] = dict( - is_tokenized=False, - streaming=False, - seed=0, - **(config_kwargs or {}), # allow overrides - ) + config_kwargs_: dict[str, Any] = { + **dict( + is_tokenized=False, + streaming=False, + seed=0, + ), + **(config_kwargs or {}), + } dataset_config: DatasetConfig = DatasetConfig( name=cfg.task_config.dataset_name, From adbb18687397fc6e5ddf1edb0b13bcaf838b1783 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 10 Oct 2025 14:30:26 +0100 Subject: [PATCH 07/20] remove timers --- tests/clustering/scripts/cluster_ss.py | 242 +++++++++++-------------- 1 file changed, 104 insertions(+), 138 deletions(-) diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 5935e615b..641fac903 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -1,52 +1,33 @@ # %% -import time - - -class TimerContext: - def __init__(self, message: str): - self.message: str = message - self.start_time: float - - def __enter__(self): - print(f"[TIMER START] {self.message}") - self.start_time = time.perf_counter() - - def __exit__(self, exc_type, exc_value, traceback): # pyright: ignore[reportUnknownParameterType, reportMissingParameterType] - elapsed_time: float = time.perf_counter() - self.start_time - TIMER_RECORDS.append((self.message, elapsed_time)) - print(f"[TIMER END] {self.message} - Elapsed time: {elapsed_time:.2f} seconds") - - -with TimerContext(message="Import modules"): - from pathlib import Path - - import matplotlib.pyplot as plt - import torch - from jaxtyping import Int - from muutils.dbg import dbg_auto - from torch import Tensor - - from spd.clustering.activations import ( - ProcessedActivations, - component_activations, - process_activations, - ) - from spd.clustering.merge import merge_iteration - from spd.clustering.merge_config import MergeConfig - from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble - from spd.clustering.merge_run_config import ClusteringRunConfig - from spd.clustering.pipeline.s1_split_dataset import split_dataset - from spd.clustering.plotting.activations import plot_activations - from spd.clustering.plotting.merge import plot_dists_distribution - from spd.models.component_model import ComponentModel, SPDRunInfo - - DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu" - TEMP_DIR: Path = Path( - "tests/.temp" - ) # save to an actual dir that is gitignored, so users can view plots - TEMP_DIR.mkdir(parents=True, exist_ok=True) - - TIMER_RECORDS: list[tuple[str, float]] = list() +from pathlib import Path + +import matplotlib.pyplot as plt +import torch +from jaxtyping import Int +from muutils.dbg import dbg_auto +from torch import Tensor + +from spd.clustering.activations import ( + ProcessedActivations, + component_activations, + process_activations, +) +from spd.clustering.merge import merge_iteration +from spd.clustering.merge_config import MergeConfig +from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble +from spd.clustering.merge_run_config import ClusteringRunConfig +from spd.clustering.pipeline.s1_split_dataset import split_dataset +from spd.clustering.plotting.activations import plot_activations +from spd.clustering.plotting.merge import plot_dists_distribution +from spd.models.component_model import ComponentModel, SPDRunInfo + +DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu" +TEMP_DIR: Path = Path( + "tests/.temp" +) # save to an actual dir that is gitignored, so users can view plots +TEMP_DIR.mkdir(parents=True, exist_ok=True) + +TIMER_RECORDS: list[tuple[str, float]] = list() # magic autoreload @@ -56,117 +37,102 @@ def __exit__(self, exc_type, exc_value, traceback): # pyright: ignore[reportUnk # %% # Load model and dataset # ============================================================ -with TimerContext(message="Load model"): - MODEL_PATH: str = "wandb:goodfire/spd-pre-Sep-2025/runs/ioprgffh" - - SPD_RUN: SPDRunInfo = SPDRunInfo.from_path(MODEL_PATH) - MODEL: ComponentModel = ComponentModel.from_pretrained(SPD_RUN.checkpoint_path) - MODEL.to(DEVICE) - SPD_CONFIG = SPD_RUN.config - - # Use split_dataset with RunConfig to get real data - CONFIG: ClusteringRunConfig = ClusteringRunConfig( - merge_config=MergeConfig(), - model_path=MODEL_PATH, - task_name="lm", - n_batches=1, - batch_size=2, - ) - -with TimerContext(message="Load data"): - BATCHES, _ = split_dataset( - config=CONFIG, - config_kwargs=dict(streaming=True), - ) +MODEL_PATH: str = "wandb:goodfire/spd-pre-Sep-2025/runs/ioprgffh" + +SPD_RUN: SPDRunInfo = SPDRunInfo.from_path(MODEL_PATH) +MODEL: ComponentModel = ComponentModel.from_pretrained(SPD_RUN.checkpoint_path) +MODEL.to(DEVICE) +SPD_CONFIG = SPD_RUN.config + +# Use split_dataset with RunConfig to get real data +CONFIG: ClusteringRunConfig = ClusteringRunConfig( + merge_config=MergeConfig(), + model_path=MODEL_PATH, + task_name="lm", + n_batches=1, + batch_size=2, +) + +BATCHES, _ = split_dataset( + config=CONFIG, + config_kwargs=dict(streaming=True), +) # %% # Load data batch # ============================================================ -with TimerContext(message="Load data batch"): - DATA_BATCH: Int[Tensor, "batch_size n_ctx"] = next(BATCHES) +DATA_BATCH: Int[Tensor, "batch_size n_ctx"] = next(BATCHES) # %% # Get component activations # ============================================================ -with TimerContext(message="Get component activations"): - COMPONENT_ACTS: dict[str, Tensor] = component_activations( - model=MODEL, - batch=DATA_BATCH, - device=DEVICE, - sigmoid_type="hard", - ) - - _ = dbg_auto(COMPONENT_ACTS) +COMPONENT_ACTS: dict[str, Tensor] = component_activations( + model=MODEL, + batch=DATA_BATCH, + device=DEVICE, + sigmoid_type="hard", +) + +_ = dbg_auto(COMPONENT_ACTS) # %% # Process activations # ============================================================ -with TimerContext(message="Process activations"): - FILTER_DEAD_THRESHOLD: float = 0.001 - FILTER_MODULES: str = "model.layers.0" - - PROCESSED_ACTIVATIONS: ProcessedActivations = process_activations( - activations=COMPONENT_ACTS, - filter_dead_threshold=FILTER_DEAD_THRESHOLD, - filter_modules=lambda x: x.startswith(FILTER_MODULES), - seq_mode="concat", - ) - -with TimerContext(message="Plot activations"): - plot_activations( - processed_activations=PROCESSED_ACTIVATIONS, - save_dir=TEMP_DIR, - n_samples_max=256, - wandb_run=None, - save_fmt="svg", - ) +FILTER_DEAD_THRESHOLD: float = 0.001 +FILTER_MODULES: str = "model.layers.0" + +PROCESSED_ACTIVATIONS: ProcessedActivations = process_activations( + activations=COMPONENT_ACTS, + filter_dead_threshold=FILTER_DEAD_THRESHOLD, + filter_modules=lambda x: x.startswith(FILTER_MODULES), + seq_mode="concat", +) + +plot_activations( + processed_activations=PROCESSED_ACTIVATIONS, + save_dir=TEMP_DIR, + n_samples_max=256, + wandb_run=None, + save_fmt="svg", +) # %% # Compute ensemble merge iterations # ============================================================ -with TimerContext(message="Compute merge iterations"): - MERGE_CFG: MergeConfig = MergeConfig( - activation_threshold=0.01, - alpha=0.01, - iters=2, - merge_pair_sampling_method="range", - merge_pair_sampling_kwargs={"threshold": 0.1}, - pop_component_prob=0, - module_name_filter=FILTER_MODULES, - filter_dead_threshold=FILTER_DEAD_THRESHOLD, +MERGE_CFG: MergeConfig = MergeConfig( + activation_threshold=0.01, + alpha=0.01, + iters=2, + merge_pair_sampling_method="range", + merge_pair_sampling_kwargs={"threshold": 0.1}, + pop_component_prob=0, + module_name_filter=FILTER_MODULES, + filter_dead_threshold=FILTER_DEAD_THRESHOLD, +) + +# Modern approach: run merge_iteration multiple times to create ensemble +ENSEMBLE_SIZE: int = 2 +HISTORIES: list[MergeHistory] = [] +for i in range(ENSEMBLE_SIZE): + HISTORY: MergeHistory = merge_iteration( + merge_config=MERGE_CFG, + batch_id=f"batch_{i}", + activations=PROCESSED_ACTIVATIONS.activations, + component_labels=PROCESSED_ACTIVATIONS.labels, + log_callback=None, ) + HISTORIES.append(HISTORY) - # Modern approach: run merge_iteration multiple times to create ensemble - ENSEMBLE_SIZE: int = 2 - HISTORIES: list[MergeHistory] = [] - for i in range(ENSEMBLE_SIZE): - HISTORY: MergeHistory = merge_iteration( - merge_config=MERGE_CFG, - batch_id=f"batch_{i}", - activations=PROCESSED_ACTIVATIONS.activations, - component_labels=PROCESSED_ACTIVATIONS.labels, - log_callback=None, - ) - HISTORIES.append(HISTORY) - - ENSEMBLE: MergeHistoryEnsemble = MergeHistoryEnsemble(data=HISTORIES) +ENSEMBLE: MergeHistoryEnsemble = MergeHistoryEnsemble(data=HISTORIES) # %% # Compute and plot distances # ============================================================ -with TimerContext(message="compute distances"): - DISTANCES = ENSEMBLE.get_distances() - - -with TimerContext(message="plot distances"): - plot_dists_distribution( - distances=DISTANCES, - mode="points", - ) - plt.legend() +DISTANCES = ENSEMBLE.get_distances() -# %% -print("Timer records (s):") -for key, value in TIMER_RECORDS: - print(f"{key:<30}{value:10.2f}") +plot_dists_distribution( + distances=DISTANCES, + mode="points", +) +plt.legend() From 51998845883799ca87be10730a648af6a9ddc8a7 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 10 Oct 2025 14:31:22 +0100 Subject: [PATCH 08/20] minimize diff --- tests/clustering/scripts/cluster_ss.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 641fac903..926c19136 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -27,9 +27,6 @@ ) # save to an actual dir that is gitignored, so users can view plots TEMP_DIR.mkdir(parents=True, exist_ok=True) -TIMER_RECORDS: list[tuple[str, float]] = list() - - # magic autoreload # %load_ext autoreload # %autoreload 2 @@ -130,7 +127,6 @@ # ============================================================ DISTANCES = ENSEMBLE.get_distances() - plot_dists_distribution( distances=DISTANCES, mode="points", From c1c7d7e5384cbac070d79dea55434dd74aeefada Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 10 Oct 2025 14:32:34 +0100 Subject: [PATCH 09/20] remove custom CI timing step --- .github/workflows/checks.yaml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index e3e865b39..523c57a0f 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -40,11 +40,6 @@ jobs: - name: Print dependencies run: uv pip list - - name: "[TEMP] run cluster_ss.py" - run: uv run python tests/clustering/scripts/cluster_ss.py - env: - WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} - - name: Run basedpyright run: uv run basedpyright From 17bfdc44c2e2b11cc28db6de8ab91c2c33c702b0 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 10 Oct 2025 14:46:19 +0100 Subject: [PATCH 10/20] dataset streaming in config+cli for spd-cluster test_cluster_ss_notebook was fast, but streaming was not enabled for test_clustering_with_simplestories_config since that tests it thru cli -- hence the downloading was just happening for the latter. so, we adjust the config and cli to allow for enabling dataset streaming --- spd/clustering/merge_run_config.py | 6 +++++- spd/clustering/pipeline/clustering_pipeline.py | 5 ++++- spd/clustering/scripts/main.py | 6 ++++++ tests/clustering/scripts/cluster_ss.py | 2 +- tests/clustering/test_clustering_experiments.py | 1 + 5 files changed, 17 insertions(+), 3 deletions(-) diff --git a/spd/clustering/merge_run_config.py b/spd/clustering/merge_run_config.py index feba16967..ee50ec657 100644 --- a/spd/clustering/merge_run_config.py +++ b/spd/clustering/merge_run_config.py @@ -71,7 +71,7 @@ class ClusteringRunConfig(BaseModel): """Configuration for a complete merge clustering run. Extends MergeConfig with parameters for model, dataset, and batch configuration. - CLI parameters (base_path, devices, workers_per_device) have defaults but will always be overridden + CLI parameters (base_path, devices, workers_per_device, dataset_streaming) have defaults but will always be overridden """ merge_config: MergeConfig = Field( @@ -100,6 +100,10 @@ class ClusteringRunConfig(BaseModel): default="perm_invariant_hamming", description="Method to use for computing distances between clusterings", ) + dataset_streaming: bool = Field( + default=False, + description="Whether to use streaming dataset loading (if supported by the dataset). see https://github.com/goodfire-ai/spd/pull/199", + ) # Implementation details # note that these are *always* overriden by CLI args in `spd/clustering/scripts/main.py`, but we have to have defaults here diff --git a/spd/clustering/pipeline/clustering_pipeline.py b/spd/clustering/pipeline/clustering_pipeline.py index 1e07c71d7..6f975ebca 100644 --- a/spd/clustering/pipeline/clustering_pipeline.py +++ b/spd/clustering/pipeline/clustering_pipeline.py @@ -48,7 +48,10 @@ def main(config: ClusteringRunConfig) -> None: logger.info(f"Splitting dataset into {config.n_batches} batches...") batches: Iterator[BatchTensor] dataset_config: dict[str, Any] - batches, dataset_config = split_dataset(config=config) + batches, dataset_config = split_dataset( + config=config, + config_kwargs=dict(streaming=config.dataset_streaming), + ) storage.save_batches(batches=batches, config=dataset_config) batch_paths: list[Path] = storage.get_batch_paths() n_batch_paths: int = len(batch_paths) diff --git a/spd/clustering/scripts/main.py b/spd/clustering/scripts/main.py index 56cee4d84..2104482e5 100644 --- a/spd/clustering/scripts/main.py +++ b/spd/clustering/scripts/main.py @@ -43,6 +43,11 @@ def cli() -> None: default=1, help="Maximum number of concurrent clustering processes per device (default: 1)", ) + parser.add_argument( + "--dataset-streaming", + action="store_true", + help="Whether to use streaming dataset loading (if supported by the dataset). see https://github.com/goodfire-ai/spd/pull/199", + ) args: argparse.Namespace = parser.parse_args() logger.info("Starting clustering pipeline") @@ -66,6 +71,7 @@ def cli() -> None: config.base_path = args.base_path config.devices = devices config.workers_per_device = args.workers_per_device + config.dataset_streaming = args.dataset_streaming logger.info(f"Configuration loaded: {config.config_identifier}") logger.info(f"Base path: {config.base_path}") diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 926c19136..1d2962829 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -52,7 +52,7 @@ BATCHES, _ = split_dataset( config=CONFIG, - config_kwargs=dict(streaming=True), + config_kwargs=dict(streaming=True), # see https://github.com/goodfire-ai/spd/pull/199 ) # %% diff --git a/tests/clustering/test_clustering_experiments.py b/tests/clustering/test_clustering_experiments.py index 5031adfce..19ff937f6 100644 --- a/tests/clustering/test_clustering_experiments.py +++ b/tests/clustering/test_clustering_experiments.py @@ -87,6 +87,7 @@ def test_clustering_with_simplestories_config(): "spd-cluster", "--config", str(config_path), + "--dataset-streaming", # see https://github.com/goodfire-ai/spd/pull/199 ], capture_output=True, text=True, From 2fb930f62ae71bee8af604f1ec9c12fb5470c2fc Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 10 Oct 2025 14:51:14 +0100 Subject: [PATCH 11/20] wip --- spd/clustering/pipeline/clustering_pipeline.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/spd/clustering/pipeline/clustering_pipeline.py b/spd/clustering/pipeline/clustering_pipeline.py index 6f975ebca..fcd3cde36 100644 --- a/spd/clustering/pipeline/clustering_pipeline.py +++ b/spd/clustering/pipeline/clustering_pipeline.py @@ -46,11 +46,16 @@ def main(config: ClusteringRunConfig) -> None: # Split dataset into batches logger.info(f"Splitting dataset into {config.n_batches} batches...") + split_dataset_kwargs: dict[str, Any] = dict() + if config.dataset_streaming: + logger.info("Using streaming dataset loading") + split_dataset_kwargs["config_kwargs"] = dict(streaming=True) + assert config.task_name == "lm", "Streaming dataset loading only supported for 'lm' task" batches: Iterator[BatchTensor] dataset_config: dict[str, Any] batches, dataset_config = split_dataset( config=config, - config_kwargs=dict(streaming=config.dataset_streaming), + **split_dataset_kwargs, ) storage.save_batches(batches=batches, config=dataset_config) batch_paths: list[Path] = storage.get_batch_paths() From c1d80a83ec02f12c15f320bae5c0a417a076ea2f Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 10 Oct 2025 15:10:17 +0100 Subject: [PATCH 12/20] cuda issues??? --- tests/clustering/scripts/cluster_ss.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 1d2962829..63c0a7f80 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -1,7 +1,11 @@ # %% +import os + +# Suppress tokenizer parallelism warning when forking +os.environ["TOKENIZERS_PARALLELISM"] = "false" + from pathlib import Path -import matplotlib.pyplot as plt import torch from jaxtyping import Int from muutils.dbg import dbg_auto @@ -131,4 +135,8 @@ distances=DISTANCES, mode="points", ) -plt.legend() + +# %% +# Exit cleanly to avoid CUDA thread GIL issues during interpreter shutdown +# ============================================================ +os._exit(0) From 2bd85c9f2f846524adf75af4bc7071df05bf8feb Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 10 Oct 2025 15:34:28 +0100 Subject: [PATCH 13/20] [temp] telemetry for action using https://github.com/catchpoint/workflow-telemetry-action --- .github/workflows/checks.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index 523c57a0f..8e524fbb8 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -18,6 +18,9 @@ jobs: image: ghcr.io/${{ github.repository }}/ci-mpi:latest timeout-minutes: 15 steps: + - name: Collect Workflow Telemetry + uses: catchpoint/workflow-telemetry-action@v2 + - name: Checkout uses: actions/checkout@v4 From 5df67621bd15b2fbd68a472ee4c2a6cfd0f30e20 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 10 Oct 2025 15:38:16 +0100 Subject: [PATCH 14/20] minor fixes from claude see https://github.com/goodfire-ai/spd/pull/199#issuecomment-3390405175 --- spd/clustering/merge_run_config.py | 9 +++++++++ spd/clustering/pipeline/clustering_pipeline.py | 4 +++- spd/clustering/pipeline/s1_split_dataset.py | 3 ++- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/spd/clustering/merge_run_config.py b/spd/clustering/merge_run_config.py index ee50ec657..d86cbc7a6 100644 --- a/spd/clustering/merge_run_config.py +++ b/spd/clustering/merge_run_config.py @@ -165,6 +165,15 @@ def validate_intervals(self) -> Self: return self + @model_validator(mode="after") + def validate_streaming_compatibility(self) -> Self: + """Ensure dataset_streaming is only enabled for compatible tasks.""" + if self.dataset_streaming and self.task_name != "lm": + raise ValueError( + f"Streaming dataset loading only supported for 'lm' task, got '{self.task_name}'" + ) + return self + @property def wandb_decomp_model(self) -> str: """Extract the WandB run ID of the source decomposition from the model_path diff --git a/spd/clustering/pipeline/clustering_pipeline.py b/spd/clustering/pipeline/clustering_pipeline.py index fcd3cde36..2c72d767a 100644 --- a/spd/clustering/pipeline/clustering_pipeline.py +++ b/spd/clustering/pipeline/clustering_pipeline.py @@ -50,7 +50,9 @@ def main(config: ClusteringRunConfig) -> None: if config.dataset_streaming: logger.info("Using streaming dataset loading") split_dataset_kwargs["config_kwargs"] = dict(streaming=True) - assert config.task_name == "lm", "Streaming dataset loading only supported for 'lm' task" + assert config.task_name == "lm", ( + f"Streaming dataset loading only supported for 'lm' task, got '{config.task_name = }'. Remove dataset_streaming=True from config or use a different task." + ) batches: Iterator[BatchTensor] dataset_config: dict[str, Any] batches, dataset_config = split_dataset( diff --git a/spd/clustering/pipeline/s1_split_dataset.py b/spd/clustering/pipeline/s1_split_dataset.py index 307c590ea..711cda65a 100644 --- a/spd/clustering/pipeline/s1_split_dataset.py +++ b/spd/clustering/pipeline/s1_split_dataset.py @@ -22,7 +22,8 @@ def split_dataset( - config: ClusteringRunConfig, **kwargs: dict[str, Any] + config: ClusteringRunConfig, + **kwargs: Any, ) -> tuple[Iterator[BatchTensor], dict[str, Any]]: """Split a dataset into n_batches of batch_size, returning iterator and config""" ds: Generator[BatchTensor, None, None] From 9eabafd1b363cc00f881c0db52ebeccdb4456c4e Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 10 Oct 2025 15:45:17 +0100 Subject: [PATCH 15/20] [temp] no docker container for workflow telemetry to work --- .github/workflows/checks.yaml | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index 8e524fbb8..90f17c677 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -14,8 +14,8 @@ on: jobs: build: runs-on: ubuntu-latest - container: - image: ghcr.io/${{ github.repository }}/ci-mpi:latest + # container: + # image: ghcr.io/${{ github.repository }}/ci-mpi:latest timeout-minutes: 15 steps: - name: Collect Workflow Telemetry @@ -24,6 +24,12 @@ jobs: - name: Checkout uses: actions/checkout@v4 + - name: install MPI + run: | + apt-get update + apt-get install -y -q git make openmpi-bin libopenmpi-dev + apt-get clean + - name: Install uv uses: astral-sh/setup-uv@v5 with: From 67c536a5ec9ba3a1f32cb89fadb2c11cf7b9459c Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 10 Oct 2025 15:46:22 +0100 Subject: [PATCH 16/20] sudo in workflow for apt-get --- .github/workflows/checks.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index 90f17c677..5a0cf8df1 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -26,9 +26,9 @@ jobs: - name: install MPI run: | - apt-get update - apt-get install -y -q git make openmpi-bin libopenmpi-dev - apt-get clean + sudo apt-get update + sudo apt-get install -y -q git make openmpi-bin libopenmpi-dev + sudo apt-get clean - name: Install uv uses: astral-sh/setup-uv@v5 From bb777a601f6d08f1b78f6e3af59262155ea109df Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 10 Oct 2025 15:54:07 +0100 Subject: [PATCH 17/20] use --dist worksteal in CI pytest --- .github/workflows/checks.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index 5a0cf8df1..f4d7aeae8 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -60,7 +60,7 @@ jobs: - name: Run tests - run: uv run pytest tests/ --runslow --durations 10 --numprocesses auto + run: uv run pytest tests/ --runslow --durations 20 --numprocesses auto --dist worksteal env: WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} OMPI_ALLOW_RUN_AS_ROOT: 1 From db726a4e1296b617de53f9fc626ab6ca37a9ac38 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 10 Oct 2025 15:55:09 +0100 Subject: [PATCH 18/20] revert temp CI changes --- .github/workflows/checks.yaml | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index f4d7aeae8..491dc7f78 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -14,22 +14,11 @@ on: jobs: build: runs-on: ubuntu-latest - # container: - # image: ghcr.io/${{ github.repository }}/ci-mpi:latest timeout-minutes: 15 steps: - - name: Collect Workflow Telemetry - uses: catchpoint/workflow-telemetry-action@v2 - - name: Checkout uses: actions/checkout@v4 - - name: install MPI - run: | - sudo apt-get update - sudo apt-get install -y -q git make openmpi-bin libopenmpi-dev - sudo apt-get clean - - name: Install uv uses: astral-sh/setup-uv@v5 with: From e3bd858d4de280d8570e1d056ddfef4cdf0fcef6 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 10 Oct 2025 15:57:15 +0100 Subject: [PATCH 19/20] minor fixes. accidentally removed container in CI lol --- .github/workflows/checks.yaml | 3 ++- spd/clustering/pipeline/clustering_pipeline.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index 491dc7f78..81027df03 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -14,6 +14,8 @@ on: jobs: build: runs-on: ubuntu-latest + container: + image: ghcr.io/${{ github.repository }}/ci-mpi:latest timeout-minutes: 15 steps: - name: Checkout @@ -47,7 +49,6 @@ jobs: - name: Run ruff format run: uv run ruff format . - - name: Run tests run: uv run pytest tests/ --runslow --durations 20 --numprocesses auto --dist worksteal env: diff --git a/spd/clustering/pipeline/clustering_pipeline.py b/spd/clustering/pipeline/clustering_pipeline.py index 2c72d767a..8c6b72f9d 100644 --- a/spd/clustering/pipeline/clustering_pipeline.py +++ b/spd/clustering/pipeline/clustering_pipeline.py @@ -50,6 +50,8 @@ def main(config: ClusteringRunConfig) -> None: if config.dataset_streaming: logger.info("Using streaming dataset loading") split_dataset_kwargs["config_kwargs"] = dict(streaming=True) + # check this here as well as the model validator because we edit `config.dataset_streaming` after init in main() after the CLI args are parsed + # not sure if this is actually a problem though assert config.task_name == "lm", ( f"Streaming dataset loading only supported for 'lm' task, got '{config.task_name = }'. Remove dataset_streaming=True from config or use a different task." ) From ca4502d23782d9229aa5a936451c8fafad0e21b1 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 10 Oct 2025 15:58:20 +0100 Subject: [PATCH 20/20] link to issue in script we do the weird thing because of CUDA issues. we should remove it at some point, and make sure that the CUDA worker catches it https://github.com/goodfire-ai/spd/issues/201#issue-3503138939 --- tests/clustering/scripts/cluster_ss.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 63c0a7f80..00ef733d6 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -138,5 +138,6 @@ # %% # Exit cleanly to avoid CUDA thread GIL issues during interpreter shutdown +# see https://github.com/goodfire-ai/spd/issues/201#issue-3503138939 # ============================================================ os._exit(0)