From 0e5595fca7ab72b344d02fb7a389d875cf6345e7 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 13 Oct 2025 16:00:49 +0100 Subject: [PATCH 1/5] switch BaseModel to BaseConfig, get rid of old save/read logic --- spd/clustering/merge_config.py | 4 +- spd/clustering/merge_run_config.py | 102 ++++++++--------------- spd/clustering/pipeline/s2_clustering.py | 2 +- spd/clustering/scripts/main.py | 2 +- spd/utils/wandb_utils.py | 4 +- tests/clustering/test_storage.py | 2 +- 6 files changed, 41 insertions(+), 75 deletions(-) diff --git a/spd/clustering/merge_config.py b/spd/clustering/merge_config.py index 03c601a9f..3bf8b6d5b 100644 --- a/spd/clustering/merge_config.py +++ b/spd/clustering/merge_config.py @@ -3,11 +3,11 @@ from typing import Any, Literal from pydantic import ( - BaseModel, Field, PositiveInt, ) +from spd.base_config import BaseConfig from spd.clustering.consts import ClusterCoactivationShaped, MergePair from spd.clustering.math.merge_pair_samplers import ( MERGE_PAIR_SAMPLERS, @@ -44,7 +44,7 @@ def _to_module_filter( raise TypeError(f"filter_modules must be str, set, or callable, got {type(filter_modules)}") # pyright: ignore[reportUnreachable] -class MergeConfig(BaseModel): +class MergeConfig(BaseConfig): activation_threshold: Probability | None = Field( default=0.01, description="Threshold for considering a component active in a group. If None, use raw scalar causal importances", diff --git a/spd/clustering/merge_run_config.py b/spd/clustering/merge_run_config.py index d86cbc7a6..82980c2bc 100644 --- a/spd/clustering/merge_run_config.py +++ b/spd/clustering/merge_run_config.py @@ -1,16 +1,15 @@ """Configuration for merge clustering runs that combines merge config with run parameters.""" import hashlib -import json import tomllib import warnings from pathlib import Path from typing import Any, Literal, Self -import yaml from muutils.misc.numerical import shorten_numerical_to_str -from pydantic import BaseModel, Field, PositiveInt, model_validator +from pydantic import Field, PositiveInt, model_validator +from spd.base_config import BaseConfig from spd.clustering.consts import DistancesMethod from spd.clustering.merge_config import MergeConfig from spd.registry import EXPERIMENT_REGISTRY, ExperimentConfig @@ -67,7 +66,7 @@ def replace_sentinel_recursive(obj: Any) -> Any: return replace_sentinel_recursive(data) -class ClusteringRunConfig(BaseModel): +class ClusteringRunConfig(BaseConfig): """Configuration for a complete merge clustering run. Extends MergeConfig with parameters for model, dataset, and batch configuration. @@ -174,6 +173,37 @@ def validate_streaming_compatibility(self) -> Self: ) return self + @model_validator(mode="before") + def handle_experiment_key(data: dict[str, Any]) -> dict[str, Any]: + """handle passing experiment key instead of model_path and task_name. + + if we provide an experiment_key, then: + 1. use the `EXPERIMENT_REGISTRY` to fill in model_path and task_name + 2. check it's consistent with model_path and task_name from the file if those are provided + + """ + experiment_key: str | None = data.get("experiment_key") + model_path: str | None = data.get("model_path") + task_name: str | None = data.get("task_name") + if experiment_key is not None: + exp_config: ExperimentConfig = EXPERIMENT_REGISTRY[experiment_key] + + # Enforce consistency if explicit fields present + if model_path is not None: + assert model_path == exp_config.canonical_run, ( + f"Inconsistent model_path for {experiment_key}, version from file ({model_path}) does not match registry ({exp_config.canonical_run})" + ) + if task_name is not None: + assert task_name == exp_config.task_name, ( + f"Inconsistent task_name for {experiment_key}, version from file ({task_name}) does not match registry ({exp_config.task_name})" + ) + + # overwrite in data dict + data["model_path"] = exp_config.canonical_run + data["task_name"] = exp_config.task_name + + return data + @property def wandb_decomp_model(self) -> str: """Extract the WandB run ID of the source decomposition from the model_path @@ -213,70 +243,6 @@ def stable_hash(self) -> str: """Generate a stable hash including all config parameters.""" return hashlib.md5(self.model_dump_json().encode()).hexdigest()[:6] - @classmethod - def read(cls, path: Path) -> "ClusteringRunConfig": - """Load config from JSON, YAML, or TOML file. - - Handles legacy spd_exp: model_path format and enforces consistency. - For TOML files, the sentinel value "__NULL__" is converted to None. - """ - # read the file contents, load them according to extension - data: dict[str, Any] - content: str - if path.suffix == ".json": - content = path.read_text() - data = json.loads(content) - elif path.suffix in [".yaml", ".yml"]: - content = path.read_text() - data = yaml.safe_load(content) - elif path.suffix == ".toml": - data = toml_read_file_with_none(path) - else: - raise ValueError( - f"Unsupported file extension '{path.suffix}' on file '{path}' -- must be .json, .yaml, .yml, or .toml" - ) - - # if we provide an experiment_key, then: - # 1. use the `EXPERIMENT_REGISTRY` to fill in model_path and task_name - # 2. check it's consistent with model_path and task_name from the file if those are provided - experiment_key: str | None = data.get("experiment_key") - model_path: str | None = data.get("model_path") - task_name: str | None = data.get("task_name") - if experiment_key is not None: - exp_config: ExperimentConfig = EXPERIMENT_REGISTRY[experiment_key] - - # Enforce consistency if explicit fields present - if model_path is not None: - assert model_path == exp_config.canonical_run, ( - f"Inconsistent model_path for {experiment_key}, version from file ({model_path}) does not match registry ({exp_config.canonical_run})" - ) - if task_name is not None: - assert task_name == exp_config.task_name, ( - f"Inconsistent task_name for {experiment_key}, version from file ({task_name}) does not match registry ({exp_config.task_name})" - ) - - # overwrite in data dict - data["model_path"] = exp_config.canonical_run - data["task_name"] = exp_config.task_name - - return cls.model_validate(data) - - def save(self, path: Path) -> None: - """Save config to file (format inferred from extension).""" - path.parent.mkdir(parents=True, exist_ok=True) - if path.suffix == ".json": - path.write_text(self.model_dump_json(indent=2)) - elif path.suffix in [".yaml", ".yml"]: - path.write_text( - yaml.dump( - self.model_dump(mode="json"), - default_flow_style=False, - sort_keys=False, - ) - ) - else: - raise ValueError(f"Unsupported file extension: {path.suffix}") - def model_dump_with_properties(self) -> dict[str, Any]: """Serialize config including computed properties for WandB logging.""" base_dump: dict[str, Any] = self.model_dump() diff --git a/spd/clustering/pipeline/s2_clustering.py b/spd/clustering/pipeline/s2_clustering.py index fc782c96e..116d7dc61 100644 --- a/spd/clustering/pipeline/s2_clustering.py +++ b/spd/clustering/pipeline/s2_clustering.py @@ -384,7 +384,7 @@ def cli() -> None: args: argparse.Namespace = parser.parse_args() # Load config - config: ClusteringRunConfig = ClusteringRunConfig.read(args.config) + config: ClusteringRunConfig = ClusteringRunConfig.from_file(args.config) # Run clustering result: ClusteringResult = run_clustering( diff --git a/spd/clustering/scripts/main.py b/spd/clustering/scripts/main.py index 2104482e5..82c57ace2 100644 --- a/spd/clustering/scripts/main.py +++ b/spd/clustering/scripts/main.py @@ -67,7 +67,7 @@ def cli() -> None: # Note that the defaults for args here always override the default values in `RunConfig` itself, # but we must have those defaults to avoid type issues logger.info(f"Loading config from {args.config}") - config: ClusteringRunConfig = ClusteringRunConfig.read(args.config) + config: ClusteringRunConfig = ClusteringRunConfig.from_file(args.config) config.base_path = args.base_path config.devices = devices config.workers_per_device = args.workers_per_device diff --git a/spd/utils/wandb_utils.py b/spd/utils/wandb_utils.py index 2cac6dd80..855440804 100644 --- a/spd/utils/wandb_utils.py +++ b/spd/utils/wandb_utils.py @@ -6,9 +6,9 @@ import wandb_workspaces.reports.v2 as wr import wandb_workspaces.workspaces as ws from dotenv import load_dotenv -from pydantic import BaseModel from wandb.apis.public import File, Run +from spd.base_config import BaseConfig from spd.log import logger from spd.registry import EXPERIMENT_REGISTRY from spd.settings import REPO_ROOT @@ -120,7 +120,7 @@ def download_wandb_file(run: Run, wandb_run_dir: Path, file_name: str) -> Path: return path -def init_wandb[T_config: BaseModel]( +def init_wandb[T_config: BaseConfig]( config: T_config, project: str, name: str | None = None, tags: list[str] | None = None ) -> T_config: """Initialize Weights & Biases and return a config updated with sweep hyperparameters. diff --git a/tests/clustering/test_storage.py b/tests/clustering/test_storage.py index d5e3d535e..389940e54 100644 --- a/tests/clustering/test_storage.py +++ b/tests/clustering/test_storage.py @@ -91,7 +91,7 @@ def test_save_and_load_run_config(self, temp_storage: ClusteringStorage): assert saved_path == temp_storage.run_config_file # Load and verify - loaded_config = ClusteringRunConfig.read(saved_path) + loaded_config = ClusteringRunConfig.from_file(saved_path) assert loaded_config.n_batches == 5 assert loaded_config.batch_size == 32 assert loaded_config.task_name == "lm" From ecfebb769cb6eaf212d114c24ff9af9eeb33b08d Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 13 Oct 2025 16:11:48 +0100 Subject: [PATCH 2/5] fix typo --- spd/clustering/pipeline/storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spd/clustering/pipeline/storage.py b/spd/clustering/pipeline/storage.py index cb38befd8..febe9bcc8 100644 --- a/spd/clustering/pipeline/storage.py +++ b/spd/clustering/pipeline/storage.py @@ -270,7 +270,7 @@ def save_run_config(self, config: ClusteringRunConfig) -> Path: ) def load_run_config(self) -> ClusteringRunConfig: - return ClusteringRunConfig.read(self.run_config_file) + return ClusteringRunConfig.from_file(self.run_config_file) # Dashboard storage methods From c27be14a74a6136466214b0f0294f1f3d79eb398 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 13 Oct 2025 16:12:01 +0100 Subject: [PATCH 3/5] fix pydantic validation issue --- spd/clustering/merge_run_config.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/spd/clustering/merge_run_config.py b/spd/clustering/merge_run_config.py index 82980c2bc..b6b8d6ab6 100644 --- a/spd/clustering/merge_run_config.py +++ b/spd/clustering/merge_run_config.py @@ -145,24 +145,27 @@ def validate_model_path(self) -> Self: ) return self - @model_validator(mode="after") - def validate_intervals(self) -> Self: + @model_validator(mode="before") + @classmethod + def validate_intervals(cls, data: dict[str, Any]) -> dict[str, Any]: """Ensure all required interval keys are present.""" + + data_intervals: dict[IntervalKey, Any] = data.get("intervals", {}) # warning if any keys are missing - missing_keys: set[IntervalKey] = set(_DEFAULT_INTERVALS.keys()) - set(self.intervals.keys()) + missing_keys: set[IntervalKey] = set(_DEFAULT_INTERVALS.keys()) - set(data_intervals.keys()) if missing_keys: warnings.warn( - f"Missing interval keys in {self.intervals = }: {missing_keys}. Using defaults for those.", + f"Missing interval keys in {data_intervals = }: {missing_keys}. Using defaults for those.", UserWarning, stacklevel=1, ) - self.intervals = { + data["intervals"] = { **_DEFAULT_INTERVALS, - **self.intervals, + **data_intervals, } - return self + return data @model_validator(mode="after") def validate_streaming_compatibility(self) -> Self: @@ -174,7 +177,8 @@ def validate_streaming_compatibility(self) -> Self: return self @model_validator(mode="before") - def handle_experiment_key(data: dict[str, Any]) -> dict[str, Any]: + @classmethod + def handle_experiment_key(cls, data: dict[str, Any]) -> dict[str, Any]: """handle passing experiment key instead of model_path and task_name. if we provide an experiment_key, then: From d9d1b208886d8178e725406fe6de214b0708ef59 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 13 Oct 2025 16:19:05 +0100 Subject: [PATCH 4/5] use model_copy to avoid editing frozen dict when updating ClusteringRunConfig from CLI --- spd/clustering/scripts/main.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/spd/clustering/scripts/main.py b/spd/clustering/scripts/main.py index 82c57ace2..65e224f5e 100644 --- a/spd/clustering/scripts/main.py +++ b/spd/clustering/scripts/main.py @@ -68,10 +68,15 @@ def cli() -> None: # but we must have those defaults to avoid type issues logger.info(f"Loading config from {args.config}") config: ClusteringRunConfig = ClusteringRunConfig.from_file(args.config) - config.base_path = args.base_path - config.devices = devices - config.workers_per_device = args.workers_per_device - config.dataset_streaming = args.dataset_streaming + # Use model_copy to update frozen fields + config = config.model_copy( + update={ + "base_path": args.base_path, + "devices": devices, + "workers_per_device": args.workers_per_device, + "dataset_streaming": args.dataset_streaming, + } + ) logger.info(f"Configuration loaded: {config.config_identifier}") logger.info(f"Base path: {config.base_path}") From cbb36a31fecd61c62f15d12128526e3a51021304 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 13 Oct 2025 16:19:32 +0100 Subject: [PATCH 5/5] remove deprecated config fields --- spd/clustering/configs/example.toml | 1 - spd/clustering/configs/example.yaml | 1 - spd/clustering/configs/test-resid_mlp1.json | 3 +-- 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/spd/clustering/configs/example.toml b/spd/clustering/configs/example.toml index d5cfe46d6..98053576b 100644 --- a/spd/clustering/configs/example.toml +++ b/spd/clustering/configs/example.toml @@ -30,7 +30,6 @@ iters = 100 # iterations to run. setting this to exactly the number of componen pop_component_prob = 0 # Probability of popping a component. i recommend 0 if you're doing an ensemble anyway filter_dead_threshold = 0.001 # Threshold for filtering dead components module_name_filter = "__NULL__" # Can be a string prefix like "model.layers.0." if you want to do only some modules -rank_cost_fn_name = "const_1" # Options: const_1, const_2, log, linear merge_pair_sampling_method = "range" # Method for sampling merge pairs: 'range' or 'mcmc' [merge_config.merge_pair_sampling_kwargs] diff --git a/spd/clustering/configs/example.yaml b/spd/clustering/configs/example.yaml index 5f3cd5fa5..259f1597c 100644 --- a/spd/clustering/configs/example.yaml +++ b/spd/clustering/configs/example.yaml @@ -11,7 +11,6 @@ merge_config: pop_component_prob: 0 # Probability of popping a component. i recommend 0 if you're doing an ensemble anyway filter_dead_threshold: 0.001 # Threshold for filtering dead components module_name_filter: null # Can be a string prefix like "model.layers.0." if you want to do only some modules - rank_cost_fn_name: const_1 # Options: const_1, const_2, log, linear # Run configuration model_path: wandb:goodfire/spd-pre-Sep-2025/runs/ioprgffh # WandB path to the decomposed model diff --git a/spd/clustering/configs/test-resid_mlp1.json b/spd/clustering/configs/test-resid_mlp1.json index 75877dd25..fbacff53a 100644 --- a/spd/clustering/configs/test-resid_mlp1.json +++ b/spd/clustering/configs/test-resid_mlp1.json @@ -7,8 +7,7 @@ "merge_pair_sampling_kwargs": {"threshold": 0.05}, "pop_component_prob": 0, "filter_dead_threshold": 0.1, - "module_name_filter": null, - "rank_cost_fn_name": "const_1" + "module_name_filter": null }, "experiment_key": "resid_mlp1", "n_batches": 2,