Draft pull request: changes from all commits (20 commits)
c0c3a17
wip: Introduce generic types for batch and output in ComponentModel
claude-spd1 Jan 29, 2026
bd03f57
wip: Refactor ComponentModel loading to use task-specific factories
claude-spd1 Jan 29, 2026
e712573
>>>>>>> dev/app
claude-spd1 Feb 2, 2026
2312fe6
Revert ">>>>>>> dev/app"
claude-spd1 Feb 2, 2026
357898a
wip: Replace getattr with cast for type safety on model attributes
claude-spd1 Feb 2, 2026
8fdfa81
Remove accidentally added files
danbraunai-goodfire Feb 6, 2026
539edb2
Remove LogitsOnlyWrapper and add back ComponentModel.from_run_info
danbraunai-goodfire Feb 6, 2026
b09e7dc
Merge branch 'main' into feature/generic-shapes
danbraunai-goodfire Feb 6, 2026
249b2a3
Merge main into feature/generic-shapes
danbraunai-goodfire Feb 9, 2026
d2c5465
Add extract_tensor_output config arg
danbraunai-goodfire Feb 9, 2026
52e9d19
Replace accessor DSL with RunBatch protocol (#375)
ocg-goodfire Feb 10, 2026
44ad9a4
Merge branch 'main' into feature/generic-shapes
danbraunai-goodfire Feb 10, 2026
2c98564
Merge main into feature/generic-shapes
danbraunai-goodfire Feb 11, 2026
a94b580
Fix typing
danbraunai-goodfire Feb 11, 2026
ca22236
Fix typing for StochasticReconSubsetCEAndKL
danbraunai-goodfire Feb 11, 2026
e695241
Fix non-deterministic test in CI
danbraunai-goodfire Feb 11, 2026
8b35022
Remove explicit __call__
danbraunai-goodfire Feb 11, 2026
4fe40e3
Remove OutputT as it was always a Tensor
danbraunai-goodfire Feb 11, 2026
1f51a37
Remove generics and simplify output_extract
danbraunai-goodfire Feb 11, 2026
3042ce0
Revert various changes that were made
danbraunai-goodfire Feb 11, 2026
spd/app/backend/routers/prompts.py (3 changes: 1 addition & 2 deletions)

@@ -16,7 +16,6 @@
 from spd.data import DatasetConfig, create_data_loader
 from spd.log import logger
 from spd.utils.distributed_utils import get_device
-from spd.utils.general_utils import extract_batch_data

 # =============================================================================
 # Schemas

@@ -120,7 +119,7 @@ def generate() -> Generator[str]:
         if added_count >= n_prompts:
             break

-        tokens = extract_batch_data(batch).to(DEVICE)
+        tokens = batch["input_ids"].to(DEVICE)
         batch_size, n_seq = tokens.shape

         # Compute CI for the whole batch
spd/clustering/dataset.py (2 changes: 1 addition & 1 deletion)

@@ -106,7 +106,7 @@ def _load_resid_mlp_batch(model_path: str, batch_size: int, seed: int) -> BatchT

     spd_run = SPDRunInfo.from_path(model_path)
     cfg = spd_run.config
-    component_model = ComponentModel.from_pretrained(spd_run.checkpoint_path)
+    component_model = ComponentModel.from_run_info(spd_run)

     assert isinstance(cfg.task_config, ResidMLPTaskConfig), (
         f"Expected task_config to be of type ResidMLPTaskConfig, but got {type(cfg.task_config) = }"
spd/configs.py (24 changes: 17 additions & 7 deletions)

@@ -461,11 +461,6 @@ def all_module_info(self) -> list[ModulePatternInfoConfig]:
             ),
         )
     )
-    output_loss_type: Literal["mse", "kl"] = Field(
-        ...,
-        description="Metric used to measure recon error between model outputs and targets",
-    )
-
     # --- Training ---
     lr_schedule: ScheduleConfig = Field(..., description="Learning rate schedule configuration")
     steps: NonNegativeInt = Field(..., description="Total number of optimisation steps")

@@ -566,9 +561,10 @@ def microbatch_size(self) -> PositiveInt:
         default=None,
         description="hf model identifier. E.g. 'SimpleStories/SimpleStories-1.25M'",
     )
-    pretrained_model_output_attr: str | None = Field(
+    output_extract: int | str | None = Field(
         default=None,
-        description="Name of the attribute on the forward output that contains logits or activations",
+        description="How to extract tensor from model output. None = raw output, int = index into "
+        "output tuple, str = attribute name.",
     )
     tokenizer_name: str | None = Field(
         default=None,

@@ -607,6 +603,7 @@ def microbatch_size(self) -> PositiveInt:
         "lr_exponential_halflife",
         "out_dir",
         "n_examples_until_dead",
+        "output_loss_type",
     ]
     RENAMED_CONFIG_KEYS: ClassVar[dict[str, str]] = {
         "grad_clip_norm": "grad_clip_norm_components",

@@ -652,6 +649,19 @@ def handle_deprecated_config_keys(cls, config_dict: dict[str, Any]) -> dict[str,
                 "simple_stories_train.models.", "spd.pretrain.models.", 1
             )

+        # Migrate old pretrained_model_output_attr to output_extract
+        if "pretrained_model_output_attr" in config_dict:
+            old_val = config_dict.pop("pretrained_model_output_attr")
+            match old_val:
+                case None:
+                    pass
+                case "idx_0":
+                    config_dict["output_extract"] = 0
+                case "logits":
+                    config_dict["output_extract"] = "logits"
+                case _:
+                    raise ValueError(f"Unknown pretrained_model_output_attr: {old_val!r}")
+
         if "eval_batch_size" not in config_dict:
             config_dict["eval_batch_size"] = config_dict["batch_size"]
         if "train_log_freq" not in config_dict:
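Note (illustration, not part of the diff): the migration hook above rewrites old configs in place. A minimal standalone sketch of the same logic, with a hypothetical function name; the real entry point is handle_deprecated_config_keys:

# Hypothetical standalone version of the migration step, for illustration only.
def migrate_output_attr(config_dict: dict) -> dict:
    if "pretrained_model_output_attr" in config_dict:
        old_val = config_dict.pop("pretrained_model_output_attr")
        match old_val:
            case None:
                pass  # key present but unset: nothing to migrate
            case "idx_0":
                config_dict["output_extract"] = 0  # index into the output tuple
            case "logits":
                config_dict["output_extract"] = "logits"  # attribute lookup
            case _:
                raise ValueError(f"Unknown pretrained_model_output_attr: {old_val!r}")
    return config_dict

# Old-style keys are rewritten to the new output_extract form:
assert migrate_output_attr({"pretrained_model_output_attr": "idx_0"}) == {"output_extract": 0}
assert migrate_output_attr({"pretrained_model_output_attr": "logits"}) == {"output_extract": "logits"}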
spd/data.py (16 changes: 13 additions & 3 deletions)

@@ -1,8 +1,10 @@
+from collections.abc import Callable, Generator
 from typing import Any

 import numpy as np
 import torch
 from datasets import Dataset, IterableDataset, load_dataset
+from jaxtyping import Int
 from numpy.typing import NDArray
 from torch import Tensor
 from torch.utils.data import DataLoader, DistributedSampler

@@ -152,7 +154,8 @@ def create_data_loader(
     dist_state: DistributedState | None = None,
     global_seed: int = 0,
     to_lower: bool = True,
-) -> tuple[DataLoader[Any], PreTrainedTokenizer]:
+    collate_fn: Callable[..., Any] | None = None,
+) -> tuple[DataLoader[Int[Tensor, "batch seq"]], PreTrainedTokenizer]:
     """Create a DataLoader for the given dataset.

     Uses PyTorch's DistributedSampler to ensure each rank gets the correct

@@ -255,7 +258,7 @@ def create_data_loader(
     generator = torch.Generator(device="cpu")
     generator.manual_seed(seed)

-    loader = DataLoader[Dataset | IterableDataset](
+    loader = DataLoader[Int[Tensor, "batch seq"]](
         torch_dataset,  # pyright: ignore[reportArgumentType]
         batch_size=batch_size,
         sampler=sampler,
@@ -264,11 +267,17 @@ def create_data_loader(
         ),
         drop_last=True,
         generator=generator,
+        collate_fn=collate_fn,
     )
     return loader, tokenizer


-def loop_dataloader[T](dl: DataLoader[T]):
+def lm_collate_fn(batch: list[dict[str, Tensor]]) -> Tensor:
+    """Collate function that extracts input_ids tensors from HuggingFace dataset dicts."""
+    return torch.stack([item["input_ids"] for item in batch])
+
+
+def loop_dataloader[T](dl: DataLoader[T]) -> Generator[T]:
     """Loop over a dataloader, resetting the iterator when it is exhausted.

     Ensures that each epoch gets different data, even when using a distributed sampler.

@@ -311,6 +320,7 @@ def train_loader_and_tokenizer(
         batch_size=batch_size,
         buffer_size=task_config.buffer_size,
         global_seed=config.seed,
+        collate_fn=lm_collate_fn,
     )

     assert isinstance(tokenizer, PreTrainedTokenizerBase)
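Note (illustration, not part of the diff): lm_collate_fn assumes every dataset item is a dict holding an equal-length "input_ids" tensor. A quick sketch of what it produces:

import torch

# Two tokenized items, as they would come from a tokenized HuggingFace dataset.
items = [
    {"input_ids": torch.tensor([1, 2, 3])},
    {"input_ids": torch.tensor([4, 5, 6])},
]
# Stacking yields a single (batch, seq) tensor, so the train loop receives a plain
# tensor instead of a dict and call sites can drop extract_batch_data.
batch = torch.stack([item["input_ids"] for item in items])
assert batch.shape == (2, 3)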
spd/dataset_attributions/harvest.py (3 changes: 1 addition & 2 deletions)

@@ -30,7 +30,6 @@
 from spd.log import logger
 from spd.models.component_model import ComponentModel, SPDRunInfo
 from spd.utils.distributed_utils import get_device
-from spd.utils.general_utils import extract_batch_data
 from spd.utils.wandb_utils import parse_wandb_run_path

@@ -206,7 +205,7 @@ def harvest_attributions(
         # Skip batches not assigned to this rank
         if world_size is not None and batch_idx % world_size != rank:
             continue
-        batch = extract_batch_data(batch_data).to(device)
+        batch = batch_data["input_ids"].to(device)
         harvester.process_batch(batch)

     logger.info(
spd/eval.py (94 changes: 53 additions & 41 deletions)

@@ -3,7 +3,6 @@
 from collections.abc import Iterator
 from typing import Any

-from jaxtyping import Float, Int
 from PIL import Image
 from torch import Tensor
 from torch.types import Number

@@ -37,34 +36,37 @@
     UnmaskedReconLossConfig,
     UVPlotsConfig,
 )
-from spd.metrics import UnmaskedReconLoss
+from spd.metrics import (
+    CI_L0,
+    CEandKLLosses,
+    CIHistograms,
+    CIMaskedReconLayerwiseLoss,
+    CIMaskedReconLoss,
+    CIMaskedReconSubsetLoss,
+    CIMeanPerComponent,
+    ComponentActivationDensity,
+    FaithfulnessLoss,
+    IdentityCIError,
+    ImportanceMinimalityLoss,
+    PermutedCIPlots,
+    PGDReconLayerwiseLoss,
+    PGDReconLoss,
+    PGDReconSubsetLoss,
+    StochasticHiddenActsReconLoss,
+    StochasticReconLayerwiseLoss,
+    StochasticReconLoss,
+    StochasticReconSubsetCEAndKL,
+    StochasticReconSubsetLoss,
+    UnmaskedReconLoss,
+    UVPlots,
+)
 from spd.metrics.base import Metric
-from spd.metrics.ce_and_kl_losses import CEandKLLosses
-from spd.metrics.ci_histograms import CIHistograms
-from spd.metrics.ci_l0 import CI_L0
-from spd.metrics.ci_masked_recon_layerwise_loss import CIMaskedReconLayerwiseLoss
-from spd.metrics.ci_masked_recon_loss import CIMaskedReconLoss
-from spd.metrics.ci_masked_recon_subset_loss import CIMaskedReconSubsetLoss
-from spd.metrics.ci_mean_per_component import CIMeanPerComponent
-from spd.metrics.component_activation_density import ComponentActivationDensity
-from spd.metrics.faithfulness_loss import FaithfulnessLoss
-from spd.metrics.identity_ci_error import IdentityCIError
-from spd.metrics.importance_minimality_loss import ImportanceMinimalityLoss
-from spd.metrics.permuted_ci_plots import PermutedCIPlots
-from spd.metrics.pgd_masked_recon_layerwise_loss import PGDReconLayerwiseLoss
-from spd.metrics.pgd_masked_recon_loss import PGDReconLoss
-from spd.metrics.pgd_masked_recon_subset_loss import PGDReconSubsetLoss
 from spd.metrics.pgd_utils import CreateDataIter, calc_multibatch_pgd_masked_recon_loss
-from spd.metrics.stochastic_hidden_acts_recon_loss import StochasticHiddenActsReconLoss
-from spd.metrics.stochastic_recon_layerwise_loss import StochasticReconLayerwiseLoss
-from spd.metrics.stochastic_recon_loss import StochasticReconLoss
-from spd.metrics.stochastic_recon_subset_ce_and_kl import StochasticReconSubsetCEAndKL
-from spd.metrics.stochastic_recon_subset_loss import StochasticReconSubsetLoss
-from spd.metrics.uv_plots import UVPlots
+from spd.models.batch_and_loss_fns import ReconstructionLoss
 from spd.models.component_model import ComponentModel, OutputWithCache
 from spd.routing import AllLayersRouter, get_subset_router
 from spd.utils.distributed_utils import avg_metrics_across_ranks, is_distributed
-from spd.utils.general_utils import dict_safe_update_, extract_batch_data
+from spd.utils.general_utils import dict_safe_update_

 MetricOutType = dict[str, str | Number | Image.Image | CustomChart]
 DistMetricOutType = dict[str, str | float | Image.Image | CustomChart]

@@ -121,6 +123,7 @@ def init_metric(
     model: ComponentModel,
     run_config: Config,
     device: str,
+    reconstruction_loss: ReconstructionLoss,
 ) -> Metric:
     match cfg:
         case ImportanceMinimalityLossConfig():

@@ -158,16 +161,20 @@ def init_metric(
             metric = CIMaskedReconSubsetLoss(
                 model=model,
                 device=device,
-                output_loss_type=run_config.output_loss_type,
                 routing=cfg.routing,
+                reconstruction_loss=reconstruction_loss,
             )
         case CIMaskedReconLayerwiseLossConfig():
             metric = CIMaskedReconLayerwiseLoss(
-                model=model, device=device, output_loss_type=run_config.output_loss_type
+                model=model,
+                device=device,
+                reconstruction_loss=reconstruction_loss,
             )
         case CIMaskedReconLossConfig():
             metric = CIMaskedReconLoss(
-                model=model, device=device, output_loss_type=run_config.output_loss_type
+                model=model,
+                device=device,
+                reconstruction_loss=reconstruction_loss,
             )
         case CIMeanPerComponentConfig():
             metric = CIMeanPerComponent(model=model, device=device)

@@ -196,7 +203,7 @@ def init_metric(
                 sampling=run_config.sampling,
                 use_delta_component=run_config.use_delta_component,
                 n_mask_samples=run_config.n_mask_samples,
-                output_loss_type=run_config.output_loss_type,
+                reconstruction_loss=reconstruction_loss,
             )
         case StochasticReconLossConfig():
             metric = StochasticReconLoss(
@@ -205,7 +212,7 @@ def init_metric(
                 sampling=run_config.sampling,
                 use_delta_component=run_config.use_delta_component,
                 n_mask_samples=run_config.n_mask_samples,
-                output_loss_type=run_config.output_loss_type,
+                reconstruction_loss=reconstruction_loss,
             )
         case StochasticReconSubsetLossConfig():
             metric = StochasticReconSubsetLoss(
@@ -214,33 +221,33 @@ def init_metric(
                 sampling=run_config.sampling,
                 use_delta_component=run_config.use_delta_component,
                 n_mask_samples=run_config.n_mask_samples,
-                output_loss_type=run_config.output_loss_type,
                 routing=cfg.routing,
+                reconstruction_loss=reconstruction_loss,
             )
         case PGDReconLossConfig():
             metric = PGDReconLoss(
                 model=model,
                 device=device,
                 use_delta_component=run_config.use_delta_component,
-                output_loss_type=run_config.output_loss_type,
                 pgd_config=cfg,
+                reconstruction_loss=reconstruction_loss,
             )
         case PGDReconSubsetLossConfig():
             metric = PGDReconSubsetLoss(
                 model=model,
                 device=device,
                 use_delta_component=run_config.use_delta_component,
-                output_loss_type=run_config.output_loss_type,
                 pgd_config=cfg,
                 routing=cfg.routing,
+                reconstruction_loss=reconstruction_loss,
             )
         case PGDReconLayerwiseLossConfig():
             metric = PGDReconLayerwiseLoss(
                 model=model,
                 device=device,
                 use_delta_component=run_config.use_delta_component,
-                output_loss_type=run_config.output_loss_type,
                 pgd_config=cfg,
+                reconstruction_loss=reconstruction_loss,
             )
         case StochasticReconSubsetCEAndKLConfig():
             metric = StochasticReconSubsetCEAndKL(

@@ -271,7 +278,7 @@ def init_metric(
             metric = UnmaskedReconLoss(
                 model=model,
                 device=device,
-                output_loss_type=run_config.output_loss_type,
+                reconstruction_loss=reconstruction_loss,
             )

         case _:

@@ -284,18 +291,25 @@ def init_metric(
 def evaluate(
     eval_metric_configs: list[MetricConfigType],
     model: ComponentModel,
-    eval_iterator: Iterator[Int[Tensor, "..."] | tuple[Float[Tensor, "..."], Float[Tensor, "..."]]],
+    eval_iterator: Iterator[Any],
     device: str,
     run_config: Config,
     slow_step: bool,
     n_eval_steps: int,
     current_frac_of_training: float,
+    reconstruction_loss: ReconstructionLoss,
 ) -> MetricOutType:
     """Run evaluation and return a mapping of metric names to values/images."""

     metrics: list[Metric] = []
     for cfg in eval_metric_configs:
-        metric = init_metric(cfg=cfg, model=model, run_config=run_config, device=device)
+        metric = init_metric(
+            cfg=cfg,
+            model=model,
+            run_config=run_config,
+            device=device,
+            reconstruction_loss=reconstruction_loss,
+        )
         if metric.slow and not slow_step:
             continue
         metrics.append(metric)

@@ -304,8 +318,7 @@ def evaluate(
     weight_deltas = model.calc_weight_deltas()

     for _ in range(n_eval_steps):
-        batch_raw = next(eval_iterator)
-        batch = extract_batch_data(batch_raw).to(device)
+        batch = next(eval_iterator)

         target_output: OutputWithCache = model(batch, cache_type="input")
         ci = model.calc_causal_importances(

@@ -344,8 +357,8 @@ def evaluate_multibatch_pgd(
     model: ComponentModel,
     create_data_iter: CreateDataIter,
     config: Config,
-    batch_dims: tuple[int, ...],
     device: str,
+    reconstruction_loss: ReconstructionLoss,
 ) -> dict[str, float]:
     """Calculate multibatch PGD metrics."""
     weight_deltas = model.calc_weight_deltas() if config.use_delta_component else None

@@ -367,11 +380,10 @@ def evaluate_multibatch_pgd(
             model=model,
             weight_deltas=weight_deltas,
             create_data_iter=create_data_iter,
-            output_loss_type=config.output_loss_type,
             router=router,
             sampling=config.sampling,
             use_delta_component=config.use_delta_component,
-            batch_dims=batch_dims,
             device=device,
+            reconstruction_loss=reconstruction_loss,
         ).item()
     return metrics
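Note (illustration, not part of the diff): metrics now receive an injected reconstruction_loss instead of threading the output_loss_type string through every constructor. A sketch of the pattern, assuming ReconstructionLoss is callable from (output, target) to a scalar tensor; the real definition in spd/models/batch_and_loss_fns.py may differ:

from typing import Protocol

import torch
import torch.nn.functional as F
from torch import Tensor

# Assumed shape of the injected dependency.
class ReconstructionLossLike(Protocol):
    def __call__(self, output: Tensor, target: Tensor) -> Tensor: ...

def mse_recon(output: Tensor, target: Tensor) -> Tensor:
    return F.mse_loss(output, target)

def kl_recon(output: Tensor, target: Tensor) -> Tensor:
    # KL divergence between target and output distributions over the last dim.
    return F.kl_div(
        F.log_softmax(output, dim=-1), F.softmax(target, dim=-1), reduction="batchmean"
    )

# Pick the loss once, then pass the callable down to init_metric/evaluate.
loss_fn: ReconstructionLossLike = mse_recon
print(loss_fn(torch.randn(2, 4), torch.randn(2, 4)))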
spd/experiments/ih/ih_config.yaml (1 change: 0 additions & 1 deletion)

@@ -33,7 +33,6 @@ ci_recon_layerwise_coeff: null
 stochastic_recon_layerwise_coeff: 1
 importance_minimality_coeff: 1e-2
 pnorm: 0.1
-output_loss_type: kl
 ci_fn_type: "vector_mlp"
 ci_fn_hidden_dims: [128]