diff --git a/spd/checkpoint.py b/spd/checkpoint.py new file mode 100644 index 000000000..951e67810 --- /dev/null +++ b/spd/checkpoint.py @@ -0,0 +1,270 @@ +"""Simple checkpoint save/load for SPD training resumption. + +This module provides functions for saving and loading full training checkpoints, +including model state, optimizer state, RNG states, and dataloader position. +Follows the SPD style: simple functions, fail-fast assertions, clear errors. +""" + +import random +from pathlib import Path +from typing import Any + +import numpy as np +import torch +import torch.optim as optim + +from spd.configs import Config +from spd.log import logger +from spd.models.component_model import ComponentModel +from spd.utils.run_utils import save_file + +# Version for checkpoint format (for future compatibility) +CHECKPOINT_VERSION = "1.0" + +# Critical config fields that must match for resume compatibility +CRITICAL_CONFIG_FIELDS = [ + "C", + "target_module_patterns", + "pretrained_model_class", + "ci_fn_type", + "ci_fn_hidden_dims", + "sigmoid_type", + "use_delta_component", +] + + +def find_latest_checkpoint(out_dir: Path) -> Path | None: + """Find the latest checkpoint in a directory by step number. 
+ + Args: + out_dir: Directory containing checkpoints + + Returns: + Path to latest checkpoint, or None if no checkpoints found + """ + if not out_dir.exists(): + return None + + checkpoints = list(out_dir.glob("model_*.pth")) + if not checkpoints: + return None + + def extract_step(path: Path) -> int: + """Extract step number from filename like 'model_1000.pth'.""" + try: + return int(path.stem.split("_")[1]) + except (IndexError, ValueError): + logger.warning(f"Could not parse step from checkpoint filename: {path.name}") + return -1 + + latest = max(checkpoints, key=extract_step) + step = extract_step(latest) + + if step < 0: + return None + + return latest + + +def save_checkpoint( + step: int, + component_model: ComponentModel, + optimizer: optim.Optimizer, + config: Config, + dataloader_steps_consumed: int, + out_dir: Path, +) -> Path: + """Save a full training checkpoint. + + Includes model state, optimizer state (momentum, etc.), RNG states for reproducibility, + dataloader position, and config snapshot for validation. + + Note: In distributed training, caller is responsible for ensuring this is only called + from the main process using is_main_process(). 
def save_checkpoint(
    step: int,
    component_model: ComponentModel,
    optimizer: optim.Optimizer,
    config: Config,
    dataloader_steps_consumed: int,
    out_dir: Path,
) -> Path:
    """Save a full training checkpoint.

    Includes model state, optimizer state (momentum, etc.), RNG states for
    reproducibility, dataloader position, and a config snapshot for validation.

    Note: In distributed training, caller is responsible for ensuring this is
    only called from the main process using is_main_process().

    Args:
        step: Current training step
        component_model: The component model to checkpoint
        optimizer: The optimizer to checkpoint
        config: Current training config
        dataloader_steps_consumed: Number of dataloader steps consumed (for skip on resume)
        out_dir: Directory to save checkpoint to

    Returns:
        Path to saved checkpoint file
    """
    # Capture every RNG stream so resumed training is bit-for-bit deterministic.
    # Torch states are moved to CPU so they serialize cleanly.
    rng_states: dict[str, Any] = {
        "torch": torch.get_rng_state().cpu(),
        "numpy": np.random.get_state(),
        "python": random.getstate(),
    }
    if torch.cuda.is_available():
        # Stored as a tuple of CPU tensors (converted from list for type compatibility).
        rng_states["torch_cuda"] = tuple(s.cpu() for s in torch.cuda.get_rng_state_all())

    checkpoint_path = out_dir / f"model_{step}.pth"
    save_file(
        {
            "step": step,
            "model_state_dict": component_model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "rng_states": rng_states,
            "dataloader_state": {"steps_consumed": dataloader_steps_consumed},
            "config_snapshot": config.model_dump(),
            "version": CHECKPOINT_VERSION,
        },
        checkpoint_path,
    )
    logger.info(f"Saved checkpoint at step {step} to {checkpoint_path}")

    return checkpoint_path
def load_checkpoint(
    checkpoint_path: Path,
    component_model: ComponentModel,
    optimizer: optim.Optimizer,
    config: Config,
) -> tuple[int, int]:
    """Load a checkpoint and restore training state.

    Validates config compatibility (errors on breaking changes, warns on non-critical changes),
    loads model and optimizer state, restores RNG states for reproducibility.

    Args:
        checkpoint_path: Path to checkpoint file
        component_model: Model to load state into
        optimizer: Optimizer to load state into
        config: Current config (for validation)

    Returns:
        Tuple of (checkpoint_step, dataloader_steps_consumed)

    Raises:
        FileNotFoundError: If checkpoint doesn't exist
        ValueError: If checkpoint is incompatible with current config
        AssertionError: If checkpoint format is invalid
    """
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    logger.info(f"Loading checkpoint from {checkpoint_path}")

    # Load checkpoint (weights_only=False needed for RNG states and optimizer).
    # NOTE(review): weights_only=False unpickles arbitrary objects and can
    # execute code on load -- only load checkpoints this codebase wrote itself.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    # Validate checkpoint structure -- fail fast on truncated files or on the
    # old raw-state_dict format, which lacks these keys entirely.
    required_keys = {
        "step",
        "model_state_dict",
        "optimizer_state_dict",
        "rng_states",
        "dataloader_state",
        "config_snapshot",
    }
    missing_keys = required_keys - checkpoint.keys()
    assert not missing_keys, f"Checkpoint missing required keys: {missing_keys}"

    # Validate config compatibility before touching any state, so an
    # incompatible resume leaves the model/optimizer untouched.
    saved_config = checkpoint["config_snapshot"]
    _validate_config_compatibility(saved_config, config.model_dump())

    # Load model state
    component_model.load_state_dict(checkpoint["model_state_dict"])
    logger.info(f"Loaded model state from step {checkpoint['step']}")

    # Load optimizer state (momentum buffers etc.)
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    logger.info("Loaded optimizer state")

    # Restore RNG states for reproducibility
    # Note: States were saved as CPU tensors, so no need to call .cpu() again
    torch.set_rng_state(checkpoint["rng_states"]["torch"])
    if torch.cuda.is_available() and "torch_cuda" in checkpoint["rng_states"]:
        cuda_states = checkpoint["rng_states"]["torch_cuda"]
        # Only set if we have the same number of devices; a mismatched device
        # count would make set_rng_state_all raise or silently misassign.
        if len(cuda_states) == torch.cuda.device_count():
            torch.cuda.set_rng_state_all(cuda_states)
        else:
            logger.warning(
                f"Saved checkpoint has {len(cuda_states)} CUDA devices, "
                f"but current setup has {torch.cuda.device_count()}. Skipping CUDA RNG restore."
            )
    np.random.set_state(checkpoint["rng_states"]["numpy"])
    random.setstate(checkpoint["rng_states"]["python"])
    logger.info("Restored RNG states")

    # Get dataloader state so the caller can fast-forward the data iterator.
    dataloader_steps = checkpoint["dataloader_state"]["steps_consumed"]

    logger.info(f"Successfully loaded checkpoint from step {checkpoint['step']}")

    return checkpoint["step"], dataloader_steps
def _validate_config_compatibility(
    saved_config: dict[str, Any], current_config: dict[str, Any]
) -> None:
    """Validate that current config is compatible with checkpoint's config.

    Errors on breaking changes (architecture differences), warns on non-critical changes
    (hyperparameters that can safely differ).

    Args:
        saved_config: Config from checkpoint
        current_config: Current training config

    Raises:
        ValueError: If configs have incompatible (breaking) differences
    """
    # Critical fields must match exactly; a field missing from either side is
    # skipped, since it may simply have been added in a newer config version.
    breaking_changes = [
        f" {field}:\n Saved: {saved_config[field]}\n Current: {current_config[field]}"
        for field in CRITICAL_CONFIG_FIELDS
        if field in saved_config
        and field in current_config
        and saved_config[field] != current_config[field]
    ]

    if breaking_changes:
        changes_str = "\n".join(breaking_changes)
        raise ValueError(
            f"Cannot resume: Config has incompatible architecture changes:\n{changes_str}\n\n"
            f"These fields affect model structure and must match for resume.\n"
            f"If you want to change these, you must start training from scratch."
        )

    # Non-critical fields may safely differ between runs: warn, don't block.
    non_critical_fields = ["lr", "steps", "batch_size", "seed", "train_log_freq", "eval_freq"]
    non_critical_changes = [
        f"{field}: {saved_config[field]} -> {current_config[field]}"
        for field in non_critical_fields
        if field in saved_config
        and field in current_config
        and saved_config[field] != current_config[field]
    ]

    if non_critical_changes:
        changes_str = ", ".join(non_critical_changes)
        logger.warning(
            f"Config has non-critical changes from checkpoint: {changes_str}. "
            f"Continuing with current config values."
        )

    logger.info("Config compatibility check passed")
When resuming with the same run ID, " + "WandB will continue logging to the same run instead of creating a new one.", + ) + # --- Component Tracking --- ci_alive_threshold: Probability = Field( default=0.0, diff --git a/spd/experiments/resid_mlp/resid_mlp_decomposition.py b/spd/experiments/resid_mlp/resid_mlp_decomposition.py index 21f0c3863..96f26a473 100644 --- a/spd/experiments/resid_mlp/resid_mlp_decomposition.py +++ b/spd/experiments/resid_mlp/resid_mlp_decomposition.py @@ -16,7 +16,7 @@ from spd.log import logger from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader -from spd.utils.distributed_utils import get_device +from spd.utils.distributed_utils import get_device, init_distributed from spd.utils.general_utils import save_pre_run_info, set_seed from spd.utils.run_utils import get_output_dir, save_file from spd.utils.wandb_utils import init_wandb @@ -41,6 +41,10 @@ def main( sweep_params = ( None if sweep_params_json is None else json.loads(sweep_params_json.removeprefix("json:")) ) + + dist_state = init_distributed() + logger.info(f"Distributed state: {dist_state}") + if config.wandb_project: tags = ["resid_mlp"] if evals_id: diff --git a/spd/experiments/tms/tms_decomposition.py b/spd/experiments/tms/tms_decomposition.py index 711264f60..3b1bdbcc9 100644 --- a/spd/experiments/tms/tms_decomposition.py +++ b/spd/experiments/tms/tms_decomposition.py @@ -16,7 +16,7 @@ from spd.log import logger from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, SparseFeatureDataset -from spd.utils.distributed_utils import get_device +from spd.utils.distributed_utils import get_device, init_distributed from spd.utils.general_utils import save_pre_run_info, set_seed from spd.utils.run_utils import get_output_dir from spd.utils.wandb_utils import init_wandb @@ -42,6 +42,9 @@ def main( None if sweep_params_json is None else json.loads(sweep_params_json.removeprefix("json:")) ) + dist_state = 
def log_rng_state(label: str) -> None:
    """Log RNG state hashes for debugging determinism issues."""

    def _digest(data: bytes) -> str:
        # A short (12 hex char) digest is plenty to spot state divergence.
        return hashlib.sha256(data).hexdigest()[:12]

    torch_hash = _digest(torch.get_rng_state().numpy().tobytes())
    np_hash = _digest(str(np.random.get_state()[1]).encode())  # pyright: ignore[reportArgumentType]
    py_hash = _digest(str(random.getstate()).encode())

    logger.info(f"[RNG DEBUG {label}] torch={torch_hash} numpy={np_hash} python={py_hash}")
0 + is_resuming = False + + # Determine checkpoint to resume from + if config.resume_from_checkpoint is not None: + # Explicit checkpoint path provided + resume_checkpoint_path = Path(config.resume_from_checkpoint) + if not resume_checkpoint_path.exists(): + raise FileNotFoundError( + f"Resume checkpoint not found: {resume_checkpoint_path}\n" + f"Check that the path is correct and the file exists." + ) + logger.info(f"Resuming from explicit checkpoint: {resume_checkpoint_path}") + + elif config.auto_resume and out_dir is not None: + # Auto-detect latest checkpoint + resume_checkpoint_path = ckpt.find_latest_checkpoint(out_dir) + if resume_checkpoint_path is not None: + logger.info(f"Auto-detected latest checkpoint: {resume_checkpoint_path}") + else: + logger.info("No checkpoints found for auto-resume, starting from scratch") + + # Load checkpoint if found + if resume_checkpoint_path is not None: + checkpoint_step, dataloader_steps_consumed = ckpt.load_checkpoint( + checkpoint_path=resume_checkpoint_path, + component_model=component_model, + optimizer=optimizer, + config=config, + ) + start_step = checkpoint_step + 1 # Resume from next step + is_resuming = True + logger.info(f"Resuming training from step {start_step}") + + # Skip faithfulness warmup when resuming (components already trained) + if config.faithfulness_warmup_steps > 0 and not is_resuming: run_faithfulness_warmup(component_model, component_params, config) eval_metric_configs = get_unique_metric_configs( @@ -220,11 +275,80 @@ def create_pgd_data_iter() -> ( eval_metric_configs = [ cfg for cfg in eval_metric_configs if cfg not in multibatch_pgd_eval_configs ] - batch_dims: tuple[int, ...] 
| None = None + + # Fast-forward data iterators and handle alive_tracker batch if resuming from checkpoint + # IMPORTANT: Do this BEFORE the main training loop to position data correctly + if dataloader_steps_consumed > 0: + # When resuming from a checkpoint, we need to regenerate the same batches that were + # consumed in the original run (for on-the-fly data generation like ResidMLP). + # The checkpoint RNG state includes advancement from both data generation AND training. + # To get back to the correct position: + # 1. Reset RNG to initial seed (to regenerate batches correctly) + # 2. Skip consumed batches (regenerates them with correct RNG progression) + # 3. Restore checkpoint RNG states (for training to continue correctly) + + # Save the checkpoint RNG states + log_rng_state("BEFORE_SAVE_CHECKPOINT_RNG") + checkpoint_rng_states = { + "torch": torch.get_rng_state().clone(), + "numpy": np.random.get_state(), + "python": random.getstate(), + } + if torch.cuda.is_available(): + checkpoint_rng_states["torch_cuda"] = tuple( + state.clone() for state in torch.cuda.get_rng_state_all() + ) + log_rng_state("AFTER_SAVE_CHECKPOINT_RNG") + + # Reset to initial seed to regenerate consumed batches correctly + torch.manual_seed(config.seed) + np.random.seed(config.seed) + random.seed(config.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(config.seed) + log_rng_state("AFTER_RESET_TO_INITIAL_SEED") + + logger.info( + f"Fast-forwarding data iterators by {dataloader_steps_consumed} steps (regenerating consumed batches)..." 
+ ) + for i in tqdm( + range(dataloader_steps_consumed), + desc="Skipping dataloader steps", + disable=not is_main_process(), + ): + next(train_iterator) + if i == 0: + log_rng_state("AFTER_SKIP_BATCH_0") + elif i == dataloader_steps_consumed - 1: + log_rng_state(f"AFTER_SKIP_BATCH_{i}_FINAL") + logger.info("Data iterator fast-forward complete") + log_rng_state("AFTER_ALL_SKIPS") + + # Now we're positioned at the correct batch for training + # Get a sample batch for alive_tracker (only shape matters) + # We'll use a dummy batch to avoid consuming from the iterator + sample_batch = torch.zeros(config.batch_size, 100, device=device) + logger.info(f"Using dummy batch for alive_tracker on resume, shape: {sample_batch.shape}") + + # Restore checkpoint RNG states for training + torch.set_rng_state(checkpoint_rng_states["torch"]) # pyright: ignore[reportArgumentType] + np.random.set_state(checkpoint_rng_states["numpy"]) # pyright: ignore[reportArgumentType] + random.setstate(checkpoint_rng_states["python"]) # pyright: ignore[reportArgumentType] + if torch.cuda.is_available() and "torch_cuda" in checkpoint_rng_states: + torch.cuda.set_rng_state_all(checkpoint_rng_states["torch_cuda"]) # pyright: ignore[reportArgumentType] + log_rng_state("AFTER_RESTORE_CHECKPOINT_RNG") + logger.info("Restored checkpoint RNG states after fast-forward") + else: + # Normal case: consume first batch for alive_tracker initialization + log_rng_state("BEFORE_CONSUME_ALIVE_TRACKER_BATCH") + sample_batch = extract_batch_data(next(train_iterator)) + log_rng_state("AFTER_CONSUME_ALIVE_TRACKER_BATCH") + logger.info( + f"Normal startup - consumed batch for alive_tracker, shape: {sample_batch.shape}" + ) # Track which components are alive based on firing frequency - sample_batch = extract_batch_data(next(train_iterator)) - batch_dims = ( + batch_dims: tuple[int, ...] 
= ( sample_batch.shape[:-1] if config.output_loss_type == "mse" # if mse then input is a vector else sample_batch.shape # else it's a batch of token ids @@ -238,7 +362,9 @@ def create_pgd_data_iter() -> ( global_n_examples_per_batch=batch_dims.numel(), ) - for step in tqdm(range(config.steps + 1), ncols=0, disable=not is_main_process()): + log_rng_state(f"BEFORE_TRAINING_LOOP_START_{start_step}") + for step in tqdm(range(start_step, config.steps + 1), ncols=0, disable=not is_main_process()): + log_rng_state(f"START_TRAINING_STEP_{step}") optimizer.zero_grad() step_lr = get_lr_with_warmup( @@ -255,8 +381,14 @@ def create_pgd_data_iter() -> ( microbatch_log_data: defaultdict[str, float] = defaultdict(float) - for _ in range(config.gradient_accumulation_steps): + for grad_acc_idx in range(config.gradient_accumulation_steps): + if is_main_process(): + log_rng_state(f"BEFORE_BATCH_STEP_{step}_GRAD_ACC_{grad_acc_idx}") microbatch = extract_batch_data(next(train_iterator)).to(device) + if is_main_process(): + log_rng_state(f"AFTER_BATCH_STEP_{step}_GRAD_ACC_{grad_acc_idx}") + batch_hash = hashlib.sha256(microbatch.cpu().numpy().tobytes()).hexdigest()[:12] + logger.info(f"[BATCH HASH STEP_{step}_GRAD_ACC_{grad_acc_idx}] {batch_hash}") # NOTE: we need to call the wrapped_model at least once each step in order to setup # the DDP gradient syncing for all parameters in the component model. 
Gradients will @@ -385,13 +517,25 @@ def create_pgd_data_iter() -> ( and out_dir is not None and is_main_process() ): - # Save the state dict of the underlying module (not DDP wrapper) - save_file(component_model.state_dict(), out_dir / f"model_{step}.pth") - logger.info(f"Saved model, optimizer, and out_dir to {out_dir}") + # Save full checkpoint with model, optimizer, RNG states, and dataloader position + # Calculate total dataloader steps consumed: + # - 1 batch for alive_tracker initialization (before main loop) + # - (step + 1) training steps (steps 0 through current step, inclusive) + # - Each step consumes gradient_accumulation_steps batches + total_dataloader_steps = 1 + (step + 1) * config.gradient_accumulation_steps + checkpoint_path = ckpt.save_checkpoint( + step=step, + component_model=component_model, + optimizer=optimizer, + config=config, + dataloader_steps_consumed=total_dataloader_steps, + out_dir=out_dir, + ) + logger.info(f"Saved checkpoint to {checkpoint_path}") if config.wandb_project: try_wandb( wandb.save, - str(out_dir / f"model_{step}.pth"), + str(checkpoint_path), base_path=str(out_dir), policy="now", )