Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/checks.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ jobs:
run: uv run ruff format .

- name: Run tests
run: uv run pytest tests/ --runslow --durations 10 --numprocesses auto
run: uv run pytest tests/ --runslow --durations 20 --numprocesses auto --dist worksteal
env:
WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}
OMPI_ALLOW_RUN_AS_ROOT: 1
Expand Down
15 changes: 14 additions & 1 deletion spd/clustering/merge_run_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class ClusteringRunConfig(BaseModel):
"""Configuration for a complete merge clustering run.

Extends MergeConfig with parameters for model, dataset, and batch configuration.
CLI parameters (base_path, devices, workers_per_device) have defaults but will always be overridden
CLI parameters (base_path, devices, workers_per_device, dataset_streaming) have defaults but will always be overridden
"""

merge_config: MergeConfig = Field(
Expand Down Expand Up @@ -100,6 +100,10 @@ class ClusteringRunConfig(BaseModel):
default="perm_invariant_hamming",
description="Method to use for computing distances between clusterings",
)
dataset_streaming: bool = Field(
default=False,
description="Whether to use streaming dataset loading (if supported by the dataset). see https://github.com/goodfire-ai/spd/pull/199",
)

# Implementation details
# note that these are *always* overridden by CLI args in `spd/clustering/scripts/main.py`, but we have to have defaults here
Expand Down Expand Up @@ -161,6 +165,15 @@ def validate_intervals(self) -> Self:

return self

@model_validator(mode="after")
def validate_streaming_compatibility(self) -> Self:
    """Reject configurations that enable streaming for an unsupported task.

    Streaming dataset loading is only implemented for the "lm" task; any
    other task combined with ``dataset_streaming=True`` is a configuration
    error.

    Returns:
        The validated config instance (pydantic "after" validator contract).

    Raises:
        ValueError: if ``dataset_streaming`` is set but ``task_name`` != "lm".
    """
    # Guard clause: nothing to check when streaming is disabled.
    if not self.dataset_streaming:
        return self
    if self.task_name != "lm":
        raise ValueError(
            f"Streaming dataset loading only supported for 'lm' task, got '{self.task_name}'"
        )
    return self

@property
def wandb_decomp_model(self) -> str:
"""Extract the WandB run ID of the source decomposition from the model_path
Expand Down
14 changes: 13 additions & 1 deletion spd/clustering/pipeline/clustering_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,21 @@ def main(config: ClusteringRunConfig) -> None:

# Split dataset into batches
logger.info(f"Splitting dataset into {config.n_batches} batches...")
split_dataset_kwargs: dict[str, Any] = dict()
if config.dataset_streaming:
logger.info("Using streaming dataset loading")
split_dataset_kwargs["config_kwargs"] = dict(streaming=True)
# check this here as well as the model validator because we edit `config.dataset_streaming` after init in main() after the CLI args are parsed
# not sure if this is actually a problem though
assert config.task_name == "lm", (
f"Streaming dataset loading only supported for 'lm' task, got '{config.task_name = }'. Remove dataset_streaming=True from config or use a different task."
)
batches: Iterator[BatchTensor]
dataset_config: dict[str, Any]
batches, dataset_config = split_dataset(config=config)
batches, dataset_config = split_dataset(
config=config,
**split_dataset_kwargs,
)
storage.save_batches(batches=batches, config=dataset_config)
batch_paths: list[Path] = storage.get_batch_paths()
n_batch_paths: int = len(batch_paths)
Expand Down
21 changes: 17 additions & 4 deletions spd/clustering/pipeline/s1_split_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
from spd.models.component_model import ComponentModel, SPDRunInfo


def split_dataset(config: ClusteringRunConfig) -> tuple[Iterator[BatchTensor], dict[str, Any]]:
def split_dataset(
config: ClusteringRunConfig,
**kwargs: Any,
) -> tuple[Iterator[BatchTensor], dict[str, Any]]:
"""Split a dataset into n_batches of batch_size, returning iterator and config"""
ds: Generator[BatchTensor, None, None]
ds_config_dict: dict[str, Any]
Expand All @@ -30,11 +33,13 @@ def split_dataset(config: ClusteringRunConfig) -> tuple[Iterator[BatchTensor], d
ds, ds_config_dict = _get_dataloader_lm(
model_path=config.model_path,
batch_size=config.batch_size,
**kwargs,
)
case "resid_mlp":
ds, ds_config_dict = _get_dataloader_resid_mlp(
model_path=config.model_path,
batch_size=config.batch_size,
**kwargs,
)
case name:
raise ValueError(
Expand All @@ -56,6 +61,7 @@ def limited_iterator() -> Iterator[BatchTensor]:
def _get_dataloader_lm(
model_path: str,
batch_size: int,
config_kwargs: dict[str, Any] | None = None,
) -> tuple[Generator[BatchTensor, None, None], dict[str, Any]]:
"""Split up a SS dataset into n_batches of batch_size, returning the saved paths

Expand All @@ -81,15 +87,22 @@ def _get_dataloader_lm(
f"Expected task_config to be of type LMTaskConfig since using `_get_dataloader_lm`, but got {type(cfg.task_config) = }"
)

config_kwargs_: dict[str, Any] = {
**dict(
is_tokenized=False,
streaming=False,
seed=0,
),
**(config_kwargs or {}),
}

dataset_config: DatasetConfig = DatasetConfig(
name=cfg.task_config.dataset_name,
hf_tokenizer_path=pretrained_model_name,
split=cfg.task_config.train_data_split,
n_ctx=cfg.task_config.max_seq_len,
is_tokenized=False,
streaming=False,
seed=0,
column_name=cfg.task_config.column_name,
**config_kwargs_,
)

with SpinnerContext(message="getting dataloader..."):
Expand Down
18 changes: 10 additions & 8 deletions spd/clustering/plotting/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections.abc import Sequence
from pathlib import Path
from typing import Literal

import matplotlib as mpl
import matplotlib.pyplot as plt
Expand All @@ -20,14 +21,15 @@ def plot_activations(
processed_activations: ProcessedActivations,
save_dir: Path,
n_samples_max: int,
pdf_prefix: str = "activations",
figure_prefix: str = "activations",
figsize_raw: tuple[int, int] = (12, 4),
figsize_concat: tuple[int, int] = (12, 2),
figsize_coact: tuple[int, int] = (8, 6),
hist_scales: tuple[str, str] = ("lin", "log"),
hist_bins: int = 100,
do_sorted_samples: bool = False,
wandb_run: wandb.sdk.wandb_run.Run | None = None,
save_fmt: Literal["pdf", "png", "svg"] = "pdf",
) -> None:
"""Plot activation visualizations including raw, concatenated, sorted, and coactivations.

Expand All @@ -37,7 +39,7 @@ def plot_activations(
coact: Coactivation matrix
labels: Component labels
save_dir: The directory to save the plots to
pdf_prefix: Prefix for PDF filenames
figure_prefix: Prefix for figure filenames
figsize_raw: Figure size for raw activations
figsize_concat: Figure size for concatenated activations
figsize_coact: Figure size for coactivations
Expand Down Expand Up @@ -77,7 +79,7 @@ def plot_activations(
axs_act[i].set_ylabel(f"components\n{key}")
axs_act[i].set_title(f"Raw Activations: {key} (shape: {act_raw_data.shape})")

fig1_fname = save_dir / f"{pdf_prefix}_raw.pdf"
fig1_fname = save_dir / f"{figure_prefix}_raw.{save_fmt}"
_fig1.savefig(fig1_fname, bbox_inches="tight", dpi=300)

# Log to WandB if available
Expand All @@ -100,7 +102,7 @@ def plot_activations(

plt.colorbar(im2)

fig2_fname: Path = save_dir / f"{pdf_prefix}_concatenated.pdf"
fig2_fname: Path = save_dir / f"{figure_prefix}_concatenated.{save_fmt}"
fig2.savefig(fig2_fname, bbox_inches="tight", dpi=300)

# Log to WandB if available
Expand Down Expand Up @@ -169,7 +171,7 @@ def plot_activations(

plt.colorbar(im3)

fig3_fname: Path = save_dir / f"{pdf_prefix}_concatenated_sorted.pdf"
fig3_fname: Path = save_dir / f"{figure_prefix}_concatenated_sorted.{save_fmt}"
fig3.savefig(fig3_fname, bbox_inches="tight", dpi=300)

# Log to WandB if available
Expand All @@ -193,7 +195,7 @@ def plot_activations(

plt.colorbar(im4)

fig4_fname: Path = save_dir / f"{pdf_prefix}_coactivations.pdf"
fig4_fname: Path = save_dir / f"{figure_prefix}_coactivations.{save_fmt}"
fig4.savefig(fig4_fname, bbox_inches="tight", dpi=300)

# Log to WandB if available
Expand All @@ -217,7 +219,7 @@ def plot_activations(
add_component_labeling(ax4_log, labels, axis="x")
add_component_labeling(ax4_log, labels, axis="y")
plt.colorbar(im4_log)
fig4_log_fname: Path = save_dir / f"{pdf_prefix}_coactivations_log.pdf"
fig4_log_fname: Path = save_dir / f"{figure_prefix}_coactivations_log.{save_fmt}"
fig4_log.savefig(fig4_log_fname, bbox_inches="tight", dpi=300)

# Log to WandB if available
Expand Down Expand Up @@ -312,7 +314,7 @@ def plot_activations(

plt.tight_layout()

fig5_fname: Path = save_dir / f"{pdf_prefix}_histograms.pdf"
fig5_fname: Path = save_dir / f"{figure_prefix}_histograms.{save_fmt}"
fig5.savefig(fig5_fname, bbox_inches="tight", dpi=300)

# Log to WandB if available
Expand Down
6 changes: 4 additions & 2 deletions spd/clustering/plotting/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
figsize=(16, 10),
tick_spacing=5,
save_pdf=False,
pdf_prefix="merge_iteration",
figure_prefix="merge_iteration",
)


Expand Down Expand Up @@ -168,7 +168,9 @@ def plot_merge_iteration(

if plot_config_["save_pdf"]:
fig.savefig(
f"{plot_config_['pdf_prefix']}_iter_{iteration:03d}.pdf", bbox_inches="tight", dpi=300
f"{plot_config_['figure_prefix']}_iter_{iteration:03d}.pdf",
bbox_inches="tight",
dpi=300,
)

if show:
Expand Down
6 changes: 6 additions & 0 deletions spd/clustering/scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ def cli() -> None:
default=1,
help="Maximum number of concurrent clustering processes per device (default: 1)",
)
parser.add_argument(
"--dataset-streaming",
action="store_true",
help="Whether to use streaming dataset loading (if supported by the dataset). see https://github.com/goodfire-ai/spd/pull/199",
)
args: argparse.Namespace = parser.parse_args()

logger.info("Starting clustering pipeline")
Expand All @@ -66,6 +71,7 @@ def cli() -> None:
config.base_path = args.base_path
config.devices = devices
config.workers_per_device = args.workers_per_device
config.dataset_streaming = args.dataset_streaming

logger.info(f"Configuration loaded: {config.config_identifier}")
logger.info(f"Base path: {config.base_path}")
Expand Down
20 changes: 17 additions & 3 deletions tests/clustering/scripts/cluster_ss.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
# %%
import os

# Suppress tokenizer parallelism warning when forking
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from pathlib import Path

import matplotlib.pyplot as plt
import torch
from jaxtyping import Int
from muutils.dbg import dbg_auto
Expand Down Expand Up @@ -49,7 +53,11 @@
n_batches=1,
batch_size=2,
)
BATCHES, _ = split_dataset(config=CONFIG)

BATCHES, _ = split_dataset(
config=CONFIG,
config_kwargs=dict(streaming=True), # see https://github.com/goodfire-ai/spd/pull/199
)

# %%
# Load data batch
Expand Down Expand Up @@ -85,6 +93,7 @@
save_dir=TEMP_DIR,
n_samples_max=256,
wandb_run=None,
save_fmt="svg",
)

# %%
Expand Down Expand Up @@ -126,4 +135,9 @@
distances=DISTANCES,
mode="points",
)
plt.legend()

# %%
# Exit cleanly to avoid CUDA thread GIL issues during interpreter shutdown
# see https://github.com/goodfire-ai/spd/issues/201#issue-3503138939
# ============================================================
os._exit(0)
1 change: 1 addition & 0 deletions tests/clustering/test_clustering_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def test_clustering_with_simplestories_config():
"spd-cluster",
"--config",
str(config_path),
"--dataset-streaming", # see https://github.com/goodfire-ai/spd/pull/199
],
capture_output=True,
text=True,
Expand Down