Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion spd/clustering/configs/example.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ iters = 100 # iterations to run. setting this to exactly the number of componen
pop_component_prob = 0 # Probability of popping a component. I recommend 0 if you're doing an ensemble anyway
filter_dead_threshold = 0.001 # Threshold for filtering dead components
module_name_filter = "__NULL__" # Can be a string prefix like "model.layers.0." if you want to do only some modules
rank_cost_fn_name = "const_1" # Options: const_1, const_2, log, linear
merge_pair_sampling_method = "range" # Method for sampling merge pairs: 'range' or 'mcmc'

[merge_config.merge_pair_sampling_kwargs]
Expand Down
1 change: 0 additions & 1 deletion spd/clustering/configs/example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ merge_config:
pop_component_prob: 0 # Probability of popping a component. I recommend 0 if you're doing an ensemble anyway
filter_dead_threshold: 0.001 # Threshold for filtering dead components
module_name_filter: null # Can be a string prefix like "model.layers.0." if you want to do only some modules
rank_cost_fn_name: const_1 # Options: const_1, const_2, log, linear

# Run configuration
model_path: wandb:goodfire/spd-pre-Sep-2025/runs/ioprgffh # WandB path to the decomposed model
Expand Down
3 changes: 1 addition & 2 deletions spd/clustering/configs/test-resid_mlp1.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
"merge_pair_sampling_kwargs": {"threshold": 0.05},
"pop_component_prob": 0,
"filter_dead_threshold": 0.1,
"module_name_filter": null,
"rank_cost_fn_name": "const_1"
"module_name_filter": null
},
"experiment_key": "resid_mlp1",
"n_batches": 2,
Expand Down
4 changes: 2 additions & 2 deletions spd/clustering/merge_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from typing import Any, Literal

from pydantic import (
BaseModel,
Field,
PositiveInt,
)

from spd.base_config import BaseConfig
from spd.clustering.consts import ClusterCoactivationShaped, MergePair
from spd.clustering.math.merge_pair_samplers import (
MERGE_PAIR_SAMPLERS,
Expand Down Expand Up @@ -44,7 +44,7 @@ def _to_module_filter(
raise TypeError(f"filter_modules must be str, set, or callable, got {type(filter_modules)}") # pyright: ignore[reportUnreachable]


class MergeConfig(BaseModel):
class MergeConfig(BaseConfig):
activation_threshold: Probability | None = Field(
default=0.01,
description="Threshold for considering a component active in a group. If None, use raw scalar causal importances",
Expand Down
120 changes: 45 additions & 75 deletions spd/clustering/merge_run_config.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
"""Configuration for merge clustering runs that combines merge config with run parameters."""

import hashlib
import json
import tomllib
import warnings
from pathlib import Path
from typing import Any, Literal, Self

import yaml
from muutils.misc.numerical import shorten_numerical_to_str
from pydantic import BaseModel, Field, PositiveInt, model_validator
from pydantic import Field, PositiveInt, model_validator

from spd.base_config import BaseConfig
from spd.clustering.consts import DistancesMethod
from spd.clustering.merge_config import MergeConfig
from spd.registry import EXPERIMENT_REGISTRY, ExperimentConfig
Expand Down Expand Up @@ -67,7 +66,7 @@ def replace_sentinel_recursive(obj: Any) -> Any:
return replace_sentinel_recursive(data)


class ClusteringRunConfig(BaseModel):
class ClusteringRunConfig(BaseConfig):
"""Configuration for a complete merge clustering run.

Extends MergeConfig with parameters for model, dataset, and batch configuration.
Expand Down Expand Up @@ -146,24 +145,27 @@ def validate_model_path(self) -> Self:
)
return self

@model_validator(mode="before")
@classmethod
def validate_intervals(cls, data: dict[str, Any]) -> dict[str, Any]:
    """Ensure all required interval keys are present.

    Runs in "before" mode so that defaults are merged into the raw input
    mapping prior to field validation. Any key of `_DEFAULT_INTERVALS`
    absent from the input falls back to its default value, and a
    `UserWarning` is emitted so the omission is visible.

    Args:
        data: raw config mapping as passed to the model constructor.

    Returns:
        The same mapping with `data["intervals"]` filled in from
        `_DEFAULT_INTERVALS` for any missing keys.
    """
    data_intervals: dict[IntervalKey, Any] = data.get("intervals", {})
    # Warn about (rather than silently fill) any missing interval keys.
    missing_keys: set[IntervalKey] = set(_DEFAULT_INTERVALS.keys()) - set(data_intervals.keys())
    if missing_keys:
        warnings.warn(
            f"Missing interval keys in {data_intervals = }: {missing_keys}. Using defaults for those.",
            UserWarning,
            stacklevel=1,
        )

    # Explicitly provided values win over defaults.
    data["intervals"] = {
        **_DEFAULT_INTERVALS,
        **data_intervals,
    }

    return data

@model_validator(mode="after")
def validate_streaming_compatibility(self) -> Self:
Expand All @@ -174,6 +176,38 @@ def validate_streaming_compatibility(self) -> Self:
)
return self

@model_validator(mode="before")
@classmethod
def handle_experiment_key(cls, data: dict[str, Any]) -> dict[str, Any]:
    """Handle passing an experiment key instead of model_path and task_name.

    If `experiment_key` is provided:
    1. use the `EXPERIMENT_REGISTRY` to fill in `model_path` and `task_name`
    2. check consistency with `model_path` and `task_name` from the file
       if those are also provided

    Args:
        data: raw config mapping as passed to the model constructor.

    Returns:
        The mapping with `model_path`/`task_name` overwritten from the
        registry when `experiment_key` is given.

    Raises:
        ValueError: if an explicitly provided `model_path` or `task_name`
            disagrees with the registry entry for `experiment_key`.
        KeyError: if `experiment_key` is not in `EXPERIMENT_REGISTRY`.
    """
    experiment_key: str | None = data.get("experiment_key")
    model_path: str | None = data.get("model_path")
    task_name: str | None = data.get("task_name")
    if experiment_key is not None:
        exp_config: ExperimentConfig = EXPERIMENT_REGISTRY[experiment_key]

        # Enforce consistency if explicit fields are present. Raise rather
        # than assert so the check survives `python -O`; pydantic converts
        # a ValueError from a validator into a ValidationError, just as it
        # does an AssertionError, so callers see the same failure mode.
        if model_path is not None and model_path != exp_config.canonical_run:
            raise ValueError(
                f"Inconsistent model_path for {experiment_key}, version from file ({model_path}) does not match registry ({exp_config.canonical_run})"
            )
        if task_name is not None and task_name != exp_config.task_name:
            raise ValueError(
                f"Inconsistent task_name for {experiment_key}, version from file ({task_name}) does not match registry ({exp_config.task_name})"
            )

        # Overwrite in the data dict so downstream field validation sees
        # the registry-canonical values.
        data["model_path"] = exp_config.canonical_run
        data["task_name"] = exp_config.task_name

    return data

@property
def wandb_decomp_model(self) -> str:
"""Extract the WandB run ID of the source decomposition from the model_path
Expand Down Expand Up @@ -213,70 +247,6 @@ def stable_hash(self) -> str:
"""Generate a stable hash including all config parameters."""
return hashlib.md5(self.model_dump_json().encode()).hexdigest()[:6]

@classmethod
def read(cls, path: Path) -> "ClusteringRunConfig":
    """Load config from JSON, YAML, or TOML file.

    Handles legacy spd_exp: model_path format and enforces consistency.
    For TOML files, the sentinel value "__NULL__" is converted to None.
    """
    # Parse the file according to its extension.
    suffix: str = path.suffix
    data: dict[str, Any]
    if suffix == ".toml":
        data = toml_read_file_with_none(path)
    elif suffix == ".json":
        data = json.loads(path.read_text())
    elif suffix in (".yaml", ".yml"):
        data = yaml.safe_load(path.read_text())
    else:
        raise ValueError(
            f"Unsupported file extension '{path.suffix}' on file '{path}' -- must be .json, .yaml, .yml, or .toml"
        )

    # An `experiment_key` lets the file omit model_path/task_name:
    # 1. fill both in from `EXPERIMENT_REGISTRY`
    # 2. if either was also given in the file, check it matches the registry
    experiment_key: str | None = data.get("experiment_key")
    if experiment_key is not None:
        exp_config: ExperimentConfig = EXPERIMENT_REGISTRY[experiment_key]
        model_path: str | None = data.get("model_path")
        task_name: str | None = data.get("task_name")

        # Enforce consistency if explicit fields present
        if model_path is not None:
            assert model_path == exp_config.canonical_run, (
                f"Inconsistent model_path for {experiment_key}, version from file ({model_path}) does not match registry ({exp_config.canonical_run})"
            )
        if task_name is not None:
            assert task_name == exp_config.task_name, (
                f"Inconsistent task_name for {experiment_key}, version from file ({task_name}) does not match registry ({exp_config.task_name})"
            )

        # overwrite in data dict
        data["model_path"] = exp_config.canonical_run
        data["task_name"] = exp_config.task_name

    return cls.model_validate(data)

def save(self, path: Path) -> None:
    """Save config to file (format inferred from extension)."""
    # Ensure the destination directory exists before writing.
    path.parent.mkdir(parents=True, exist_ok=True)

    suffix: str = path.suffix
    serialized: str
    if suffix == ".json":
        serialized = self.model_dump_json(indent=2)
    elif suffix in (".yaml", ".yml"):
        serialized = yaml.dump(
            self.model_dump(mode="json"),
            default_flow_style=False,
            sort_keys=False,
        )
    else:
        raise ValueError(f"Unsupported file extension: {path.suffix}")

    path.write_text(serialized)

def model_dump_with_properties(self) -> dict[str, Any]:
"""Serialize config including computed properties for WandB logging."""
base_dump: dict[str, Any] = self.model_dump()
Expand Down
2 changes: 1 addition & 1 deletion spd/clustering/pipeline/s2_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def cli() -> None:
args: argparse.Namespace = parser.parse_args()

# Load config
config: ClusteringRunConfig = ClusteringRunConfig.read(args.config)
config: ClusteringRunConfig = ClusteringRunConfig.from_file(args.config)

# Run clustering
result: ClusteringResult = run_clustering(
Expand Down
2 changes: 1 addition & 1 deletion spd/clustering/pipeline/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def save_run_config(self, config: ClusteringRunConfig) -> Path:
)

def load_run_config(self) -> ClusteringRunConfig:
return ClusteringRunConfig.read(self.run_config_file)
return ClusteringRunConfig.from_file(self.run_config_file)

# Dashboard storage methods

Expand Down
15 changes: 10 additions & 5 deletions spd/clustering/scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,16 @@ def cli() -> None:
# Note that the defaults for args here always override the default values in `RunConfig` itself,
# but we must have those defaults to avoid type issues
logger.info(f"Loading config from {args.config}")
config: ClusteringRunConfig = ClusteringRunConfig.read(args.config)
config.base_path = args.base_path
config.devices = devices
config.workers_per_device = args.workers_per_device
config.dataset_streaming = args.dataset_streaming
config: ClusteringRunConfig = ClusteringRunConfig.from_file(args.config)
# Use model_copy to update frozen fields
config = config.model_copy(
update={
"base_path": args.base_path,
"devices": devices,
"workers_per_device": args.workers_per_device,
"dataset_streaming": args.dataset_streaming,
}
)

logger.info(f"Configuration loaded: {config.config_identifier}")
logger.info(f"Base path: {config.base_path}")
Expand Down
4 changes: 2 additions & 2 deletions spd/utils/wandb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import wandb_workspaces.reports.v2 as wr
import wandb_workspaces.workspaces as ws
from dotenv import load_dotenv
from pydantic import BaseModel
from wandb.apis.public import File, Run

from spd.base_config import BaseConfig
from spd.log import logger
from spd.registry import EXPERIMENT_REGISTRY
from spd.settings import REPO_ROOT
Expand Down Expand Up @@ -120,7 +120,7 @@ def download_wandb_file(run: Run, wandb_run_dir: Path, file_name: str) -> Path:
return path


def init_wandb[T_config: BaseModel](
def init_wandb[T_config: BaseConfig](
config: T_config, project: str, name: str | None = None, tags: list[str] | None = None
) -> T_config:
"""Initialize Weights & Biases and return a config updated with sweep hyperparameters.
Expand Down
2 changes: 1 addition & 1 deletion tests/clustering/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_save_and_load_run_config(self, temp_storage: ClusteringStorage):
assert saved_path == temp_storage.run_config_file

# Load and verify
loaded_config = ClusteringRunConfig.read(saved_path)
loaded_config = ClusteringRunConfig.from_file(saved_path)
assert loaded_config.n_batches == 5
assert loaded_config.batch_size == 32
assert loaded_config.task_name == "lm"
Expand Down