diff --git a/spd/app/backend/routers/prompts.py b/spd/app/backend/routers/prompts.py index 8002aa11c..d2099c6ca 100644 --- a/spd/app/backend/routers/prompts.py +++ b/spd/app/backend/routers/prompts.py @@ -16,7 +16,6 @@ from spd.data import DatasetConfig, create_data_loader from spd.log import logger from spd.utils.distributed_utils import get_device -from spd.utils.general_utils import extract_batch_data # ============================================================================= # Schemas @@ -120,7 +119,7 @@ def generate() -> Generator[str]: if added_count >= n_prompts: break - tokens = extract_batch_data(batch).to(DEVICE) + tokens = batch["input_ids"].to(DEVICE) batch_size, n_seq = tokens.shape # Compute CI for the whole batch diff --git a/spd/clustering/dataset.py b/spd/clustering/dataset.py index c8e86f0fc..ae745a0c9 100644 --- a/spd/clustering/dataset.py +++ b/spd/clustering/dataset.py @@ -106,7 +106,7 @@ def _load_resid_mlp_batch(model_path: str, batch_size: int, seed: int) -> BatchT spd_run = SPDRunInfo.from_path(model_path) cfg = spd_run.config - component_model = ComponentModel.from_pretrained(spd_run.checkpoint_path) + component_model = ComponentModel.from_run_info(spd_run) assert isinstance(cfg.task_config, ResidMLPTaskConfig), ( f"Expected task_config to be of type ResidMLPTaskConfig, but got {type(cfg.task_config) = }" diff --git a/spd/configs.py b/spd/configs.py index f71eec08e..20eecbee8 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -461,11 +461,6 @@ def all_module_info(self) -> list[ModulePatternInfoConfig]: ), ) ) - output_loss_type: Literal["mse", "kl"] = Field( - ..., - description="Metric used to measure recon error between model outputs and targets", - ) - # --- Training --- lr_schedule: ScheduleConfig = Field(..., description="Learning rate schedule configuration") steps: NonNegativeInt = Field(..., description="Total number of optimisation steps") @@ -566,9 +561,10 @@ def microbatch_size(self) -> PositiveInt: default=None, 
description="hf model identifier. E.g. 'SimpleStories/SimpleStories-1.25M'", ) - pretrained_model_output_attr: str | None = Field( + output_extract: int | str | None = Field( default=None, - description="Name of the attribute on the forward output that contains logits or activations", + description="How to extract tensor from model output. None = raw output, int = index into " + "output tuple, str = attribute name.", ) tokenizer_name: str | None = Field( default=None, @@ -607,6 +603,7 @@ def microbatch_size(self) -> PositiveInt: "lr_exponential_halflife", "out_dir", "n_examples_until_dead", + "output_loss_type", ] RENAMED_CONFIG_KEYS: ClassVar[dict[str, str]] = { "grad_clip_norm": "grad_clip_norm_components", @@ -652,6 +649,19 @@ def handle_deprecated_config_keys(cls, config_dict: dict[str, Any]) -> dict[str, "simple_stories_train.models.", "spd.pretrain.models.", 1 ) + # Migrate old pretrained_model_output_attr to output_extract + if "pretrained_model_output_attr" in config_dict: + old_val = config_dict.pop("pretrained_model_output_attr") + match old_val: + case None: + pass + case "idx_0": + config_dict["output_extract"] = 0 + case "logits": + config_dict["output_extract"] = "logits" + case _: + raise ValueError(f"Unknown pretrained_model_output_attr: {old_val!r}") + if "eval_batch_size" not in config_dict: config_dict["eval_batch_size"] = config_dict["batch_size"] if "train_log_freq" not in config_dict: diff --git a/spd/data.py b/spd/data.py index b1ed33a62..39479f0b0 100644 --- a/spd/data.py +++ b/spd/data.py @@ -1,8 +1,10 @@ +from collections.abc import Callable, Generator from typing import Any import numpy as np import torch from datasets import Dataset, IterableDataset, load_dataset +from jaxtyping import Int from numpy.typing import NDArray from torch import Tensor from torch.utils.data import DataLoader, DistributedSampler @@ -152,7 +154,8 @@ def create_data_loader( dist_state: DistributedState | None = None, global_seed: int = 0, to_lower: bool = True, 
-) -> tuple[DataLoader[Any], PreTrainedTokenizer]: + collate_fn: Callable[..., Any] | None = None, +) -> tuple[DataLoader[Int[Tensor, "batch seq"]], PreTrainedTokenizer]: """Create a DataLoader for the given dataset. Uses PyTorch's DistributedSampler to ensure each rank gets the correct @@ -255,7 +258,7 @@ def create_data_loader( generator = torch.Generator(device="cpu") generator.manual_seed(seed) - loader = DataLoader[Dataset | IterableDataset]( + loader = DataLoader[Int[Tensor, "batch seq"]]( torch_dataset, # pyright: ignore[reportArgumentType] batch_size=batch_size, sampler=sampler, @@ -264,11 +267,17 @@ def create_data_loader( ), drop_last=True, generator=generator, + collate_fn=collate_fn, ) return loader, tokenizer -def loop_dataloader[T](dl: DataLoader[T]): +def lm_collate_fn(batch: list[dict[str, Tensor]]) -> Tensor: + """Collate function that extracts input_ids tensors from HuggingFace dataset dicts.""" + return torch.stack([item["input_ids"] for item in batch]) + + +def loop_dataloader[T](dl: DataLoader[T]) -> Generator[T]: """Loop over a dataloader, resetting the iterator when it is exhausted. Ensures that each epoch gets different data, even when using a distributed sampler. 
@@ -311,6 +320,7 @@ def train_loader_and_tokenizer( batch_size=batch_size, buffer_size=task_config.buffer_size, global_seed=config.seed, + collate_fn=lm_collate_fn, ) assert isinstance(tokenizer, PreTrainedTokenizerBase) diff --git a/spd/dataset_attributions/harvest.py b/spd/dataset_attributions/harvest.py index 0651a3a7b..32e7aaf7c 100644 --- a/spd/dataset_attributions/harvest.py +++ b/spd/dataset_attributions/harvest.py @@ -30,7 +30,6 @@ from spd.log import logger from spd.models.component_model import ComponentModel, SPDRunInfo from spd.utils.distributed_utils import get_device -from spd.utils.general_utils import extract_batch_data from spd.utils.wandb_utils import parse_wandb_run_path @@ -206,7 +205,7 @@ def harvest_attributions( # Skip batches not assigned to this rank if world_size is not None and batch_idx % world_size != rank: continue - batch = extract_batch_data(batch_data).to(device) + batch = batch_data["input_ids"].to(device) harvester.process_batch(batch) logger.info( diff --git a/spd/eval.py b/spd/eval.py index c6f0b47ff..eb2c946d6 100644 --- a/spd/eval.py +++ b/spd/eval.py @@ -3,7 +3,6 @@ from collections.abc import Iterator from typing import Any -from jaxtyping import Float, Int from PIL import Image from torch import Tensor from torch.types import Number @@ -37,34 +36,37 @@ UnmaskedReconLossConfig, UVPlotsConfig, ) -from spd.metrics import UnmaskedReconLoss +from spd.metrics import ( + CI_L0, + CEandKLLosses, + CIHistograms, + CIMaskedReconLayerwiseLoss, + CIMaskedReconLoss, + CIMaskedReconSubsetLoss, + CIMeanPerComponent, + ComponentActivationDensity, + FaithfulnessLoss, + IdentityCIError, + ImportanceMinimalityLoss, + PermutedCIPlots, + PGDReconLayerwiseLoss, + PGDReconLoss, + PGDReconSubsetLoss, + StochasticHiddenActsReconLoss, + StochasticReconLayerwiseLoss, + StochasticReconLoss, + StochasticReconSubsetCEAndKL, + StochasticReconSubsetLoss, + UnmaskedReconLoss, + UVPlots, +) from spd.metrics.base import Metric -from 
spd.metrics.ce_and_kl_losses import CEandKLLosses -from spd.metrics.ci_histograms import CIHistograms -from spd.metrics.ci_l0 import CI_L0 -from spd.metrics.ci_masked_recon_layerwise_loss import CIMaskedReconLayerwiseLoss -from spd.metrics.ci_masked_recon_loss import CIMaskedReconLoss -from spd.metrics.ci_masked_recon_subset_loss import CIMaskedReconSubsetLoss -from spd.metrics.ci_mean_per_component import CIMeanPerComponent -from spd.metrics.component_activation_density import ComponentActivationDensity -from spd.metrics.faithfulness_loss import FaithfulnessLoss -from spd.metrics.identity_ci_error import IdentityCIError -from spd.metrics.importance_minimality_loss import ImportanceMinimalityLoss -from spd.metrics.permuted_ci_plots import PermutedCIPlots -from spd.metrics.pgd_masked_recon_layerwise_loss import PGDReconLayerwiseLoss -from spd.metrics.pgd_masked_recon_loss import PGDReconLoss -from spd.metrics.pgd_masked_recon_subset_loss import PGDReconSubsetLoss from spd.metrics.pgd_utils import CreateDataIter, calc_multibatch_pgd_masked_recon_loss -from spd.metrics.stochastic_hidden_acts_recon_loss import StochasticHiddenActsReconLoss -from spd.metrics.stochastic_recon_layerwise_loss import StochasticReconLayerwiseLoss -from spd.metrics.stochastic_recon_loss import StochasticReconLoss -from spd.metrics.stochastic_recon_subset_ce_and_kl import StochasticReconSubsetCEAndKL -from spd.metrics.stochastic_recon_subset_loss import StochasticReconSubsetLoss -from spd.metrics.uv_plots import UVPlots +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import ComponentModel, OutputWithCache from spd.routing import AllLayersRouter, get_subset_router from spd.utils.distributed_utils import avg_metrics_across_ranks, is_distributed -from spd.utils.general_utils import dict_safe_update_, extract_batch_data +from spd.utils.general_utils import dict_safe_update_ MetricOutType = dict[str, str | Number | Image.Image | CustomChart] 
DistMetricOutType = dict[str, str | float | Image.Image | CustomChart] @@ -121,6 +123,7 @@ def init_metric( model: ComponentModel, run_config: Config, device: str, + reconstruction_loss: ReconstructionLoss, ) -> Metric: match cfg: case ImportanceMinimalityLossConfig(): @@ -158,16 +161,20 @@ def init_metric( metric = CIMaskedReconSubsetLoss( model=model, device=device, - output_loss_type=run_config.output_loss_type, routing=cfg.routing, + reconstruction_loss=reconstruction_loss, ) case CIMaskedReconLayerwiseLossConfig(): metric = CIMaskedReconLayerwiseLoss( - model=model, device=device, output_loss_type=run_config.output_loss_type + model=model, + device=device, + reconstruction_loss=reconstruction_loss, ) case CIMaskedReconLossConfig(): metric = CIMaskedReconLoss( - model=model, device=device, output_loss_type=run_config.output_loss_type + model=model, + device=device, + reconstruction_loss=reconstruction_loss, ) case CIMeanPerComponentConfig(): metric = CIMeanPerComponent(model=model, device=device) @@ -196,7 +203,7 @@ def init_metric( sampling=run_config.sampling, use_delta_component=run_config.use_delta_component, n_mask_samples=run_config.n_mask_samples, - output_loss_type=run_config.output_loss_type, + reconstruction_loss=reconstruction_loss, ) case StochasticReconLossConfig(): metric = StochasticReconLoss( @@ -205,7 +212,7 @@ def init_metric( sampling=run_config.sampling, use_delta_component=run_config.use_delta_component, n_mask_samples=run_config.n_mask_samples, - output_loss_type=run_config.output_loss_type, + reconstruction_loss=reconstruction_loss, ) case StochasticReconSubsetLossConfig(): metric = StochasticReconSubsetLoss( @@ -214,33 +221,33 @@ def init_metric( sampling=run_config.sampling, use_delta_component=run_config.use_delta_component, n_mask_samples=run_config.n_mask_samples, - output_loss_type=run_config.output_loss_type, routing=cfg.routing, + reconstruction_loss=reconstruction_loss, ) case PGDReconLossConfig(): metric = PGDReconLoss( 
model=model, device=device, use_delta_component=run_config.use_delta_component, - output_loss_type=run_config.output_loss_type, pgd_config=cfg, + reconstruction_loss=reconstruction_loss, ) case PGDReconSubsetLossConfig(): metric = PGDReconSubsetLoss( model=model, device=device, use_delta_component=run_config.use_delta_component, - output_loss_type=run_config.output_loss_type, pgd_config=cfg, routing=cfg.routing, + reconstruction_loss=reconstruction_loss, ) case PGDReconLayerwiseLossConfig(): metric = PGDReconLayerwiseLoss( model=model, device=device, use_delta_component=run_config.use_delta_component, - output_loss_type=run_config.output_loss_type, pgd_config=cfg, + reconstruction_loss=reconstruction_loss, ) case StochasticReconSubsetCEAndKLConfig(): metric = StochasticReconSubsetCEAndKL( @@ -271,7 +278,7 @@ def init_metric( metric = UnmaskedReconLoss( model=model, device=device, - output_loss_type=run_config.output_loss_type, + reconstruction_loss=reconstruction_loss, ) case _: @@ -284,18 +291,25 @@ def init_metric( def evaluate( eval_metric_configs: list[MetricConfigType], model: ComponentModel, - eval_iterator: Iterator[Int[Tensor, "..."] | tuple[Float[Tensor, "..."], Float[Tensor, "..."]]], + eval_iterator: Iterator[Any], device: str, run_config: Config, slow_step: bool, n_eval_steps: int, current_frac_of_training: float, + reconstruction_loss: ReconstructionLoss, ) -> MetricOutType: """Run evaluation and return a mapping of metric names to values/images.""" metrics: list[Metric] = [] for cfg in eval_metric_configs: - metric = init_metric(cfg=cfg, model=model, run_config=run_config, device=device) + metric = init_metric( + cfg=cfg, + model=model, + run_config=run_config, + device=device, + reconstruction_loss=reconstruction_loss, + ) if metric.slow and not slow_step: continue metrics.append(metric) @@ -304,8 +318,7 @@ def evaluate( weight_deltas = model.calc_weight_deltas() for _ in range(n_eval_steps): - batch_raw = next(eval_iterator) - batch = 
extract_batch_data(batch_raw).to(device) + batch = next(eval_iterator) target_output: OutputWithCache = model(batch, cache_type="input") ci = model.calc_causal_importances( @@ -344,8 +357,8 @@ def evaluate_multibatch_pgd( model: ComponentModel, create_data_iter: CreateDataIter, config: Config, - batch_dims: tuple[int, ...], device: str, + reconstruction_loss: ReconstructionLoss, ) -> dict[str, float]: """Calculate multibatch PGD metrics.""" weight_deltas = model.calc_weight_deltas() if config.use_delta_component else None @@ -367,11 +380,10 @@ def evaluate_multibatch_pgd( model=model, weight_deltas=weight_deltas, create_data_iter=create_data_iter, - output_loss_type=config.output_loss_type, router=router, sampling=config.sampling, use_delta_component=config.use_delta_component, - batch_dims=batch_dims, device=device, + reconstruction_loss=reconstruction_loss, ).item() return metrics diff --git a/spd/experiments/ih/ih_config.yaml b/spd/experiments/ih/ih_config.yaml index 9c844723a..6391b5dd9 100644 --- a/spd/experiments/ih/ih_config.yaml +++ b/spd/experiments/ih/ih_config.yaml @@ -33,7 +33,6 @@ ci_recon_layerwise_coeff: null stochastic_recon_layerwise_coeff: 1 importance_minimality_coeff: 1e-2 pnorm: 0.1 -output_loss_type: kl ci_fn_type: "vector_mlp" ci_fn_hidden_dims: [128] diff --git a/spd/experiments/ih/ih_decomposition.py b/spd/experiments/ih/ih_decomposition.py index 1b0b268fc..2ca5f8f1d 100644 --- a/spd/experiments/ih/ih_decomposition.py +++ b/spd/experiments/ih/ih_decomposition.py @@ -7,13 +7,11 @@ from spd.configs import Config, IHTaskConfig from spd.experiments.ih.model import InductionModelTargetRunInfo, InductionTransformer from spd.log import logger +from spd.models.batch_and_loss_fns import recon_loss_kl, run_batch_first_element from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, InductionDataset from spd.utils.distributed_utils import get_device -from spd.utils.general_utils import ( - save_pre_run_info, - 
set_seed, -) +from spd.utils.general_utils import save_pre_run_info, set_seed from spd.utils.run_utils import ExecutionStamp from spd.utils.wandb_utils import init_wandb @@ -99,7 +97,8 @@ def main( device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, + run_batch=run_batch_first_element, + reconstruction_loss=recon_loss_kl, out_dir=out_dir, ) diff --git a/spd/experiments/ih/model.py b/spd/experiments/ih/model.py index 84babad42..0d14b5866 100644 --- a/spd/experiments/ih/model.py +++ b/spd/experiments/ih/model.py @@ -210,7 +210,7 @@ def __init__(self, cfg: InductionModelConfig): self.unembed = nn.Linear(cfg.d_model, adjusted_vocab_size, bias=False) @override - def forward(self, tokens: Float[Tensor, "B S"], **_): + def forward(self, tokens: Float[Tensor, "B S"]) -> Float[Tensor, "B S V"]: x = self.token_embed(tokens) for block in self.blocks: diff --git a/spd/experiments/lm/gpt2_config.yaml b/spd/experiments/lm/gpt2_config.yaml index 3aa180fe5..a2414f99c 100644 --- a/spd/experiments/lm/gpt2_config.yaml +++ b/spd/experiments/lm/gpt2_config.yaml @@ -25,7 +25,6 @@ loss_metric_configs: p_anneal_end_frac: 1.0 - classname: "StochasticReconLayerwiseLoss" coeff: 2.0 -output_loss_type: kl # --- Training --- batch_size: 2 @@ -64,7 +63,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: transformers.GPT2LMHeadModel pretrained_model_name: openai-community/gpt2 -pretrained_model_output_attr: logits +output_extract: logits tokenizer_name: openai-community/gpt2 # --- Task Specific --- diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py index 747a90832..e189ef1e2 100644 --- a/spd/experiments/lm/lm_decomposition.py +++ b/spd/experiments/lm/lm_decomposition.py @@ -7,8 +7,9 @@ import wandb from spd.configs import Config, LMTaskConfig -from spd.data import DatasetConfig, create_data_loader +from spd.data import DatasetConfig, create_data_loader, lm_collate_fn from 
spd.log import logger +from spd.models.batch_and_loss_fns import make_run_batch, recon_loss_kl from spd.pretrain.run_info import PretrainRunInfo from spd.run_spd import optimize from spd.utils.distributed_utils import ( @@ -140,6 +141,7 @@ def main( buffer_size=config.task_config.buffer_size, global_seed=config.seed, dist_state=dist_state, + collate_fn=lm_collate_fn, ) eval_data_config = DatasetConfig( @@ -169,18 +171,21 @@ def main( buffer_size=config.task_config.buffer_size, global_seed=config.seed + 1, dist_state=dist_state, + collate_fn=lm_collate_fn, ) if is_main_process(): logger.info("Starting optimization...") + assert config.output_extract is not None, "LM models require output_extract" optimize( target_model=target_model, config=config, device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, + run_batch=make_run_batch(config.output_extract), + reconstruction_loss=recon_loss_kl, out_dir=out_dir, ) diff --git a/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml b/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml index fc179119f..efb62bc9c 100644 --- a/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml +++ b/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml @@ -67,7 +67,6 @@ loss_metric_configs: type: uniform_k_subset - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl lr_schedule: start_val: 1.0e-04 warmup_pct: 0.0 @@ -115,7 +114,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/t-bd02d372 -pretrained_model_output_attr: idx_0 +output_extract: 0 tokenizer_name: EleutherAI/gpt-neox-20b task_config: task_name: lm diff --git a/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml b/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml index c3c3a106a..8a8f7f632 100644 --- a/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml +++ b/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml @@ -67,7 +67,6 
@@ loss_metric_configs: type: uniform_k_subset - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl lr_schedule: start_val: 1.0e-04 warmup_pct: 0.0 @@ -119,7 +118,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/t-32d1bb3b -pretrained_model_output_attr: idx_0 +output_extract: 0 tokenizer_name: EleutherAI/gpt-neox-20b task_config: task_name: lm diff --git a/spd/experiments/lm/ss_gpt2_config.yaml b/spd/experiments/lm/ss_gpt2_config.yaml index 9fcbeec47..d1b8977ca 100644 --- a/spd/experiments/lm/ss_gpt2_config.yaml +++ b/spd/experiments/lm/ss_gpt2_config.yaml @@ -25,7 +25,6 @@ loss_metric_configs: p_anneal_end_frac: 1.0 - classname: "StochasticReconLayerwiseLoss" coeff: 2.0 -output_loss_type: kl # --- Training --- batch_size: 16 @@ -64,7 +63,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: transformers.GPT2LMHeadModel pretrained_model_name: SimpleStories/test-SimpleStories-gpt2-1.25M -pretrained_model_output_attr: logits +output_extract: logits tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # --- Task Specific --- diff --git a/spd/experiments/lm/ss_gpt2_simple-1L.yaml b/spd/experiments/lm/ss_gpt2_simple-1L.yaml index 790002de7..8ee2e0e0f 100644 --- a/spd/experiments/lm/ss_gpt2_simple-1L.yaml +++ b/spd/experiments/lm/ss_gpt2_simple-1L.yaml @@ -47,7 +47,6 @@ loss_metric_configs: type: uniform_k_subset - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl lr_schedule: start_val: 0.0002 fn_type: cosine @@ -92,7 +91,7 @@ ci_alive_threshold: 0.0 # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.gpt2_simple.GPT2Simple pretrained_model_name: wandb:goodfire/spd/runs/3qhd7rnb # 100k steps. 
4019 tokenizer -pretrained_model_output_attr: idx_0 +output_extract: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # 4019 tok: TODO: Load from wandb instead # --- Task Specific --- diff --git a/spd/experiments/lm/ss_gpt2_simple-2L.yaml b/spd/experiments/lm/ss_gpt2_simple-2L.yaml index 4080b1634..f80c49936 100644 --- a/spd/experiments/lm/ss_gpt2_simple-2L.yaml +++ b/spd/experiments/lm/ss_gpt2_simple-2L.yaml @@ -47,7 +47,6 @@ loss_metric_configs: type: uniform_k_subset - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl lr_schedule: start_val: 0.0002 fn_type: cosine @@ -94,7 +93,7 @@ ci_alive_threshold: 0.0 # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.gpt2_simple.GPT2Simple pretrained_model_name: wandb:goodfire/spd/runs/wr1su18m # 100k steps. 4019 tokenizer -pretrained_model_output_attr: idx_0 +output_extract: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # 4019 tok: TODO: Load from wandb instead # --- Task Specific --- diff --git a/spd/experiments/lm/ss_gpt2_simple_config.yaml b/spd/experiments/lm/ss_gpt2_simple_config.yaml index ed6c497ee..36e6a581d 100644 --- a/spd/experiments/lm/ss_gpt2_simple_config.yaml +++ b/spd/experiments/lm/ss_gpt2_simple_config.yaml @@ -38,7 +38,6 @@ loss_metric_configs: routing: type: uniform_k_subset coeff: 1.0 -output_loss_type: kl # --- Training --- batch_size: 256 @@ -97,7 +96,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.gpt2_simple.GPT2Simple pretrained_model_name: wandb:goodfire/spd/runs/rvu66183 # 100k steps. 
4019 tokenizer -pretrained_model_output_attr: idx_0 +output_extract: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # 4019 tok: TODO: Load from wandb instead # --- Task Specific --- diff --git a/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml b/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml index 72dba01f0..56ec9e14e 100644 --- a/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml +++ b/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml @@ -36,7 +36,6 @@ loss_metric_configs: coeff: 2.0 - classname: "StochasticReconLoss" coeff: 0.2 -output_loss_type: kl # --- Training --- batch_size: 48 @@ -94,7 +93,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.gpt2_simple.GPT2Simple pretrained_model_name: wandb:goodfire/spd/runs/xi36b9az # No ln -pretrained_model_output_attr: idx_0 +output_extract: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # We'll load this from wandb in future # --- Task Specific --- diff --git a/spd/experiments/lm/ss_llama_simple-1L.yaml b/spd/experiments/lm/ss_llama_simple-1L.yaml index f9324bb2d..59a609a4e 100644 --- a/spd/experiments/lm/ss_llama_simple-1L.yaml +++ b/spd/experiments/lm/ss_llama_simple-1L.yaml @@ -48,7 +48,6 @@ loss_metric_configs: type: uniform_k_subset - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl steps: 200000 batch_size: 64 gradient_accumulation_steps: 1 @@ -92,7 +91,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple.LlamaSimple pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/tfacbi70 # 100k steps -pretrained_model_output_attr: idx_0 +output_extract: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ss_llama_simple-2L.yaml b/spd/experiments/lm/ss_llama_simple-2L.yaml index 512d0d3a1..0c321c289 100644 --- a/spd/experiments/lm/ss_llama_simple-2L.yaml +++ 
b/spd/experiments/lm/ss_llama_simple-2L.yaml @@ -48,7 +48,6 @@ loss_metric_configs: type: uniform_k_subset - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl steps: 200000 batch_size: 64 gradient_accumulation_steps: 1 @@ -94,7 +93,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple.LlamaSimple pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/tb8373uo # 100k steps -pretrained_model_output_attr: idx_0 +output_extract: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ss_llama_simple_config.yaml b/spd/experiments/lm/ss_llama_simple_config.yaml index 92d9eced4..d90cf9868 100644 --- a/spd/experiments/lm/ss_llama_simple_config.yaml +++ b/spd/experiments/lm/ss_llama_simple_config.yaml @@ -42,7 +42,6 @@ loss_metric_configs: type: uniform_k_subset coeff: 1.0 -output_loss_type: kl # --- Training --- batch_size: 256 @@ -94,7 +93,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.llama_simple.LlamaSimple pretrained_model_name: wandb:goodfire/spd/runs/erq48r3w # 100k steps 4019 tok -pretrained_model_output_attr: idx_0 +output_extract: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # 4019 tok: TODO: Load from wandb instead # --- Task Specific --- diff --git a/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml b/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml index 9cb54c2de..62f7003cd 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml @@ -42,7 +42,6 @@ loss_metric_configs: classname: PGDReconSubsetLoss - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl steps: 400000 batch_size: 64 gradient_accumulation_steps: 1 @@ -86,7 +85,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: 
wandb:goodfire/spd/runs/gvbmdt9w -pretrained_model_output_attr: idx_0 +output_extract: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml b/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml index 712a288da..e7a6929ea 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml @@ -46,7 +46,6 @@ loss_metric_configs: type: uniform_k_subset - classname: FaithfulnessLoss coeff: 1000000 -output_loss_type: kl steps: 400000 batch_size: 128 gradient_accumulation_steps: 1 @@ -94,7 +93,7 @@ ci_alive_threshold: 0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/gf6rbga0 -pretrained_model_output_attr: idx_0 +output_extract: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: buffer_size: 1000 diff --git a/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml b/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml index 72cf4937e..2262cbeaf 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml @@ -46,7 +46,6 @@ loss_metric_configs: type: uniform_k_subset - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl steps: 200000 batch_size: 64 gradient_accumulation_steps: 1 @@ -92,7 +91,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/7pt957pf # 100k steps -pretrained_model_output_attr: idx_0 +output_extract: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ss_llama_simple_mlp.yaml b/spd/experiments/lm/ss_llama_simple_mlp.yaml index 79622689d..271ed0813 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp.yaml +++ 
b/spd/experiments/lm/ss_llama_simple_mlp.yaml @@ -42,7 +42,6 @@ loss_metric_configs: classname: PGDReconSubsetLoss - coeff: 100000.0 classname: FaithfulnessLoss -output_loss_type: kl steps: 400000 batch_size: 128 gradient_accumulation_steps: 1 @@ -117,7 +116,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/9de1zu65 # 100k steps -pretrained_model_output_attr: idx_0 +output_extract: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ts_config.yaml b/spd/experiments/lm/ts_config.yaml index 4f8c966ac..f6f39aeb8 100644 --- a/spd/experiments/lm/ts_config.yaml +++ b/spd/experiments/lm/ts_config.yaml @@ -28,7 +28,6 @@ loss_metric_configs: p_anneal_end_frac: 1.0 - classname: "StochasticReconLayerwiseLoss" coeff: 2.0 -output_loss_type: kl # --- Training --- batch_size: 4 @@ -65,7 +64,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: transformers.AutoModelForCausalLM pretrained_model_name: roneneldan/TinyStories-1M -pretrained_model_output_attr: logits +output_extract: logits tokenizer_name: EleutherAI/gpt-neo-125M # --- Task Specific --- diff --git a/spd/experiments/resid_mlp/resid_mlp1_config.yaml b/spd/experiments/resid_mlp/resid_mlp1_config.yaml index 1d8bc7fed..093c8e554 100644 --- a/spd/experiments/resid_mlp/resid_mlp1_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp1_config.yaml @@ -28,7 +28,6 @@ loss_metric_configs: coeff: 1.0 - classname: "StochasticReconLoss" coeff: 1.0 -output_loss_type: mse # --- Training --- batch_size: 2048 diff --git a/spd/experiments/resid_mlp/resid_mlp2_config.yaml b/spd/experiments/resid_mlp/resid_mlp2_config.yaml index bae662b6f..b9d7e48fb 100644 --- a/spd/experiments/resid_mlp/resid_mlp2_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp2_config.yaml @@ -34,7 +34,6 @@ loss_metric_configs: mask_scope: 
shared_across_batch - classname: "FaithfulnessLoss" coeff: 0.0 -output_loss_type: mse # --- Training --- batch_size: 2048 diff --git a/spd/experiments/resid_mlp/resid_mlp3_config.yaml b/spd/experiments/resid_mlp/resid_mlp3_config.yaml index dac4f9c10..c3dbda014 100644 --- a/spd/experiments/resid_mlp/resid_mlp3_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp3_config.yaml @@ -27,7 +27,6 @@ loss_metric_configs: coeff: 1.0 - classname: "StochasticReconLoss" coeff: 1.0 -output_loss_type: mse # --- Training --- batch_size: 2048 diff --git a/spd/experiments/resid_mlp/resid_mlp_decomposition.py b/spd/experiments/resid_mlp/resid_mlp_decomposition.py index 75e423099..a61d8904a 100644 --- a/spd/experiments/resid_mlp/resid_mlp_decomposition.py +++ b/spd/experiments/resid_mlp/resid_mlp_decomposition.py @@ -13,6 +13,7 @@ ) from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset from spd.log import logger +from spd.models.batch_and_loss_fns import recon_loss_mse, run_batch_first_element from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader from spd.utils.distributed_utils import get_device @@ -108,7 +109,8 @@ def main( device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, + run_batch=run_batch_first_element, + reconstruction_loss=recon_loss_mse, out_dir=out_dir, ) diff --git a/spd/experiments/resid_mlp/resid_mlp_interp.py b/spd/experiments/resid_mlp/resid_mlp_interp.py index 209ef9747..506989176 100644 --- a/spd/experiments/resid_mlp/resid_mlp_interp.py +++ b/spd/experiments/resid_mlp/resid_mlp_interp.py @@ -9,7 +9,10 @@ from PIL import Image from torch import Tensor -from spd.experiments.resid_mlp.models import MLP, ResidMLP +from spd.experiments.resid_mlp.models import ( + MLP, + ResidMLP, +) from spd.experiments.tms.models import TMSModel from spd.log import logger from spd.models.component_model import ComponentModel, SPDRunInfo diff --git 
a/spd/experiments/tms/tms_40-10-id_config.yaml b/spd/experiments/tms/tms_40-10-id_config.yaml index e3e40d5fc..f3b21c094 100644 --- a/spd/experiments/tms/tms_40-10-id_config.yaml +++ b/spd/experiments/tms/tms_40-10-id_config.yaml @@ -28,7 +28,6 @@ loss_metric_configs: coeff: 1.0 - classname: "StochasticReconLoss" coeff: 1.0 -output_loss_type: "mse" # --- Training --- batch_size: 4096 diff --git a/spd/experiments/tms/tms_40-10_config.yaml b/spd/experiments/tms/tms_40-10_config.yaml index a4aeb6a97..7d264c77d 100644 --- a/spd/experiments/tms/tms_40-10_config.yaml +++ b/spd/experiments/tms/tms_40-10_config.yaml @@ -27,7 +27,6 @@ loss_metric_configs: coeff: 1.0 - classname: "StochasticReconLoss" coeff: 1.0 -output_loss_type: "mse" # --- Training --- batch_size: 4096 diff --git a/spd/experiments/tms/tms_5-2-id_config.yaml b/spd/experiments/tms/tms_5-2-id_config.yaml index c9b2234e8..11670b63d 100644 --- a/spd/experiments/tms/tms_5-2-id_config.yaml +++ b/spd/experiments/tms/tms_5-2-id_config.yaml @@ -28,7 +28,6 @@ loss_metric_configs: coeff: 1.0 - classname: "StochasticReconLoss" coeff: 1.0 -output_loss_type: mse # --- Training --- batch_size: 4096 diff --git a/spd/experiments/tms/tms_5-2_config.yaml b/spd/experiments/tms/tms_5-2_config.yaml index 34c92fa08..cc5d6a668 100644 --- a/spd/experiments/tms/tms_5-2_config.yaml +++ b/spd/experiments/tms/tms_5-2_config.yaml @@ -26,7 +26,6 @@ loss_metric_configs: coeff: 1.0 - classname: "StochasticReconLayerwiseLoss" coeff: 1.0 -output_loss_type: mse # --- Training --- batch_size: 4096 diff --git a/spd/experiments/tms/tms_decomposition.py b/spd/experiments/tms/tms_decomposition.py index 18c437a68..38dc47a6b 100644 --- a/spd/experiments/tms/tms_decomposition.py +++ b/spd/experiments/tms/tms_decomposition.py @@ -13,6 +13,7 @@ from spd.configs import Config, TMSTaskConfig from spd.experiments.tms.models import TMSModel, TMSTargetRunInfo from spd.log import logger +from spd.models.batch_and_loss_fns import recon_loss_mse, 
run_batch_first_element from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, SparseFeatureDataset from spd.utils.distributed_utils import get_device @@ -104,7 +105,8 @@ def main( device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, + run_batch=run_batch_first_element, + reconstruction_loss=recon_loss_mse, out_dir=out_dir, tied_weights=tied_weights, ) diff --git a/spd/harvest/harvest.py b/spd/harvest/harvest.py index 0c1d5daf1..e2afd6eab 100644 --- a/spd/harvest/harvest.py +++ b/spd/harvest/harvest.py @@ -34,7 +34,7 @@ from spd.log import logger from spd.models.component_model import ComponentModel, SPDRunInfo from spd.utils.distributed_utils import get_device -from spd.utils.general_utils import bf16_autocast, extract_batch_data +from spd.utils.general_utils import bf16_autocast def _compute_u_norms(model: ComponentModel) -> dict[str, Float[Tensor, " C"]]: @@ -231,7 +231,7 @@ def harvest_activation_contexts( batch_range = range(config.n_batches) if config.n_batches is not None else itertools.count() for batch_idx in tqdm.tqdm(batch_range, desc="Harvesting", disable=rank is not None): try: - batch_data = extract_batch_data(next(train_iter)) + batch = next(train_iter) except StopIteration: logger.info(f"Dataset exhausted at batch {batch_idx}. 
Processing complete.") break @@ -240,8 +240,8 @@ def harvest_activation_contexts( if world_size is not None and batch_idx % world_size != rank: continue - batch = batch_data.to(device) - with torch.no_grad(), bf16_autocast(): + batch = batch.to(device) + with torch.no_grad(), bf16_autocast(enabled=spd_config.autocast_bf16): out = model(batch, cache_type="input") probs = torch.softmax(out.output, dim=-1) diff --git a/spd/losses.py b/spd/losses.py index daef1773c..5d0aa9525 100644 --- a/spd/losses.py +++ b/spd/losses.py @@ -1,7 +1,7 @@ -from typing import Literal +from typing import Any import torch -from jaxtyping import Float, Int +from jaxtyping import Float from torch import Tensor from spd.configs import ( @@ -36,13 +36,15 @@ stochastic_recon_subset_loss, unmasked_recon_loss, ) +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel +from spd.utils.general_utils import get_obj_device def compute_total_loss( loss_metric_configs: list[LossMetricConfigType], model: ComponentModel, - batch: Int[Tensor, "..."], + batch: Any, ci: CIOutputs, target_out: Tensor, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], @@ -51,13 +53,13 @@ def compute_total_loss( sampling: SamplingType, use_delta_component: bool, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], dict[str, float]]: """Compute weighted total loss and per-term raw values using new loss primitives. Returns (total, terms_dict). terms_dict contains raw per-term values (no coeffs) and a weighted total. 
""" - total = torch.tensor(0.0, device=batch.device) + total = torch.tensor(0.0, device=get_obj_device(model)) terms: dict[str, float] = {} for cfg in loss_metric_configs: @@ -79,99 +81,99 @@ def compute_total_loss( case UnmaskedReconLossConfig(): loss = unmasked_recon_loss( model=model, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, + reconstruction_loss=reconstruction_loss, ) case CIMaskedReconSubsetLossConfig(): loss = ci_masked_recon_subset_loss( model=model, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, routing=cfg.routing, + reconstruction_loss=reconstruction_loss, ) case CIMaskedReconLayerwiseLossConfig(): loss = ci_masked_recon_layerwise_loss( model=model, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, + reconstruction_loss=reconstruction_loss, ) case CIMaskedReconLossConfig(): loss = ci_masked_recon_loss( model=model, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, + reconstruction_loss=reconstruction_loss, ) case StochasticReconLayerwiseLossConfig(): loss = stochastic_recon_layerwise_loss( model=model, sampling=sampling, n_mask_samples=n_mask_samples, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, weight_deltas=weight_deltas if use_delta_component else None, + reconstruction_loss=reconstruction_loss, ) case StochasticReconLossConfig(): loss = stochastic_recon_loss( model=model, sampling=sampling, n_mask_samples=n_mask_samples, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, weight_deltas=weight_deltas if use_delta_component else None, + reconstruction_loss=reconstruction_loss, ) case StochasticReconSubsetLossConfig(): loss = stochastic_recon_subset_loss( model=model, sampling=sampling, n_mask_samples=n_mask_samples, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, 
weight_deltas=weight_deltas if use_delta_component else None, routing=cfg.routing, + reconstruction_loss=reconstruction_loss, ) case PGDReconLossConfig(): loss = pgd_recon_loss( model=model, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, weight_deltas=weight_deltas if use_delta_component else None, pgd_config=cfg, + reconstruction_loss=reconstruction_loss, ) case PGDReconSubsetLossConfig(): loss = pgd_recon_subset_loss( model=model, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, weight_deltas=weight_deltas if use_delta_component else None, pgd_config=cfg, routing=cfg.routing, + reconstruction_loss=reconstruction_loss, ) case PGDReconLayerwiseLossConfig(): loss = pgd_recon_layerwise_loss( model=model, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, weight_deltas=weight_deltas if use_delta_component else None, pgd_config=cfg, + reconstruction_loss=reconstruction_loss, ) case StochasticHiddenActsReconLossConfig(): loss = stochastic_hidden_acts_recon_loss( diff --git a/spd/metrics/base.py b/spd/metrics/base.py index 97665464b..2e9a0fd4d 100644 --- a/spd/metrics/base.py +++ b/spd/metrics/base.py @@ -6,7 +6,7 @@ from typing import Any, ClassVar, Protocol -from jaxtyping import Float, Int +from jaxtyping import Float from torch import Tensor from spd.models.component_model import CIOutputs @@ -21,12 +21,12 @@ class Metric(Protocol): def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Tensor, pre_weight_acts: dict[str, Float[Tensor, "..."]], ci: CIOutputs, current_frac_of_training: float, - weight_deltas: dict[str, Float[Tensor, "... C"]], + weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], ) -> None: """Update metric state with a batch of data.""" ... 
diff --git a/spd/metrics/ci_masked_recon_layerwise_loss.py b/spd/metrics/ci_masked_recon_layerwise_loss.py index b7ff12be9..2862eb9e7 100644 --- a/spd/metrics/ci_masked_recon_layerwise_loss.py +++ b/spd/metrics/ci_masked_recon_layerwise_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -6,51 +6,52 @@ from torch.distributed import ReduceOp from spd.metrics.base import Metric +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.models.components import make_mask_infos from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm +from spd.utils.general_utils import get_obj_device def _ci_masked_recon_layerwise_loss_update( model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Tensor, ci: dict[str, Float[Tensor, "... 
C"]], + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], int]: - sum_loss = torch.tensor(0.0, device=batch.device) - n_examples = 0 + sum_loss = torch.tensor(0.0, device=get_obj_device(model)) + sum_n_examples = 0 mask_infos = make_mask_infos(ci, weight_deltas_and_masks=None) for module_name, mask_info in mask_infos.items(): out = model(batch, mask_infos={module_name: mask_info}) - loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=output_loss_type) - n_examples += out.shape.numel() if output_loss_type == "mse" else out.shape[:-1].numel() + loss, n_examples = reconstruction_loss(out, target_out) sum_loss += loss - return sum_loss, n_examples + sum_n_examples += n_examples + return sum_loss, sum_n_examples def _ci_masked_recon_layerwise_loss_compute( - sum_loss: Float[Tensor, ""], n_examples: Int[Tensor, ""] | int + sum_loss: Float[Tensor, ""], sum_n_examples: Int[Tensor, ""] | int ) -> Float[Tensor, ""]: - return sum_loss / n_examples + return sum_loss / sum_n_examples def ci_masked_recon_layerwise_loss( model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Tensor, ci: dict[str, Float[Tensor, "... 
C"]], + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: - sum_loss, n_examples = _ci_masked_recon_layerwise_loss_update( + sum_loss, sum_n_examples = _ci_masked_recon_layerwise_loss_update( model=model, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci, + reconstruction_loss=reconstruction_loss, ) - return _ci_masked_recon_layerwise_loss_compute(sum_loss, n_examples) + return _ci_masked_recon_layerwise_loss_compute(sum_loss, sum_n_examples) class CIMaskedReconLayerwiseLoss(Metric): @@ -59,34 +60,37 @@ class CIMaskedReconLayerwiseLoss(Metric): metric_section: ClassVar[str] = "loss" def __init__( - self, model: ComponentModel, device: str, output_loss_type: Literal["mse", "kl"] + self, + model: ComponentModel, + device: str, + reconstruction_loss: ReconstructionLoss, ) -> None: self.model = model - self.output_loss_type: Literal["mse", "kl"] = output_loss_type + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) - self.n_examples = torch.tensor(0, device=device) + self.sum_n_examples = torch.tensor(0, device=device) @override def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... 
vocab"], + batch: Any, + target_out: Tensor, ci: CIOutputs, **_: Any, ) -> None: - sum_loss, n_examples = _ci_masked_recon_layerwise_loss_update( + sum_loss, sum_n_examples = _ci_masked_recon_layerwise_loss_update( model=self.model, - output_loss_type=self.output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, + reconstruction_loss=self.reconstruction_loss, ) self.sum_loss += sum_loss - self.n_examples += n_examples + self.sum_n_examples += sum_n_examples @override def compute(self) -> Float[Tensor, ""]: sum_loss = all_reduce(self.sum_loss, op=ReduceOp.SUM) - n_examples = all_reduce(self.n_examples, op=ReduceOp.SUM) - return _ci_masked_recon_layerwise_loss_compute(sum_loss, n_examples) + sum_n_examples = all_reduce(self.sum_n_examples, op=ReduceOp.SUM) + return _ci_masked_recon_layerwise_loss_compute(sum_loss, sum_n_examples) diff --git a/spd/metrics/ci_masked_recon_loss.py b/spd/metrics/ci_masked_recon_loss.py index a11c11469..2cb871064 100644 --- a/spd/metrics/ci_masked_recon_loss.py +++ b/spd/metrics/ci_masked_recon_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -6,24 +6,22 @@ from torch.distributed import ReduceOp from spd.metrics.base import Metric +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.models.components import make_mask_infos from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm def _ci_masked_recon_loss_update( model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Tensor, ci: dict[str, Float[Tensor, "... 
C"]], + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], int]: mask_infos = make_mask_infos(ci, weight_deltas_and_masks=None) out = model(batch, mask_infos=mask_infos) - loss_type = output_loss_type - loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=loss_type) - return loss, out.shape.numel() if loss_type == "mse" else out.shape[:-1].numel() + return reconstruction_loss(out, target_out) def _ci_masked_recon_loss_compute( @@ -34,17 +32,17 @@ def _ci_masked_recon_loss_compute( def ci_masked_recon_loss( model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: sum_loss, n_examples = _ci_masked_recon_loss_update( model=model, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci, + reconstruction_loss=reconstruction_loss, ) return _ci_masked_recon_loss_compute(sum_loss, n_examples) @@ -55,10 +53,13 @@ class CIMaskedReconLoss(Metric): metric_section: ClassVar[str] = "loss" def __init__( - self, model: ComponentModel, device: str, output_loss_type: Literal["mse", "kl"] + self, + model: ComponentModel, + device: str, + reconstruction_loss: ReconstructionLoss, ) -> None: self.model = model - self.output_loss_type: Literal["mse", "kl"] = output_loss_type + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -66,17 +67,17 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... 
vocab"], + batch: Any, + target_out: Tensor, ci: CIOutputs, **_: Any, ) -> None: sum_loss, n_examples = _ci_masked_recon_loss_update( model=self.model, - output_loss_type=self.output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, + reconstruction_loss=self.reconstruction_loss, ) self.sum_loss += sum_loss self.n_examples += n_examples diff --git a/spd/metrics/ci_masked_recon_subset_loss.py b/spd/metrics/ci_masked_recon_subset_loss.py index 0a2e83441..785ae15bd 100644 --- a/spd/metrics/ci_masked_recon_subset_loss.py +++ b/spd/metrics/ci_masked_recon_subset_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -7,20 +7,21 @@ from spd.configs import SubsetRoutingType from spd.metrics.base import Metric +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.models.components import make_mask_infos from spd.routing import Router, get_subset_router from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm +from spd.utils.general_utils import get_obj_device def _ci_masked_recon_subset_loss_update( model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Tensor, ci: dict[str, Float[Tensor, "... 
C"]], router: Router, + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], int]: subset_routing_masks = router.get_masks( module_names=model.target_module_paths, @@ -32,9 +33,7 @@ def _ci_masked_recon_subset_loss_update( weight_deltas_and_masks=None, ) out = model(batch, mask_infos=mask_infos) - loss_type = output_loss_type - loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=loss_type) - return loss, out.shape.numel() if loss_type == "mse" else out.shape[:-1].numel() + return reconstruction_loss(out, target_out) def _ci_masked_recon_subset_loss_compute( @@ -45,19 +44,19 @@ def _ci_masked_recon_subset_loss_compute( def ci_masked_recon_subset_loss( model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], routing: SubsetRoutingType, + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: sum_loss, n_examples = _ci_masked_recon_subset_loss_update( model=model, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci, - router=get_subset_router(routing, batch.device), + router=get_subset_router(routing, device=get_obj_device(model)), + reconstruction_loss=reconstruction_loss, ) return _ci_masked_recon_subset_loss_compute(sum_loss, n_examples) @@ -71,13 +70,12 @@ def __init__( self, model: ComponentModel, device: str, - output_loss_type: Literal["mse", "kl"], routing: SubsetRoutingType, + reconstruction_loss: ReconstructionLoss, ) -> None: self.model = model - self.output_loss_type: Literal["mse", "kl"] = output_loss_type self.router = get_subset_router(routing, device) - + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -85,18 +83,18 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - 
target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Tensor, ci: CIOutputs, **_: Any, ) -> None: sum_loss, n_examples = _ci_masked_recon_subset_loss_update( model=self.model, - output_loss_type=self.output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, router=self.router, + reconstruction_loss=self.reconstruction_loss, ) self.sum_loss += sum_loss self.n_examples += n_examples diff --git a/spd/metrics/identity_ci_error.py b/spd/metrics/identity_ci_error.py index d619c5082..7c3dcf881 100644 --- a/spd/metrics/identity_ci_error.py +++ b/spd/metrics/identity_ci_error.py @@ -32,9 +32,10 @@ def __init__( self.batch_shape: tuple[int, ...] | None = None @override - def update(self, *, batch: Tensor, **_: Any) -> None: + def update(self, *, batch: Tensor | tuple[Tensor, ...], **_: Any) -> None: if self.batch_shape is None: - self.batch_shape = tuple(batch.shape) + input_tensor = batch[0] if isinstance(batch, tuple) else batch + self.batch_shape = tuple(input_tensor.shape) @override def compute(self) -> dict[str, float]: diff --git a/spd/metrics/permuted_ci_plots.py b/spd/metrics/permuted_ci_plots.py index d5baa8b28..77d713499 100644 --- a/spd/metrics/permuted_ci_plots.py +++ b/spd/metrics/permuted_ci_plots.py @@ -30,9 +30,10 @@ def __init__( self.batch_shape: tuple[int, ...] 
| None = None @override - def update(self, *, batch: Tensor, **_: Any) -> None: + def update(self, *, batch: Tensor | tuple[Tensor, ...], **_: Any) -> None: if self.batch_shape is None: - self.batch_shape = tuple(batch.shape) + input_tensor = batch[0] if isinstance(batch, tuple) else batch + self.batch_shape = tuple(input_tensor.shape) @override def compute(self) -> dict[str, Image.Image]: diff --git a/spd/metrics/pgd_masked_recon_layerwise_loss.py b/spd/metrics/pgd_masked_recon_layerwise_loss.py index 787cad8a2..97c45f3bd 100644 --- a/spd/metrics/pgd_masked_recon_layerwise_loss.py +++ b/spd/metrics/pgd_masked_recon_layerwise_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -8,6 +8,7 @@ from spd.configs import PGDConfig from spd.metrics.base import Metric from spd.metrics.pgd_utils import pgd_masked_recon_loss_update +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.routing import LayerRouter from spd.utils.distributed_utils import all_reduce @@ -16,12 +17,12 @@ def _pgd_recon_layerwise_loss_update( *, model: ComponentModel, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], - output_loss_type: Literal["mse", "kl"], + batch: Any, + target_out: Tensor, ci: dict[str, Float[Tensor, "... 
C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, pgd_config: PGDConfig, + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], Int[Tensor, ""]]: device = next(iter(ci.values())).device sum_loss = torch.tensor(0.0, device=device) @@ -33,9 +34,9 @@ def _pgd_recon_layerwise_loss_update( ci=ci, weight_deltas=weight_deltas, target_out=target_out, - output_loss_type=output_loss_type, router=LayerRouter(device=device, layer_name=layer), pgd_config=pgd_config, + reconstruction_loss=reconstruction_loss, ) sum_loss += sum_loss_layer n_examples += n_examples_layer @@ -45,21 +46,21 @@ def _pgd_recon_layerwise_loss_update( def pgd_recon_layerwise_loss( *, model: ComponentModel, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], - output_loss_type: Literal["mse", "kl"], + batch: Any, + target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, pgd_config: PGDConfig, + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: sum_loss, n_examples = _pgd_recon_layerwise_loss_update( model=model, batch=batch, target_out=target_out, - output_loss_type=output_loss_type, ci=ci, weight_deltas=weight_deltas, pgd_config=pgd_config, + reconstruction_loss=reconstruction_loss, ) return sum_loss / n_examples @@ -73,15 +74,15 @@ class PGDReconLayerwiseLoss(Metric): def __init__( self, model: ComponentModel, - output_loss_type: Literal["mse", "kl"], pgd_config: PGDConfig, device: str, use_delta_component: bool, + reconstruction_loss: ReconstructionLoss, ) -> None: self.model = model self.pgd_config: PGDConfig = pgd_config - self.output_loss_type: Literal["mse", "kl"] = output_loss_type self.use_delta_component: bool = use_delta_component + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -89,8 +90,8 @@ def __init__( def update( self, *, - batch: 
Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Tensor, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, **_: Any, @@ -99,10 +100,10 @@ def update( model=self.model, batch=batch, target_out=target_out, - output_loss_type=self.output_loss_type, ci=ci.lower_leaky, weight_deltas=weight_deltas if self.use_delta_component else None, pgd_config=self.pgd_config, + reconstruction_loss=self.reconstruction_loss, ) self.sum_loss += sum_loss self.n_examples += n_examples diff --git a/spd/metrics/pgd_masked_recon_loss.py b/spd/metrics/pgd_masked_recon_loss.py index 7d35e149f..4ab242393 100644 --- a/spd/metrics/pgd_masked_recon_loss.py +++ b/spd/metrics/pgd_masked_recon_loss.py @@ -1,13 +1,14 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch -from jaxtyping import Float, Int +from jaxtyping import Float from torch import Tensor from torch.distributed import ReduceOp from spd.configs import PGDConfig from spd.metrics.base import Metric from spd.metrics.pgd_utils import pgd_masked_recon_loss_update +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.routing import AllLayersRouter from spd.utils.distributed_utils import all_reduce @@ -16,12 +17,12 @@ def pgd_recon_loss( *, model: ComponentModel, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], - output_loss_type: Literal["mse", "kl"], + batch: Any, + target_out: Tensor, ci: dict[str, Float[Tensor, "... 
C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, pgd_config: PGDConfig, + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: sum_loss, n_examples = pgd_masked_recon_loss_update( model=model, @@ -29,9 +30,9 @@ def pgd_recon_loss( ci=ci, weight_deltas=weight_deltas, target_out=target_out, - output_loss_type=output_loss_type, router=AllLayersRouter(), pgd_config=pgd_config, + reconstruction_loss=reconstruction_loss, ) return sum_loss / n_examples @@ -46,14 +47,14 @@ def __init__( self, model: ComponentModel, device: str, - output_loss_type: Literal["mse", "kl"], pgd_config: PGDConfig, use_delta_component: bool, + reconstruction_loss: ReconstructionLoss, ) -> None: self.model = model self.pgd_config: PGDConfig = pgd_config - self.output_loss_type: Literal["mse", "kl"] = output_loss_type self.use_delta_component: bool = use_delta_component + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -61,8 +62,8 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... 
vocab"], + batch: Any, + target_out: Tensor, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, **_: Any, @@ -73,9 +74,9 @@ def update( ci=ci.lower_leaky, weight_deltas=weight_deltas if self.use_delta_component else None, target_out=target_out, - output_loss_type=self.output_loss_type, router=AllLayersRouter(), pgd_config=self.pgd_config, + reconstruction_loss=self.reconstruction_loss, ) self.sum_loss += sum_loss self.n_examples += n_examples diff --git a/spd/metrics/pgd_masked_recon_subset_loss.py b/spd/metrics/pgd_masked_recon_subset_loss.py index a9e8a7eac..c904c14d7 100644 --- a/spd/metrics/pgd_masked_recon_subset_loss.py +++ b/spd/metrics/pgd_masked_recon_subset_loss.py @@ -1,28 +1,30 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch -from jaxtyping import Float, Int +from jaxtyping import Float from torch import Tensor from torch.distributed import ReduceOp from spd.configs import PGDConfig, SubsetRoutingType from spd.metrics.base import Metric from spd.metrics.pgd_utils import pgd_masked_recon_loss_update +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.routing import get_subset_router from spd.utils.distributed_utils import all_reduce +from spd.utils.general_utils import get_obj_device def pgd_recon_subset_loss( *, model: ComponentModel, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], - output_loss_type: Literal["mse", "kl"], + batch: Any, + target_out: Tensor, ci: dict[str, Float[Tensor, "... 
C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, pgd_config: PGDConfig, routing: SubsetRoutingType, + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: sum_loss, n_examples = pgd_masked_recon_loss_update( model=model, @@ -30,9 +32,9 @@ def pgd_recon_subset_loss( ci=ci, weight_deltas=weight_deltas, target_out=target_out, - output_loss_type=output_loss_type, - router=get_subset_router(routing, batch.device), + router=get_subset_router(routing, device=get_obj_device(model)), pgd_config=pgd_config, + reconstruction_loss=reconstruction_loss, ) return sum_loss / n_examples @@ -47,16 +49,16 @@ def __init__( self, model: ComponentModel, device: str, - output_loss_type: Literal["mse", "kl"], use_delta_component: bool, pgd_config: PGDConfig, routing: SubsetRoutingType, + reconstruction_loss: ReconstructionLoss, ) -> None: self.model = model self.pgd_config: PGDConfig = pgd_config - self.output_loss_type: Literal["mse", "kl"] = output_loss_type self.use_delta_component: bool = use_delta_component - self.router = get_subset_router(routing, device) + self.router = get_subset_router(routing, device=get_obj_device(model)) + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -65,8 +67,8 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... 
vocab"], + batch: Any, + target_out: Tensor, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], **_: Any, @@ -77,9 +79,9 @@ def update( ci=ci.lower_leaky, weight_deltas=weight_deltas if self.use_delta_component else None, target_out=target_out, - output_loss_type=self.output_loss_type, router=self.router, pgd_config=self.pgd_config, + reconstruction_loss=self.reconstruction_loss, ) self.sum_loss += sum_loss self.n_examples += n_examples diff --git a/spd/metrics/pgd_utils.py b/spd/metrics/pgd_utils.py index d83e42c6b..d26fe6291 100644 --- a/spd/metrics/pgd_utils.py +++ b/spd/metrics/pgd_utils.py @@ -1,30 +1,31 @@ from collections.abc import Callable, Iterator from functools import partial -from typing import Literal +from typing import Any import torch -from jaxtyping import Float, Int +from jaxtyping import Float from torch import Tensor from torch.distributed import ReduceOp from spd.configs import PGDConfig, PGDInitStrategy, PGDMultiBatchConfig, SamplingType from spd.log import logger +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import ComponentModel, OutputWithCache from spd.models.components import RoutingMasks, make_mask_infos from spd.routing import Router from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm, extract_batch_data +from spd.utils.general_utils import get_obj_device def pgd_masked_recon_loss_update( model: ComponentModel, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], + batch: Any, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, - target_out: Float[Tensor, "... vocab"], - output_loss_type: Literal["mse", "kl"], + target_out: Tensor, router: Router, pgd_config: PGDConfig, + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], int]: """Central implementation of PGD masked reconstruction loss. 
@@ -45,7 +46,7 @@ def pgd_masked_recon_loss_update( singleton_batch_dims = [1 for _ in batch_dims] shape = torch.Size([*singleton_batch_dims, mask_c]) adv_sources[module_name] = _get_pgd_init_tensor( - pgd_config.init, shape, batch.device + pgd_config.init, shape, device=get_obj_device(model) ).requires_grad_(True) fwd_pass = partial( @@ -57,8 +58,8 @@ def pgd_masked_recon_loss_update( weight_deltas=weight_deltas, routing_masks=routing_masks, target_out=target_out, - output_loss_type=output_loss_type, batch_dims=batch_dims, + reconstruction_loss=reconstruction_loss, ) for _ in range(pgd_config.n_steps): @@ -79,10 +80,7 @@ def pgd_masked_recon_loss_update( return fwd_pass() -CreateDataIter = Callable[ - [], - Iterator[Int[Tensor, "..."]] | Iterator[tuple[Float[Tensor, "..."], Float[Tensor, "..."]]], -] +CreateDataIter = Callable[[], Iterator[Any]] def calc_multibatch_pgd_masked_recon_loss( @@ -90,12 +88,11 @@ def calc_multibatch_pgd_masked_recon_loss( model: ComponentModel, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, create_data_iter: CreateDataIter, - output_loss_type: Literal["mse", "kl"], router: Router, sampling: SamplingType, use_delta_component: bool, - batch_dims: tuple[int, ...], device: str, + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: """PGD masked reconstruction loss with gradient accumulation over multiple batches. @@ -108,19 +105,25 @@ def calc_multibatch_pgd_masked_recon_loss( create_data_iter: Function to create an iterator over batches. This function should return an iterator which behaves identically each time. Specifically in terms of data ordering and shuffling. 
- output_loss_type: Loss type for reconstruction ("mse" or "kl") router: Router to use for routing masks sampling: Sampling mode for causal importance calculation use_delta_component: Whether to include weight delta component - batch_dims: Dimensions of batch (e.g., (batch_size,) or (batch_size, seq_len)) + reconstruction_loss: Function to compute reconstruction loss Returns: Final reconstruction loss after PGD optimization """ - singleton_batch_dims = [1 for _ in batch_dims] + + demo_batch = next(create_data_iter()) + demo_output = model(demo_batch, cache_type="input") + ci_demo = model.calc_causal_importances( + pre_weight_acts=demo_output.cache, sampling=sampling + ).lower_leaky adv_sources: dict[str, Float[Tensor, "*ones mask_c"]] = {} for module_name in model.target_module_paths: - module_c = model.module_to_c[module_name] + demo_ci = ci_demo[module_name] + *batch_dims, module_c = demo_ci.shape + singleton_batch_dims = [1 for _ in batch_dims] mask_c = module_c if not use_delta_component else module_c + 1 shape = torch.Size([*singleton_batch_dims, mask_c]) adv_sources[module_name] = _get_pgd_init_tensor( @@ -134,35 +137,34 @@ def calc_multibatch_pgd_masked_recon_loss( model=model, weight_deltas=weight_deltas, device=device, - output_loss_type=output_loss_type, sampling=sampling, router=router, - batch_dims=batch_dims, + reconstruction_loss=reconstruction_loss, ) for _ in range(pgd_config.n_steps): assert all(adv.grad is None for adv in adv_sources.values()) - _, _, adv_sources_grads = fwd_bwd_fn(data_iter=create_data_iter()) + _, _, adv_sources_sum_grads = fwd_bwd_fn(data_iter=create_data_iter()) with torch.no_grad(): for k in adv_sources: - adv_sources[k].add_(pgd_config.step_size * adv_sources_grads[k].sign()) + adv_sources[k].add_(pgd_config.step_size * adv_sources_sum_grads[k].sign()) adv_sources[k].clamp_(0.0, 1.0) - final_loss, final_n_examples, _ = fwd_bwd_fn(data_iter=create_data_iter()) - return final_loss / final_n_examples + final_loss, 
final_sum_n_examples, _ = fwd_bwd_fn(data_iter=create_data_iter()) + return final_loss / final_sum_n_examples def _forward_with_adv_sources( model: ComponentModel, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], + batch: Any, adv_sources: dict[str, Float[Tensor, "*batch_dim_or_ones mask_c"]], ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, routing_masks: RoutingMasks, - target_out: Float[Tensor, "... vocab"], - output_loss_type: Literal["mse", "kl"], + target_out: Tensor, batch_dims: tuple[int, ...], + reconstruction_loss: ReconstructionLoss, ): expanded_adv_sources = {k: v.expand(*batch_dims, -1) for k, v in adv_sources.items()} adv_sources_components: dict[str, Float[Tensor, "*batch_dims C"]] @@ -183,11 +185,7 @@ def _forward_with_adv_sources( ) out = model(batch, mask_infos=mask_infos) - sum_loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=output_loss_type) - - n_examples = ( - target_out.shape.numel() if output_loss_type == "mse" else target_out.shape[:-1].numel() - ) + sum_loss, n_examples = reconstruction_loss(out, target_out) return sum_loss, n_examples @@ -197,13 +195,11 @@ def _multibatch_pgd_fwd_bwd( pgd_config: PGDMultiBatchConfig, model: ComponentModel, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, - data_iter: Iterator[Int[Tensor, "..."]] - | Iterator[tuple[Float[Tensor, "..."], Float[Tensor, "..."]]], + data_iter: Iterator[Any], device: torch.device | str, - output_loss_type: Literal["mse", "kl"], router: Router, sampling: SamplingType, - batch_dims: tuple[int, ...], + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], int, dict[str, Float[Tensor, "*ones mask_c"]]]: """Perform a forward and backward pass over multiple batches with gradient accumulation. 
@@ -213,16 +209,15 @@ def _multibatch_pgd_fwd_bwd( - The gradients of the adv_sources (dict keyed by module name) """ pgd_step_accum_sum_loss = torch.tensor(0.0, device=device) - pgd_step_accum_n_examples = 0 - pgd_step_accum_grads = {k: torch.zeros_like(v) for k, v in adv_sources.items()} + pgd_step_accum_sum_n_examples = 0 + pgd_step_accum_sum_grads = {k: torch.zeros_like(v) for k, v in adv_sources.items()} for microbatch_idx in range(pgd_config.gradient_accumulation_steps): try: - microbatch_item = next(data_iter) + microbatch = next(data_iter) except StopIteration: logger.warning(f"Dataloader exhausted after {microbatch_idx} batches, ending PGD step.") break - microbatch = extract_batch_data(microbatch_item).to(device) # NOTE: technically this is duplicated work across PGD steps, but that's the price we pay to # enable accumulating gradients over more microbatches than we'd be able to fit CI values in @@ -233,13 +228,15 @@ def _multibatch_pgd_fwd_bwd( sampling=sampling, ).lower_leaky + batch_dims = next(iter(ci.values())).shape[:-1] + # It's important that we call this every microbatch to ensure stochastic routing masks are # sampled independently for each example. 
routing_masks = router.get_masks( module_names=model.target_module_paths, mask_shape=batch_dims ) - batch_sum_loss, batch_n_examples = _forward_with_adv_sources( + batch_sum_loss, batch_sum_n_examples = _forward_with_adv_sources( model=model, batch=microbatch, adv_sources=adv_sources, @@ -247,21 +244,21 @@ def _multibatch_pgd_fwd_bwd( weight_deltas=weight_deltas, routing_masks=routing_masks, target_out=target_model_output.output, - output_loss_type=output_loss_type, batch_dims=batch_dims, + reconstruction_loss=reconstruction_loss, ) pgd_step_accum_sum_loss += batch_sum_loss - pgd_step_accum_n_examples += batch_n_examples + pgd_step_accum_sum_n_examples += batch_sum_n_examples # important: take gradient wrt the UNEXPANDED adv_sources, not the expanded ones grads = torch.autograd.grad(batch_sum_loss, list(adv_sources.values())) for k, g in zip(adv_sources.keys(), grads, strict=True): - pgd_step_accum_grads[k] += all_reduce(g, op=ReduceOp.AVG).detach() + pgd_step_accum_sum_grads[k] += all_reduce(g, op=ReduceOp.AVG).detach() del target_model_output, ci - return pgd_step_accum_sum_loss, pgd_step_accum_n_examples, pgd_step_accum_grads + return pgd_step_accum_sum_loss, pgd_step_accum_sum_n_examples, pgd_step_accum_sum_grads def _get_pgd_init_tensor( diff --git a/spd/metrics/stochastic_hidden_acts_recon_loss.py b/spd/metrics/stochastic_hidden_acts_recon_loss.py index 814e6e18c..2e97e2dda 100644 --- a/spd/metrics/stochastic_hidden_acts_recon_loss.py +++ b/spd/metrics/stochastic_hidden_acts_recon_loss.py @@ -18,7 +18,7 @@ def _stochastic_hidden_acts_recon_loss_update( model: ComponentModel, sampling: SamplingType, n_mask_samples: int, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], + batch: Any, pre_weight_acts: dict[str, Float[Tensor, "..."]], ci: dict[str, Float[Tensor, "... 
C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, @@ -63,7 +63,7 @@ def stochastic_hidden_acts_recon_loss( model: ComponentModel, sampling: SamplingType, n_mask_samples: int, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], + batch: Any, pre_weight_acts: dict[str, Float[Tensor, "..."]], ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, @@ -104,7 +104,7 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], + batch: Any, pre_weight_acts: dict[str, Float[Tensor, "..."]], ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], diff --git a/spd/metrics/stochastic_recon_layerwise_loss.py b/spd/metrics/stochastic_recon_layerwise_loss.py index b14d57fe3..d90675e6e 100644 --- a/spd/metrics/stochastic_recon_layerwise_loss.py +++ b/spd/metrics/stochastic_recon_layerwise_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -7,27 +7,28 @@ from spd.configs import SamplingType from spd.metrics.base import Metric +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.routing import AllLayersRouter from spd.utils.component_utils import calc_stochastic_component_mask_info from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm, get_obj_device +from spd.utils.general_utils import get_obj_device def _stochastic_recon_layerwise_loss_update( model: ComponentModel, sampling: SamplingType, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Tensor, ci: dict[str, Float[Tensor, "... 
C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], int]: assert ci, "Empty ci" device = get_obj_device(ci) sum_loss = torch.tensor(0.0, device=device) - n_examples = 0 + sum_n_examples = 0 stochastic_mask_infos_list = [ calc_stochastic_component_mask_info( @@ -42,40 +43,39 @@ def _stochastic_recon_layerwise_loss_update( for stochastic_mask_infos in stochastic_mask_infos_list: for module_name, mask_info in stochastic_mask_infos.items(): out = model(batch, mask_infos={module_name: mask_info}) - loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=output_loss_type) - - n_examples += out.shape.numel() if output_loss_type == "mse" else out.shape[:-1].numel() + loss, batch_n_examples = reconstruction_loss(out, target_out) sum_loss += loss - return sum_loss, n_examples + sum_n_examples += batch_n_examples + return sum_loss, sum_n_examples def _stochastic_recon_layerwise_loss_compute( - sum_loss: Float[Tensor, ""], n_examples: Int[Tensor, ""] | int + sum_loss: Float[Tensor, ""], sum_n_examples: Int[Tensor, ""] | int ) -> Float[Tensor, ""]: - return sum_loss / n_examples + return sum_loss / sum_n_examples def stochastic_recon_layerwise_loss( model: ComponentModel, sampling: SamplingType, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Tensor, ci: dict[str, Float[Tensor, "... 
C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: - sum_loss, n_examples = _stochastic_recon_layerwise_loss_update( + sum_loss, sum_n_examples = _stochastic_recon_layerwise_loss_update( model=model, sampling=sampling, n_mask_samples=n_mask_samples, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci, weight_deltas=weight_deltas, + reconstruction_loss=reconstruction_loss, ) - return _stochastic_recon_layerwise_loss_compute(sum_loss, n_examples) + return _stochastic_recon_layerwise_loss_compute(sum_loss, sum_n_examples) class StochasticReconLayerwiseLoss(Metric): @@ -90,41 +90,41 @@ def __init__( sampling: SamplingType, use_delta_component: bool, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], + reconstruction_loss: ReconstructionLoss, ) -> None: self.model = model self.sampling: SamplingType = sampling self.use_delta_component: bool = use_delta_component self.n_mask_samples: int = n_mask_samples - self.output_loss_type: Literal["mse", "kl"] = output_loss_type + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) - self.n_examples = torch.tensor(0, device=device) + self.sum_n_examples = torch.tensor(0, device=device) @override def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... 
vocab"], + batch: Any, + target_out: Tensor, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], **_: Any, ) -> None: - sum_loss, n_examples = _stochastic_recon_layerwise_loss_update( + sum_loss, sum_n_examples = _stochastic_recon_layerwise_loss_update( model=self.model, sampling=self.sampling, n_mask_samples=self.n_mask_samples, - output_loss_type=self.output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, weight_deltas=weight_deltas if self.use_delta_component else None, + reconstruction_loss=self.reconstruction_loss, ) self.sum_loss += sum_loss - self.n_examples += n_examples + self.sum_n_examples += sum_n_examples @override def compute(self) -> Float[Tensor, ""]: sum_loss = all_reduce(self.sum_loss, op=ReduceOp.SUM) - n_examples = all_reduce(self.n_examples, op=ReduceOp.SUM) - return _stochastic_recon_layerwise_loss_compute(sum_loss, n_examples) + sum_n_examples = all_reduce(self.sum_n_examples, op=ReduceOp.SUM) + return _stochastic_recon_layerwise_loss_compute(sum_loss, sum_n_examples) diff --git a/spd/metrics/stochastic_recon_loss.py b/spd/metrics/stochastic_recon_loss.py index 46cb0ad61..793fbed32 100644 --- a/spd/metrics/stochastic_recon_loss.py +++ b/spd/metrics/stochastic_recon_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -7,73 +7,71 @@ from spd.configs import SamplingType from spd.metrics.base import Metric +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.routing import AllLayersRouter from spd.utils.component_utils import calc_stochastic_component_mask_info from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm, get_obj_device +from spd.utils.general_utils import get_obj_device def _stochastic_recon_loss_update( model: ComponentModel, sampling: 
SamplingType, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], int]: assert ci, "Empty ci" device = get_obj_device(ci) sum_loss = torch.tensor(0.0, device=device) - n_examples = 0 + sum_n_examples = 0 - stoch_mask_infos_list = [ - calc_stochastic_component_mask_info( + for _ in range(n_mask_samples): + stoch_mask_infos = calc_stochastic_component_mask_info( causal_importances=ci, component_mask_sampling=sampling, weight_deltas=weight_deltas, router=AllLayersRouter(), ) - for _ in range(n_mask_samples) - ] - for stoch_mask_infos in stoch_mask_infos_list: out = model(batch, mask_infos=stoch_mask_infos) - loss_type = output_loss_type - loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=loss_type) - n_examples += out.shape.numel() if loss_type == "mse" else out.shape[:-1].numel() + loss, n_examples = reconstruction_loss(out, target_out) sum_loss += loss - return sum_loss, n_examples + sum_n_examples += n_examples + + return sum_loss, sum_n_examples def _stochastic_recon_loss_compute( - sum_loss: Float[Tensor, ""], n_examples: Int[Tensor, ""] | int + sum_loss: Float[Tensor, ""], sum_n_examples: Int[Tensor, ""] | int ) -> Float[Tensor, ""]: - return sum_loss / n_examples + return sum_loss / sum_n_examples def stochastic_recon_loss( model: ComponentModel, sampling: SamplingType, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Tensor, ci: dict[str, Float[Tensor, "... 
C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: - sum_loss, n_examples = _stochastic_recon_loss_update( - model, - sampling, - n_mask_samples, - output_loss_type, - batch, - target_out, - ci, - weight_deltas, + sum_loss, sum_n_examples = _stochastic_recon_loss_update( + model=model, + sampling=sampling, + n_mask_samples=n_mask_samples, + batch=batch, + target_out=target_out, + ci=ci, + weight_deltas=weight_deltas, + reconstruction_loss=reconstruction_loss, ) - return _stochastic_recon_loss_compute(sum_loss, n_examples) + return _stochastic_recon_loss_compute(sum_loss, sum_n_examples) class StochasticReconLoss(Metric): @@ -88,41 +86,41 @@ def __init__( sampling: SamplingType, use_delta_component: bool, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], + reconstruction_loss: ReconstructionLoss, ) -> None: self.model = model self.sampling: SamplingType = sampling self.use_delta_component: bool = use_delta_component self.n_mask_samples: int = n_mask_samples - self.output_loss_type: Literal["mse", "kl"] = output_loss_type + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) - self.n_examples = torch.tensor(0, device=device) + self.sum_n_examples = torch.tensor(0, device=device) @override def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... 
vocab"], + batch: Any, + target_out: Tensor, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], **_: Any, ) -> None: - sum_loss, n_examples = _stochastic_recon_loss_update( + sum_loss, sum_n_examples = _stochastic_recon_loss_update( model=self.model, sampling=self.sampling, n_mask_samples=self.n_mask_samples, - output_loss_type=self.output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, weight_deltas=weight_deltas if self.use_delta_component else None, + reconstruction_loss=self.reconstruction_loss, ) self.sum_loss += sum_loss - self.n_examples += n_examples + self.sum_n_examples += sum_n_examples @override def compute(self) -> Float[Tensor, ""]: sum_loss = all_reduce(self.sum_loss, op=ReduceOp.SUM) - n_examples = all_reduce(self.n_examples, op=ReduceOp.SUM) - return _stochastic_recon_loss_compute(sum_loss, n_examples) + sum_n_examples = all_reduce(self.sum_n_examples, op=ReduceOp.SUM) + return _stochastic_recon_loss_compute(sum_loss, sum_n_examples) diff --git a/spd/metrics/stochastic_recon_subset_loss.py b/spd/metrics/stochastic_recon_subset_loss.py index 62573a889..85293b17b 100644 --- a/spd/metrics/stochastic_recon_subset_loss.py +++ b/spd/metrics/stochastic_recon_subset_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -7,23 +7,24 @@ from spd.configs import SamplingType, SubsetRoutingType from spd.metrics.base import Metric +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.routing import Router, get_subset_router from spd.utils.component_utils import calc_stochastic_component_mask_info from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm, get_obj_device +from spd.utils.general_utils import get_obj_device def _stochastic_recon_subset_loss_update( model: 
ComponentModel, sampling: SamplingType, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, router: Router, + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], int]: assert ci, "Empty ci" device = get_obj_device(ci) @@ -42,11 +43,9 @@ def _stochastic_recon_subset_loss_update( for stoch_mask_infos in stoch_mask_infos_list: out = model(batch, mask_infos=stoch_mask_infos) - loss_type = output_loss_type - loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=loss_type) - n_examples += out.shape.numel() if loss_type == "mse" else out.shape[:-1].numel() + loss, batch_n_examples = reconstruction_loss(out, target_out) sum_loss += loss - + n_examples += batch_n_examples return sum_loss, n_examples @@ -60,23 +59,23 @@ def stochastic_recon_subset_loss( model: ComponentModel, sampling: SamplingType, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Tensor, ci: dict[str, Float[Tensor, "... 
C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, routing: SubsetRoutingType, + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: sum_loss, n_examples = _stochastic_recon_subset_loss_update( model=model, sampling=sampling, n_mask_samples=n_mask_samples, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci, weight_deltas=weight_deltas, - router=get_subset_router(routing, batch.device), + router=get_subset_router(routing, device=get_obj_device(model)), + reconstruction_loss=reconstruction_loss, ) return _stochastic_recon_subset_loss_compute(sum_loss, n_examples) @@ -93,15 +92,15 @@ def __init__( sampling: SamplingType, use_delta_component: bool, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], routing: SubsetRoutingType, + reconstruction_loss: ReconstructionLoss, ) -> None: self.model = model self.sampling: SamplingType = sampling self.use_delta_component: bool = use_delta_component self.n_mask_samples: int = n_mask_samples - self.output_loss_type: Literal["mse", "kl"] = output_loss_type self.router = get_subset_router(routing, device) + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -109,8 +108,8 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... 
vocab"], + batch: Any, + target_out: Tensor, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], **_: Any, @@ -119,12 +118,12 @@ def update( model=self.model, sampling=self.sampling, n_mask_samples=self.n_mask_samples, - output_loss_type=self.output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, weight_deltas=weight_deltas if self.use_delta_component else None, router=self.router, + reconstruction_loss=self.reconstruction_loss, ) self.sum_loss += sum_loss self.n_examples += n_examples diff --git a/spd/metrics/unmasked_recon_loss.py b/spd/metrics/unmasked_recon_loss.py index 01cf67fe0..e113d5e36 100644 --- a/spd/metrics/unmasked_recon_loss.py +++ b/spd/metrics/unmasked_recon_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -6,29 +6,28 @@ from torch.distributed import ReduceOp from spd.metrics.base import Metric +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import ComponentModel from spd.models.components import make_mask_infos from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm +from spd.utils.general_utils import get_obj_device def _unmasked_recon_loss_update( model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... 
vocab"], + batch: Any, + target_out: Tensor, + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], int]: all_ones_mask_infos = make_mask_infos( # (C,) will broadcast to (B, S, C) { - module_path: torch.ones(model.module_to_c[module_path], device=batch.device) + module_path: torch.ones(model.module_to_c[module_path], device=get_obj_device(model)) for module_path in model.target_module_paths } ) out = model(batch, mask_infos=all_ones_mask_infos) - loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=output_loss_type) - n_examples = out.shape.numel() if output_loss_type == "mse" else out.shape[:-1].numel() - return loss, n_examples + return reconstruction_loss(out, target_out) def _unmasked_recon_loss_compute( @@ -39,15 +38,15 @@ def _unmasked_recon_loss_compute( def unmasked_recon_loss( model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Tensor, + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: sum_loss, n_examples = _unmasked_recon_loss_update( model, - output_loss_type, batch, target_out, + reconstruction_loss, ) return _unmasked_recon_loss_compute(sum_loss, n_examples) @@ -61,10 +60,10 @@ def __init__( self, model: ComponentModel, device: str, - output_loss_type: Literal["mse", "kl"], + reconstruction_loss: ReconstructionLoss, ) -> None: self.model = model - self.output_loss_type: Literal["mse", "kl"] = output_loss_type + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -72,15 +71,15 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... 
vocab"], + batch: Any, + target_out: Tensor, **_: Any, ) -> None: sum_loss, n_examples = _unmasked_recon_loss_update( model=self.model, - output_loss_type=self.output_loss_type, batch=batch, target_out=target_out, + reconstruction_loss=self.reconstruction_loss, ) self.sum_loss += sum_loss self.n_examples += n_examples diff --git a/spd/metrics/uv_plots.py b/spd/metrics/uv_plots.py index 26880f19e..0d29c6401 100644 --- a/spd/metrics/uv_plots.py +++ b/spd/metrics/uv_plots.py @@ -30,9 +30,10 @@ def __init__( self.batch_shape: tuple[int, ...] | None = None @override - def update(self, *, batch: Tensor, **_: Any) -> None: + def update(self, *, batch: Tensor | tuple[Tensor, ...], **_: Any) -> None: if self.batch_shape is None: - self.batch_shape = tuple(batch.shape) + input_tensor = batch[0] if isinstance(batch, tuple) else batch + self.batch_shape = tuple(input_tensor.shape) @override def compute(self) -> dict[str, Image.Image]: diff --git a/spd/models/batch_and_loss_fns.py b/spd/models/batch_and_loss_fns.py new file mode 100644 index 000000000..6ec940167 --- /dev/null +++ b/spd/models/batch_and_loss_fns.py @@ -0,0 +1,86 @@ +"""Batch handling and reconstruction loss functions for different model types. + +These functions parameterize ComponentModel and training for different target model architectures. +""" + +from typing import Any, Protocol + +import torch +import torch.nn.functional as F +from jaxtyping import Float +from torch import Tensor, nn + +from spd.utils.general_utils import runtime_cast + + +class RunBatch(Protocol): + """Protocol for running a batch through a model and returning the output.""" + + def __call__(self, model: nn.Module, batch: Any) -> Tensor: ... + + +class ReconstructionLoss(Protocol): + """Protocol for computing reconstruction loss between predictions and targets.""" + + def __call__(self, pred: Tensor, target: Tensor) -> tuple[Float[Tensor, ""], int]: ... 
+ + +def run_batch_passthrough(model: nn.Module, batch: Any) -> Tensor: + return runtime_cast(Tensor, model(batch)) + + +def run_batch_first_element(model: nn.Module, batch: Any) -> Tensor: + """Run model on the first element of a batch tuple (e.g. (input, labels) -> model(input)).""" + return runtime_cast(Tensor, model(batch[0])) + + +def make_run_batch(output_extract: int | str | None) -> RunBatch: + """Creates a RunBatch function for a given configuration. + + NOTE: If you plan to override the RunBatch functionality, you can simply pass + a custom RunBatch function into optimize and do not need to use this function at + all. + + Args: + output_extract: How to extract the tensor from model output. + None: passthrough (model output is the tensor) + int: index into model output tuple (e.g. 0 for first element) + str: attribute name on model output (e.g. "logits") + """ + match output_extract: + case None: + return run_batch_passthrough + case int(idx): + + def _run_index(model: nn.Module, batch: Any) -> Tensor: + return model(batch)[idx] + + return _run_index + case str(attr): + + def _run_attr(model: nn.Module, batch: Any) -> Tensor: + return getattr(model(batch), attr) + + return _run_attr + + +def recon_loss_mse( + pred: Float[Tensor, "... d"], + target: Float[Tensor, "... d"], +) -> tuple[Float[Tensor, ""], int]: + """MSE reconstruction loss. Returns (sum_of_squared_errors, n_elements).""" + assert pred.shape == target.shape + squared_errors = (pred - target) ** 2 + return squared_errors.sum(), pred.numel() + + +def recon_loss_kl( + pred: Float[Tensor, "... vocab"], + target: Float[Tensor, "... vocab"], +) -> tuple[Float[Tensor, ""], int]: + """KL divergence reconstruction loss for logits. 
Returns (sum_of_kl, n_positions).""" + assert pred.shape == target.shape + log_q = torch.log_softmax(pred, dim=-1) # log Q + p = torch.softmax(target, dim=-1) # P + kl_per_position = F.kl_div(log_q, p, reduction="none").sum(dim=-1) # P · (log P − log Q) + return kl_per_position.sum(), pred[..., 0].numel() diff --git a/spd/models/component_model.py b/spd/models/component_model.py index 4a2a20fb2..5376aae65 100644 --- a/spd/models/component_model.py +++ b/spd/models/component_model.py @@ -1,4 +1,4 @@ -from collections.abc import Callable, Generator, Sequence +from collections.abc import Callable, Generator from contextlib import contextmanager from dataclasses import dataclass from functools import partial @@ -13,6 +13,7 @@ from spd.configs import Config, SamplingType from spd.identity_insertion import insert_identity_operations_ from spd.interfaces import LoadableModule, RunInfo +from spd.models.batch_and_loss_fns import RunBatch, make_run_batch from spd.models.components import ( Components, ComponentsMaskInfo, @@ -52,7 +53,7 @@ class CIOutputs: pre_sigmoid: dict[str, Tensor] -class ComponentModel(LoadableModule): +class ComponentModel(nn.Module): """Wrapper around an arbitrary pytorch model for running SPD. The underlying *base model* can be any subclass of `nn.Module` (e.g. 
@@ -74,13 +75,14 @@ class ComponentModel(LoadableModule): def __init__( self, target_model: nn.Module, + run_batch: RunBatch, module_path_info: list[ModulePathInfo], ci_fn_type: CiFnType, ci_fn_hidden_dims: list[int], sigmoid_type: SigmoidType, - pretrained_model_output_attr: str | None, ): super().__init__() + self._run_batch: RunBatch = run_batch for name, param in target_model.named_parameters(): assert not param.requires_grad, ( @@ -89,7 +91,6 @@ def __init__( ) self.target_model = target_model - self.pretrained_model_output_attr = pretrained_model_output_attr self.module_to_c = {info.module_path: info.C for info in module_path_info} self.target_module_paths = list(self.module_to_c.keys()) @@ -119,6 +120,69 @@ def __init__( self.lower_leaky_fn = SIGMOID_TYPES[sigmoid_type] self.upper_leaky_fn = SIGMOID_TYPES[sigmoid_type] + @classmethod + def from_run_info(cls, run_info: RunInfo[Config]) -> "ComponentModel": + """Load a trained ComponentModel from a run info object.""" + config = run_info.config + + model_class = resolve_class(config.pretrained_model_class) + if config.pretrained_model_name is not None: + assert hasattr(model_class, "from_pretrained"), ( + f"Model class {model_class} should have a `from_pretrained` method" + ) + # Handle spd.pretrain models: patch missing model_type in old pretrain runs + if config.pretrained_model_class.startswith("spd.pretrain.models."): + from spd.pretrain.run_info import PretrainRunInfo + + pretrain_run_info = PretrainRunInfo.from_path(config.pretrained_model_name) + if "model_type" not in pretrain_run_info.model_config_dict: + pretrain_run_info.model_config_dict["model_type"] = ( + config.pretrained_model_class.split(".")[-1] + ) + target_model = model_class.from_run_info(pretrain_run_info) # pyright: ignore[reportAttributeAccessIssue] + else: + target_model = model_class.from_pretrained(config.pretrained_model_name) # pyright: ignore[reportAttributeAccessIssue] + else: + assert issubclass(model_class, LoadableModule), ( + 
f"Model class {model_class} should be a subclass of LoadableModule which " + "defines a `from_pretrained` method" + ) + assert config.pretrained_model_path is not None + target_model = model_class.from_pretrained(config.pretrained_model_path) + + target_model.eval() + target_model.requires_grad_(False) + + if config.identity_module_info is not None: + insert_identity_operations_( + target_model, + identity_module_info=config.identity_module_info, + ) + + module_path_info = expand_module_patterns(target_model, config.all_module_info) + + run_batch = make_run_batch(config.output_extract) + + comp_model = ComponentModel( + target_model=target_model, + run_batch=run_batch, + module_path_info=module_path_info, + ci_fn_hidden_dims=config.ci_fn_hidden_dims, + ci_fn_type=config.ci_fn_type, + sigmoid_type=config.sigmoid_type, + ) + + weights = torch.load(run_info.checkpoint_path, map_location="cpu", weights_only=True) + handle_deprecated_state_dict_keys_(weights) + comp_model.load_state_dict(weights) + return comp_model + + @classmethod + def from_pretrained(cls, path: ModelPath) -> "ComponentModel": + """Load a trained ComponentModel from a wandb or local path.""" + run_info = SPDRunInfo.from_path(path) + return cls.from_run_info(run_info) + def target_weight(self, module_name: str) -> Float[Tensor, "rows cols"]: target_module = self.target_model.get_submodule(module_name) @@ -234,53 +298,28 @@ def _create_ci_fns( ) return ci_fns - def _extract_output(self, raw_output: Any) -> Tensor: - """Extract the desired output from the model's raw output. - - If pretrained_model_output_attr is None, returns the raw output directly. - If pretrained_model_output_attr starts with "idx_", returns the index specified by the - second part of the string. E.g. "idx_0" returns the first element of the raw output. - Otherwise, returns the specified attribute from the raw output. - - Args: - raw_output: The raw output from the model. - - Returns: - The extracted output. 
- """ - if self.pretrained_model_output_attr is None: - out = raw_output - elif self.pretrained_model_output_attr.startswith("idx_"): - idx_val = int(self.pretrained_model_output_attr.split("_")[1]) - assert isinstance(raw_output, Sequence), ( - f"raw_output must be a sequence, not {type(raw_output)}" - ) - assert idx_val < len(raw_output), ( - f"Index {idx_val} out of range for raw_output of length {len(raw_output)}" - ) - out = raw_output[idx_val] - else: - out = getattr(raw_output, self.pretrained_model_output_attr) - - assert isinstance(out, Tensor), f"Expected tensor output, got {type(out)}" - return out + @overload + def __call__( + self, + batch: Any, + cache_type: Literal["component_acts"], + mask_infos: dict[str, ComponentsMaskInfo] | None = None, + ) -> OutputWithCache: ... @overload def __call__( self, - *args: Any, + batch: Any, + cache_type: Literal["input"], mask_infos: dict[str, ComponentsMaskInfo] | None = None, - cache_type: Literal["component_acts", "input"], - **kwargs: Any, ) -> OutputWithCache: ... @overload def __call__( self, - *args: Any, + batch: Any, mask_infos: dict[str, ComponentsMaskInfo] | None = None, cache_type: Literal["none"] = "none", - **kwargs: Any, ) -> Tensor: ... @override @@ -290,10 +329,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Tensor | OutputWithCache: @override def forward( self, - *args: Any, + batch: Any, mask_infos: dict[str, ComponentsMaskInfo] | None = None, cache_type: Literal["component_acts", "input", "none"] = "none", - **kwargs: Any, ) -> Tensor | OutputWithCache: """Forward pass with optional component replacement and/or input caching. @@ -318,8 +356,7 @@ def forward( model output tensor. """ if mask_infos is None and cache_type == "none": - # No hooks needed. Do a regular forward pass of the target model. 
- return self._extract_output(self.target_model(*args, **kwargs)) + return self._run_batch(self.target_model, batch) cache: dict[str, Tensor] = {} hooks: dict[str, Callable[..., Any]] = {} @@ -340,9 +377,8 @@ def forward( ) with self._attach_forward_hooks(hooks): - raw_out = self.target_model(*args, **kwargs) + out: Tensor = self._run_batch(self.target_model, batch) - out = self._extract_output(raw_out) match cache_type: case "input" | "component_acts": return OutputWithCache(output=out, cache=cache) @@ -424,74 +460,6 @@ def _attach_forward_hooks(self, hooks: dict[str, Callable[..., Any]]) -> Generat for handle in handles: handle.remove() - @classmethod - @override - def from_run_info(cls, run_info: RunInfo[Config]) -> "ComponentModel": - """Load a trained ComponentModel checkpoint from a run info object.""" - config = run_info.config - - # Load the target model - model_class = resolve_class(config.pretrained_model_class) - if config.pretrained_model_name is not None: - assert hasattr(model_class, "from_pretrained"), ( - f"Model class {model_class} should have a `from_pretrained` method" - ) - # Handle spd.pretrain models: patch missing model_type in old pretrain runs - if config.pretrained_model_class.startswith("spd.pretrain.models."): - from spd.pretrain.run_info import PretrainRunInfo - - pretrain_run_info = PretrainRunInfo.from_path(config.pretrained_model_name) - if "model_type" not in pretrain_run_info.model_config_dict: - pretrain_run_info.model_config_dict["model_type"] = ( - config.pretrained_model_class.split(".")[-1] - ) - target_model = model_class.from_run_info(pretrain_run_info) # pyright: ignore[reportAttributeAccessIssue] - else: - target_model = model_class.from_pretrained(config.pretrained_model_name) # pyright: ignore[reportAttributeAccessIssue] - else: - assert issubclass(model_class, LoadableModule), ( - f"Model class {model_class} should be a subclass of LoadableModule which " - "defines a `from_pretrained` method" - ) - assert 
run_info.config.pretrained_model_path is not None - target_model = model_class.from_pretrained(run_info.config.pretrained_model_path) - - target_model.eval() - target_model.requires_grad_(False) - - if config.identity_module_info is not None: - insert_identity_operations_( - target_model, - identity_module_info=config.identity_module_info, - ) - - module_path_info = expand_module_patterns(target_model, config.all_module_info) - - comp_model = ComponentModel( - target_model=target_model, - module_path_info=module_path_info, - ci_fn_hidden_dims=config.ci_fn_hidden_dims, - ci_fn_type=config.ci_fn_type, - pretrained_model_output_attr=config.pretrained_model_output_attr, - sigmoid_type=config.sigmoid_type, - ) - - comp_model_weights = torch.load( - run_info.checkpoint_path, map_location="cpu", weights_only=True - ) - - handle_deprecated_state_dict_keys_(comp_model_weights) - - comp_model.load_state_dict(comp_model_weights) - return comp_model - - @classmethod - @override - def from_pretrained(cls, path: ModelPath) -> "ComponentModel": - """Load a trained ComponentModel checkpoint from a local or wandb path.""" - run_info = SPDRunInfo.from_path(path) - return cls.from_run_info(run_info) - def calc_causal_importances( self, pre_weight_acts: dict[str, Float[Tensor, "... d_in"] | Int[Tensor, "... 
pos"]], diff --git a/spd/run_spd.py b/spd/run_spd.py index 337f186b1..cda6c29b6 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -4,16 +4,14 @@ from collections import defaultdict from collections.abc import Iterator from pathlib import Path -from typing import cast +from typing import Any, cast import torch import torch.nn as nn import torch.nn.parallel -import torch.optim as optim import wandb -from jaxtyping import Float, Int from PIL import Image -from torch import Tensor +from torch import optim from torch.nn.utils import clip_grad_norm_ from torch.utils.data import DataLoader from tqdm import tqdm @@ -32,6 +30,7 @@ from spd.log import logger from spd.losses import compute_total_loss from spd.metrics import faithfulness_loss +from spd.models.batch_and_loss_fns import ReconstructionLoss, RunBatch from spd.models.component_model import ComponentModel, OutputWithCache from spd.utils.component_utils import calc_ci_l_zero from spd.utils.distributed_utils import ( @@ -43,7 +42,6 @@ from spd.utils.general_utils import ( bf16_autocast, dict_safe_update_, - extract_batch_data, get_scheduled_value, ) from spd.utils.logging_utils import get_grad_norms_dict, local_log @@ -115,11 +113,10 @@ def optimize( target_model: nn.Module, config: Config, device: str, - train_loader: DataLoader[Int[Tensor, "..."]] - | DataLoader[tuple[Float[Tensor, "..."], Float[Tensor, "..."]]], - eval_loader: DataLoader[Int[Tensor, "..."]] - | DataLoader[tuple[Float[Tensor, "..."], Float[Tensor, "..."]]], - n_eval_steps: int, + train_loader: DataLoader[Any], + eval_loader: DataLoader[Any], + run_batch: RunBatch, + reconstruction_loss: ReconstructionLoss, out_dir: Path | None, tied_weights: list[tuple[str, str]] | None = None, ) -> None: @@ -128,9 +125,7 @@ def optimize( train_iterator = loop_dataloader(train_loader) eval_iterator = loop_dataloader(eval_loader) - def create_pgd_data_iter() -> ( - Iterator[Int[Tensor, "..."]] | Iterator[tuple[Float[Tensor, "..."], Float[Tensor, "..."]]] - ): + def 
create_pgd_data_iter() -> Iterator[Any]: assert hasattr(train_loader, "generator") and train_loader.generator is not None train_loader.generator.manual_seed(config.seed) return iter(train_loader) @@ -150,10 +145,10 @@ def create_pgd_data_iter() -> ( model = ComponentModel( target_model=target_model, + run_batch=run_batch, module_path_info=module_path_info, ci_fn_type=config.ci_fn_type, ci_fn_hidden_dims=config.ci_fn_hidden_dims, - pretrained_model_output_attr=config.pretrained_model_output_attr, sigmoid_type=config.sigmoid_type, ) @@ -162,6 +157,8 @@ def create_pgd_data_iter() -> ( # Wrap model with DDP if distributed dist_state = get_distributed_state() wrapped_model: nn.Module = model + + component_model: ComponentModel if dist_state is not None: if dist_state.backend == "nccl": device_id = dist_state.local_rank @@ -174,7 +171,7 @@ def create_pgd_data_iter() -> ( # For CPU, don't pass device_ids or output_device wrapped_model = torch.nn.parallel.DistributedDataParallel(model) # Access the underlying module for component operations - component_model = wrapped_model.module # type: ignore[attr-defined] + component_model = cast(ComponentModel, wrapped_model.module) # type: ignore[attr-defined] else: component_model = model assert isinstance(component_model, ComponentModel), "component_model is not a ComponentModel" @@ -216,14 +213,6 @@ def create_pgd_data_iter() -> ( eval_metric_configs = [ cfg for cfg in eval_metric_configs if cfg not in multibatch_pgd_eval_configs ] - batch_dims: tuple[int, ...] 
| None = None - - sample_batch = extract_batch_data(next(train_iterator)) - batch_dims = ( - sample_batch.shape[:-1] - if config.output_loss_type == "mse" # if mse then input is a vector - else sample_batch.shape # else it's a batch of token ids - ) for step in tqdm(range(config.steps + 1), ncols=0, disable=not is_main_process()): optimizer.zero_grad() @@ -239,13 +228,11 @@ def create_pgd_data_iter() -> ( microbatch_log_data: defaultdict[str, float] = defaultdict(float) for _ in range(config.gradient_accumulation_steps): - microbatch = extract_batch_data(next(train_iterator)).to(device) - + microbatch = next(train_iterator) with bf16_autocast(enabled=config.autocast_bf16): - # NOTE: we need to call the wrapped_model at least once each step in order - # to setup the DDP gradient syncing for all parameters in the component model. - # Gradients will sync regardless of whether the parameters are used in this - # call to wrapped_model. + # NOTE: we need to call the wrapped_model at least once each step in order to setup + # the DDP gradient syncing for all parameters in the component model. Gradients will + # sync regardless of whether the parameters are used in this call to wrapped_model. 
target_model_output: OutputWithCache = wrapped_model(microbatch, cache_type="input") ci = component_model.calc_causal_importances( @@ -266,9 +253,8 @@ def create_pgd_data_iter() -> ( sampling=config.sampling, use_delta_component=config.use_delta_component, n_mask_samples=config.n_mask_samples, - output_loss_type=config.output_loss_type, + reconstruction_loss=reconstruction_loss, ) - microbatch_total_loss.div_(config.gradient_accumulation_steps).backward() for loss_name, loss_value in microbatch_loss_terms.items(): @@ -313,14 +299,13 @@ def create_pgd_data_iter() -> ( else step % config.slow_eval_freq == 0 ) - assert batch_dims is not None, "batch_dims is not set" multibatch_pgd_metrics = evaluate_multibatch_pgd( multibatch_pgd_eval_configs=multibatch_pgd_eval_configs, model=component_model, create_data_iter=create_pgd_data_iter, config=config, - batch_dims=batch_dims, device=device, + reconstruction_loss=reconstruction_loss, ) metrics = evaluate( @@ -330,8 +315,9 @@ def create_pgd_data_iter() -> ( device=device, run_config=config, slow_step=slow_step, - n_eval_steps=n_eval_steps, + n_eval_steps=config.n_eval_steps, current_frac_of_training=step / config.steps, + reconstruction_loss=reconstruction_loss, ) dict_safe_update_(metrics, multibatch_pgd_metrics) diff --git a/spd/scripts/compare_models/compare_models.py b/spd/scripts/compare_models/compare_models.py index 4bfee9c9a..3e93da8ef 100644 --- a/spd/scripts/compare_models/compare_models.py +++ b/spd/scripts/compare_models/compare_models.py @@ -25,7 +25,7 @@ from spd.log import logger from spd.models.component_model import ComponentModel, SPDRunInfo from spd.utils.distributed_utils import get_device -from spd.utils.general_utils import extract_batch_data, get_obj_device +from spd.utils.general_utils import get_obj_device from spd.utils.run_utils import save_file @@ -82,6 +82,7 @@ def __init__(self, config: CompareModelsConfig): def _load_model_and_config(self, model_path: str) -> tuple[ComponentModel, Config]: 
"""Load model and config using the standard pattern from existing codebase.""" run_info = SPDRunInfo.from_path(model_path) + # TODO(oli): this should actually be generic (one of the only instances of this I think) model = ComponentModel.from_run_info(run_info) model.to(self.device) model.eval() @@ -250,8 +251,7 @@ def compute_activation_densities( model.eval() with torch.no_grad(): for _step in range(n_steps): - batch = extract_batch_data(next(eval_iterator)) - batch = batch.to(self.device) + batch = next(eval_iterator)["input_ids"].to(self.device) pre_weight_acts = model(batch, cache_type="input").cache ci = model.calc_causal_importances( diff --git a/spd/utils/general_utils.py b/spd/utils/general_utils.py index 1fda48111..de30831fb 100644 --- a/spd/utils/general_utils.py +++ b/spd/utils/general_utils.py @@ -2,7 +2,7 @@ import random from collections.abc import Sequence from pathlib import Path -from typing import Any, Literal, Protocol +from typing import Any, Protocol import einops import numpy as np @@ -167,49 +167,9 @@ def resolve_class(path: str) -> type[nn.Module]: return getattr(module, class_name) -def extract_batch_data( - batch_item: dict[str, Any] | tuple[Tensor, "..."] | Tensor, - input_key: str = "input_ids", -) -> Tensor: - """Extract input data from various batch formats. - - This utility function handles different batch formats commonly used across the codebase: - 1. Dictionary format: {"input_ids": Tensor, "..."} - common in LM tasks - 2. Tuple format: (input_tensor, labels) - common in SPD optimization - 3. Direct tensor: when batch is already the input tensor - - Args: - batch_item: The batch item from a data loader - input_key: Key to use for dictionary format (default: "input_ids") - - Returns: - The input tensor extracted from the batch - """ - assert isinstance(batch_item, dict | tuple | Tensor), ( - f"Unsupported batch format: {type(batch_item)}. Must be a dictionary, tuple, or tensor." 
- ) - if isinstance(batch_item, dict): - # Dictionary format: extract the specified key - if input_key not in batch_item: - available_keys = list(batch_item.keys()) - raise KeyError( - f"Key '{input_key}' not found in batch. Available keys: {available_keys}" - ) - tensor = batch_item[input_key] - elif isinstance(batch_item, tuple): - # Assume input is the first element - tensor = batch_item[0] - else: - # Direct tensor format - tensor = batch_item - - return tensor - - def calc_kl_divergence_lm( pred: Float[Tensor, "... vocab"], target: Float[Tensor, "... vocab"], - reduce: bool = True, ) -> Float[Tensor, ""] | Float[Tensor, "..."]: """Calculate the KL divergence between two logits. @@ -226,24 +186,7 @@ def calc_kl_divergence_lm( p = torch.softmax(target, dim=-1) # P kl_raw = F.kl_div(log_q, p, reduction="none") # P · (log P − log Q) kl = kl_raw.sum(dim=-1) - if reduce: - return kl.mean() # Σ_vocab / (batch·seq) - else: - return kl - - -def calc_sum_recon_loss_lm( - pred: Float[Tensor, "... vocab"], - target: Float[Tensor, "... vocab"], - loss_type: Literal["mse", "kl"], -) -> Float[Tensor, ""]: - """Calculate the reconstruction loss for a language model without reduction.""" - match loss_type: - case "mse": - loss = ((pred - target) ** 2).sum() - case "kl": - loss = calc_kl_divergence_lm(pred=pred, target=target, reduce=False).sum() - return loss + return kl.mean() # Σ_vocab / (batch·seq) def runtime_cast[T](type_: type[T], obj: Any) -> T: diff --git a/spd/utils/wandb_utils.py b/spd/utils/wandb_utils.py index 04cb6727c..55efbb3d9 100644 --- a/spd/utils/wandb_utils.py +++ b/spd/utils/wandb_utils.py @@ -593,6 +593,7 @@ def create_view_and_report( _n_try_wandb_comm_errors = 0 +# this exists to stop infra issues from crashing training runs def try_wandb[**P, T](wandb_fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T | None: """Attempts to call `wandb_fn` and if it fails with a wandb CommError, logs a warning and returns None. 
The choice of wandb CommError is to catch issues communicating with the wandb server but diff --git a/tests/app/test_server_api.py b/tests/app/test_server_api.py index 508f69ac4..72ace2cc6 100644 --- a/tests/app/test_server_api.py +++ b/tests/app/test_server_api.py @@ -20,7 +20,13 @@ from spd.app.backend.routers import runs as runs_router from spd.app.backend.server import app from spd.app.backend.state import HarvestCache, RunState, StateManager -from spd.configs import Config, LMTaskConfig, ModulePatternInfoConfig, ScheduleConfig +from spd.configs import ( + Config, + LMTaskConfig, + ModulePatternInfoConfig, + ScheduleConfig, +) +from spd.models.batch_and_loss_fns import make_run_batch from spd.models.component_model import ComponentModel from spd.pretrain.models.gpt2_simple import GPT2Simple, GPT2SimpleConfig from spd.utils.module_utils import expand_module_patterns @@ -91,9 +97,8 @@ def app_with_state(): ModulePatternInfoConfig(module_pattern=p, C=C) for p in target_module_patterns ], pretrained_model_class="spd.pretrain.models.gpt2_simple.GPT2Simple", - pretrained_model_output_attr="idx_0", + output_extract=0, tokenizer_name="SimpleStories/test-SimpleStories-gpt2-1.25M", - output_loss_type="kl", lr_schedule=ScheduleConfig(start_val=1e-3), steps=1, batch_size=1, @@ -114,10 +119,10 @@ def app_with_state(): module_path_info = expand_module_patterns(target_model, config.module_info) model = ComponentModel( target_model=target_model, + run_batch=make_run_batch(config.output_extract), module_path_info=module_path_info, ci_fn_type=config.ci_fn_type, ci_fn_hidden_dims=config.ci_fn_hidden_dims, - pretrained_model_output_attr=config.pretrained_model_output_attr, sigmoid_type=config.sigmoid_type, ) model.eval() diff --git a/tests/metrics/fixtures.py b/tests/metrics/fixtures.py index fa32cc1e3..d2caeaccd 100644 --- a/tests/metrics/fixtures.py +++ b/tests/metrics/fixtures.py @@ -7,6 +7,7 @@ from jaxtyping import Float from torch import Tensor +from 
spd.models.batch_and_loss_fns import run_batch_passthrough from spd.models.component_model import ComponentModel from spd.utils.module_utils import ModulePathInfo @@ -38,7 +39,9 @@ def forward(self, x: Tensor) -> Tensor: return x -def make_one_layer_component_model(weight: Float[Tensor, "d_out d_in"]) -> ComponentModel: +def make_one_layer_component_model( + weight: Float[Tensor, "d_out d_in"], +) -> ComponentModel: """Create a ComponentModel with a single linear layer for testing. Args: @@ -55,10 +58,10 @@ def make_one_layer_component_model(weight: Float[Tensor, "d_out d_in"]) -> Compo comp_model = ComponentModel( target_model=target, + run_batch=run_batch_passthrough, module_path_info=[ModulePathInfo(module_path="fc", C=1)], ci_fn_hidden_dims=[2], ci_fn_type="mlp", - pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) @@ -89,13 +92,13 @@ def make_two_layer_component_model( comp_model = ComponentModel( target_model=target, + run_batch=run_batch_passthrough, module_path_info=[ ModulePathInfo(module_path="fc1", C=1), ModulePathInfo(module_path="fc2", C=1), ], ci_fn_hidden_dims=[2], ci_fn_type="mlp", - pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) diff --git a/tests/metrics/test_ci_masked_recon_layerwise_loss.py b/tests/metrics/test_ci_masked_recon_layerwise_loss.py index 00b8092b2..04ef2609e 100644 --- a/tests/metrics/test_ci_masked_recon_layerwise_loss.py +++ b/tests/metrics/test_ci_masked_recon_layerwise_loss.py @@ -1,6 +1,7 @@ import torch from spd.metrics import ci_masked_recon_layerwise_loss, ci_masked_recon_loss +from spd.models.batch_and_loss_fns import recon_loss_mse from tests.metrics.fixtures import make_one_layer_component_model, make_two_layer_component_model @@ -43,7 +44,11 @@ def test_two_layer_manual_calculation(self: object) -> None: # Calculate actual loss actual_loss = ci_masked_recon_layerwise_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci + model=model, + batch=batch, + 
target_out=target_out, + ci=ci, + reconstruction_loss=recon_loss_mse, ) assert torch.allclose(actual_loss, expected_loss, rtol=1e-5), ( @@ -60,10 +65,18 @@ def test_layerwise_vs_all_layer(self: object) -> None: ci = {"fc": torch.tensor([[1.0]], dtype=torch.float32)} loss_all = ci_masked_recon_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci + model=model, + batch=batch, + target_out=target_out, + ci=ci, + reconstruction_loss=recon_loss_mse, ) loss_layerwise = ci_masked_recon_layerwise_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci + model=model, + batch=batch, + target_out=target_out, + ci=ci, + reconstruction_loss=recon_loss_mse, ) # For single layer, results should be the same diff --git a/tests/metrics/test_ci_masked_recon_loss.py b/tests/metrics/test_ci_masked_recon_loss.py index 3f1202425..7635d2757 100644 --- a/tests/metrics/test_ci_masked_recon_loss.py +++ b/tests/metrics/test_ci_masked_recon_loss.py @@ -1,6 +1,7 @@ import torch from spd.metrics import ci_masked_recon_loss +from spd.models.batch_and_loss_fns import recon_loss_mse from tests.metrics.fixtures import make_one_layer_component_model @@ -26,7 +27,11 @@ def test_manual_calculation(self: object) -> None: # Calculate actual loss actual_loss = ci_masked_recon_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci + model=model, + batch=batch, + target_out=target_out, + ci=ci, + reconstruction_loss=recon_loss_mse, ) assert torch.allclose(actual_loss, expected_loss, rtol=1e-5), ( @@ -45,10 +50,18 @@ def test_different_ci_values_produce_different_losses(self: object) -> None: ci_half = {"fc": torch.tensor([[0.5]], dtype=torch.float32)} loss_full = ci_masked_recon_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci_full + model=model, + batch=batch, + target_out=target_out, + ci=ci_full, + reconstruction_loss=recon_loss_mse, ) loss_half = ci_masked_recon_loss( - 
model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci_half + model=model, + batch=batch, + target_out=target_out, + ci=ci_half, + reconstruction_loss=recon_loss_mse, ) # Different CI values should produce different losses diff --git a/tests/metrics/test_ci_masked_recon_subset_loss.py b/tests/metrics/test_ci_masked_recon_subset_loss.py index 4a9f870b7..9c5661b3f 100644 --- a/tests/metrics/test_ci_masked_recon_subset_loss.py +++ b/tests/metrics/test_ci_masked_recon_subset_loss.py @@ -5,6 +5,7 @@ from spd.configs import UniformKSubsetRoutingConfig from spd.metrics import ci_masked_recon_subset_loss +from spd.models.batch_and_loss_fns import recon_loss_mse from tests.metrics.fixtures import make_one_layer_component_model @@ -77,11 +78,11 @@ def mock_sample_uniform_k_subset_routing_masks( for _ in range(2): actual_loss = ci_masked_recon_subset_loss( model=model, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, routing=UniformKSubsetRoutingConfig(), + reconstruction_loss=recon_loss_mse, ) actual_losses.append(actual_loss.item()) diff --git a/tests/metrics/test_stochastic_recon_layerwise_loss.py b/tests/metrics/test_stochastic_recon_layerwise_loss.py index 3862d85f8..2f4d97515 100644 --- a/tests/metrics/test_stochastic_recon_layerwise_loss.py +++ b/tests/metrics/test_stochastic_recon_layerwise_loss.py @@ -5,6 +5,7 @@ from spd.configs import SamplingType from spd.metrics import stochastic_recon_layerwise_loss, stochastic_recon_loss +from spd.models.batch_and_loss_fns import recon_loss_mse from spd.models.components import ComponentsMaskInfo, make_mask_infos from spd.routing import Router from tests.metrics.fixtures import make_one_layer_component_model, make_two_layer_component_model @@ -105,11 +106,11 @@ def mock_calc_stochastic_component_mask_info( model=model, sampling="continuous", n_mask_samples=2, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, weight_deltas=None, + reconstruction_loss=recon_loss_mse, 
) assert torch.allclose(actual_loss, torch.tensor(expected_loss), rtol=1e-5), ( @@ -130,21 +131,21 @@ def test_layerwise_vs_full_loss(self: object) -> None: model=model, sampling="continuous", n_mask_samples=5, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, weight_deltas=None, + reconstruction_loss=recon_loss_mse, ) loss_layerwise = stochastic_recon_layerwise_loss( model=model, sampling="continuous", n_mask_samples=5, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, weight_deltas=None, + reconstruction_loss=recon_loss_mse, ) # For single layer, results should be the same diff --git a/tests/metrics/test_stochastic_recon_loss.py b/tests/metrics/test_stochastic_recon_loss.py index 594b55a7f..e20e25f84 100644 --- a/tests/metrics/test_stochastic_recon_loss.py +++ b/tests/metrics/test_stochastic_recon_loss.py @@ -5,6 +5,7 @@ from spd.configs import SamplingType from spd.metrics import stochastic_recon_loss +from spd.models.batch_and_loss_fns import recon_loss_mse from spd.models.components import ComponentsMaskInfo, make_mask_infos from spd.routing import Router from tests.metrics.fixtures import make_one_layer_component_model @@ -78,11 +79,11 @@ def mock_calc_stochastic_component_mask_info( model=model, sampling="continuous", n_mask_samples=2, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, weight_deltas=None, + reconstruction_loss=recon_loss_mse, ) assert torch.allclose(actual_loss, torch.tensor(expected_loss), rtol=1e-5), ( diff --git a/tests/metrics/test_stochastic_recon_subset_loss.py b/tests/metrics/test_stochastic_recon_subset_loss.py index 484e3d49f..54c928c39 100644 --- a/tests/metrics/test_stochastic_recon_subset_loss.py +++ b/tests/metrics/test_stochastic_recon_subset_loss.py @@ -5,6 +5,7 @@ from spd.configs import SamplingType, UniformKSubsetRoutingConfig from spd.metrics import stochastic_recon_subset_loss +from spd.models.batch_and_loss_fns import recon_loss_mse from spd.models.components import 
ComponentsMaskInfo, make_mask_infos from spd.routing import Router from tests.metrics.fixtures import make_one_layer_component_model @@ -92,12 +93,12 @@ def mock_calc_stochastic_component_mask_info( model=model, sampling="continuous", n_mask_samples=2, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, weight_deltas=None, routing=UniformKSubsetRoutingConfig(), + reconstruction_loss=recon_loss_mse, ) assert torch.allclose(actual_loss, torch.tensor(expected_loss), rtol=1e-5), ( diff --git a/tests/scripts_run/test_grid_search.py b/tests/scripts_run/test_grid_search.py index a5f59c65d..51bd84618 100644 --- a/tests/scripts_run/test_grid_search.py +++ b/tests/scripts_run/test_grid_search.py @@ -336,7 +336,6 @@ def test_tms_config_with_loss_sweep(self): "coeff": 1.0, }, ], - "output_loss_type": "mse", "lr": 0.001, "steps": 1000, "batch_size": 32, @@ -386,7 +385,6 @@ def test_lm_config_with_loss_sweep(self): "eps": 1e-12, } ], - "output_loss_type": "kl", "lr": 0.001, "steps": 1000, "batch_size": 32, @@ -451,7 +449,6 @@ def test_full_sweep_workflow(self): "eps": 1e-12, } ], - "output_loss_type": "mse", "lr": 0.01, # Will be overridden "steps": 1000, "batch_size": 32, diff --git a/tests/test_component_model.py b/tests/test_component_model.py index 22d60a8f0..3d483edd8 100644 --- a/tests/test_component_model.py +++ b/tests/test_component_model.py @@ -17,6 +17,7 @@ ) from spd.identity_insertion import insert_identity_operations_ from spd.interfaces import LoadableModule, RunInfo +from spd.models.batch_and_loss_fns import run_batch_passthrough from spd.models.component_model import ( ComponentModel, SPDRunInfo, @@ -82,6 +83,7 @@ def test_correct_parameters_require_grad(): component_model = ComponentModel( target_model=target_model, + run_batch=run_batch_passthrough, module_path_info=[ ModulePathInfo(module_path="linear1", C=4), ModulePathInfo(module_path="linear2", C=8), @@ -91,7 +93,6 @@ def test_correct_parameters_require_grad(): ], ci_fn_type="mlp", 
ci_fn_hidden_dims=[4], - pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) @@ -152,7 +153,6 @@ def test_from_run_info(): eval_freq=1, slow_eval_freq=1, loss_metric_configs=[ImportanceMinimalityLossConfig(coeff=1.0, pnorm=1.0, beta=0.5)], - output_loss_type="mse", train_log_freq=1, n_mask_samples=1, task_config=TMSTaskConfig( @@ -171,10 +171,10 @@ def test_from_run_info(): module_path_info = expand_module_patterns(target_model, config.all_module_info) cm = ComponentModel( target_model=target_model, + run_batch=run_batch_passthrough, module_path_info=module_path_info, ci_fn_type=config.ci_fn_type, ci_fn_hidden_dims=config.ci_fn_hidden_dims, - pretrained_model_output_attr=config.pretrained_model_output_attr, sigmoid_type=config.sigmoid_type, ) @@ -278,10 +278,10 @@ def test_full_weight_delta_matches_target_behaviour(): C = 4 cm = ComponentModel( target_model=target_model, + run_batch=run_batch_passthrough, module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], ci_fn_type="mlp", ci_fn_hidden_dims=[4], - pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) @@ -310,10 +310,10 @@ def test_input_cache_captures_pre_weight_input(): cm = ComponentModel( target_model=target_model, + run_batch=run_batch_passthrough, module_path_info=[ModulePathInfo(module_path=p, C=2) for p in target_module_paths], ci_fn_type="mlp", ci_fn_hidden_dims=[2], - pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) @@ -345,10 +345,10 @@ def test_weight_deltas(): target_module_paths = ["embed", "mlp", "out"] cm = ComponentModel( target_model=target_model, + run_batch=run_batch_passthrough, module_path_info=[ModulePathInfo(module_path=p, C=3) for p in target_module_paths], ci_fn_type="mlp", ci_fn_hidden_dims=[2], - pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) @@ -380,10 +380,10 @@ def forward(self, x: Tensor) -> Tensor: cm = ComponentModel( target_model=model, + run_batch=run_batch_passthrough, 
module_path_info=[ModulePathInfo(module_path="linear", C=C)], ci_fn_type="mlp", ci_fn_hidden_dims=[2], - pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) @@ -436,10 +436,10 @@ def forward(self, x: Tensor) -> Tensor: # wrapped in a component model that decomposes the prepended identity layer cm = ComponentModel( target_model=model, + run_batch=run_batch_passthrough, module_path_info=[ModulePathInfo(module_path="linear.pre_identity", C=C)], ci_fn_type="mlp", ci_fn_hidden_dims=[2], - pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) @@ -486,10 +486,10 @@ def forward(self, x: Tensor) -> Tensor: # wrapped in a component model that decomposes the layer cm = ComponentModel( target_model=model, + run_batch=run_batch_passthrough, module_path_info=[ModulePathInfo(module_path="linear", C=C)], ci_fn_type="mlp", ci_fn_hidden_dims=[2], - pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) diff --git a/tests/test_distributed.py b/tests/test_distributed.py index ddbfd8c0d..f1a76ae8e 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -34,7 +34,6 @@ {"classname": "CIMaskedReconLayerwiseLoss", "coeff": 1.0}, {"classname": "CIMaskedReconLoss", "coeff": 1.0}, ], - "output_loss_type": "kl", # --- Training --- "batch_size": 2, "steps": 20, @@ -54,7 +53,7 @@ # --- Pretrained model info --- "pretrained_model_class": "transformers.LlamaForCausalLM", "pretrained_model_name": "SimpleStories/SimpleStories-1.25M", - "pretrained_model_output_attr": "logits", + "output_extract": "logits", "tokenizer_name": "SimpleStories/SimpleStories-1.25M", # --- Task Specific --- "task_config": { diff --git a/tests/test_gpt2.py b/tests/test_gpt2.py index b66cfa31d..62c14ce12 100644 --- a/tests/test_gpt2.py +++ b/tests/test_gpt2.py @@ -14,8 +14,9 @@ StochasticReconLayerwiseLossConfig, StochasticReconLossConfig, ) -from spd.data import DatasetConfig, create_data_loader +from spd.data import DatasetConfig, create_data_loader, lm_collate_fn from 
spd.identity_insertion import insert_identity_operations_ +from spd.models.batch_and_loss_fns import make_run_batch, recon_loss_kl from spd.run_spd import optimize from spd.utils.general_utils import resolve_class, set_seed @@ -55,7 +56,6 @@ def test_gpt_2_decomposition_happy_path(tmp_path: Path) -> None: StochasticReconLossConfig(coeff=1.0), FaithfulnessLossConfig(coeff=200), ], - output_loss_type="kl", # Training lr_schedule=ScheduleConfig( start_val=1e-3, fn_type="cosine", warmup_pct=0.01, final_val_frac=0.0 @@ -78,7 +78,7 @@ def test_gpt_2_decomposition_happy_path(tmp_path: Path) -> None: pretrained_model_class="transformers.GPT2LMHeadModel", pretrained_model_path=None, pretrained_model_name="SimpleStories/test-SimpleStories-gpt2-1.25M", - pretrained_model_output_attr="logits", + output_extract="logits", tokenizer_name="SimpleStories/test-SimpleStories-gpt2-1.25M", # Task Specific task_config=LMTaskConfig( @@ -123,6 +123,7 @@ def test_gpt_2_decomposition_happy_path(tmp_path: Path) -> None: batch_size=config.batch_size, buffer_size=config.task_config.buffer_size, global_seed=config.seed, + collate_fn=lm_collate_fn, ) eval_data_config = DatasetConfig( @@ -140,16 +141,19 @@ def test_gpt_2_decomposition_happy_path(tmp_path: Path) -> None: batch_size=config.batch_size, buffer_size=config.task_config.buffer_size, global_seed=config.seed + 1, + collate_fn=lm_collate_fn, ) # Run optimize function + assert config.output_extract is not None optimize( target_model=target_model, config=config, device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, + run_batch=make_run_batch(config.output_extract), + reconstruction_loss=recon_loss_kl, out_dir=tmp_path, ) diff --git a/tests/test_ih_transformer.py b/tests/test_ih_transformer.py index 0cd27f521..0392a4354 100644 --- a/tests/test_ih_transformer.py +++ b/tests/test_ih_transformer.py @@ -17,6 +17,7 @@ from spd.experiments.ih.configs import InductionModelConfig from 
spd.experiments.ih.model import InductionTransformer from spd.identity_insertion import insert_identity_operations_ +from spd.models.batch_and_loss_fns import recon_loss_kl, run_batch_first_element from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, InductionDataset from spd.utils.general_utils import set_seed @@ -71,7 +72,6 @@ def test_ih_transformer_decomposition_happy_path(tmp_path: Path) -> None: StochasticReconLossConfig(coeff=1.0), FaithfulnessLossConfig(coeff=200), ], - output_loss_type="kl", # Training lr_schedule=ScheduleConfig( start_val=1e-3, fn_type="cosine", warmup_pct=0.01, final_val_frac=0.0 @@ -95,7 +95,6 @@ def test_ih_transformer_decomposition_happy_path(tmp_path: Path) -> None: pretrained_model_class="spd.experiments.ih.model.InductionTransformer", pretrained_model_path=None, pretrained_model_name=None, - pretrained_model_output_attr=None, tokenizer_name=None, # Task Specific task_config=IHTaskConfig( @@ -133,7 +132,8 @@ def test_ih_transformer_decomposition_happy_path(tmp_path: Path) -> None: device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, + run_batch=run_batch_first_element, + reconstruction_loss=recon_loss_kl, out_dir=tmp_path, ) diff --git a/tests/test_resid_mlp.py b/tests/test_resid_mlp.py index 844982415..8db454a87 100644 --- a/tests/test_resid_mlp.py +++ b/tests/test_resid_mlp.py @@ -13,6 +13,7 @@ from spd.experiments.resid_mlp.models import ResidMLP from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset from spd.identity_insertion import insert_identity_operations_ +from spd.models.batch_and_loss_fns import recon_loss_mse, run_batch_first_element from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader from spd.utils.general_utils import set_seed @@ -62,7 +63,6 @@ def test_resid_mlp_decomposition_happy_path(tmp_path: Path) -> None: identity_module_info=[ 
ModulePatternInfoConfig(module_pattern="layers.*.mlp_in", C=10), ], - output_loss_type="mse", # Training lr_schedule=ScheduleConfig( start_val=1e-3, fn_type="cosine", warmup_pct=0.01, final_val_frac=0.0 @@ -82,7 +82,6 @@ def test_resid_mlp_decomposition_happy_path(tmp_path: Path) -> None: pretrained_model_class="spd.experiments.resid_mlp.models.ResidMLP", pretrained_model_path=None, pretrained_model_name=None, - pretrained_model_output_attr=None, tokenizer_name=None, # Task Specific task_config=ResidMLPTaskConfig( @@ -129,7 +128,8 @@ def test_resid_mlp_decomposition_happy_path(tmp_path: Path) -> None: device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, + run_batch=run_batch_first_element, + reconstruction_loss=recon_loss_mse, out_dir=tmp_path, ) diff --git a/tests/test_spd_losses.py b/tests/test_spd_losses.py index 69a546f6e..833d4671a 100644 --- a/tests/test_spd_losses.py +++ b/tests/test_spd_losses.py @@ -16,6 +16,7 @@ stochastic_recon_loss, stochastic_recon_subset_loss, ) +from spd.models.batch_and_loss_fns import recon_loss_kl, recon_loss_mse, run_batch_passthrough from spd.models.component_model import ComponentModel from spd.utils.module_utils import ModulePathInfo @@ -39,10 +40,10 @@ def _make_component_model(weight: Float[Tensor, "d_out d_in"]) -> ComponentModel comp_model = ComponentModel( target_model=target, + run_batch=run_batch_passthrough, module_path_info=[ModulePathInfo(module_path="fc", C=1)], ci_fn_hidden_dims=[2], ci_fn_type="mlp", - pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) @@ -279,10 +280,10 @@ def test_mse_loss_basic(self: object) -> None: result = ci_masked_recon_loss( model=model, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, + reconstruction_loss=recon_loss_mse, ) # Since we're using a simple identity-like weight, and CI is 1, @@ -304,10 +305,10 @@ def test_kl_loss_basic(self: object) -> None: result = ci_masked_recon_loss( model=model, - output_loss_type="kl",
batch=batch, target_out=target_out, ci=ci, + reconstruction_loss=recon_loss_kl, ) assert result >= 0.0 @@ -324,10 +325,18 @@ def test_different_ci_values_produce_different_losses(self: object) -> None: ci_half = {"fc": torch.tensor([[0.5]], dtype=torch.float32)} loss_full = ci_masked_recon_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci_full + model=model, + batch=batch, + target_out=target_out, + ci=ci_full, + reconstruction_loss=recon_loss_mse, ) loss_half = ci_masked_recon_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci_half + model=model, + batch=batch, + target_out=target_out, + ci=ci_half, + reconstruction_loss=recon_loss_mse, ) # Different CI values should produce different losses @@ -346,10 +355,10 @@ def test_layerwise_basic(self: object) -> None: result = ci_masked_recon_layerwise_loss( model=model, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, + reconstruction_loss=recon_loss_mse, ) # Layerwise should produce a valid loss @@ -366,10 +375,18 @@ def test_layerwise_vs_all_layer(self: object) -> None: ci = {"fc": torch.tensor([[1.0]], dtype=torch.float32)} loss_all = ci_masked_recon_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci + model=model, + batch=batch, + target_out=target_out, + ci=ci, + reconstruction_loss=recon_loss_mse, ) loss_layerwise = ci_masked_recon_layerwise_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci + model=model, + batch=batch, + target_out=target_out, + ci=ci, + reconstruction_loss=recon_loss_mse, ) # For single layer, results should be the same @@ -388,11 +405,11 @@ def test_subset_basic(self: object) -> None: result = ci_masked_recon_subset_loss( model=model, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, routing=UniformKSubsetRoutingConfig(), + reconstruction_loss=recon_loss_mse, ) # Subset routing should produce a valid loss @@
-411,11 +428,11 @@ def test_subset_stochastic_behavior(self: object) -> None: losses = [ ci_masked_recon_subset_loss( model=model, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, routing=UniformKSubsetRoutingConfig(), + reconstruction_loss=recon_loss_mse, ) for _ in range(3) ] @@ -439,11 +456,11 @@ def test_continuous_sampling_basic(self: object) -> None: model=model, sampling="continuous", n_mask_samples=3, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, weight_deltas=weight_deltas, + reconstruction_loss=recon_loss_mse, ) assert result >= 0.0 @@ -462,11 +479,11 @@ def test_binomial_sampling_basic(self: object) -> None: model=model, sampling="binomial", n_mask_samples=3, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, weight_deltas=weight_deltas, + reconstruction_loss=recon_loss_mse, ) assert result >= 0.0 @@ -487,11 +504,11 @@ def test_multiple_mask_samples(self: object) -> None: model=model, sampling="continuous", n_mask_samples=n_samples, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, weight_deltas=weight_deltas, + reconstruction_loss=recon_loss_mse, ) assert result >= 0.0 @@ -509,22 +526,22 @@ def test_with_and_without_delta_component(self: object) -> None: model=model, sampling="continuous", n_mask_samples=3, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, weight_deltas=weight_deltas, + reconstruction_loss=recon_loss_mse, ) loss_without_delta = stochastic_recon_loss( model=model, sampling="continuous", n_mask_samples=3, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, weight_deltas=None, + reconstruction_loss=recon_loss_mse, ) # Both should be valid @@ -547,11 +564,11 @@ def test_layerwise_stochastic_basic(self: object) -> None: model=model, sampling="continuous", n_mask_samples=2, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, weight_deltas=weight_deltas, + reconstruction_loss=recon_loss_mse, ) assert result >= 
0.0 @@ -571,11 +588,11 @@ def test_layerwise_multiple_samples(self: object) -> None: model=model, sampling="continuous", n_mask_samples=n_samples, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, weight_deltas=weight_deltas, + reconstruction_loss=recon_loss_mse, ) assert result >= 0.0 @@ -595,12 +612,12 @@ def test_subset_stochastic_basic(self: object) -> None: model=model, sampling="continuous", n_mask_samples=3, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, weight_deltas=weight_deltas, routing=UniformKSubsetRoutingConfig(), + reconstruction_loss=recon_loss_mse, ) assert result >= 0.0 @@ -619,12 +636,12 @@ def test_subset_with_binomial_sampling(self: object) -> None: model=model, sampling="binomial", n_mask_samples=3, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, weight_deltas=weight_deltas, routing=UniformKSubsetRoutingConfig(), + reconstruction_loss=recon_loss_mse, ) assert result >= 0.0 @@ -644,12 +661,12 @@ def test_subset_stochastic_variability(self: object) -> None: model=model, sampling="continuous", n_mask_samples=2, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, weight_deltas=weight_deltas, routing=UniformKSubsetRoutingConfig(), + reconstruction_loss=recon_loss_mse, ) for _ in range(3) ] diff --git a/tests/test_tms.py b/tests/test_tms.py index b51ab2109..f34a8b693 100644 --- a/tests/test_tms.py +++ b/tests/test_tms.py @@ -18,6 +18,7 @@ from spd.experiments.tms.models import TMSModel from spd.experiments.tms.train_tms import get_model_and_dataloader, train from spd.identity_insertion import insert_identity_operations_ +from spd.models.batch_and_loss_fns import recon_loss_mse, run_batch_first_element from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, SparseFeatureDataset from spd.utils.general_utils import set_seed @@ -68,7 +69,6 @@ def test_tms_decomposition_happy_path(tmp_path: Path) -> None: 
StochasticReconLossConfig(coeff=1.0), FaithfulnessLossConfig(coeff=1.0), ], - output_loss_type="mse", # Training lr_schedule=ScheduleConfig( start_val=1e-3, fn_type="cosine", warmup_pct=0.0, final_val_frac=0.0 @@ -91,7 +91,6 @@ def test_tms_decomposition_happy_path(tmp_path: Path) -> None: pretrained_model_class="spd.experiments.tms.models.TMSModel", pretrained_model_path=None, pretrained_model_name=None, - pretrained_model_output_attr=None, tokenizer_name=None, # Task Specific task_config=TMSTaskConfig( @@ -137,7 +136,8 @@ def test_tms_decomposition_happy_path(tmp_path: Path) -> None: device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, + run_batch=run_batch_first_element, + reconstruction_loss=recon_loss_mse, out_dir=tmp_path, tied_weights=tied_weights, ) diff --git a/tests/test_wandb_run_loading.py b/tests/test_wandb_run_loading.py index 55653b7c8..3c8fd2efd 100644 --- a/tests/test_wandb_run_loading.py +++ b/tests/test_wandb_run_loading.py @@ -11,16 +11,6 @@ from spd.registry import EXPERIMENT_REGISTRY from spd.utils.wandb_utils import parse_wandb_run_path - -def from_run_info(canonical_run: str) -> ComponentModel: - run_info = SPDRunInfo.from_path(canonical_run) - return ComponentModel.from_run_info(run_info) - - -def from_pretrained(canonical_run: str) -> ComponentModel: - return ComponentModel.from_pretrained(canonical_run) - - CANONICAL_EXPS = [ (exp_name, exp_config.canonical_run) for exp_name, exp_config in EXPERIMENT_REGISTRY.items() @@ -32,17 +22,11 @@ def from_pretrained(canonical_run: str) -> ComponentModel: @pytest.mark.slow @pytest.mark.parametrize("exp_name, canonical_run", CANONICAL_EXPS) def test_loading_from_wandb(exp_name: str, canonical_run: str) -> None: - # We put both from_run_info and from_pretrained in the same test to avoid distributed read - # errors from the same wandb cache - try: - from_run_info(canonical_run) - except Exception as e: - e.add_note(f"Error with from_run_info for 
{exp_name} from {canonical_run}") - raise e try: - from_pretrained(canonical_run) + run_info = SPDRunInfo.from_path(canonical_run) + ComponentModel.from_run_info(run_info) except Exception as e: - e.add_note(f"Error with from_pretrained for {exp_name} from {canonical_run}") + e.add_note(f"Error loading {exp_name} from {canonical_run}") raise e