Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions scripts/harvest_transcoders_example.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Example: Harvest transcoder activations using spd-harvest.
#
# Loads trained BatchTopK transcoders from wandb artifacts and runs the generic
# harvest pipeline to collect activation statistics (firing densities, token PMI,
# activation examples).
#
# Usage:
# spd-harvest scripts/harvest_transcoders_example.yaml

config:
  method_config:
    type: TranscoderHarvestConfig
    # Pretrained base model whose MLP activations the transcoders decompose.
    base_model_path: "wandb:goodfire/spd/t-32d1bb3b"
    # One trained transcoder wandb artifact per target MLP module path.
    artifact_paths:
      "h.0.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L0_5d5b1f_checkpoint_final:v0"
      "h.1.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L1_c208d7_checkpoint_final:v0"
      "h.2.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L2_4f6e37_checkpoint_final:v0"
      "h.3.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L3_e76468_checkpoint_final:v0"
  # Harvest size: n_batches x batch_size sequences are processed in total.
  n_batches: 20
  batch_size: 8
  # How many example activations to keep per component, and how much token
  # context to store on each side of the max-activating token.
  activation_examples_per_component: 20
  activation_context_tokens_per_side: 10
  # Number of top tokens kept for the token-PMI statistic.
  pmi_token_top_k: 10

n_gpus: 1
62 changes: 52 additions & 10 deletions spd/adapters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Harvest method adapters: method-specific logic for the generic harvest pipeline.

Each decomposition method (SPD, CLT, MOLT) provides an adapter that knows how to:
Each decomposition method (SPD, CLT, MOLT, Transcoder) provides an adapter that knows how to:
- Load the model and build a dataloader
- Compute firings and activations from a batch (harvest_fn)
- Report layer structure and vocab size
Expand All @@ -9,16 +9,58 @@
"""

from spd.adapters.base import DecompositionAdapter
from spd.harvest.config import DecompositionMethodHarvestConfig


def adapter_from_id(id: str) -> DecompositionAdapter:
from spd.adapters.spd import SPDAdapter
def adapter_from_config(method_config: DecompositionMethodHarvestConfig) -> DecompositionAdapter:
from spd.harvest.config import (
CLTHarvestConfig,
MOLTHarvestConfig,
SPDHarvestConfig,
TranscoderHarvestConfig,
)

if id.startswith("s-"):
return SPDAdapter(id)
elif id.startswith("clt-"):
raise NotImplementedError("CLT adapter not implemented yet")
elif id.startswith("molt-"):
raise NotImplementedError("MOLT adapter not implemented yet")
match method_config:
case SPDHarvestConfig():
from spd.adapters.spd import SPDAdapter

raise ValueError(f"Unsupported decomposition ID: {id}")
return SPDAdapter(method_config.id)
case TranscoderHarvestConfig():
from spd.adapters.transcoder import TranscoderAdapter

return TranscoderAdapter(method_config)
case CLTHarvestConfig():
raise NotImplementedError("CLT adapter not implemented yet")
case MOLTHarvestConfig():
raise NotImplementedError("MOLT adapter not implemented yet")


def adapter_from_id(decomposition_id: str) -> DecompositionAdapter:
    """Construct an adapter from a decomposition ID (e.g. "s-abc123", "tc-1a2b3c4d").

    For SPD runs, the ID is sufficient. For other methods, recovers the full
    method config from the harvest DB (which is always populated before downstream
    steps like autointerp run).
    """
    if not decomposition_id.startswith("s-"):
        # Non-SPD methods need their stored config to know how to load.
        return adapter_from_config(_load_method_config(decomposition_id))

    from spd.adapters.spd import SPDAdapter

    return SPDAdapter(decomposition_id)


def _load_method_config(decomposition_id: str) -> DecompositionMethodHarvestConfig:
    """Recover the method config persisted by a previous harvest run."""
    from pydantic import TypeAdapter

    from spd.harvest.repo import HarvestRepo

    repo = HarvestRepo.open_most_recent(decomposition_id)
    assert repo is not None, (
        f"No harvest data found for {decomposition_id!r}. "
        f"Run spd-harvest first to populate the method config."
    )
    raw_method_config = repo.get_config()["method_config"]
    # Validate against the discriminated union so the right subtype is built.
    return TypeAdapter(DecompositionMethodHarvestConfig).validate_python(raw_method_config)
68 changes: 68 additions & 0 deletions spd/adapters/encoder_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Encoder configuration for transcoder architectures.

Originally by Bart Bussmann, vendored from https://github.com/bartbussmann/nn_decompositions (MIT license).
Only EncoderConfig is used; CLTConfig and SAEConfig are omitted.
"""

from dataclasses import dataclass, field
from typing import Literal

import torch


@dataclass
class EncoderConfig:
"""Base config for encoder architectures (SAE and Transcoder)."""

# Architecture
input_size: int
output_size: int
dict_size: int = 12288

# Encoder type
encoder_type: Literal["vanilla", "topk", "batchtopk", "jumprelu"] = "topk"

# Training
seed: int = 49
batch_size: int = 4096
lr: float = 3e-4
num_tokens: int = int(1e9)
l1_coeff: float = 0.0
beta1: float = 0.9
beta2: float = 0.99
max_grad_norm: float = 1.0

# Device
device: str = "cuda:0"
dtype: torch.dtype = field(default=torch.float32)

# Dead feature tracking
n_batches_to_dead: int = 50

# Optional features
input_unit_norm: bool = False
pre_enc_bias: bool = False

# TopK specific
top_k: int = 32
top_k_aux: int = 512
aux_penalty: float = 1 / 32

# JumpReLU specific
bandwidth: float = 0.001

# Logging
run_name: str | None = None
wandb_project: str = "encoders"
perf_log_freq: int = 1000
checkpoint_freq: int | Literal["final"] = "final"
n_eval_seqs: int = 8

@property
def name(self) -> str:
if self.run_name is not None:
return self.run_name
base = f"{self.dict_size}_{self.encoder_type}"
if self.encoder_type in ("topk", "batchtopk"):
base += f"_k{self.top_k}"
return f"{base}_{self.lr}"
128 changes: 128 additions & 0 deletions spd/adapters/transcoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""Transcoder adapter: loads trained transcoders from wandb artifacts."""

import json
from functools import cached_property
from pathlib import Path
from typing import Any, override

import torch
import wandb
from torch.utils.data import DataLoader

from spd.adapters.base import DecompositionAdapter
from spd.adapters.encoder_config import EncoderConfig
from spd.adapters.transcoders import (
BatchTopKTranscoder,
JumpReLUTranscoder,
SharedTranscoder,
TopKTranscoder,
VanillaTranscoder,
)
from spd.autointerp.schemas import ModelMetadata
from spd.data import DatasetConfig, create_data_loader
from spd.harvest.config import TranscoderHarvestConfig
from spd.pretrain.models.llama_simple_mlp import LlamaSimpleMLP
from spd.pretrain.run_info import PretrainRunInfo
from spd.topology import TransformerTopology

# Maps EncoderConfig.encoder_type to the transcoder class implementing it;
# used by _load_transcoder to rehydrate the right class from a checkpoint.
_ENCODER_CLASSES: dict[str, type[SharedTranscoder]] = {
    "vanilla": VanillaTranscoder,
    "topk": TopKTranscoder,
    "batchtopk": BatchTopKTranscoder,
    "jumprelu": JumpReLUTranscoder,
}


def _load_transcoder(checkpoint_dir: Path, device: str) -> SharedTranscoder:
    """Rehydrate a trained transcoder (config + weights) from a checkpoint dir."""
    cfg_dict: dict[str, Any] = json.loads((checkpoint_dir / "config.json").read_text())
    # The saved dtype is a string like "torch.float32"; map it back to a torch.dtype.
    dtype_name = cfg_dict.get("dtype", "torch.float32").replace("torch.", "")
    cfg_dict["dtype"] = getattr(torch, dtype_name)
    cfg_dict["device"] = device
    cfg = EncoderConfig(**cfg_dict)
    encoder = _ENCODER_CLASSES[cfg.encoder_type](cfg)
    state_dict = torch.load(checkpoint_dir / "encoder.pt", map_location=device)
    encoder.load_state_dict(state_dict)
    encoder.eval()
    return encoder


def _download_artifact(artifact_path: str, dest: Path) -> Path:
    """Download a wandb artifact into *dest*, skipping if already cached there."""
    already_cached = dest.exists() and (dest / "encoder.pt").exists()
    if not already_cached:
        wandb.Api().artifact(artifact_path).download(root=str(dest))
    return dest


class TranscoderAdapter(DecompositionAdapter):
    """Adapter exposing per-layer trained transcoders to the harvest pipeline.

    Loads the base model from a pretrain run (``base_model_path``) and one
    trained transcoder per target MLP module from wandb artifacts
    (``artifact_paths``), as specified by a TranscoderHarvestConfig.
    """

    def __init__(self, config: TranscoderHarvestConfig):
        self._config = config

    @cached_property
    def _run_info(self) -> PretrainRunInfo:
        # Run metadata for the base model (config dict, tokenizer path, ...).
        return PretrainRunInfo.from_path(self._config.base_model_path)

    @cached_property
    def base_model(self) -> LlamaSimpleMLP:
        # The frozen pretrained model whose MLP activations are transcoded.
        return LlamaSimpleMLP.from_run_info(self._run_info)

    @cached_property
    def _topology(self) -> TransformerTopology:
        # Layer-structure helper used for block counts and canonical names.
        return TransformerTopology(self.base_model)

    @cached_property
    def transcoders(self) -> dict[str, SharedTranscoder]:
        """Download (or reuse cached) artifacts and load one transcoder per module path."""
        result: dict[str, SharedTranscoder] = {}
        for module_path, artifact_path in self._config.artifact_paths.items():
            # Derive a filesystem-safe local cache dir name from the artifact path.
            safe_name = artifact_path.replace("/", "_").replace(":", "_")
            dest = Path(f"checkpoints/tc_{safe_name}")
            checkpoint_dir = _download_artifact(artifact_path, dest)
            # NOTE(review): loaded on CPU here -- presumably moved to the harvest
            # device downstream; confirm against the harvest pipeline.
            result[module_path] = _load_transcoder(checkpoint_dir, "cpu")
        return result

    @property
    @override
    def decomposition_id(self) -> str:
        return self._config.id

    @property
    @override
    def vocab_size(self) -> int:
        return self.base_model.config.vocab_size

    @property
    @override
    def layer_activation_sizes(self) -> list[tuple[str, int]]:
        # One (module_path, n_components) pair per transcoder; dict_size is the
        # width of the transcoder's feature dictionary.
        return [(path, tc.dict_size) for path, tc in self.transcoders.items()]

    @property
    @override
    def tokenizer_name(self) -> str:
        tok = self._run_info.hf_tokenizer_path
        assert tok is not None, "base model run missing hf_tokenizer_path"
        return tok

    @property
    @override
    def model_metadata(self) -> ModelMetadata:
        """Metadata consumed by downstream steps (e.g. autointerp)."""
        # Falls back to "unknown" if the pretrain run config lacks dataset info.
        ds_cfg = self._run_info.config_dict.get("train_dataset_config", {})
        model_cls = type(self.base_model)
        return ModelMetadata(
            n_blocks=self._topology.n_blocks,
            model_class=f"{model_cls.__module__}.{model_cls.__qualname__}",
            dataset_name=ds_cfg.get("name", "unknown"),
            layer_descriptions={
                path: self._topology.target_to_canon(path) for path in self.transcoders
            },
        )

    @override
    def dataloader(self, batch_size: int) -> DataLoader[torch.Tensor]:
        """Stream batches from the same dataset the base model was pretrained on.

        Uses the model's block_size as context length. Raises KeyError if the
        pretrain run config has no "train_dataset_config" entry.
        """
        ds_cfg = self._run_info.config_dict["train_dataset_config"]
        dataset_config = DatasetConfig.model_validate(
            {**ds_cfg, "streaming": True, "n_ctx": self.base_model.config.block_size}
        )
        loader, _ = create_data_loader(
            dataset_config=dataset_config, batch_size=batch_size, buffer_size=1000
        )
        return loader
Loading