Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
b93b9d6
Geometric similarity comparison made consistent with other evals and …
leesharkey Sep 16, 2025
cd5fda2
Replaced mean max cosine sim with mean max ABS cosine sim
leesharkey Sep 17, 2025
61d3408
Configs for geom comparison runs
leesharkey Sep 17, 2025
63c85f0
Merge remote-tracking branch 'origin/main' into feature/geom_sim_compar
leesharkey Sep 17, 2025
770a5c5
Minor modifications to make PR-ready
leesharkey Sep 17, 2025
49ba925
Merge remote-tracking branch 'origin/main' into feature/geom_sim_compar
leesharkey Sep 17, 2025
364198e
Update seed to be consistent with other configs again
leesharkey Sep 17, 2025
57c2c76
Cleaned up some comments and other bits
leesharkey Sep 18, 2025
2e7752d
Major update of PR following review: Now implemented as script rather…
leesharkey Sep 18, 2025
4fbf807
Merge remote-tracking branch 'origin/main' into feature/geom_sim_compar
leesharkey Sep 18, 2025
98a6620
Updated registry to delete old obsolete experiments
leesharkey Sep 18, 2025
bede346
Merge branch 'main' into feature/geom_sim_compar
leesharkey Sep 18, 2025
acc04f1
Merge branch 'main' into feature/geom_sim_compar
leesharkey Sep 22, 2025
62bd77e
Reorganized compare_models into subdirectory and cleaned up config code
leesharkey Sep 22, 2025
b84814a
Merging
leesharkey Sep 22, 2025
5173a6a
Updated README.md
leesharkey Sep 22, 2025
181cac8
Added some example models to the config
leesharkey Sep 22, 2025
8db7559
Getting rid of newline
leesharkey Sep 22, 2025
0d05f0a
Minor changes to make the PR mergeable
leesharkey Sep 23, 2025
8767194
Merge branch 'main' of https://github.com/goodfire-ai/spd
leesharkey Sep 23, 2025
019eb2d
Merge branch 'main' of https://github.com/goodfire-ai/spd
leesharkey Sep 24, 2025
b935b4c
Merge branch 'main' of https://github.com/goodfire-ai/spd
leesharkey Sep 29, 2025
3d1edeb
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Sep 30, 2025
1dd738d
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Oct 3, 2025
956f3d4
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Oct 5, 2025
f7ad411
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Oct 6, 2025
ade1377
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Oct 7, 2025
08875a9
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Oct 13, 2025
7ca7037
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Oct 22, 2025
cbbdb61
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Oct 22, 2025
267deb6
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Oct 28, 2025
f49e9e0
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Nov 5, 2025
22f7cfc
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Nov 12, 2025
ab5346d
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Nov 14, 2025
7cb528f
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Nov 20, 2025
01d1b6b
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Nov 21, 2025
a78fdc5
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Nov 24, 2025
296a8d2
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Dec 2, 2025
8ecd0d7
Add resume functionality for SPD training jobs
leesharkey Dec 3, 2025
77f3b18
Fix determinism bug in resume: remove incorrect RNG save/restore
leesharkey Dec 3, 2025
a2b0191
Fix: Add init_distributed() to TMS and ResidMLP decomposition scripts…
leesharkey Dec 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 270 additions & 0 deletions spd/checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
"""Simple checkpoint save/load for SPD training resumption.

This module provides functions for saving and loading full training checkpoints,
including model state, optimizer state, RNG states, and dataloader position.
Follows the SPD style: simple functions, fail-fast assertions, clear errors.
"""

import random
from pathlib import Path
from typing import Any

import numpy as np
import torch
import torch.optim as optim

from spd.configs import Config
from spd.log import logger
from spd.models.component_model import ComponentModel
from spd.utils.run_utils import save_file

# Version tag written into every checkpoint dict; lets future code detect and
# migrate older checkpoint formats instead of failing on them.
CHECKPOINT_VERSION = "1.0"

# Config fields that determine model architecture/structure. Resume is refused
# (ValueError) if any of these differ between the checkpoint's saved config and
# the current config — see _validate_config_compatibility.
CRITICAL_CONFIG_FIELDS = [
    "C",
    "target_module_patterns",
    "pretrained_model_class",
    "ci_fn_type",
    "ci_fn_hidden_dims",
    "sigmoid_type",
    "use_delta_component",
]


def find_latest_checkpoint(out_dir: Path) -> Path | None:
"""Find the latest checkpoint in a directory by step number.

Args:
out_dir: Directory containing checkpoints

Returns:
Path to latest checkpoint, or None if no checkpoints found
"""
if not out_dir.exists():
return None

checkpoints = list(out_dir.glob("model_*.pth"))
if not checkpoints:
return None

def extract_step(path: Path) -> int:
"""Extract step number from filename like 'model_1000.pth'."""
try:
return int(path.stem.split("_")[1])
except (IndexError, ValueError):
logger.warning(f"Could not parse step from checkpoint filename: {path.name}")
return -1

latest = max(checkpoints, key=extract_step)
step = extract_step(latest)

if step < 0:
return None

return latest


def save_checkpoint(
    step: int,
    component_model: ComponentModel,
    optimizer: optim.Optimizer,
    config: Config,
    dataloader_steps_consumed: int,
    out_dir: Path,
) -> Path:
    """Save a full training checkpoint.

    Includes model state, optimizer state (momentum, etc.), RNG states for
    reproducibility, dataloader position, and a config snapshot used for
    compatibility validation on resume.

    Note: In distributed training, caller is responsible for ensuring this is
    only called from the main process using is_main_process().

    Args:
        step: Current training step
        component_model: The component model to checkpoint
        optimizer: The optimizer to checkpoint
        config: Current training config
        dataloader_steps_consumed: Number of dataloader steps consumed (for skip on resume)
        out_dir: Directory to save checkpoint to

    Returns:
        Path to saved checkpoint file
    """
    # Snapshot every RNG stream the training loop may draw from; tensors are
    # moved to CPU so the checkpoint serializes cleanly.
    rng_states: dict[str, Any] = {
        "torch": torch.get_rng_state().cpu(),
        "numpy": np.random.get_state(),
        "python": random.getstate(),
    }
    if torch.cuda.is_available():
        # One state per device; stored as a tuple of CPU tensors (tuple rather
        # than list for type compatibility).
        rng_states["torch_cuda"] = tuple(s.cpu() for s in torch.cuda.get_rng_state_all())

    payload = {
        "step": step,
        "model_state_dict": component_model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "rng_states": rng_states,
        "dataloader_state": {"steps_consumed": dataloader_steps_consumed},
        "config_snapshot": config.model_dump(),
        "version": CHECKPOINT_VERSION,
    }

    checkpoint_path = out_dir / f"model_{step}.pth"
    save_file(payload, checkpoint_path)
    logger.info(f"Saved checkpoint at step {step} to {checkpoint_path}")

    return checkpoint_path


def load_checkpoint(
    checkpoint_path: Path,
    component_model: ComponentModel,
    optimizer: optim.Optimizer,
    config: Config,
) -> tuple[int, int]:
    """Load a checkpoint and restore training state.

    Validates config compatibility (errors on breaking changes, warns on
    non-critical changes), loads model and optimizer state, and restores RNG
    states so the resumed run is reproducible.

    Args:
        checkpoint_path: Path to checkpoint file
        component_model: Model to load state into
        optimizer: Optimizer to load state into
        config: Current config (for validation)

    Returns:
        Tuple of (checkpoint_step, dataloader_steps_consumed)

    Raises:
        FileNotFoundError: If checkpoint doesn't exist
        ValueError: If checkpoint is incompatible with current config
        AssertionError: If checkpoint format is invalid
    """
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    logger.info(f"Loading checkpoint from {checkpoint_path}")

    # weights_only=False is required: RNG states and optimizer state are not
    # plain tensors.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    # Fail fast on malformed checkpoints before mutating any live state.
    missing_keys = {
        "step",
        "model_state_dict",
        "optimizer_state_dict",
        "rng_states",
        "dataloader_state",
        "config_snapshot",
    } - ckpt.keys()
    assert not missing_keys, f"Checkpoint missing required keys: {missing_keys}"

    # Refuse to resume if architecture-relevant config fields have changed.
    _validate_config_compatibility(ckpt["config_snapshot"], config.model_dump())

    component_model.load_state_dict(ckpt["model_state_dict"])
    logger.info(f"Loaded model state from step {ckpt['step']}")

    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    logger.info("Loaded optimizer state")

    # Restore every RNG stream. States were saved as CPU tensors, so no extra
    # device transfer is needed here.
    rng = ckpt["rng_states"]
    torch.set_rng_state(rng["torch"])
    if torch.cuda.is_available() and "torch_cuda" in rng:
        cuda_states = rng["torch_cuda"]
        if len(cuda_states) == torch.cuda.device_count():
            torch.cuda.set_rng_state_all(cuda_states)
        else:
            # Device count changed since saving; per-device states would not
            # line up, so skip rather than restore the wrong streams.
            logger.warning(
                f"Saved checkpoint has {len(cuda_states)} CUDA devices, "
                f"but current setup has {torch.cuda.device_count()}. Skipping CUDA RNG restore."
            )
    np.random.set_state(rng["numpy"])
    random.setstate(rng["python"])
    logger.info("Restored RNG states")

    logger.info(f"Successfully loaded checkpoint from step {ckpt['step']}")

    return ckpt["step"], ckpt["dataloader_state"]["steps_consumed"]


def _validate_config_compatibility(
    saved_config: dict[str, Any], current_config: dict[str, Any]
) -> None:
    """Validate that current config is compatible with checkpoint's config.

    Errors on breaking changes (architecture differences), warns on non-critical
    changes (hyperparameters that can safely differ).

    Args:
        saved_config: Config from checkpoint
        current_config: Current training config

    Raises:
        ValueError: If configs have incompatible (breaking) differences
    """
    # Critical fields must match exactly. A field missing from either config is
    # skipped — it may have been added or removed between code versions.
    breaking_changes = [
        f" {field}:\n Saved: {saved_config[field]}\n Current: {current_config[field]}"
        for field in CRITICAL_CONFIG_FIELDS
        if field in saved_config
        and field in current_config
        and saved_config[field] != current_config[field]
    ]

    if breaking_changes:
        changes_str = "\n".join(breaking_changes)
        raise ValueError(
            f"Cannot resume: Config has incompatible architecture changes:\n{changes_str}\n\n"
            f"These fields affect model structure and must match for resume.\n"
            f"If you want to change these, you must start training from scratch."
        )

    # Non-critical hyperparameters may differ between runs; surface the change
    # in a warning but keep going with the current config's values.
    non_critical_fields = ["lr", "steps", "batch_size", "seed", "train_log_freq", "eval_freq"]
    non_critical_changes = [
        f"{field}: {saved_config[field]} -> {current_config[field]}"
        for field in non_critical_fields
        if field in saved_config
        and field in current_config
        and saved_config[field] != current_config[field]
    ]

    if non_critical_changes:
        changes_str = ", ".join(non_critical_changes)
        logger.warning(
            f"Config has non-critical changes from checkpoint: {changes_str}. "
            f"Continuing with current config values."
        )

    logger.info("Config compatibility check passed")
17 changes: 17 additions & 0 deletions spd/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,23 @@ def microbatch_size(self) -> PositiveInt:
)
)

# --- Resume Configuration ---
auto_resume: bool = Field(
default=False,
description="Automatically resume from the latest checkpoint in out_dir if available. "
"Searches for model_*.pth files and resumes from the one with the highest step number.",
)
resume_from_checkpoint: Path | None = Field(
default=None,
description="Explicit path to checkpoint file to resume from. If specified, overrides "
"auto_resume. Checkpoint must be compatible with current config (same C, architecture, etc.).",
)
wandb_run_id: str | None = Field(
default=None,
description="WandB run ID for resuming an existing run. When resuming with the same run ID, "
"WandB will continue logging to the same run instead of creating a new one.",
)

# --- Component Tracking ---
ci_alive_threshold: Probability = Field(
default=0.0,
Expand Down
6 changes: 5 additions & 1 deletion spd/experiments/resid_mlp/resid_mlp_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from spd.log import logger
from spd.run_spd import optimize
from spd.utils.data_utils import DatasetGeneratedDataLoader
from spd.utils.distributed_utils import get_device
from spd.utils.distributed_utils import get_device, init_distributed
from spd.utils.general_utils import save_pre_run_info, set_seed
from spd.utils.run_utils import get_output_dir, save_file
from spd.utils.wandb_utils import init_wandb
Expand All @@ -41,6 +41,10 @@ def main(
sweep_params = (
None if sweep_params_json is None else json.loads(sweep_params_json.removeprefix("json:"))
)

dist_state = init_distributed()
logger.info(f"Distributed state: {dist_state}")

if config.wandb_project:
tags = ["resid_mlp"]
if evals_id:
Expand Down
5 changes: 4 additions & 1 deletion spd/experiments/tms/tms_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from spd.log import logger
from spd.run_spd import optimize
from spd.utils.data_utils import DatasetGeneratedDataLoader, SparseFeatureDataset
from spd.utils.distributed_utils import get_device
from spd.utils.distributed_utils import get_device, init_distributed
from spd.utils.general_utils import save_pre_run_info, set_seed
from spd.utils.run_utils import get_output_dir
from spd.utils.wandb_utils import init_wandb
Expand All @@ -42,6 +42,9 @@ def main(
None if sweep_params_json is None else json.loads(sweep_params_json.removeprefix("json:"))
)

dist_state = init_distributed()
logger.info(f"Distributed state: {dist_state}")

device = get_device()
logger.info(f"Using device: {device}")

Expand Down
Loading
Loading