diff --git a/scripts/harvest_transcoders_example.yaml b/scripts/harvest_transcoders_example.yaml new file mode 100644 index 000000000..f586f604d --- /dev/null +++ b/scripts/harvest_transcoders_example.yaml @@ -0,0 +1,25 @@ +# Example: Harvest transcoder activations using spd-harvest. +# +# Loads trained BatchTopK transcoders from wandb artifacts and runs the generic +# harvest pipeline to collect activation statistics (firing densities, token PMI, +# activation examples). +# +# Usage: +# spd-harvest scripts/harvest_transcoders_example.yaml + +config: + method_config: + type: TranscoderHarvestConfig + base_model_path: "wandb:goodfire/spd/t-32d1bb3b" + artifact_paths: + "h.0.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L0_5d5b1f_checkpoint_final:v0" + "h.1.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L1_c208d7_checkpoint_final:v0" + "h.2.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L2_4f6e37_checkpoint_final:v0" + "h.3.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L3_e76468_checkpoint_final:v0" + n_batches: 20 + batch_size: 8 + activation_examples_per_component: 20 + activation_context_tokens_per_side: 10 + pmi_token_top_k: 10 + +n_gpus: 1 diff --git a/spd/adapters/__init__.py b/spd/adapters/__init__.py index aded4d188..b5dcb0765 100644 --- a/spd/adapters/__init__.py +++ b/spd/adapters/__init__.py @@ -1,6 +1,6 @@ """Harvest method adapters: method-specific logic for the generic harvest pipeline. 
-Each decomposition method (SPD, CLT, MOLT) provides an adapter that knows how to: +Each decomposition method (SPD, CLT, MOLT, Transcoder) provides an adapter that knows how to: - Load the model and build a dataloader - Compute firings and activations from a batch (harvest_fn) - Report layer structure and vocab size @@ -9,16 +9,58 @@ """ from spd.adapters.base import DecompositionAdapter +from spd.harvest.config import DecompositionMethodHarvestConfig -def adapter_from_id(id: str) -> DecompositionAdapter: - from spd.adapters.spd import SPDAdapter +def adapter_from_config(method_config: DecompositionMethodHarvestConfig) -> DecompositionAdapter: + from spd.harvest.config import ( + CLTHarvestConfig, + MOLTHarvestConfig, + SPDHarvestConfig, + TranscoderHarvestConfig, + ) - if id.startswith("s-"): - return SPDAdapter(id) - elif id.startswith("clt-"): - raise NotImplementedError("CLT adapter not implemented yet") - elif id.startswith("molt-"): - raise NotImplementedError("MOLT adapter not implemented yet") + match method_config: + case SPDHarvestConfig(): + from spd.adapters.spd import SPDAdapter - raise ValueError(f"Unsupported decomposition ID: {id}") + return SPDAdapter(method_config.id) + case TranscoderHarvestConfig(): + from spd.adapters.transcoder import TranscoderAdapter + + return TranscoderAdapter(method_config) + case CLTHarvestConfig(): + raise NotImplementedError("CLT adapter not implemented yet") + case MOLTHarvestConfig(): + raise NotImplementedError("MOLT adapter not implemented yet") + + +def adapter_from_id(decomposition_id: str) -> DecompositionAdapter: + """Construct an adapter from a decomposition ID (e.g. "s-abc123", "tc-1a2b3c4d"). + + For SPD runs, the ID is sufficient. For other methods, recovers the full + method config from the harvest DB (which is always populated before downstream + steps like autointerp run). 
+ """ + if decomposition_id.startswith("s-"): + from spd.adapters.spd import SPDAdapter + + return SPDAdapter(decomposition_id) + + return adapter_from_config(_load_method_config(decomposition_id)) + + +def _load_method_config(decomposition_id: str) -> DecompositionMethodHarvestConfig: + from pydantic import TypeAdapter + + from spd.harvest.repo import HarvestRepo + + repo = HarvestRepo.open_most_recent(decomposition_id) + assert repo is not None, ( + f"No harvest data found for {decomposition_id!r}. " + f"Run spd-harvest first to populate the method config." + ) + config_dict = repo.get_config() + method_config_raw = config_dict["method_config"] + ta = TypeAdapter(DecompositionMethodHarvestConfig) + return ta.validate_python(method_config_raw) diff --git a/spd/adapters/encoder_config.py b/spd/adapters/encoder_config.py new file mode 100644 index 000000000..c2625c77e --- /dev/null +++ b/spd/adapters/encoder_config.py @@ -0,0 +1,68 @@ +"""Encoder configuration for transcoder architectures. + +Originally by Bart Bussmann, vendored from https://github.com/bartbussmann/nn_decompositions (MIT license). +Only EncoderConfig is used; CLTConfig and SAEConfig are omitted. 
+""" + +from dataclasses import dataclass, field +from typing import Literal + +import torch + + +@dataclass +class EncoderConfig: + """Base config for encoder architectures (SAE and Transcoder).""" + + # Architecture + input_size: int + output_size: int + dict_size: int = 12288 + + # Encoder type + encoder_type: Literal["vanilla", "topk", "batchtopk", "jumprelu"] = "topk" + + # Training + seed: int = 49 + batch_size: int = 4096 + lr: float = 3e-4 + num_tokens: int = int(1e9) + l1_coeff: float = 0.0 + beta1: float = 0.9 + beta2: float = 0.99 + max_grad_norm: float = 1.0 + + # Device + device: str = "cuda:0" + dtype: torch.dtype = field(default=torch.float32) + + # Dead feature tracking + n_batches_to_dead: int = 50 + + # Optional features + input_unit_norm: bool = False + pre_enc_bias: bool = False + + # TopK specific + top_k: int = 32 + top_k_aux: int = 512 + aux_penalty: float = 1 / 32 + + # JumpReLU specific + bandwidth: float = 0.001 + + # Logging + run_name: str | None = None + wandb_project: str = "encoders" + perf_log_freq: int = 1000 + checkpoint_freq: int | Literal["final"] = "final" + n_eval_seqs: int = 8 + + @property + def name(self) -> str: + if self.run_name is not None: + return self.run_name + base = f"{self.dict_size}_{self.encoder_type}" + if self.encoder_type in ("topk", "batchtopk"): + base += f"_k{self.top_k}" + return f"{base}_{self.lr}" diff --git a/spd/adapters/transcoder.py b/spd/adapters/transcoder.py new file mode 100644 index 000000000..457c7afaa --- /dev/null +++ b/spd/adapters/transcoder.py @@ -0,0 +1,128 @@ +"""Transcoder adapter: loads trained transcoders from wandb artifacts.""" + +import json +from functools import cached_property +from pathlib import Path +from typing import Any, override + +import torch +import wandb +from torch.utils.data import DataLoader + +from spd.adapters.base import DecompositionAdapter +from spd.adapters.encoder_config import EncoderConfig +from spd.adapters.transcoders import ( + BatchTopKTranscoder, + 
JumpReLUTranscoder, + SharedTranscoder, + TopKTranscoder, + VanillaTranscoder, +) +from spd.autointerp.schemas import ModelMetadata +from spd.data import DatasetConfig, create_data_loader +from spd.harvest.config import TranscoderHarvestConfig +from spd.pretrain.models.llama_simple_mlp import LlamaSimpleMLP +from spd.pretrain.run_info import PretrainRunInfo +from spd.topology import TransformerTopology + +_ENCODER_CLASSES: dict[str, type[SharedTranscoder]] = { + "vanilla": VanillaTranscoder, + "topk": TopKTranscoder, + "batchtopk": BatchTopKTranscoder, + "jumprelu": JumpReLUTranscoder, +} + + +def _load_transcoder(checkpoint_dir: Path, device: str) -> SharedTranscoder: + with open(checkpoint_dir / "config.json") as f: + cfg_dict: dict[str, Any] = json.load(f) + cfg_dict["dtype"] = getattr(torch, cfg_dict.get("dtype", "torch.float32").replace("torch.", "")) + cfg_dict["device"] = device + cfg = EncoderConfig(**cfg_dict) + encoder = _ENCODER_CLASSES[cfg.encoder_type](cfg) + encoder.load_state_dict(torch.load(checkpoint_dir / "encoder.pt", map_location=device)) + encoder.eval() + return encoder + + +def _download_artifact(artifact_path: str, dest: Path) -> Path: + if dest.exists() and (dest / "encoder.pt").exists(): + return dest + api = wandb.Api() + artifact = api.artifact(artifact_path) + artifact.download(root=str(dest)) + return dest + + +class TranscoderAdapter(DecompositionAdapter): + def __init__(self, config: TranscoderHarvestConfig): + self._config = config + + @cached_property + def _run_info(self) -> PretrainRunInfo: + return PretrainRunInfo.from_path(self._config.base_model_path) + + @cached_property + def base_model(self) -> LlamaSimpleMLP: + return LlamaSimpleMLP.from_run_info(self._run_info) + + @cached_property + def _topology(self) -> TransformerTopology: + return TransformerTopology(self.base_model) + + @cached_property + def transcoders(self) -> dict[str, SharedTranscoder]: + result: dict[str, SharedTranscoder] = {} + for module_path, artifact_path 
in self._config.artifact_paths.items(): + safe_name = artifact_path.replace("/", "_").replace(":", "_") + dest = Path(f"checkpoints/tc_{safe_name}") + checkpoint_dir = _download_artifact(artifact_path, dest) + result[module_path] = _load_transcoder(checkpoint_dir, "cpu") + return result + + @property + @override + def decomposition_id(self) -> str: + return self._config.id + + @property + @override + def vocab_size(self) -> int: + return self.base_model.config.vocab_size + + @property + @override + def layer_activation_sizes(self) -> list[tuple[str, int]]: + return [(path, tc.dict_size) for path, tc in self.transcoders.items()] + + @property + @override + def tokenizer_name(self) -> str: + tok = self._run_info.hf_tokenizer_path + assert tok is not None, "base model run missing hf_tokenizer_path" + return tok + + @property + @override + def model_metadata(self) -> ModelMetadata: + ds_cfg = self._run_info.config_dict.get("train_dataset_config", {}) + model_cls = type(self.base_model) + return ModelMetadata( + n_blocks=self._topology.n_blocks, + model_class=f"{model_cls.__module__}.{model_cls.__qualname__}", + dataset_name=ds_cfg.get("name", "unknown"), + layer_descriptions={ + path: self._topology.target_to_canon(path) for path in self.transcoders + }, + ) + + @override + def dataloader(self, batch_size: int) -> DataLoader[torch.Tensor]: + ds_cfg = self._run_info.config_dict["train_dataset_config"] + dataset_config = DatasetConfig.model_validate( + {**ds_cfg, "streaming": True, "n_ctx": self.base_model.config.block_size} + ) + loader, _ = create_data_loader( + dataset_config=dataset_config, batch_size=batch_size, buffer_size=1000 + ) + return loader diff --git a/spd/adapters/transcoders.py b/spd/adapters/transcoders.py new file mode 100644 index 000000000..46d4e57a0 --- /dev/null +++ b/spd/adapters/transcoders.py @@ -0,0 +1,378 @@ +"""Transcoder nn.Module implementations (Vanilla, TopK, BatchTopK, JumpReLU). 
+ +Originally by Bart Bussmann, vendored from https://github.com/bartbussmann/nn_decompositions (MIT license). +""" + +from typing import Any, override + +import torch +import torch.autograd as autograd +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from spd.adapters.encoder_config import EncoderConfig + + +class SharedTranscoder(nn.Module): + """Base class for encoder-decoder models (SAE and Transcoder). + + Supports both SAE mode (input = target) and Transcoder mode (input != target). + All subclasses use forward(x_in, y_target) signature. + """ + + def __init__(self, cfg: EncoderConfig): + super().__init__() + + self.cfg = cfg + torch.manual_seed(cfg.seed) + + self.input_size = cfg.input_size + self.output_size = cfg.output_size + self.dict_size = cfg.dict_size + + self.b_dec = nn.Parameter(torch.zeros(cfg.output_size)) + self.b_enc = nn.Parameter(torch.zeros(cfg.dict_size)) + self.W_enc = nn.Parameter( + torch.nn.init.kaiming_uniform_(torch.empty(cfg.input_size, cfg.dict_size)) + ) + self.W_dec = nn.Parameter( + torch.nn.init.kaiming_uniform_(torch.empty(cfg.dict_size, cfg.output_size)) + ) + if cfg.input_size == cfg.output_size: + self.W_dec.data[:] = self.W_enc.t().data + self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True) + self.num_batches_not_active = torch.zeros((cfg.dict_size,)).to(cfg.device) + + self.to(cfg.dtype).to(cfg.device) + + def encode(self, x: Tensor) -> Tensor: + _ = x + raise NotImplementedError + + def encode_dense(self, x: Tensor) -> tuple[Tensor, Tensor]: + _ = x + raise NotImplementedError + + def decode(self, acts: Tensor) -> Tensor: + return acts @ self.W_dec + self.b_dec + + def preprocess_input(self, x: Tensor) -> tuple[Tensor, Tensor | None, Tensor | None]: + if self.cfg.input_unit_norm: + x_mean = x.mean(dim=-1, keepdim=True) + x = x - x_mean + x_std = x.std(dim=-1, keepdim=True) + x = x / (x_std + 1e-5) + return x, x_mean, x_std + return x, None, None + + def 
postprocess_output(self, out: Tensor, mean: Tensor | None, std: Tensor | None) -> Tensor: + if self.cfg.input_unit_norm and mean is not None: + assert std is not None + return out * std + mean + return out + + def make_decoder_weights_and_grad_unit_norm(self) -> None: + with torch.no_grad(): + W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True) + assert self.W_dec.grad is not None + W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed + self.W_dec.grad -= W_dec_grad_proj + self.W_dec.data = W_dec_normed + + def update_inactive_features(self, acts: Tensor) -> None: + self.num_batches_not_active += (acts.sum(0) == 0).float() + self.num_batches_not_active[acts.sum(0) > 0] = 0 + + def _get_auxiliary_loss(self, y_target: Tensor, y_pred: Tensor, acts: Tensor) -> Tensor: + dead_features = self.num_batches_not_active >= self.cfg.n_batches_to_dead + if dead_features.sum() > 0: + residual = y_target.float() - y_pred.float() + acts_topk_aux = torch.topk( + acts[:, dead_features], + min(self.cfg.top_k_aux, int(dead_features.sum().item())), + dim=-1, + ) + acts_aux = torch.zeros_like(acts[:, dead_features]).scatter( + -1, acts_topk_aux.indices, acts_topk_aux.values + ) + y_pred_aux = acts_aux @ self.W_dec[dead_features] + return self.cfg.aux_penalty * (y_pred_aux.float() - residual.float()).pow(2).mean() + return torch.tensor(0, dtype=y_target.dtype, device=y_target.device) + + def _build_loss_dict( + self, + y_target: Tensor, + y_pred: Tensor, + acts: Tensor, + y_pred_out: Tensor, + l0_norm: Tensor, + extra_losses: dict[str, Tensor] | None = None, + ) -> dict[str, Any]: + l2_loss = (y_pred.float() - y_target.float()).pow(2).mean() + l1_norm = acts.float().abs().sum(-1).mean() + num_dead = (self.num_batches_not_active > self.cfg.n_batches_to_dead).sum() + + loss = l2_loss + if extra_losses: + loss = loss + sum(extra_losses.values()) + + result: dict[str, Any] = { + "output": y_pred_out, + "feature_acts": acts, + 
"num_dead_features": num_dead, + "loss": loss, + "l2_loss": l2_loss, + "l0_norm": l0_norm, + "l1_norm": l1_norm, + } + if extra_losses: + result.update(extra_losses) + return result + + +class VanillaTranscoder(SharedTranscoder): + @override + def encode(self, x: Tensor) -> Tensor: + use_pre_enc_bias = self.cfg.pre_enc_bias and self.input_size == self.output_size + x_enc = x - self.b_dec if use_pre_enc_bias else x + return F.relu(x_enc @ self.W_enc + self.b_enc) + + @override + def encode_dense(self, x: Tensor) -> tuple[Tensor, Tensor]: + acts = self.encode(x) + return acts, acts + + @override + def forward(self, x_in: Tensor, y_target: Tensor) -> dict[str, Any]: + x_in, _, _ = self.preprocess_input(x_in) + y_target, y_mean, y_std = self.preprocess_input(y_target) + + acts = self.encode(x_in) + y_pred = self.decode(acts) + y_pred_out = self.postprocess_output(y_pred, y_mean, y_std) + + self.update_inactive_features(acts) + + l0_norm = (acts > 0).float().sum(-1).mean() + l1_loss = self.cfg.l1_coeff * acts.float().abs().sum(-1).mean() + return self._build_loss_dict( + y_target, + y_pred, + acts, + y_pred_out, + l0_norm, + extra_losses={"l1_loss": l1_loss}, + ) + + +class TopKTranscoder(SharedTranscoder): + @override + def encode(self, x: Tensor) -> Tensor: + use_pre_enc_bias = self.cfg.pre_enc_bias and self.input_size == self.output_size + x_enc = x - self.b_dec if use_pre_enc_bias else x + acts = F.relu(x_enc @ self.W_enc + self.b_enc) + acts_topk = torch.topk(acts, self.cfg.top_k, dim=-1) + return torch.zeros_like(acts).scatter(-1, acts_topk.indices, acts_topk.values) + + @override + def encode_dense(self, x: Tensor) -> tuple[Tensor, Tensor]: + use_pre_enc_bias = self.cfg.pre_enc_bias and self.input_size == self.output_size + x_enc = x - self.b_dec if use_pre_enc_bias else x + acts = F.relu(x_enc @ self.W_enc + self.b_enc) + acts_topk = torch.topk(acts, self.cfg.top_k, dim=-1) + acts_sparse = torch.zeros_like(acts).scatter(-1, acts_topk.indices, acts_topk.values) + 
return acts_sparse, acts + + @override + def forward(self, x_in: Tensor, y_target: Tensor) -> dict[str, Any]: + x_in, _, _ = self.preprocess_input(x_in) + y_target, y_mean, y_std = self.preprocess_input(y_target) + + acts, acts_dense = self.encode_dense(x_in) + y_pred = self.decode(acts) + y_pred_out = self.postprocess_output(y_pred, y_mean, y_std) + + self.update_inactive_features(acts) + + l0_norm = (acts > 0).float().sum(-1).mean() + l1_loss = self.cfg.l1_coeff * acts.float().abs().sum(-1).mean() + aux_loss = self._get_auxiliary_loss(y_target, y_pred, acts_dense) + return self._build_loss_dict( + y_target, + y_pred, + acts, + y_pred_out, + l0_norm, + extra_losses={"l1_loss": l1_loss, "aux_loss": aux_loss}, + ) + + +class BatchTopKTranscoder(SharedTranscoder): + @override + def encode(self, x: Tensor) -> Tensor: + use_pre_enc_bias = self.cfg.pre_enc_bias and self.input_size == self.output_size + x_enc = x - self.b_dec if use_pre_enc_bias else x + acts = F.relu(x_enc @ self.W_enc + self.b_enc) + acts_topk = torch.topk(acts.flatten(), self.cfg.top_k * x.shape[0], dim=-1) + return ( + torch.zeros_like(acts.flatten()) + .scatter(-1, acts_topk.indices, acts_topk.values) + .reshape(acts.shape) + ) + + @override + def encode_dense(self, x: Tensor) -> tuple[Tensor, Tensor]: + use_pre_enc_bias = self.cfg.pre_enc_bias and self.input_size == self.output_size + x_enc = x - self.b_dec if use_pre_enc_bias else x + acts = F.relu(x_enc @ self.W_enc + self.b_enc) + acts_topk = torch.topk(acts.flatten(), self.cfg.top_k * x.shape[0], dim=-1) + acts_sparse = ( + torch.zeros_like(acts.flatten()) + .scatter(-1, acts_topk.indices, acts_topk.values) + .reshape(acts.shape) + ) + return acts_sparse, acts + + @override + def forward(self, x_in: Tensor, y_target: Tensor) -> dict[str, Any]: + x_in, _, _ = self.preprocess_input(x_in) + y_target, y_mean, y_std = self.preprocess_input(y_target) + + acts, acts_dense = self.encode_dense(x_in) + y_pred = self.decode(acts) + y_pred_out = 
self.postprocess_output(y_pred, y_mean, y_std) + + self.update_inactive_features(acts) + + l0_norm = (acts > 0).float().sum(-1).mean() + l1_loss = self.cfg.l1_coeff * acts.float().abs().sum(-1).mean() + aux_loss = self._get_auxiliary_loss(y_target, y_pred, acts_dense) + return self._build_loss_dict( + y_target, + y_pred, + acts, + y_pred_out, + l0_norm, + extra_losses={"l1_loss": l1_loss, "aux_loss": aux_loss}, + ) + + +class RectangleFunction(autograd.Function): + @staticmethod + @override + def forward(ctx: Any, x: Tensor) -> Tensor: + ctx.save_for_backward(x) + return ((x > -0.5) & (x < 0.5)).float() + + @staticmethod + @override + def backward(ctx: Any, *grad_outputs: Tensor) -> Tensor: + (grad_output,) = grad_outputs + (x,) = ctx.saved_tensors + grad_input = grad_output.clone() + grad_input[(x <= -0.5) | (x >= 0.5)] = 0 + return grad_input + + +class JumpReLUFunction(autograd.Function): + @staticmethod + @override + def forward(ctx: Any, x: Tensor, log_threshold: Tensor, bandwidth: float) -> Tensor: + ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth)) + threshold = torch.exp(log_threshold) + return x * (x > threshold).float() + + @staticmethod + @override + def backward(ctx: Any, *grad_outputs: Tensor) -> tuple[Tensor, Tensor, None]: + (grad_output,) = grad_outputs + x, log_threshold, bandwidth_tensor = ctx.saved_tensors + bandwidth = bandwidth_tensor.item() + threshold = torch.exp(log_threshold) + x_grad = (x > threshold).float() * grad_output + threshold_grad = ( + -(threshold / bandwidth) + * RectangleFunction.apply((x - threshold) / bandwidth) + * grad_output + ) + return x_grad, threshold_grad, None + + +class JumpReLUActivation(nn.Module): + def __init__(self, feature_size: int, bandwidth: float, device: str = "cpu"): + super().__init__() + self.log_threshold = nn.Parameter(torch.zeros(feature_size, device=device)) + self.bandwidth = bandwidth + + @override + def forward(self, x: Tensor) -> Tensor: + result = JumpReLUFunction.apply(x, 
self.log_threshold, self.bandwidth) + assert isinstance(result, Tensor) + return result + + +class StepFunction(autograd.Function): + @staticmethod + @override + def forward(ctx: Any, x: Tensor, log_threshold: Tensor, bandwidth: float) -> Tensor: + ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth)) + threshold = torch.exp(log_threshold) + return (x > threshold).float() + + @staticmethod + @override + def backward(ctx: Any, *grad_outputs: Tensor) -> tuple[Tensor, Tensor, None]: + (grad_output,) = grad_outputs + x, log_threshold, bandwidth_tensor = ctx.saved_tensors + bandwidth = bandwidth_tensor.item() + threshold = torch.exp(log_threshold) + x_grad = torch.zeros_like(x) + threshold_grad = ( + -(1.0 / bandwidth) * RectangleFunction.apply((x - threshold) / bandwidth) * grad_output + ) + return x_grad, threshold_grad, None + + +class JumpReLUTranscoder(SharedTranscoder): + def __init__(self, cfg: EncoderConfig): + super().__init__(cfg) + self.jumprelu = JumpReLUActivation(cfg.dict_size, cfg.bandwidth, cfg.device) + + @override + def encode(self, x: Tensor) -> Tensor: + use_pre_enc_bias = self.cfg.pre_enc_bias and self.input_size == self.output_size + x_enc = x - self.b_dec if use_pre_enc_bias else x + pre_acts = F.relu(x_enc @ self.W_enc + self.b_enc) + return self.jumprelu(pre_acts) + + @override + def encode_dense(self, x: Tensor) -> tuple[Tensor, Tensor]: + use_pre_enc_bias = self.cfg.pre_enc_bias and self.input_size == self.output_size + x_enc = x - self.b_dec if use_pre_enc_bias else x + pre_acts = F.relu(x_enc @ self.W_enc + self.b_enc) + return self.jumprelu(pre_acts), pre_acts + + @override + def forward(self, x_in: Tensor, y_target: Tensor) -> dict[str, Any]: + x_in, _, _ = self.preprocess_input(x_in) + y_target, y_mean, y_std = self.preprocess_input(y_target) + + acts, pre_acts = self.encode_dense(x_in) + y_pred = self.decode(acts) + y_pred_out = self.postprocess_output(y_pred, y_mean, y_std) + + self.update_inactive_features(acts) + + 
step_result = StepFunction.apply(pre_acts, self.jumprelu.log_threshold, self.cfg.bandwidth) + assert isinstance(step_result, Tensor) + l0_norm = step_result.sum(dim=-1).mean() + sparsity_loss = self.cfg.l1_coeff * l0_norm + return self._build_loss_dict( + y_target, + y_pred, + acts, + y_pred_out, + l0_norm, + extra_losses={"sparsity_loss": sparsity_loss}, + ) diff --git a/spd/harvest/config.py b/spd/harvest/config.py index cc01b3cd0..15a30ac55 100644 --- a/spd/harvest/config.py +++ b/spd/harvest/config.py @@ -51,7 +51,23 @@ def id(self) -> str: return "molt" -DecompositionMethodHarvestConfig = SPDHarvestConfig | CLTHarvestConfig | MOLTHarvestConfig +class TranscoderHarvestConfig(BaseConfig): + type: Literal["TranscoderHarvestConfig"] = "TranscoderHarvestConfig" + base_model_path: str + artifact_paths: dict[str, str] + """Maps module paths (e.g. "h.0.mlp") to wandb artifact paths.""" + activation_threshold: float = 0.0 + + @property + def id(self) -> str: + # Must be stable across processes: downstream steps look up harvest data by this ID. + # Builtin hash() of strings is randomized per process (PYTHONHASHSEED), so use sha256. + import hashlib + + digest = hashlib.sha256(repr(sorted(self.artifact_paths.items())).encode()).hexdigest() + return "tc-" + digest[:8] + + +DecompositionMethodHarvestConfig = ( + SPDHarvestConfig | CLTHarvestConfig | MOLTHarvestConfig | TranscoderHarvestConfig +) # -- Pipeline configs ---------------------------------------------------------- diff --git a/spd/harvest/harvest_fn/__init__.py b/spd/harvest/harvest_fn/__init__.py index de57541ba..f2d297712 100644 --- a/spd/harvest/harvest_fn/__init__.py +++ b/spd/harvest/harvest_fn/__init__.py @@ -2,14 +2,17 @@ from spd.adapters.base import DecompositionAdapter from spd.adapters.spd import SPDAdapter +from spd.adapters.transcoder import TranscoderAdapter from spd.harvest.config import ( CLTHarvestConfig, DecompositionMethodHarvestConfig, MOLTHarvestConfig, SPDHarvestConfig, + TranscoderHarvestConfig, ) from spd.harvest.harvest_fn.base import HarvestFn from spd.harvest.harvest_fn.spd import SPDHarvestFn +from spd.harvest.harvest_fn.transcoder import TranscoderHarvestFn def make_harvest_fn(
match method_config, adapter: case SPDHarvestConfig(), SPDAdapter(): return SPDHarvestFn(method_config, adapter, device=device) + case TranscoderHarvestConfig(), TranscoderAdapter(): + return TranscoderHarvestFn(adapter, method_config.activation_threshold, device=device) case CLTHarvestConfig(), _: raise NotImplementedError("CLT harvest not implemented yet") case MOLTHarvestConfig(), _: diff --git a/spd/harvest/harvest_fn/transcoder.py b/spd/harvest/harvest_fn/transcoder.py new file mode 100644 index 000000000..20f044ddb --- /dev/null +++ b/spd/harvest/harvest_fn/transcoder.py @@ -0,0 +1,72 @@ +"""Transcoder harvest function: computes sparse activations from transcoders.""" + +from typing import override + +import torch +from torch import Tensor + +from spd.adapters.transcoder import TranscoderAdapter +from spd.harvest.harvest_fn.base import HarvestFn +from spd.harvest.schemas import HarvestBatch +from spd.utils.general_utils import extract_batch_data + + +class TranscoderHarvestFn(HarvestFn): + def __init__( + self, adapter: TranscoderAdapter, activation_threshold: float, device: torch.device + ): + self._adapter = adapter + self._activation_threshold = activation_threshold + self._device = device + + adapter.base_model.to(device).eval() + for tc in adapter.transcoders.values(): + tc.to(device).eval() + + @override + def __call__(self, batch_item: torch.Tensor) -> HarvestBatch: + model = self._adapter.base_model + + batch = extract_batch_data(batch_item).to(self._device) + + # Hook target modules to capture their inputs + mlp_inputs: dict[str, Tensor] = {} + hooks = [] + for module_path in self._adapter.transcoders: + module = model.get_submodule(module_path) + + def _hook( + _mod: torch.nn.Module, + inp: tuple[Tensor, ...], + _out: Tensor, + path: str = module_path, + ) -> None: + mlp_inputs[path] = inp[0].detach() + + hooks.append(module.register_forward_hook(_hook)) + + logits, _ = model(batch) + for h in hooks: + h.remove() + + assert logits is not None + 
probs = torch.softmax(logits, dim=-1) + + firings: dict[str, Tensor] = {} + activations: dict[str, dict[str, Tensor]] = {} + for module_path, tc in self._adapter.transcoders.items(): + mlp_in = mlp_inputs[module_path] + B, S, _ = mlp_in.shape + flat = mlp_in.reshape(-1, tc.input_size) + acts_raw = tc.encode(flat) + assert isinstance(acts_raw, Tensor) + acts = acts_raw.reshape(B, S, -1) + firings[module_path] = acts > self._activation_threshold + activations[module_path] = {"activation": acts} + + return HarvestBatch( + tokens=batch, + firings=firings, + activations=activations, + output_probs=probs, + ) diff --git a/spd/harvest/scripts/run_worker.py b/spd/harvest/scripts/run_worker.py index 9c3d1c582..51057c510 100644 --- a/spd/harvest/scripts/run_worker.py +++ b/spd/harvest/scripts/run_worker.py @@ -11,7 +11,7 @@ import fire import torch -from spd.adapters import adapter_from_id +from spd.adapters import adapter_from_config from spd.harvest.config import HarvestConfig from spd.harvest.harvest import harvest from spd.harvest.harvest_fn import make_harvest_fn @@ -35,7 +35,7 @@ def main( config = HarvestConfig.model_validate(config_json) - adapter = adapter_from_id(config.method_config.id) + adapter = adapter_from_config(config.method_config) output_dir = get_harvest_subrun_dir(adapter.decomposition_id, subrun_id)