diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index 258935d2b..81027df03 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -50,7 +50,7 @@ jobs: run: uv run ruff format . - name: Run tests - run: uv run pytest tests/ --runslow --durations 10 --numprocesses auto + run: uv run pytest tests/ --runslow --durations 20 --numprocesses auto --dist worksteal env: WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} OMPI_ALLOW_RUN_AS_ROOT: 1 diff --git a/spd/clustering/merge_run_config.py b/spd/clustering/merge_run_config.py index feba16967..d86cbc7a6 100644 --- a/spd/clustering/merge_run_config.py +++ b/spd/clustering/merge_run_config.py @@ -71,7 +71,7 @@ class ClusteringRunConfig(BaseModel): """Configuration for a complete merge clustering run. Extends MergeConfig with parameters for model, dataset, and batch configuration. - CLI parameters (base_path, devices, workers_per_device) have defaults but will always be overridden + CLI parameters (base_path, devices, workers_per_device, dataset_streaming) have defaults but will always be overridden """ merge_config: MergeConfig = Field( @@ -100,6 +100,10 @@ class ClusteringRunConfig(BaseModel): default="perm_invariant_hamming", description="Method to use for computing distances between clusterings", ) + dataset_streaming: bool = Field( + default=False, + description="Whether to use streaming dataset loading (if supported by the dataset). 
see https://github.com/goodfire-ai/spd/pull/199", + ) # Implementation details # note that these are *always* overriden by CLI args in `spd/clustering/scripts/main.py`, but we have to have defaults here @@ -161,6 +165,15 @@ def validate_intervals(self) -> Self: return self + @model_validator(mode="after") + def validate_streaming_compatibility(self) -> Self: + """Ensure dataset_streaming is only enabled for compatible tasks.""" + if self.dataset_streaming and self.task_name != "lm": + raise ValueError( + f"Streaming dataset loading only supported for 'lm' task, got '{self.task_name}'" + ) + return self + @property def wandb_decomp_model(self) -> str: """Extract the WandB run ID of the source decomposition from the model_path diff --git a/spd/clustering/pipeline/clustering_pipeline.py b/spd/clustering/pipeline/clustering_pipeline.py index 1e07c71d7..8c6b72f9d 100644 --- a/spd/clustering/pipeline/clustering_pipeline.py +++ b/spd/clustering/pipeline/clustering_pipeline.py @@ -46,9 +46,21 @@ def main(config: ClusteringRunConfig) -> None: # Split dataset into batches logger.info(f"Splitting dataset into {config.n_batches} batches...") + split_dataset_kwargs: dict[str, Any] = dict() + if config.dataset_streaming: + logger.info("Using streaming dataset loading") + split_dataset_kwargs["config_kwargs"] = dict(streaming=True) + # check this here as well as in the model validator, because `config.dataset_streaming` is reassigned after init by `cli()` in `spd/clustering/scripts/main.py` once the CLI args are parsed + # NOTE(review): unclear whether the validator re-runs on that assignment (depends on pydantic `validate_assignment`) -- confirm before removing this check + assert config.task_name == "lm", ( + f"Streaming dataset loading only supported for 'lm' task, got '{config.task_name = }'. Remove dataset_streaming=True from config or use a different task."
+ ) batches: Iterator[BatchTensor] dataset_config: dict[str, Any] - batches, dataset_config = split_dataset(config=config) + batches, dataset_config = split_dataset( + config=config, + **split_dataset_kwargs, + ) storage.save_batches(batches=batches, config=dataset_config) batch_paths: list[Path] = storage.get_batch_paths() n_batch_paths: int = len(batch_paths) diff --git a/spd/clustering/pipeline/s1_split_dataset.py b/spd/clustering/pipeline/s1_split_dataset.py index d5427e600..711cda65a 100644 --- a/spd/clustering/pipeline/s1_split_dataset.py +++ b/spd/clustering/pipeline/s1_split_dataset.py @@ -21,7 +21,10 @@ from spd.models.component_model import ComponentModel, SPDRunInfo -def split_dataset(config: ClusteringRunConfig) -> tuple[Iterator[BatchTensor], dict[str, Any]]: +def split_dataset( + config: ClusteringRunConfig, + **kwargs: Any, +) -> tuple[Iterator[BatchTensor], dict[str, Any]]: """Split a dataset into n_batches of batch_size, returning iterator and config""" ds: Generator[BatchTensor, None, None] ds_config_dict: dict[str, Any] @@ -30,11 +33,13 @@ def split_dataset(config: ClusteringRunConfig) -> tuple[Iterator[BatchTensor], d ds, ds_config_dict = _get_dataloader_lm( model_path=config.model_path, batch_size=config.batch_size, + **kwargs, ) case "resid_mlp": ds, ds_config_dict = _get_dataloader_resid_mlp( model_path=config.model_path, batch_size=config.batch_size, + **kwargs, ) case name: raise ValueError( @@ -56,6 +61,7 @@ def limited_iterator() -> Iterator[BatchTensor]: def _get_dataloader_lm( model_path: str, batch_size: int, + config_kwargs: dict[str, Any] | None = None, ) -> tuple[Generator[BatchTensor, None, None], dict[str, Any]]: """split up a SS dataset into n_batches of batch_size, returned the saved paths @@ -81,15 +87,22 @@ def _get_dataloader_lm( f"Expected task_config to be of type LMTaskConfig since using `_get_dataloader_lm`, but got {type(cfg.task_config) = }" ) + config_kwargs_: dict[str, Any] = { + **dict( + is_tokenized=False, + 
streaming=False, + seed=0, + ), + **(config_kwargs or {}), + } + dataset_config: DatasetConfig = DatasetConfig( name=cfg.task_config.dataset_name, hf_tokenizer_path=pretrained_model_name, split=cfg.task_config.train_data_split, n_ctx=cfg.task_config.max_seq_len, - is_tokenized=False, - streaming=False, - seed=0, column_name=cfg.task_config.column_name, + **config_kwargs_, ) with SpinnerContext(message="getting dataloader..."): diff --git a/spd/clustering/plotting/activations.py b/spd/clustering/plotting/activations.py index 2411eca38..dc8c8658c 100644 --- a/spd/clustering/plotting/activations.py +++ b/spd/clustering/plotting/activations.py @@ -2,6 +2,7 @@ from collections.abc import Sequence from pathlib import Path +from typing import Literal import matplotlib as mpl import matplotlib.pyplot as plt @@ -20,7 +21,7 @@ def plot_activations( processed_activations: ProcessedActivations, save_dir: Path, n_samples_max: int, - pdf_prefix: str = "activations", + figure_prefix: str = "activations", figsize_raw: tuple[int, int] = (12, 4), figsize_concat: tuple[int, int] = (12, 2), figsize_coact: tuple[int, int] = (8, 6), @@ -28,6 +29,7 @@ def plot_activations( hist_bins: int = 100, do_sorted_samples: bool = False, wandb_run: wandb.sdk.wandb_run.Run | None = None, + save_fmt: Literal["pdf", "png", "svg"] = "pdf", ) -> None: """Plot activation visualizations including raw, concatenated, sorted, and coactivations. 
@@ -37,7 +39,7 @@ def plot_activations( coact: Coactivation matrix labels: Component labels save_dir: The directory to save the plots to - pdf_prefix: Prefix for PDF filenames + figure_prefix: Prefix for figure filenames figsize_raw: Figure size for raw activations figsize_concat: Figure size for concatenated activations figsize_coact: Figure size for coactivations @@ -77,7 +79,7 @@ def plot_activations( axs_act[i].set_ylabel(f"components\n{key}") axs_act[i].set_title(f"Raw Activations: {key} (shape: {act_raw_data.shape})") - fig1_fname = save_dir / f"{pdf_prefix}_raw.pdf" + fig1_fname = save_dir / f"{figure_prefix}_raw.{save_fmt}" _fig1.savefig(fig1_fname, bbox_inches="tight", dpi=300) # Log to WandB if available @@ -100,7 +102,7 @@ def plot_activations( plt.colorbar(im2) - fig2_fname: Path = save_dir / f"{pdf_prefix}_concatenated.pdf" + fig2_fname: Path = save_dir / f"{figure_prefix}_concatenated.{save_fmt}" fig2.savefig(fig2_fname, bbox_inches="tight", dpi=300) # Log to WandB if available @@ -169,7 +171,7 @@ def plot_activations( plt.colorbar(im3) - fig3_fname: Path = save_dir / f"{pdf_prefix}_concatenated_sorted.pdf" + fig3_fname: Path = save_dir / f"{figure_prefix}_concatenated_sorted.{save_fmt}" fig3.savefig(fig3_fname, bbox_inches="tight", dpi=300) # Log to WandB if available @@ -193,7 +195,7 @@ def plot_activations( plt.colorbar(im4) - fig4_fname: Path = save_dir / f"{pdf_prefix}_coactivations.pdf" + fig4_fname: Path = save_dir / f"{figure_prefix}_coactivations.{save_fmt}" fig4.savefig(fig4_fname, bbox_inches="tight", dpi=300) # Log to WandB if available @@ -217,7 +219,7 @@ def plot_activations( add_component_labeling(ax4_log, labels, axis="x") add_component_labeling(ax4_log, labels, axis="y") plt.colorbar(im4_log) - fig4_log_fname: Path = save_dir / f"{pdf_prefix}_coactivations_log.pdf" + fig4_log_fname: Path = save_dir / f"{figure_prefix}_coactivations_log.{save_fmt}" fig4_log.savefig(fig4_log_fname, bbox_inches="tight", dpi=300) # Log to WandB if 
available @@ -312,7 +314,7 @@ def plot_activations( plt.tight_layout() - fig5_fname: Path = save_dir / f"{pdf_prefix}_histograms.pdf" + fig5_fname: Path = save_dir / f"{figure_prefix}_histograms.{save_fmt}" fig5.savefig(fig5_fname, bbox_inches="tight", dpi=300) # Log to WandB if available diff --git a/spd/clustering/plotting/merge.py b/spd/clustering/plotting/merge.py index e470b3114..8a2cc20df 100644 --- a/spd/clustering/plotting/merge.py +++ b/spd/clustering/plotting/merge.py @@ -17,7 +17,7 @@ figsize=(16, 10), tick_spacing=5, save_pdf=False, - pdf_prefix="merge_iteration", + figure_prefix="merge_iteration", ) @@ -168,7 +168,9 @@ def plot_merge_iteration( if plot_config_["save_pdf"]: fig.savefig( - f"{plot_config_['pdf_prefix']}_iter_{iteration:03d}.pdf", bbox_inches="tight", dpi=300 + f"{plot_config_['figure_prefix']}_iter_{iteration:03d}.pdf", + bbox_inches="tight", + dpi=300, ) if show: diff --git a/spd/clustering/scripts/main.py b/spd/clustering/scripts/main.py index 56cee4d84..2104482e5 100644 --- a/spd/clustering/scripts/main.py +++ b/spd/clustering/scripts/main.py @@ -43,6 +43,11 @@ def cli() -> None: default=1, help="Maximum number of concurrent clustering processes per device (default: 1)", ) + parser.add_argument( + "--dataset-streaming", + action="store_true", + help="Whether to use streaming dataset loading (if supported by the dataset). 
see https://github.com/goodfire-ai/spd/pull/199", + ) args: argparse.Namespace = parser.parse_args() logger.info("Starting clustering pipeline") @@ -66,6 +71,7 @@ def cli() -> None: config.base_path = args.base_path config.devices = devices config.workers_per_device = args.workers_per_device + config.dataset_streaming = args.dataset_streaming logger.info(f"Configuration loaded: {config.config_identifier}") logger.info(f"Base path: {config.base_path}") diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 6ede368f0..00ef733d6 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -1,7 +1,11 @@ # %% +import os + +# Suppress tokenizer parallelism warning when forking +os.environ["TOKENIZERS_PARALLELISM"] = "false" + from pathlib import Path -import matplotlib.pyplot as plt import torch from jaxtyping import Int from muutils.dbg import dbg_auto @@ -49,7 +53,11 @@ n_batches=1, batch_size=2, ) -BATCHES, _ = split_dataset(config=CONFIG) + +BATCHES, _ = split_dataset( + config=CONFIG, + config_kwargs=dict(streaming=True), # see https://github.com/goodfire-ai/spd/pull/199 +) # %% # Load data batch @@ -85,6 +93,7 @@ save_dir=TEMP_DIR, n_samples_max=256, wandb_run=None, + save_fmt="svg", ) # %% @@ -126,4 +135,9 @@ distances=DISTANCES, mode="points", ) -plt.legend() + +# %% +# Exit cleanly to avoid CUDA thread GIL issues during interpreter shutdown +# see https://github.com/goodfire-ai/spd/issues/201#issue-3503138939 +# ============================================================ +os._exit(0) diff --git a/tests/clustering/test_clustering_experiments.py b/tests/clustering/test_clustering_experiments.py index 5031adfce..19ff937f6 100644 --- a/tests/clustering/test_clustering_experiments.py +++ b/tests/clustering/test_clustering_experiments.py @@ -87,6 +87,7 @@ def test_clustering_with_simplestories_config(): "spd-cluster", "--config", str(config_path), + "--dataset-streaming", # see 
https://github.com/goodfire-ai/spd/pull/199 ], capture_output=True, text=True,