From e6b88065c4394891c159e11ef9c2e2a4faae823f Mon Sep 17 00:00:00 2001 From: bartbussmann Date: Tue, 3 Mar 2026 01:56:47 +0000 Subject: [PATCH 01/13] Add transcoder integration for the harvest pipeline Extends the generic harvest pipeline (from #398) to support transcoders from nn_decompositions. Adds TranscoderAdapter, TranscoderHarvestFn, and TranscoderHarvestConfig so that trained transcoders (loaded from wandb artifacts) can be harvested for activation statistics using the same pipeline as SPD. Includes an example script demonstrating end-to-end harvesting of BatchTopK k=32 transcoders across all 4 LlamaSimpleMLP layers. Co-Authored-By: Claude Opus 4.6 --- pyproject.toml | 5 + scripts/harvest_transcoders_example.py | 108 +++++++++++++++ spd/adapters/__init__.py | 7 +- spd/adapters/transcoder.py | 126 ++++++++++++++++++ spd/harvest/config.py | 16 ++- spd/harvest/harvest_fn/__init__.py | 5 + spd/harvest/harvest_fn/transcoder.py | 72 ++++++++++ uv.lock | 175 ++++++++++++++++++++++++- 8 files changed, 511 insertions(+), 3 deletions(-) create mode 100644 scripts/harvest_transcoders_example.py create mode 100644 spd/adapters/transcoder.py create mode 100644 spd/harvest/harvest_fn/transcoder.py diff --git a/pyproject.toml b/pyproject.toml index 88c3405a8..1dc03386d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,11 @@ spd-autointerp = "spd.autointerp.scripts.run_slurm_cli:cli" spd-attributions = "spd.dataset_attributions.scripts.run_slurm_cli:cli" spd-postprocess = "spd.postprocess.cli:cli" +[project.optional-dependencies] +transcoder = [ + "nn-decompositions @ git+https://github.com/bartbussmann/nn_decompositions.git", +] + [build-system] requires = ["setuptools", "wheel"] build-backend = "setuptools.build_meta" diff --git a/scripts/harvest_transcoders_example.py b/scripts/harvest_transcoders_example.py new file mode 100644 index 000000000..7cd75e28e --- /dev/null +++ b/scripts/harvest_transcoders_example.py @@ -0,0 +1,108 @@ +"""Example: Harvest transcoder activations using the SPD harvest pipeline. + +Loads trained BatchTopK transcoders from wandb artifacts and runs the generic +harvest pipeline to collect activation statistics (firing densities, token PMI, +activation examples). + +Usage: + python scripts/harvest_transcoders_example.py + +Prerequisites: + pip install -e /workspace/nn_decompositions +""" + +from datetime import datetime + +import torch + +from spd.adapters.transcoder import TranscoderAdapter +from spd.harvest.config import HarvestConfig, TranscoderHarvestConfig +from spd.harvest.harvest import harvest +from spd.harvest.harvest_fn.transcoder import TranscoderHarvestFn +from spd.harvest.repo import HarvestRepo +from spd.harvest.schemas import get_harvest_subrun_dir + +# -- Configuration ------------------------------------------------------------ + +# Fill in the wandb artifact paths for each layer's transcoder. +# Find these at: https://wandb.ai/mats-sprint/pile_transcoder_sweep3 +# Each artifact should contain encoder.pt + config.json. +TRANSCODER_CONFIG = TranscoderHarvestConfig( + base_model_path="wandb:goodfire/spd/t-32d1bb3b", + artifact_paths={ + "h.0.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L0_5d5b1f_checkpoint_final:v0", + "h.1.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L1_c208d7_checkpoint_final:v0", + "h.2.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L2_4f6e37_checkpoint_final:v0", + "h.3.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L3_e76468_checkpoint_final:v0", + }, +) + +HARVEST_CONFIG = HarvestConfig( + method_config=TRANSCODER_CONFIG, + n_batches=20, + batch_size=8, + activation_examples_per_component=20, + activation_context_tokens_per_side=10, + pmi_token_top_k=10, +) + + +def main() -> None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + # Build adapter (downloads artifacts + loads base model and transcoders) + print("Loading transcoders and base model...") + adapter = TranscoderAdapter(TRANSCODER_CONFIG) + + print(f"Base model vocab size: {adapter.vocab_size}") + print(f"Layers: {adapter.layer_activation_sizes}") + for path, tc in adapter.transcoders.items(): + print( + f" {path}: dict_size={tc.dict_size}, encoder_type={tc.cfg.encoder_type}, top_k={tc.cfg.top_k}" + ) + + # Build harvest function + harvest_fn = TranscoderHarvestFn(adapter, TRANSCODER_CONFIG.activation_threshold, device) + + # Run harvest + subrun_id = "h-" + datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = get_harvest_subrun_dir(adapter.decomposition_id, subrun_id) + print(f"\nHarvesting to: {output_dir}") + + harvest( + layers=adapter.layer_activation_sizes, + vocab_size=adapter.vocab_size, + dataloader=adapter.dataloader(HARVEST_CONFIG.batch_size), + harvest_fn=harvest_fn, + config=HARVEST_CONFIG, + output_dir=output_dir, + rank_world_size=None, + device=device, + ) + + # Print summary + print("\n=== Harvest Summary ===") + repo = HarvestRepo(adapter.decomposition_id, subrun_id, readonly=True) + components = repo.get_all_components() + print(f"Total components harvested: {len(components)}") + + for comp in components[:10]: + n_examples = len(comp.activation_examples) + top_tokens = comp.input_token_pmi.top[:3] + print( + f" {comp.component_key}: " + f"density={comp.firing_density:.4f}, " + f"examples={n_examples}, " + f"mean_act={comp.mean_activations.get('activation', 0):.4f}, " + f"top_pmi_tokens={top_tokens}" + ) + + if len(components) > 10: + print(f" ... and {len(components) - 10} more components") + + print(f"\nResults saved to: {output_dir}") + + +if __name__ == "__main__": + main() diff --git a/spd/adapters/__init__.py b/spd/adapters/__init__.py index aded4d188..87bf94758 100644 --- a/spd/adapters/__init__.py +++ b/spd/adapters/__init__.py @@ -1,6 +1,6 @@ """Harvest method adapters: method-specific logic for the generic harvest pipeline. -Each decomposition method (SPD, CLT, MOLT) provides an adapter that knows how to: +Each decomposition method (SPD, CLT, MOLT, Transcoder) provides an adapter that knows how to: - Load the model and build a dataloader - Compute firings and activations from a batch (harvest_fn) - Report layer structure and vocab size @@ -16,6 +16,11 @@ def adapter_from_id(id: str) -> DecompositionAdapter: if id.startswith("s-"): return SPDAdapter(id) + elif id.startswith("tc-"): + raise NotImplementedError( + "TranscoderAdapter requires a TranscoderHarvestConfig. " + "Use TranscoderAdapter(config) directly." + ) elif id.startswith("clt-"): raise NotImplementedError("CLT adapter not implemented yet") elif id.startswith("molt-"): diff --git a/spd/adapters/transcoder.py b/spd/adapters/transcoder.py new file mode 100644 index 000000000..f484d96aa --- /dev/null +++ b/spd/adapters/transcoder.py @@ -0,0 +1,126 @@ +"""Transcoder adapter: loads trained transcoders from wandb artifacts.""" + +import json +from functools import cached_property +from pathlib import Path +from typing import Any, override + +import torch +import wandb +from nn_decompositions.config import EncoderConfig +from nn_decompositions.transcoder import ( + BatchTopKTranscoder, + JumpReLUTranscoder, + SharedTranscoder, + TopKTranscoder, + VanillaTranscoder, +) +from torch.utils.data import DataLoader + +from spd.adapters.base import DecompositionAdapter +from spd.autointerp.schemas import ModelMetadata +from spd.data import DatasetConfig, create_data_loader +from spd.harvest.config import TranscoderHarvestConfig +from spd.pretrain.models.llama_simple_mlp import LlamaSimpleMLP +from spd.topology import TransformerTopology + +_ENCODER_CLASSES: dict[str, type[SharedTranscoder]] = { + "vanilla": VanillaTranscoder, + "topk": TopKTranscoder, + "batchtopk": BatchTopKTranscoder, + "jumprelu": JumpReLUTranscoder, +} + + +def _load_transcoder(checkpoint_dir: Path, device: str) -> SharedTranscoder: + with open(checkpoint_dir / "config.json") as f: + cfg_dict: dict[str, Any] = json.load(f) + cfg_dict["dtype"] = getattr(torch, cfg_dict.get("dtype", "torch.float32").replace("torch.", "")) + cfg_dict["device"] = device + cfg = EncoderConfig(**cfg_dict) + encoder = _ENCODER_CLASSES[cfg.encoder_type](cfg) + encoder.load_state_dict(torch.load(checkpoint_dir / "encoder.pt", map_location=device)) + encoder.eval() + return encoder + + +def _download_artifact(artifact_path: str, dest: Path) -> Path: + if dest.exists() and (dest / "encoder.pt").exists(): + return dest + api = wandb.Api() + artifact = api.artifact(artifact_path) + artifact.download(root=str(dest)) + return dest + + +class TranscoderAdapter(DecompositionAdapter): + def __init__(self, config: TranscoderHarvestConfig): + self._config = config + + @cached_property + def base_model(self) -> LlamaSimpleMLP: + return LlamaSimpleMLP.from_pretrained(self._config.base_model_path) + + @cached_property + def _topology(self) -> TransformerTopology: + return TransformerTopology(self.base_model) + + @cached_property + def transcoders(self) -> dict[str, SharedTranscoder]: + result: dict[str, SharedTranscoder] = {} + for module_path, artifact_path in self._config.artifact_paths.items(): + safe_name = artifact_path.replace("/", "_").replace(":", "_") + dest = Path(f"checkpoints/tc_{safe_name}") + checkpoint_dir = _download_artifact(artifact_path, dest) + result[module_path] = _load_transcoder(checkpoint_dir, "cpu") + return result + + @property + @override + def decomposition_id(self) -> str: + return self._config.id + + @property + @override + def vocab_size(self) -> int: + return self.base_model.config.vocab_size + + @property + @override + def layer_activation_sizes(self) -> list[tuple[str, int]]: + return [(path, tc.dict_size) for path, tc in self.transcoders.items()] + + @property + @override + def tokenizer_name(self) -> str: + return "gpt2" + + @property + @override + def model_metadata(self) -> ModelMetadata: + return ModelMetadata( + n_blocks=self._topology.n_blocks, + model_class="spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP", + dataset_name="danbraunai/pile-uncopyrighted-tok", + layer_descriptions={ + path: self._topology.target_to_canon(path) for path in self.transcoders + }, + ) + + @override + def dataloader(self, batch_size: int) -> DataLoader[torch.Tensor]: + dataset_config = DatasetConfig( + name="danbraunai/pile-uncopyrighted-tok", + is_tokenized=True, + hf_tokenizer_path="gpt2", + streaming=True, + split="train", + n_ctx=512, + column_name="input_ids", + ) + loader, _ = create_data_loader( + dataset_config=dataset_config, + batch_size=batch_size, + buffer_size=1000, + ) + return loader diff --git a/spd/harvest/config.py b/spd/harvest/config.py index 8c3afba4e..7bcad4884 100644 --- a/spd/harvest/config.py +++ b/spd/harvest/config.py @@ -51,7 +51,21 @@ def id(self) -> str: return "molt" -DecompositionMethodHarvestConfig = SPDHarvestConfig | CLTHarvestConfig | MOLTHarvestConfig +class TranscoderHarvestConfig(BaseConfig): + type: Literal["TranscoderHarvestConfig"] = "TranscoderHarvestConfig" + base_model_path: str + artifact_paths: dict[str, str] + """Maps module paths (e.g. "h.0.mlp") to wandb artifact paths.""" + activation_threshold: float = 0.0 + + @property + def id(self) -> str: + return "tc-" + str(abs(hash(frozenset(self.artifact_paths.items()))))[:8] + + +DecompositionMethodHarvestConfig = ( + SPDHarvestConfig | CLTHarvestConfig | MOLTHarvestConfig | TranscoderHarvestConfig +) # -- Pipeline configs ---------------------------------------------------------- diff --git a/spd/harvest/harvest_fn/__init__.py b/spd/harvest/harvest_fn/__init__.py index de57541ba..f2d297712 100644 --- a/spd/harvest/harvest_fn/__init__.py +++ b/spd/harvest/harvest_fn/__init__.py @@ -2,14 +2,17 @@ from spd.adapters.base import DecompositionAdapter from spd.adapters.spd import SPDAdapter +from spd.adapters.transcoder import TranscoderAdapter from spd.harvest.config import ( CLTHarvestConfig, DecompositionMethodHarvestConfig, MOLTHarvestConfig, SPDHarvestConfig, + TranscoderHarvestConfig, ) from spd.harvest.harvest_fn.base import HarvestFn from spd.harvest.harvest_fn.spd import SPDHarvestFn +from spd.harvest.harvest_fn.transcoder import TranscoderHarvestFn def make_harvest_fn( @@ -20,6 +23,8 @@ def make_harvest_fn( match method_config, adapter: case SPDHarvestConfig(), SPDAdapter(): return SPDHarvestFn(method_config, adapter, device=device) + case TranscoderHarvestConfig(), TranscoderAdapter(): + return TranscoderHarvestFn(adapter, method_config.activation_threshold, device=device) case CLTHarvestConfig(), _: raise NotImplementedError("CLT harvest not implemented yet") case MOLTHarvestConfig(), _: diff --git a/spd/harvest/harvest_fn/transcoder.py b/spd/harvest/harvest_fn/transcoder.py new file mode 100644 index 000000000..20f044ddb --- /dev/null +++ b/spd/harvest/harvest_fn/transcoder.py @@ -0,0 +1,72 @@ +"""Transcoder harvest function: computes sparse activations from transcoders.""" + +from typing import override + +import torch +from torch import Tensor + +from spd.adapters.transcoder import TranscoderAdapter +from spd.harvest.harvest_fn.base import HarvestFn +from spd.harvest.schemas import HarvestBatch +from spd.utils.general_utils import extract_batch_data + + +class TranscoderHarvestFn(HarvestFn): + def __init__( + self, adapter: TranscoderAdapter, activation_threshold: float, device: torch.device + ): + self._adapter = adapter + self._activation_threshold = activation_threshold + self._device = device + + adapter.base_model.to(device).eval() + for tc in adapter.transcoders.values(): + tc.to(device).eval() + + @override + def __call__(self, batch_item: torch.Tensor) -> HarvestBatch: + model = self._adapter.base_model + + batch = extract_batch_data(batch_item).to(self._device) + + # Hook target modules to capture their inputs + mlp_inputs: dict[str, Tensor] = {} + hooks = [] + for module_path in self._adapter.transcoders: + module = model.get_submodule(module_path) + + def _hook( + _mod: torch.nn.Module, + inp: tuple[Tensor, ...], + _out: Tensor, + path: str = module_path, + ) -> None: + mlp_inputs[path] = inp[0].detach() + + hooks.append(module.register_forward_hook(_hook)) + + logits, _ = model(batch) + for h in hooks: + h.remove() + + assert logits is not None + probs = torch.softmax(logits, dim=-1) + + firings: dict[str, Tensor] = {} + activations: dict[str, dict[str, Tensor]] = {} + for module_path, tc in self._adapter.transcoders.items(): + mlp_in = mlp_inputs[module_path] + B, S, _ = mlp_in.shape + flat = mlp_in.reshape(-1, tc.input_size) + acts_raw = tc.encode(flat) + assert isinstance(acts_raw, Tensor) + acts = acts_raw.reshape(B, S, -1) + firings[module_path] = acts > self._activation_threshold + activations[module_path] = {"activation": acts} + + return HarvestBatch( + tokens=batch, + firings=firings, + activations=activations, + output_probs=probs, + ) diff --git a/uv.lock b/uv.lock index a750ef28e..33622313d 100644 --- a/uv.lock +++ b/uv.lock @@ -1,11 +1,29 @@ version = 1 -revision = 2 +revision = 3 requires-python = "==3.13.*" resolution-markers = [ "sys_platform == 'linux'", "sys_platform != 'linux'", ] +[[package]] +name = "accelerate" +version = "1.12.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "huggingface-hub" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "psutil" }, + { name = "pyyaml" }, + { name = "safetensors" }, + { name = "torch" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4a/8e/ac2a9566747a93f8be36ee08532eb0160558b07630a081a6056a9f89bf1d/accelerate-1.12.0.tar.gz", hash = "sha256:70988c352feb481887077d2ab845125024b2a137a5090d6d7a32b57d03a45df6", size = 398399, upload-time = "2025-11-21T11:27:46.973Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9f/d2/c581486aa6c4fbd7394c23c47b83fa1a919d34194e16944241daf9e762dd/accelerate-1.12.0-py3-none-any.whl", hash = "sha256:3e2091cd341423207e2f084a6654b1efcd250dc326f2a37d6dde446e07cabb11", size = 380935, upload-time = "2025-11-21T11:27:44.522Z" }, +] + [[package]] name = "aiohappyeyeballs" version = "2.6.1" @@ -155,6 +173,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f8/90/ce01ad2d0afdc1b82b8b5aaba27e60d2e138e39d887e71c35c55d8f1bfcd/basedpyright-1.31.7-py3-none-any.whl", hash = "sha256:7c54beb7828c9ed0028630aaa6904f395c27e5a9f5a313aa9e91fc1d11170831", size = 11817571, upload-time = "2025-10-11T05:12:45.432Z" }, ] +[[package]] +name = "beartype" +version = "0.14.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/89/3b/9ecfc75d1f8bb75cbdc87fcb3df7c6ec4bc8f7481cb7102859ade1736c9d/beartype-0.14.1.tar.gz", hash = "sha256:23df4715d19cebb2ce60e53c3cf44cd925843f00c71938222d777ea6332de3cb", size = 964899, upload-time = "2023-06-07T05:38:56.905Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/db/8d01583b4175e0e45a6e6cd0c28db2dae38ffe5477141a7ac3a5a09c8bb9/beartype-0.14.1-py3-none-any.whl", hash = "sha256:0f70fccdb8eb6d7ddfaa3ffe3a0b66cf2edeb13452bd71ad46615775c2fa34f6", size = 739737, upload-time = "2023-06-07T05:38:54.076Z" }, +] + +[[package]] +name = "better-abc" +version = "0.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8b/72/3d630f781659015357cc08cad32aa636b252e007df0bae31184a3d872427/better-abc-0.0.3.tar.gz", hash = "sha256:a880fd6bc9675da2ec991e8712a555bffa0f12722efed78c739f78343cf989f6", size = 2852, upload-time = "2020-11-10T22:47:31.303Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d9/e8/7d00a23039ab74c5741736ce05d7700eb6237e83747aac4df07a5bf2d074/better_abc-0.0.3-py3-none-any.whl", hash = "sha256:3ae73b473fbeb536a548f542984976e80b821676ae6e18f14e24d8e180647187", size = 3475, upload-time = "2020-11-10T22:47:30.354Z" }, +] + [[package]] name = "blinker" version = "1.9.0" @@ -438,6 +474,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017", size = 28317, upload-time = "2025-09-01T09:48:08.5Z" }, ] +[[package]] +name = "fancy-einsum" +version = "0.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b9/b1/f5a13cdc05b9a16502d760ead310a689a1538f3fee9618b92011200b9717/fancy_einsum-0.0.3.tar.gz", hash = "sha256:05ca6689999d0949bdaa5320c81117effa13644ec68a200121e93d7ebf3d3356", size = 4916, upload-time = "2022-02-04T01:53:46.028Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/27/14/26fc262ba70976eea9a42e67b05c67aa78a0ee38332d9d094cca5d2c5ec3/fancy_einsum-0.0.3-py3-none-any.whl", hash = "sha256:e0bf33587a61822b0668512ada237a0ffa5662adfb9acfcbb0356ee15a0396a1", size = 6239, upload-time = "2022-02-04T01:53:44.44Z" }, +] + [[package]] name = "fastapi" version = "0.127.0" @@ -857,6 +902,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/8f/8f6f491d595a9e5912971f3f863d81baddccc8a4d0c3749d6a0dd9ffc9df/kiwisolver-1.4.9-cp313-cp313t-win_arm64.whl", hash = "sha256:0749fd8f4218ad2e851e11cc4dc05c7cbc0cbc4267bdfdb31782e65aace4ee9c", size = 68646, upload-time = "2025-08-10T21:27:00.52Z" }, ] +[[package]] +name = "markdown-it-py" +version = "4.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mdurl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5b/f5/4ec618ed16cc4f8fb3b701563655a69816155e79e24a17b651541804721d/markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3", size = 73070, upload-time = "2025-08-11T12:57:52.854Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" }, +] + [[package]] name = "markupsafe" version = "3.0.3" @@ -932,6 +989,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl", hash = "sha256:d56ce5156ba6085e00a9d54fead6ed29a9c47e215cd1bba2e976ef39f5710a76", size = 9516, upload-time = "2025-10-23T09:00:20.675Z" }, ] +[[package]] +name = "mdurl" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, +] + [[package]] name = "mpmath" version = "1.3.0" @@ -1030,6 +1096,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/c9/b2622292ea83fbb4ec318f5b9ab867d0a28ab43c5717bb85b0a5f6b3b0a4/networkx-3.6.1-py3-none-any.whl", hash = "sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762", size = 2068504, upload-time = "2025-12-08T17:02:38.159Z" }, ] +[[package]] +name = "nn-decompositions" +version = "0.1.0" +source = { git = "https://github.com/bartbussmann/nn_decompositions.git#b7dc504eb19d68990e7b2a611be264428a5c8896" } +dependencies = [ + { name = "datasets" }, + { name = "torch" }, + { name = "tqdm" }, + { name = "transformer-lens" }, + { name = "wandb" }, +] + [[package]] name = "nodeenv" version = "1.10.0" @@ -1781,6 +1859,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, ] +[[package]] +name = "rich" +version = "14.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b3/c6/f3b320c27991c46f43ee9d856302c70dc2d0fb2dba4842ff739d5f46b393/rich-14.3.3.tar.gz", hash = "sha256:b8daa0b9e4eef54dd8cf7c86c03713f53241884e814f4e2f5fb342fe520f639b", size = 230582, upload-time = "2026-02-19T17:23:12.474Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/25/b208c5683343959b670dc001595f2f3737e051da617f66c31f7c4fa93abc/rich-14.3.3-py3-none-any.whl", hash = "sha256:793431c1f8619afa7d3b52b2cdec859562b950ea0d4b6b505397612db8d5362d", size = 310458, upload-time = "2026-02-19T17:23:13.732Z" }, +] + [[package]] name = "rpds-py" version = "0.30.0" @@ -1897,6 +1988,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8e/f3/d854ff38789aca9b0cc23008d607ced9de4f7ab14fa1ca4329f86b3758ca/scipy-1.16.3-cp313-cp313t-win_arm64.whl", hash = "sha256:0c623a54f7b79dd88ef56da19bc2873afec9673a48f3b85b18e4d402bdd29a5a", size = 25803246, upload-time = "2025-10-28T17:35:42.155Z" }, ] +[[package]] +name = "sentencepiece" +version = "0.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/15/15/2e7a025fc62d764b151ae6d0f2a92f8081755ebe8d4a64099accc6f77ba6/sentencepiece-0.2.1.tar.gz", hash = "sha256:8138cec27c2f2282f4a34d9a016e3374cd40e5c6e9cb335063db66a0a3b71fad", size = 3228515, upload-time = "2025-08-12T07:00:51.718Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ba/4a/85fbe1706d4d04a7e826b53f327c4b80f849cf1c7b7c5e31a20a97d8f28b/sentencepiece-0.2.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:dcd8161eee7b41aae57ded06272905dbd680a0a04b91edd0f64790c796b2f706", size = 1943150, upload-time = "2025-08-12T06:59:53.588Z" }, + { url = "https://files.pythonhosted.org/packages/c2/83/4cfb393e287509fc2155480b9d184706ef8d9fa8cbf5505d02a5792bf220/sentencepiece-0.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:c6c8f42949f419ff8c7e9960dbadcfbc982d7b5efc2f6748210d3dd53a7de062", size = 1325651, upload-time = "2025-08-12T06:59:55.073Z" }, + { url = "https://files.pythonhosted.org/packages/8d/de/5a007fb53b1ab0aafc69d11a5a3dd72a289d5a3e78dcf2c3a3d9b14ffe93/sentencepiece-0.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:097f3394e99456e9e4efba1737c3749d7e23563dd1588ce71a3d007f25475fff", size = 1253641, upload-time = "2025-08-12T06:59:56.562Z" }, + { url = "https://files.pythonhosted.org/packages/2c/d2/f552be5928105588f4f4d66ee37dd4c61460d8097e62d0e2e0eec41bc61d/sentencepiece-0.2.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d7b670879c370d350557edabadbad1f6561a9e6968126e6debca4029e5547820", size = 1316271, upload-time = "2025-08-12T06:59:58.109Z" }, + { url = "https://files.pythonhosted.org/packages/96/df/0cfe748ace5485be740fed9476dee7877f109da32ed0d280312c94ec259f/sentencepiece-0.2.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c7f0fd2f2693309e6628aeeb2e2faf6edd221134dfccac3308ca0de01f8dab47", size = 1387882, upload-time = "2025-08-12T07:00:00.701Z" }, + { url = "https://files.pythonhosted.org/packages/ac/dd/f7774d42a881ced8e1739f393ab1e82ece39fc9abd4779e28050c2e975b5/sentencepiece-0.2.1-cp313-cp313-win32.whl", hash = "sha256:92b3816aa2339355fda2c8c4e021a5de92180b00aaccaf5e2808972e77a4b22f", size = 999541, upload-time = "2025-08-12T07:00:02.709Z" }, + { url = "https://files.pythonhosted.org/packages/dd/e9/932b9eae6fd7019548321eee1ab8d5e3b3d1294df9d9a0c9ac517c7b636d/sentencepiece-0.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:10ed3dab2044c47f7a2e7b4969b0c430420cdd45735d78c8f853191fa0e3148b", size = 1054669, upload-time = "2025-08-12T07:00:04.915Z" }, + { url = "https://files.pythonhosted.org/packages/c9/3a/76488a00ea7d6931689cda28726a1447d66bf1a4837943489314593d5596/sentencepiece-0.2.1-cp313-cp313-win_arm64.whl", hash = "sha256:ac650534e2251083c5f75dde4ff28896ce7c8904133dc8fef42780f4d5588fcd", size = 1033922, upload-time = "2025-08-12T07:00:06.496Z" }, + { url = "https://files.pythonhosted.org/packages/4a/b6/08fe2ce819e02ccb0296f4843e3f195764ce9829cbda61b7513f29b95718/sentencepiece-0.2.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:8dd4b477a7b069648d19363aad0cab9bad2f4e83b2d179be668efa672500dc94", size = 1946052, upload-time = "2025-08-12T07:00:08.136Z" }, + { url = "https://files.pythonhosted.org/packages/ab/d9/1ea0e740591ff4c6fc2b6eb1d7510d02f3fb885093f19b2f3abd1363b402/sentencepiece-0.2.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0c0f672da370cc490e4c59d89e12289778310a0e71d176c541e4834759e1ae07", size = 1327408, upload-time = "2025-08-12T07:00:09.572Z" }, + { url = "https://files.pythonhosted.org/packages/99/7e/1fb26e8a21613f6200e1ab88824d5d203714162cf2883248b517deb500b7/sentencepiece-0.2.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:ad8493bea8432dae8d6830365352350f3b4144415a1d09c4c8cb8d30cf3b6c3c", size = 1254857, upload-time = "2025-08-12T07:00:11.021Z" }, + { url = "https://files.pythonhosted.org/packages/bc/85/c72fd1f3c7a6010544d6ae07f8ddb38b5e2a7e33bd4318f87266c0bbafbf/sentencepiece-0.2.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b81a24733726e3678d2db63619acc5a8dccd074f7aa7a54ecd5ca33ca6d2d596", size = 1315722, upload-time = "2025-08-12T07:00:12.989Z" }, + { url = "https://files.pythonhosted.org/packages/4a/e8/661e5bd82a8aa641fd6c1020bd0e890ef73230a2b7215ddf9c8cd8e941c2/sentencepiece-0.2.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0a81799d0a68d618e89063fb423c3001a034c893069135ffe51fee439ae474d6", size = 1387452, upload-time = "2025-08-12T07:00:15.088Z" }, + { url = "https://files.pythonhosted.org/packages/99/5e/ae66c361023a470afcbc1fbb8da722c72ea678a2fcd9a18f1a12598c7501/sentencepiece-0.2.1-cp313-cp313t-win32.whl", hash = "sha256:89a3ea015517c42c0341d0d962f3e6aaf2cf10d71b1932d475c44ba48d00aa2b", size = 1002501, upload-time = "2025-08-12T07:00:16.966Z" }, + { url = "https://files.pythonhosted.org/packages/c1/03/d332828c4ff764e16c1b56c2c8f9a33488bbe796b53fb6b9c4205ddbf167/sentencepiece-0.2.1-cp313-cp313t-win_amd64.whl", hash = "sha256:33f068c9382dc2e7c228eedfd8163b52baa86bb92f50d0488bf2b7da7032e484", size = 1057555, upload-time = "2025-08-12T07:00:18.573Z" }, + { url = "https://files.pythonhosted.org/packages/88/14/5aee0bf0864df9bd82bd59e7711362908e4935e3f9cdc1f57246b5d5c9b9/sentencepiece-0.2.1-cp313-cp313t-win_arm64.whl", hash = "sha256:b3616ad246f360e52c85781e47682d31abfb6554c779e42b65333d4b5f44ecc0", size = 1036042, upload-time = "2025-08-12T07:00:20.209Z" }, +] + [[package]] name = "sentry-sdk" version = "2.48.0" @@ -1970,6 +2085,11 @@ dependencies = [ { name = "zstandard" }, ] +[package.optional-dependencies] +transcoder = [ + { name = "nn-decompositions" }, +] + [package.dev-dependencies] dev = [ { name = "basedpyright" }, @@ -1992,6 +2112,7 @@ requires-dist = [ { name = "ipykernel" }, { name = "jaxtyping" }, { name = "matplotlib" }, + { name = "nn-decompositions", marker = "extra == 'transcoder'", git = "https://github.com/bartbussmann/nn_decompositions.git" }, { name = "numpy" }, { name = "openrouter", specifier = ">=0.1.1" }, { name = "orjson" }, @@ -2010,6 +2131,7 @@ requires-dist = [ { name = "wandb-workspaces", specifier = "==0.1.12" }, { name = "zstandard" }, ] +provides-extras = ["transcoder"] [package.metadata.requires-dev] dev = [ @@ -2251,6 +2373,36 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359, upload-time = "2024-04-19T11:11:46.763Z" }, ] +[[package]] +name = "transformer-lens" +version = "2.17.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "accelerate" }, + { name = "beartype" }, + { name = "better-abc" }, + { name = "datasets" }, + { name = "einops" }, + { name = "fancy-einsum" }, + { name = "huggingface-hub" }, + { name = "jaxtyping" }, + { name = "pandas" }, + { name = "protobuf" }, + { name = "rich" }, + { name = "sentencepiece" }, + { name = "torch" }, + { name = "tqdm" }, + { name = "transformers" }, + { name = "transformers-stream-generator" }, + { name = "typeguard" }, + { name = "typing-extensions" }, + { name = "wandb" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0c/3b/e34b774ec23e146d307872920fec27e0e7d105507afb87cc155ab0f9954d/transformer_lens-2.17.0.tar.gz", hash = "sha256:93bfd6b6ac65e2b274edae430495663c8c8e99b53fd2c80ac428c4b0398e1272", size = 156918, upload-time = "2026-01-21T23:53:49.046Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/17/137f4e3791261f2cad2feee8f1ed31e0dea21ac7e7dec10ca133de841cbd/transformer_lens-2.17.0-py3-none-any.whl", hash = "sha256:7186cb9d29c4a20c5ea020dacc13579494e35f0b2316ebdbde972879fc4669f7", size = 195183, upload-time = "2026-01-21T23:53:50.076Z" }, +] + [[package]] name = "transformers" version = "4.57.3" @@ -2272,6 +2424,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/6b/2f416568b3c4c91c96e5a365d164f8a4a4a88030aa8ab4644181fdadce97/transformers-4.57.3-py3-none-any.whl", hash = "sha256:c77d353a4851b1880191603d36acb313411d3577f6e2897814f333841f7003f4", size = 11993463, upload-time = "2025-11-25T15:51:26.493Z" }, ] +[[package]] +name = "transformers-stream-generator" +version = "0.0.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "transformers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/c2/65f13aec253100e1916e9bd7965fe17bde796ebabeb1265f45191ab4ddc0/transformers-stream-generator-0.0.5.tar.gz", hash = "sha256:271deace0abf9c0f83b36db472c8ba61fdc7b04d1bf89d845644acac2795ed57", size = 13033, upload-time = "2024-03-11T14:18:02.079Z" } + [[package]] name = "triton" version = "3.4.0" @@ -2284,6 +2445,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/20/63/8cb444ad5cdb25d999b7d647abac25af0ee37d292afc009940c05b82dda0/triton-3.4.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7936b18a3499ed62059414d7df563e6c163c5e16c3773678a3ee3d417865035d", size = 155659780, upload-time = "2025-07-30T19:58:51.171Z" }, ] +[[package]] +name = "typeguard" +version = "4.5.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2b/e8/66e25efcc18542d58706ce4e50415710593721aae26e794ab1dec34fb66f/typeguard-4.5.1.tar.gz", hash = "sha256:f6f8ecbbc819c9bc749983cc67c02391e16a9b43b8b27f15dc70ed7c4a007274", size = 80121, upload-time = "2026-02-19T16:09:03.392Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/91/88/b55b3117287a8540b76dbdd87733808d4d01c8067a3b339408c250bb3600/typeguard-4.5.1-py3-none-any.whl", hash = "sha256:44d2bf329d49a244110a090b55f5f91aa82d9a9834ebfd30bcc73651e4a8cc40", size = 36745, upload-time = "2026-02-19T16:09:01.6Z" }, +] + [[package]] name = "typing-extensions" version = "4.15.0" From bf10ac9021c91103f632dffb3715bcb4675f80ff Mon Sep 17 00:00:00 2001 From: bartbussmann Date: Tue, 3 Mar 2026 02:12:54 +0000 Subject: [PATCH 02/13] Move tokenizer_name and dataset_name to TranscoderHarvestConfig These were incorrectly hardcoded as "gpt2" and "danbraunai/pile-uncopyrighted-tok" in the adapter. The transcoders are actually trained with the EleutherAI/gpt-neox-20b tokenizer. Co-Authored-By: Claude Opus 4.6 --- scripts/harvest_transcoders_example.py | 2 ++ spd/adapters/transcoder.py | 8 ++++---- spd/harvest/config.py | 2 ++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/scripts/harvest_transcoders_example.py b/scripts/harvest_transcoders_example.py index 7cd75e28e..567703493 100644 --- a/scripts/harvest_transcoders_example.py +++ b/scripts/harvest_transcoders_example.py @@ -35,6 +35,8 @@ "h.2.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L2_4f6e37_checkpoint_final:v0", "h.3.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L3_e76468_checkpoint_final:v0", }, + tokenizer_name="EleutherAI/gpt-neox-20b", + dataset_name="danbraunai/pile-uncopyrighted-tok", ) HARVEST_CONFIG = HarvestConfig( diff --git a/spd/adapters/transcoder.py b/spd/adapters/transcoder.py index f484d96aa..a94e7ddaf 100644 --- a/spd/adapters/transcoder.py +++ b/spd/adapters/transcoder.py @@ -93,7 +93,7 @@ def layer_activation_sizes(self) -> list[tuple[str, int]]: @property @override def tokenizer_name(self) -> str: - return "gpt2" + return self._config.tokenizer_name @property @override @@ -101,7 +101,7 @@ def model_metadata(self) -> ModelMetadata: return ModelMetadata( n_blocks=self._topology.n_blocks, model_class="spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP", - dataset_name="danbraunai/pile-uncopyrighted-tok", + dataset_name=self._config.dataset_name, layer_descriptions={ path: self._topology.target_to_canon(path) for path in self.transcoders }, @@ -110,9 +110,9 @@ def model_metadata(self) -> ModelMetadata: @override def dataloader(self, batch_size: int) -> DataLoader[torch.Tensor]: dataset_config = DatasetConfig( - name="danbraunai/pile-uncopyrighted-tok", + name=self._config.dataset_name, is_tokenized=True, - hf_tokenizer_path="gpt2", + hf_tokenizer_path=self._config.tokenizer_name, streaming=True, split="train", n_ctx=512, diff --git a/spd/harvest/config.py b/spd/harvest/config.py index 7bcad4884..f2aa88320 100644 --- a/spd/harvest/config.py +++ b/spd/harvest/config.py @@ -56,6 +56,8 @@ class TranscoderHarvestConfig(BaseConfig): base_model_path: str artifact_paths: dict[str, str] """Maps module paths (e.g. "h.0.mlp") to wandb artifact paths.""" + tokenizer_name: str + dataset_name: str activation_threshold: float = 0.0 @property From 0991b9c7f3bd12a23f32a54eb2368c59ca69be81 Mon Sep 17 00:00:00 2001 From: bartbussmann Date: Tue, 3 Mar 2026 02:18:35 +0000 Subject: [PATCH 03/13] Extract tokenizer and dataset from base model run info Instead of requiring tokenizer_name and dataset_name in the harvest config, extract them from the base model's PretrainRunInfo. The base model's wandb run already stores the full training config including hf_tokenizer_path and train_dataset_config. Co-Authored-By: Claude Opus 4.6 --- scripts/harvest_transcoders_example.py | 2 -- spd/adapters/transcoder.py | 30 +++++++++++++++++++------- spd/harvest/config.py | 2 -- 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/scripts/harvest_transcoders_example.py b/scripts/harvest_transcoders_example.py index 567703493..7cd75e28e 100644 --- a/scripts/harvest_transcoders_example.py +++ b/scripts/harvest_transcoders_example.py @@ -35,8 +35,6 @@ "h.2.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L2_4f6e37_checkpoint_final:v0", "h.3.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L3_e76468_checkpoint_final:v0", }, - tokenizer_name="EleutherAI/gpt-neox-20b", - dataset_name="danbraunai/pile-uncopyrighted-tok", ) HARVEST_CONFIG = HarvestConfig( diff --git a/spd/adapters/transcoder.py b/spd/adapters/transcoder.py index a94e7ddaf..36f06d7dc 100644 --- a/spd/adapters/transcoder.py +++ b/spd/adapters/transcoder.py @@ -22,6 +22,7 @@ from spd.data import DatasetConfig, create_data_loader from spd.harvest.config import TranscoderHarvestConfig from spd.pretrain.models.llama_simple_mlp import LlamaSimpleMLP +from spd.pretrain.run_info import PretrainRunInfo from spd.topology import TransformerTopology _ENCODER_CLASSES: dict[str, type[SharedTranscoder]] = { @@ -57,14 +58,24 @@ class TranscoderAdapter(DecompositionAdapter): def __init__(self, config: TranscoderHarvestConfig): self._config = config + @cached_property + def _run_info(self) -> PretrainRunInfo: + return PretrainRunInfo.from_path(self._config.base_model_path) + @cached_property def base_model(self) -> LlamaSimpleMLP: - return LlamaSimpleMLP.from_pretrained(self._config.base_model_path) + return LlamaSimpleMLP.from_run_info(self._run_info) @cached_property def _topology(self) -> TransformerTopology: return TransformerTopology(self.base_model) + @cached_property + def _train_dataset_config(self) -> dict[str, Any]: + cfg = self._run_info.config_dict.get("train_dataset_config") + assert isinstance(cfg, dict), "base model run missing train_dataset_config" + return cfg + @cached_property def transcoders(self) -> dict[str, SharedTranscoder]: result: dict[str, SharedTranscoder] = {} @@ -93,7 +104,9 @@ def layer_activation_sizes(self) -> list[tuple[str, int]]: @property @override def tokenizer_name(self) -> str: - return self._config.tokenizer_name + tok = self._run_info.hf_tokenizer_path + assert tok is not None, "base model run missing hf_tokenizer_path" + return tok @property @override @@ -101,7 +114,7 @@ def model_metadata(self) -> ModelMetadata: return ModelMetadata( n_blocks=self._topology.n_blocks, model_class="spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP", - dataset_name=self._config.dataset_name, + dataset_name=self._train_dataset_config["name"], layer_descriptions={ path: self._topology.target_to_canon(path) for path in self.transcoders }, @@ -109,14 +122,15 @@ def model_metadata(self) -> ModelMetadata: @override def dataloader(self, batch_size: int) -> DataLoader[torch.Tensor]: + ds_cfg = self._train_dataset_config dataset_config = DatasetConfig( - name=self._config.dataset_name, - is_tokenized=True, - hf_tokenizer_path=self._config.tokenizer_name, + name=ds_cfg["name"], + is_tokenized=ds_cfg.get("is_tokenized", True), + hf_tokenizer_path=self.tokenizer_name, streaming=True, split="train", - n_ctx=512, - column_name="input_ids", + n_ctx=self.base_model.config.block_size, + column_name=ds_cfg.get("column_name", "input_ids"), ) loader, _ = create_data_loader( dataset_config=dataset_config, diff --git a/spd/harvest/config.py b/spd/harvest/config.py index f2aa88320..7bcad4884 100644 --- a/spd/harvest/config.py +++ b/spd/harvest/config.py @@ -56,8 +56,6 @@ class TranscoderHarvestConfig(BaseConfig): base_model_path: str artifact_paths: dict[str, str] """Maps module paths (e.g. "h.0.mlp") to wandb artifact paths.""" - tokenizer_name: str - dataset_name: str activation_threshold: float = 0.0 @property From b9cdf8bfb25d3e15bf0a46fae72f4e9e1ee2edbd Mon Sep 17 00:00:00 2001 From: bartbussmann Date: Tue, 3 Mar 2026 02:26:56 +0000 Subject: [PATCH 04/13] Read dataloader config from base model run info Use the base model's train_dataset_config directly instead of hardcoding dataset fields. Only override streaming=True (for harvest) and n_ctx=block_size (strip the extra label token). Co-Authored-By: Claude Opus 4.6 --- spd/adapters/transcoder.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/spd/adapters/transcoder.py b/spd/adapters/transcoder.py index 36f06d7dc..3097c43f6 100644 --- a/spd/adapters/transcoder.py +++ b/spd/adapters/transcoder.py @@ -70,12 +70,6 @@ def base_model(self) -> LlamaSimpleMLP: def _topology(self) -> TransformerTopology: return TransformerTopology(self.base_model) - @cached_property - def _train_dataset_config(self) -> dict[str, Any]: - cfg = self._run_info.config_dict.get("train_dataset_config") - assert isinstance(cfg, dict), "base model run missing train_dataset_config" - return cfg - @cached_property def transcoders(self) -> dict[str, SharedTranscoder]: result: dict[str, SharedTranscoder] = {} @@ -111,10 +105,11 @@ def tokenizer_name(self) -> str: @property @override def model_metadata(self) -> ModelMetadata: + ds_cfg = self._run_info.config_dict.get("train_dataset_config", {}) return ModelMetadata( n_blocks=self._topology.n_blocks, model_class="spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP", - dataset_name=self._train_dataset_config["name"], + dataset_name=ds_cfg.get("name", "unknown"), layer_descriptions={ path: self._topology.target_to_canon(path) for path in self.transcoders }, @@ -122,15 +117,15 @@ def model_metadata(self) -> ModelMetadata: @override def dataloader(self, batch_size: int) -> DataLoader[torch.Tensor]: - ds_cfg = self._train_dataset_config + ds_cfg = self._run_info.config_dict["train_dataset_config"] dataset_config = DatasetConfig( name=ds_cfg["name"], - is_tokenized=ds_cfg.get("is_tokenized", True), - hf_tokenizer_path=self.tokenizer_name, + is_tokenized=ds_cfg["is_tokenized"], + hf_tokenizer_path=ds_cfg["hf_tokenizer_path"], streaming=True, - split="train", + split=ds_cfg["split"], n_ctx=self.base_model.config.block_size, - column_name=ds_cfg.get("column_name", "input_ids"), + column_name=ds_cfg["column_name"], ) loader, _ = create_data_loader( dataset_config=dataset_config, From 31e882e5af269c2636527d5c64d5a05f62692960 Mon Sep 17 00:00:00 2001 From: bartbussmann Date: Tue, 3 Mar 2026 02:31:37 +0000 Subject: [PATCH 05/13] Simplify dataloader: construct DatasetConfig from pretrain run config Co-Authored-By: Claude Opus 4.6 --- spd/adapters/transcoder.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/spd/adapters/transcoder.py b/spd/adapters/transcoder.py index 3097c43f6..87b442ac3 100644 --- a/spd/adapters/transcoder.py +++ b/spd/adapters/transcoder.py @@ -118,18 +118,10 @@ def model_metadata(self) -> ModelMetadata: @override def dataloader(self, batch_size: int) -> DataLoader[torch.Tensor]: ds_cfg = self._run_info.config_dict["train_dataset_config"] - dataset_config = DatasetConfig( - name=ds_cfg["name"], - is_tokenized=ds_cfg["is_tokenized"], - hf_tokenizer_path=ds_cfg["hf_tokenizer_path"], - streaming=True, - split=ds_cfg["split"], - n_ctx=self.base_model.config.block_size, - column_name=ds_cfg["column_name"], + dataset_config = DatasetConfig.model_validate( + {**ds_cfg, "streaming": True, "n_ctx": self.base_model.config.block_size} ) loader, _ = create_data_loader( - dataset_config=dataset_config, - batch_size=batch_size, - buffer_size=1000, + dataset_config=dataset_config, batch_size=batch_size, buffer_size=1000 ) return loader From 323f3d356f11c55e43f8a0e961b700c299e5d058 Mon Sep 17 00:00:00 2001 From: bartbussmann Date: Tue, 3 Mar 2026 02:35:45 +0000 Subject: [PATCH 06/13] Derive model_class from actual model type instead of hardcoding Co-Authored-By: Claude Opus 4.6 --- spd/adapters/transcoder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/spd/adapters/transcoder.py b/spd/adapters/transcoder.py index 87b442ac3..b99474352 100644 --- a/spd/adapters/transcoder.py +++ b/spd/adapters/transcoder.py @@ -106,9 +106,10 @@ def tokenizer_name(self) -> str: @override def model_metadata(self) -> ModelMetadata: ds_cfg = self._run_info.config_dict.get("train_dataset_config", {}) + model_cls = type(self.base_model) return ModelMetadata( n_blocks=self._topology.n_blocks, - model_class="spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP", + model_class=f"{model_cls.__module__}.{model_cls.__qualname__}", dataset_name=ds_cfg.get("name", "unknown"), layer_descriptions={ path: self._topology.target_to_canon(path) for path in self.transcoders From db9847015a03ae9e3a24e2da94c209d4a738b9d6 Mon Sep 17 00:00:00 2001 From: bartbussmann Date: Tue, 3 Mar 2026 03:48:07 +0000 Subject: [PATCH 07/13] Fix prerequisite in example script to use optional dependency Co-Authored-By: Claude Opus 4.6 --- scripts/harvest_transcoders_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/harvest_transcoders_example.py b/scripts/harvest_transcoders_example.py index 7cd75e28e..83ef46461 100644 --- a/scripts/harvest_transcoders_example.py +++ b/scripts/harvest_transcoders_example.py @@ -8,7 +8,7 @@ python scripts/harvest_transcoders_example.py Prerequisites: - pip install -e /workspace/nn_decompositions + pip install -e ".[transcoder]" """ from datetime import datetime From c2af88490b4163c0f6c9d0ec5b2bed8d4fd6aa77 Mon Sep 17 00:00:00 2001 From: bartbussmann Date: Tue, 3 Mar 2026 03:50:04 +0000 Subject: [PATCH 08/13] Rename optional dependency from 'transcoder' to 'nn_decompositions' Co-Authored-By: Claude Opus 4.6 --- pyproject.toml | 2 +- scripts/harvest_transcoders_example.py | 2 +- uv.lock | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1dc03386d..6d732a594 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ spd-attributions = "spd.dataset_attributions.scripts.run_slurm_cli:cli" spd-postprocess = "spd.postprocess.cli:cli" [project.optional-dependencies] -transcoder = [ +nn_decompositions = [ "nn-decompositions @ git+https://github.com/bartbussmann/nn_decompositions.git", ] diff --git a/scripts/harvest_transcoders_example.py b/scripts/harvest_transcoders_example.py index 83ef46461..6055fcdd2 100644 --- a/scripts/harvest_transcoders_example.py +++ b/scripts/harvest_transcoders_example.py @@ -8,7 +8,7 @@ python scripts/harvest_transcoders_example.py Prerequisites: - pip install -e ".[transcoder]" + pip install -e ".[nn_decompositions]" """ from datetime import datetime diff --git a/uv.lock b/uv.lock index 33622313d..2ab178f43 100644 --- a/uv.lock +++ b/uv.lock @@ -2086,7 +2086,7 @@ dependencies = [ ] [package.optional-dependencies] -transcoder = [ +nn-decompositions = [ { name = "nn-decompositions" }, ] @@ -2112,7 +2112,7 @@ requires-dist = [ { name = "ipykernel" }, { name = "jaxtyping" }, { name = "matplotlib" }, - { name = "nn-decompositions", marker = "extra == 'transcoder'", git = "https://github.com/bartbussmann/nn_decompositions.git" }, + { name = "nn-decompositions", marker = "extra == 'nn-decompositions'", git = "https://github.com/bartbussmann/nn_decompositions.git" }, { name = "numpy" }, { name = "openrouter", specifier = ">=0.1.1" }, { name = "orjson" }, @@ -2131,7 +2131,7 @@ requires-dist = [ { name = "wandb-workspaces", specifier = "==0.1.12" }, { name = "zstandard" }, ] -provides-extras = ["transcoder"] +provides-extras = ["nn-decompositions"] [package.metadata.requires-dev] dev = [ From 14ce4caeab1f490f4eeb045365fcd14672c58855 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Tue, 3 Mar 2026 10:19:52 +0000 Subject: [PATCH 09/13] Make transcoder harvest launchable from CLI config - Add adapter_from_config() that takes the full method_config, so TranscoderAdapter can be constructed in the harvest worker - Keep adapter_from_id() for downstream consumers (autointerp, intruder) that only have a decomposition ID - Replace Python example script with YAML config for spd-harvest - Exclude transcoder files from basedpyright (optional nn_decompositions dep) Co-Authored-By: Claude Opus 4.6 (1M context) --- pyproject.toml | 2 +- scripts/harvest_transcoders_example.py | 108 ----------------------- scripts/harvest_transcoders_example.yaml | 28 ++++++ spd/adapters/__init__.py | 45 +++++++--- spd/adapters/transcoder.py | 5 +- spd/harvest/scripts/run_worker.py | 4 +- 6 files changed, 69 insertions(+), 123 deletions(-) delete mode 100644 scripts/harvest_transcoders_example.py create mode 100644 scripts/harvest_transcoders_example.yaml diff --git a/pyproject.toml b/pyproject.toml index 6d732a594..f8998809a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,7 +106,7 @@ known-third-party = ["wandb"] [tool.pyright] include = ["spd", "tests"] -exclude = ["**/wandb/**", "spd/utils/linear_sum_assignment.py", "spd/app/frontend"] +exclude = ["**/wandb/**", "spd/utils/linear_sum_assignment.py", "spd/app/frontend", "spd/adapters/transcoder.py", "spd/harvest/harvest_fn/transcoder.py"] stubPath = "typings" # Having type stubs for transformers shaves 10 seconds off basedpyright calls strictListInference = true diff --git a/scripts/harvest_transcoders_example.py b/scripts/harvest_transcoders_example.py deleted file mode 100644 index 6055fcdd2..000000000 --- a/scripts/harvest_transcoders_example.py +++ /dev/null @@ -1,108 +0,0 @@ -"""Example: Harvest transcoder activations using the SPD harvest pipeline. - -Loads trained BatchTopK transcoders from wandb artifacts and runs the generic -harvest pipeline to collect activation statistics (firing densities, token PMI, -activation examples). - -Usage: - python scripts/harvest_transcoders_example.py - -Prerequisites: - pip install -e ".[nn_decompositions]" -""" - -from datetime import datetime - -import torch - -from spd.adapters.transcoder import TranscoderAdapter -from spd.harvest.config import HarvestConfig, TranscoderHarvestConfig -from spd.harvest.harvest import harvest -from spd.harvest.harvest_fn.transcoder import TranscoderHarvestFn -from spd.harvest.repo import HarvestRepo -from spd.harvest.schemas import get_harvest_subrun_dir - -# -- Configuration ------------------------------------------------------------ - -# Fill in the wandb artifact paths for each layer's transcoder. -# Find these at: https://wandb.ai/mats-sprint/pile_transcoder_sweep3 -# Each artifact should contain encoder.pt + config.json. -TRANSCODER_CONFIG = TranscoderHarvestConfig( - base_model_path="wandb:goodfire/spd/t-32d1bb3b", - artifact_paths={ - "h.0.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L0_5d5b1f_checkpoint_final:v0", - "h.1.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L1_c208d7_checkpoint_final:v0", - "h.2.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L2_4f6e37_checkpoint_final:v0", - "h.3.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L3_e76468_checkpoint_final:v0", - }, -) - -HARVEST_CONFIG = HarvestConfig( - method_config=TRANSCODER_CONFIG, - n_batches=20, - batch_size=8, - activation_examples_per_component=20, - activation_context_tokens_per_side=10, - pmi_token_top_k=10, -) - - -def main() -> None: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"Using device: {device}") - - # Build adapter (downloads artifacts + loads base model and transcoders) - print("Loading transcoders and base model...") - adapter = TranscoderAdapter(TRANSCODER_CONFIG) - - print(f"Base model vocab size: {adapter.vocab_size}") - print(f"Layers: {adapter.layer_activation_sizes}") - for path, tc in adapter.transcoders.items(): - print( - f" {path}: dict_size={tc.dict_size}, encoder_type={tc.cfg.encoder_type}, top_k={tc.cfg.top_k}" - ) - - # Build harvest function - harvest_fn = TranscoderHarvestFn(adapter, TRANSCODER_CONFIG.activation_threshold, device) - - # Run harvest - subrun_id = "h-" + datetime.now().strftime("%Y%m%d_%H%M%S") - output_dir = get_harvest_subrun_dir(adapter.decomposition_id, subrun_id) - print(f"\nHarvesting to: {output_dir}") - - harvest( - layers=adapter.layer_activation_sizes, - vocab_size=adapter.vocab_size, - dataloader=adapter.dataloader(HARVEST_CONFIG.batch_size), - harvest_fn=harvest_fn, - config=HARVEST_CONFIG, - output_dir=output_dir, - rank_world_size=None, - device=device, - ) - - # Print summary - print("\n=== Harvest Summary ===") - repo = HarvestRepo(adapter.decomposition_id, subrun_id, readonly=True) - components = repo.get_all_components() - print(f"Total components harvested: {len(components)}") - - for comp in components[:10]: - n_examples = len(comp.activation_examples) - top_tokens = comp.input_token_pmi.top[:3] - print( - f" {comp.component_key}: " - f"density={comp.firing_density:.4f}, " - f"examples={n_examples}, " - f"mean_act={comp.mean_activations.get('activation', 0):.4f}, " - f"top_pmi_tokens={top_tokens}" - ) - - if len(components) > 10: - print(f" ... and {len(components) - 10} more components") - - print(f"\nResults saved to: {output_dir}") - - -if __name__ == "__main__": - main() diff --git a/scripts/harvest_transcoders_example.yaml b/scripts/harvest_transcoders_example.yaml new file mode 100644 index 000000000..a1026ff80 --- /dev/null +++ b/scripts/harvest_transcoders_example.yaml @@ -0,0 +1,28 @@ +# Example: Harvest transcoder activations using spd-harvest. +# +# Loads trained BatchTopK transcoders from wandb artifacts and runs the generic +# harvest pipeline to collect activation statistics (firing densities, token PMI, +# activation examples). +# +# Usage: +# spd-harvest scripts/harvest_transcoders_example.yaml +# +# Prerequisites: +# pip install -e ".[nn_decompositions]" + +config: + method_config: + type: TranscoderHarvestConfig + base_model_path: "wandb:goodfire/spd/t-32d1bb3b" + artifact_paths: + "h.0.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L0_5d5b1f_checkpoint_final:v0" + "h.1.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L1_c208d7_checkpoint_final:v0" + "h.2.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L2_4f6e37_checkpoint_final:v0" + "h.3.mlp": "mats-sprint/pile_transcoder_sweep3/4096_batchtopk_k32_0.0003_L3_e76468_checkpoint_final:v0" + n_batches: 20 + batch_size: 8 + activation_examples_per_component: 20 + activation_context_tokens_per_side: 10 + pmi_token_top_k: 10 + +n_gpus: 1 diff --git a/spd/adapters/__init__.py b/spd/adapters/__init__.py index 87bf94758..4168b603c 100644 --- a/spd/adapters/__init__.py +++ b/spd/adapters/__init__.py @@ -9,21 +9,44 @@ """ from spd.adapters.base import DecompositionAdapter +from spd.harvest.config import DecompositionMethodHarvestConfig + + +def adapter_from_config(method_config: DecompositionMethodHarvestConfig) -> DecompositionAdapter: + from spd.harvest.config import ( + CLTHarvestConfig, + MOLTHarvestConfig, + SPDHarvestConfig, + TranscoderHarvestConfig, + ) + + match method_config: + case SPDHarvestConfig(): + from spd.adapters.spd import SPDAdapter + + return SPDAdapter(method_config.id) + case TranscoderHarvestConfig(): + from spd.adapters.transcoder import TranscoderAdapter + + return TranscoderAdapter(method_config) + case CLTHarvestConfig(): + raise NotImplementedError("CLT adapter not implemented yet") + case MOLTHarvestConfig(): + raise NotImplementedError("MOLT adapter not implemented yet") def adapter_from_id(id: str) -> DecompositionAdapter: + """Construct an adapter from just a decomposition ID (e.g. "s-abc123"). + + Only works for methods whose adapter can be constructed from an ID alone (SPD). + For transcoders, use adapter_from_config() with the full method config. + """ from spd.adapters.spd import SPDAdapter if id.startswith("s-"): return SPDAdapter(id) - elif id.startswith("tc-"): - raise NotImplementedError( - "TranscoderAdapter requires a TranscoderHarvestConfig. " - "Use TranscoderAdapter(config) directly." - ) - elif id.startswith("clt-"): - raise NotImplementedError("CLT adapter not implemented yet") - elif id.startswith("molt-"): - raise NotImplementedError("MOLT adapter not implemented yet") - - raise ValueError(f"Unsupported decomposition ID: {id}") + + raise ValueError( + f"Cannot construct adapter from ID alone: {id!r}. " + f"Use adapter_from_config() with the full method config." + ) diff --git a/spd/adapters/transcoder.py b/spd/adapters/transcoder.py index b99474352..7fb147743 100644 --- a/spd/adapters/transcoder.py +++ b/spd/adapters/transcoder.py @@ -1,4 +1,7 @@ -"""Transcoder adapter: loads trained transcoders from wandb artifacts.""" +"""Transcoder adapter: loads trained transcoders from wandb artifacts. + +Requires the optional `nn_decompositions` dependency: pip install -e ".[nn_decompositions]" +""" import json from functools import cached_property diff --git a/spd/harvest/scripts/run_worker.py b/spd/harvest/scripts/run_worker.py index 9c3d1c582..51057c510 100644 --- a/spd/harvest/scripts/run_worker.py +++ b/spd/harvest/scripts/run_worker.py @@ -11,7 +11,7 @@ import fire import torch -from spd.adapters import adapter_from_id +from spd.adapters import adapter_from_config from spd.harvest.config import HarvestConfig from spd.harvest.harvest import harvest from spd.harvest.harvest_fn import make_harvest_fn @@ -35,7 +35,7 @@ def main( config = HarvestConfig.model_validate(config_json) - adapter = adapter_from_id(config.method_config.id) + adapter = adapter_from_config(config.method_config) output_dir = get_harvest_subrun_dir(adapter.decomposition_id, subrun_id) From 8c331223a0b0585c1a6f01c7a1d3d3744a483788 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Tue, 3 Mar 2026 10:42:44 +0000 Subject: [PATCH 10/13] Vendor nn_decompositions transcoder code into spd/adapters/ Copies EncoderConfig and SharedTranscoder + subclasses (474 lines) from bartbussmann/nn_decompositions (MIT) into spd/adapters/, eliminating the optional dependency. Only torch + stdlib needed, both already deps. - spd/adapters/encoder_config.py: EncoderConfig dataclass - spd/adapters/transcoders.py: SharedTranscoder, Vanilla/TopK/BatchTopK/JumpReLU - Remove nn_decompositions optional dep from pyproject.toml Co-Authored-By: Claude Opus 4.6 (1M context) --- pyproject.toml | 7 +- scripts/harvest_transcoders_example.yaml | 3 - spd/adapters/encoder_config.py | 68 +++++ spd/adapters/transcoder.py | 15 +- spd/adapters/transcoders.py | 334 +++++++++++++++++++++++ uv.lock | 175 +----------- 6 files changed, 410 insertions(+), 192 deletions(-) create mode 100644 spd/adapters/encoder_config.py create mode 100644 spd/adapters/transcoders.py diff --git a/pyproject.toml b/pyproject.toml index f8998809a..88c3405a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,11 +58,6 @@ spd-autointerp = "spd.autointerp.scripts.run_slurm_cli:cli" spd-attributions = "spd.dataset_attributions.scripts.run_slurm_cli:cli" spd-postprocess = "spd.postprocess.cli:cli" -[project.optional-dependencies] -nn_decompositions = [ - "nn-decompositions @ git+https://github.com/bartbussmann/nn_decompositions.git", -] - [build-system] requires = ["setuptools", "wheel"] build-backend = "setuptools.build_meta" @@ -106,7 +101,7 @@ known-third-party = ["wandb"] [tool.pyright] include = ["spd", "tests"] -exclude = ["**/wandb/**", "spd/utils/linear_sum_assignment.py", "spd/app/frontend", "spd/adapters/transcoder.py", "spd/harvest/harvest_fn/transcoder.py"] +exclude = ["**/wandb/**", "spd/utils/linear_sum_assignment.py", "spd/app/frontend"] stubPath = "typings" # Having type stubs for transformers shaves 10 seconds off basedpyright calls strictListInference = true diff --git a/scripts/harvest_transcoders_example.yaml b/scripts/harvest_transcoders_example.yaml index a1026ff80..f586f604d 100644 --- a/scripts/harvest_transcoders_example.yaml +++ b/scripts/harvest_transcoders_example.yaml @@ -6,9 +6,6 @@ # # Usage: # spd-harvest scripts/harvest_transcoders_example.yaml -# -# Prerequisites: -# pip install -e ".[nn_decompositions]" config: method_config: diff --git a/spd/adapters/encoder_config.py b/spd/adapters/encoder_config.py new file mode 100644 index 000000000..e9a0e2e40 --- /dev/null +++ b/spd/adapters/encoder_config.py @@ -0,0 +1,68 @@ +"""Encoder configuration for transcoder architectures. + +Vendored from https://github.com/bartbussmann/nn_decompositions (MIT license). +Only EncoderConfig is used; CLTConfig and SAEConfig are omitted. +""" + +from dataclasses import dataclass, field +from typing import Literal + +import torch + + +@dataclass +class EncoderConfig: + """Base config for encoder architectures (SAE and Transcoder).""" + + # Architecture + input_size: int + output_size: int + dict_size: int = 12288 + + # Encoder type + encoder_type: Literal["vanilla", "topk", "batchtopk", "jumprelu"] = "topk" + + # Training + seed: int = 49 + batch_size: int = 4096 + lr: float = 3e-4 + num_tokens: int = int(1e9) + l1_coeff: float = 0.0 + beta1: float = 0.9 + beta2: float = 0.99 + max_grad_norm: float = 1.0 + + # Device + device: str = "cuda:0" + dtype: torch.dtype = field(default=torch.float32) + + # Dead feature tracking + n_batches_to_dead: int = 50 + + # Optional features + input_unit_norm: bool = False + pre_enc_bias: bool = False + + # TopK specific + top_k: int = 32 + top_k_aux: int = 512 + aux_penalty: float = 1 / 32 + + # JumpReLU specific + bandwidth: float = 0.001 + + # Logging + run_name: str | None = None + wandb_project: str = "encoders" + perf_log_freq: int = 1000 + checkpoint_freq: int | Literal["final"] = "final" + n_eval_seqs: int = 8 + + @property + def name(self) -> str: + if self.run_name is not None: + return self.run_name + base = f"{self.dict_size}_{self.encoder_type}" + if self.encoder_type in ("topk", "batchtopk"): + base += f"_k{self.top_k}" + return f"{base}_{self.lr}" diff --git a/spd/adapters/transcoder.py b/spd/adapters/transcoder.py index 7fb147743..457c7afaa 100644 --- a/spd/adapters/transcoder.py +++ b/spd/adapters/transcoder.py @@ -1,7 +1,4 @@ -"""Transcoder adapter: loads trained transcoders from wandb artifacts. - -Requires the optional `nn_decompositions` dependency: pip install -e ".[nn_decompositions]" -""" +"""Transcoder adapter: loads trained transcoders from wandb artifacts.""" import json from functools import cached_property @@ -10,17 +7,17 @@ import torch import wandb -from nn_decompositions.config import EncoderConfig -from nn_decompositions.transcoder import ( +from torch.utils.data import DataLoader + +from spd.adapters.base import DecompositionAdapter +from spd.adapters.encoder_config import EncoderConfig +from spd.adapters.transcoders import ( BatchTopKTranscoder, JumpReLUTranscoder, SharedTranscoder, TopKTranscoder, VanillaTranscoder, ) -from torch.utils.data import DataLoader - -from spd.adapters.base import DecompositionAdapter from spd.autointerp.schemas import ModelMetadata from spd.data import DatasetConfig, create_data_loader from spd.harvest.config import TranscoderHarvestConfig diff --git a/spd/adapters/transcoders.py b/spd/adapters/transcoders.py new file mode 100644 index 000000000..8ca0cd93d --- /dev/null +++ b/spd/adapters/transcoders.py @@ -0,0 +1,334 @@ +"""Transcoder nn.Module implementations (Vanilla, TopK, BatchTopK, JumpReLU). + +Vendored from https://github.com/bartbussmann/nn_decompositions (MIT license). +""" + +import torch +import torch.autograd as autograd +import torch.nn as nn +import torch.nn.functional as F + +from spd.adapters.encoder_config import EncoderConfig + + +class SharedTranscoder(nn.Module): + """Base class for encoder-decoder models (SAE and Transcoder). + + Supports both SAE mode (input = target) and Transcoder mode (input != target). + All subclasses use forward(x_in, y_target) signature. + """ + + def __init__(self, cfg: EncoderConfig): + super().__init__() + + self.cfg = cfg + torch.manual_seed(cfg.seed) + + self.input_size = cfg.input_size + self.output_size = cfg.output_size + self.dict_size = cfg.dict_size + + self.b_dec = nn.Parameter(torch.zeros(cfg.output_size)) + self.b_enc = nn.Parameter(torch.zeros(cfg.dict_size)) + self.W_enc = nn.Parameter( + torch.nn.init.kaiming_uniform_(torch.empty(cfg.input_size, cfg.dict_size)) + ) + self.W_dec = nn.Parameter( + torch.nn.init.kaiming_uniform_(torch.empty(cfg.dict_size, cfg.output_size)) + ) + # Initialize W_dec from W_enc only if input_size == output_size (SAE case) + if cfg.input_size == cfg.output_size: + self.W_dec.data[:] = self.W_enc.t().data + self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True) + self.num_batches_not_active = torch.zeros((cfg.dict_size,)).to(cfg.device) + + self.to(cfg.dtype).to(cfg.device) + + def encode( + self, x: torch.Tensor, return_dense: bool = False + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """Encode input to sparse activations. Subclasses must implement. + + If return_dense=True, also returns the dense (pre-sparsification) activations + as a second element. + """ + raise NotImplementedError + + def decode(self, acts: torch.Tensor) -> torch.Tensor: + return acts @ self.W_dec + self.b_dec + + def preprocess_input( + self, x: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + if self.cfg.input_unit_norm: + x_mean = x.mean(dim=-1, keepdim=True) + x = x - x_mean + x_std = x.std(dim=-1, keepdim=True) + x = x / (x_std + 1e-5) + return x, x_mean, x_std + return x, None, None + + def postprocess_output( + self, out: torch.Tensor, mean: torch.Tensor | None, std: torch.Tensor | None + ) -> torch.Tensor: + if self.cfg.input_unit_norm and mean is not None: + return out * std + mean + return out + + @torch.no_grad() + def make_decoder_weights_and_grad_unit_norm(self): + W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True) + W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed + self.W_dec.grad -= W_dec_grad_proj + self.W_dec.data = W_dec_normed + + def update_inactive_features(self, acts: torch.Tensor): + self.num_batches_not_active += (acts.sum(0) == 0).float() + self.num_batches_not_active[acts.sum(0) > 0] = 0 + + def _get_auxiliary_loss( + self, y_target: torch.Tensor, y_pred: torch.Tensor, acts: torch.Tensor + ) -> torch.Tensor: + dead_features = self.num_batches_not_active >= self.cfg.n_batches_to_dead + if dead_features.sum() > 0: + residual = y_target.float() - y_pred.float() + acts_topk_aux = torch.topk( + acts[:, dead_features], + min(self.cfg.top_k_aux, dead_features.sum()), + dim=-1, + ) + acts_aux = torch.zeros_like(acts[:, dead_features]).scatter( + -1, acts_topk_aux.indices, acts_topk_aux.values + ) + y_pred_aux = acts_aux @ self.W_dec[dead_features] + return self.cfg.aux_penalty * (y_pred_aux.float() - residual.float()).pow(2).mean() + return torch.tensor(0, dtype=y_target.dtype, device=y_target.device) + + def _build_loss_dict( + self, + y_target: torch.Tensor, + y_pred: torch.Tensor, + acts: torch.Tensor, + y_pred_out: torch.Tensor, + l0_norm: torch.Tensor, + extra_losses: dict[str, torch.Tensor] | None = None, + ) -> dict: + l2_loss = (y_pred.float() - y_target.float()).pow(2).mean() + l1_norm = acts.float().abs().sum(-1).mean() + num_dead = (self.num_batches_not_active > self.cfg.n_batches_to_dead).sum() + + loss = l2_loss + if extra_losses: + loss = loss + sum(extra_losses.values()) + + result = { + "output": y_pred_out, + "feature_acts": acts, + "num_dead_features": num_dead, + "loss": loss, + "l2_loss": l2_loss, + "l0_norm": l0_norm, + "l1_norm": l1_norm, + } + if extra_losses: + result.update(extra_losses) + return result + + +class VanillaTranscoder(SharedTranscoder): + def encode(self, x: torch.Tensor, return_dense: bool = False): + use_pre_enc_bias = self.cfg.pre_enc_bias and self.input_size == self.output_size + x_enc = x - self.b_dec if use_pre_enc_bias else x + acts = F.relu(x_enc @ self.W_enc + self.b_enc) + return (acts, acts) if return_dense else acts + + def forward(self, x_in: torch.Tensor, y_target: torch.Tensor) -> dict: + x_in, _, _ = self.preprocess_input(x_in) + y_target, y_mean, y_std = self.preprocess_input(y_target) + + acts = self.encode(x_in) + y_pred = self.decode(acts) + y_pred_out = self.postprocess_output(y_pred, y_mean, y_std) + + self.update_inactive_features(acts) + + l0_norm = (acts > 0).float().sum(-1).mean() + l1_loss = self.cfg.l1_coeff * acts.float().abs().sum(-1).mean() + return self._build_loss_dict( + y_target, + y_pred, + acts, + y_pred_out, + l0_norm, + extra_losses={"l1_loss": l1_loss}, + ) + + +class TopKTranscoder(SharedTranscoder): + def encode(self, x: torch.Tensor, return_dense: bool = False): + use_pre_enc_bias = self.cfg.pre_enc_bias and self.input_size == self.output_size + x_enc = x - self.b_dec if use_pre_enc_bias else x + acts = F.relu(x_enc @ self.W_enc + self.b_enc) + acts_topk = torch.topk(acts, self.cfg.top_k, dim=-1) + acts_sparse = torch.zeros_like(acts).scatter(-1, acts_topk.indices, acts_topk.values) + return (acts_sparse, acts) if return_dense else acts_sparse + + def forward(self, x_in: torch.Tensor, y_target: torch.Tensor) -> dict: + x_in, _, _ = self.preprocess_input(x_in) + y_target, y_mean, y_std = self.preprocess_input(y_target) + + acts, acts_dense = self.encode(x_in, return_dense=True) + y_pred = self.decode(acts) + y_pred_out = self.postprocess_output(y_pred, y_mean, y_std) + + self.update_inactive_features(acts) + + l0_norm = (acts > 0).float().sum(-1).mean() + l1_loss = self.cfg.l1_coeff * acts.float().abs().sum(-1).mean() + aux_loss = self._get_auxiliary_loss(y_target, y_pred, acts_dense) + return self._build_loss_dict( + y_target, + y_pred, + acts, + y_pred_out, + l0_norm, + extra_losses={"l1_loss": l1_loss, "aux_loss": aux_loss}, + ) + + +class BatchTopKTranscoder(SharedTranscoder): + def encode(self, x: torch.Tensor, return_dense: bool = False): + use_pre_enc_bias = self.cfg.pre_enc_bias and self.input_size == self.output_size + x_enc = x - self.b_dec if use_pre_enc_bias else x + acts = F.relu(x_enc @ self.W_enc + self.b_enc) + acts_topk = torch.topk(acts.flatten(), self.cfg.top_k * x.shape[0], dim=-1) + acts_sparse = ( + torch.zeros_like(acts.flatten()) + .scatter(-1, acts_topk.indices, acts_topk.values) + .reshape(acts.shape) + ) + return (acts_sparse, acts) if return_dense else acts_sparse + + def forward(self, x_in: torch.Tensor, y_target: torch.Tensor) -> dict: + x_in, _, _ = self.preprocess_input(x_in) + y_target, y_mean, y_std = self.preprocess_input(y_target) + + acts, acts_dense = self.encode(x_in, return_dense=True) + y_pred = self.decode(acts) + y_pred_out = self.postprocess_output(y_pred, y_mean, y_std) + + self.update_inactive_features(acts) + + l0_norm = (acts > 0).float().sum(-1).mean() + l1_loss = self.cfg.l1_coeff * acts.float().abs().sum(-1).mean() + aux_loss = self._get_auxiliary_loss(y_target, y_pred, acts_dense) + return self._build_loss_dict( + y_target, + y_pred, + acts, + y_pred_out, + l0_norm, + extra_losses={"l1_loss": l1_loss, "aux_loss": aux_loss}, + ) + + +class RectangleFunction(autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return ((x > -0.5) & (x < 0.5)).float() + + @staticmethod + def backward(ctx, grad_output): + (x,) = ctx.saved_tensors + grad_input = grad_output.clone() + grad_input[(x <= -0.5) | (x >= 0.5)] = 0 + return grad_input + + +class JumpReLUFunction(autograd.Function): + @staticmethod + def forward(ctx, x, log_threshold, bandwidth): + ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth)) + threshold = torch.exp(log_threshold) + return x * (x > threshold).float() + + @staticmethod + def backward(ctx, grad_output): + x, log_threshold, bandwidth_tensor = ctx.saved_tensors + bandwidth = bandwidth_tensor.item() + threshold = torch.exp(log_threshold) + x_grad = (x > threshold).float() * grad_output + threshold_grad = ( + -(threshold / bandwidth) + * RectangleFunction.apply((x - threshold) / bandwidth) + * grad_output + ) + return x_grad, threshold_grad, None + + +class JumpReLUActivation(nn.Module): + def __init__(self, feature_size: int, bandwidth: float, device: str = "cpu"): + super().__init__() + self.log_threshold = nn.Parameter(torch.zeros(feature_size, device=device)) + self.bandwidth = bandwidth + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return JumpReLUFunction.apply(x, self.log_threshold, self.bandwidth) + + +class StepFunction(autograd.Function): + @staticmethod + def forward(ctx, x, log_threshold, bandwidth): + ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth)) + threshold = torch.exp(log_threshold) + return (x > threshold).float() + + @staticmethod + def backward(ctx, grad_output): + x, log_threshold, bandwidth_tensor = ctx.saved_tensors + bandwidth = bandwidth_tensor.item() + threshold = torch.exp(log_threshold) + x_grad = torch.zeros_like(x) + threshold_grad = ( + -(1.0 / bandwidth) * RectangleFunction.apply((x - threshold) / bandwidth) * grad_output + ) + return x_grad, threshold_grad, None + + +class JumpReLUTranscoder(SharedTranscoder): + def __init__(self, cfg: EncoderConfig): + super().__init__(cfg) + self.jumprelu = JumpReLUActivation(cfg.dict_size, cfg.bandwidth, cfg.device) + + def encode(self, x: torch.Tensor, return_dense: bool = False): + use_pre_enc_bias = self.cfg.pre_enc_bias and self.input_size == self.output_size + x_enc = x - self.b_dec if use_pre_enc_bias else x + pre_acts = F.relu(x_enc @ self.W_enc + self.b_enc) + acts = self.jumprelu(pre_acts) + return (acts, pre_acts) if return_dense else acts + + def forward(self, x_in: torch.Tensor, y_target: torch.Tensor) -> dict: + x_in, _, _ = self.preprocess_input(x_in) + y_target, y_mean, y_std = self.preprocess_input(y_target) + + acts, pre_acts = self.encode(x_in, return_dense=True) + y_pred = self.decode(acts) + y_pred_out = self.postprocess_output(y_pred, y_mean, y_std) + + self.update_inactive_features(acts) + + l0_norm = ( + StepFunction.apply(pre_acts, self.jumprelu.log_threshold, self.cfg.bandwidth) + .sum(dim=-1) + .mean() + ) + sparsity_loss = self.cfg.l1_coeff * l0_norm + return self._build_loss_dict( + y_target, + y_pred, + acts, + y_pred_out, + l0_norm, + extra_losses={"sparsity_loss": sparsity_loss}, + ) diff --git a/uv.lock b/uv.lock index 2ab178f43..a750ef28e 100644 --- a/uv.lock +++ b/uv.lock @@ -1,29 +1,11 @@ version = 1 -revision = 3 +revision = 2 requires-python = "==3.13.*" resolution-markers = [ "sys_platform == 'linux'", "sys_platform != 'linux'", ] -[[package]] -name = "accelerate" -version = "1.12.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "huggingface-hub" }, - { name = "numpy" }, - { name = "packaging" }, - { name = "psutil" }, - { name = "pyyaml" }, - { name = "safetensors" }, - { name = "torch" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/4a/8e/ac2a9566747a93f8be36ee08532eb0160558b07630a081a6056a9f89bf1d/accelerate-1.12.0.tar.gz", hash = "sha256:70988c352feb481887077d2ab845125024b2a137a5090d6d7a32b57d03a45df6", size = 398399, upload-time = "2025-11-21T11:27:46.973Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9f/d2/c581486aa6c4fbd7394c23c47b83fa1a919d34194e16944241daf9e762dd/accelerate-1.12.0-py3-none-any.whl", hash = "sha256:3e2091cd341423207e2f084a6654b1efcd250dc326f2a37d6dde446e07cabb11", size = 380935, upload-time = "2025-11-21T11:27:44.522Z" }, -] - [[package]] name = "aiohappyeyeballs" version = "2.6.1" @@ -173,24 +155,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f8/90/ce01ad2d0afdc1b82b8b5aaba27e60d2e138e39d887e71c35c55d8f1bfcd/basedpyright-1.31.7-py3-none-any.whl", hash = "sha256:7c54beb7828c9ed0028630aaa6904f395c27e5a9f5a313aa9e91fc1d11170831", size = 11817571, upload-time = "2025-10-11T05:12:45.432Z" }, ] -[[package]] -name = "beartype" -version = "0.14.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/89/3b/9ecfc75d1f8bb75cbdc87fcb3df7c6ec4bc8f7481cb7102859ade1736c9d/beartype-0.14.1.tar.gz", hash = "sha256:23df4715d19cebb2ce60e53c3cf44cd925843f00c71938222d777ea6332de3cb", size = 964899, upload-time = "2023-06-07T05:38:56.905Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f6/db/8d01583b4175e0e45a6e6cd0c28db2dae38ffe5477141a7ac3a5a09c8bb9/beartype-0.14.1-py3-none-any.whl", hash = "sha256:0f70fccdb8eb6d7ddfaa3ffe3a0b66cf2edeb13452bd71ad46615775c2fa34f6", size = 739737, upload-time = "2023-06-07T05:38:54.076Z" }, -] - -[[package]] -name = "better-abc" -version = "0.0.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/8b/72/3d630f781659015357cc08cad32aa636b252e007df0bae31184a3d872427/better-abc-0.0.3.tar.gz", hash = "sha256:a880fd6bc9675da2ec991e8712a555bffa0f12722efed78c739f78343cf989f6", size = 2852, upload-time = "2020-11-10T22:47:31.303Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d9/e8/7d00a23039ab74c5741736ce05d7700eb6237e83747aac4df07a5bf2d074/better_abc-0.0.3-py3-none-any.whl", hash = "sha256:3ae73b473fbeb536a548f542984976e80b821676ae6e18f14e24d8e180647187", size = 3475, upload-time = "2020-11-10T22:47:30.354Z" }, -] - [[package]] name = "blinker" version = "1.9.0" @@ -474,15 +438,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017", size = 28317, upload-time = "2025-09-01T09:48:08.5Z" }, ] -[[package]] -name = "fancy-einsum" -version = "0.0.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b9/b1/f5a13cdc05b9a16502d760ead310a689a1538f3fee9618b92011200b9717/fancy_einsum-0.0.3.tar.gz", hash = "sha256:05ca6689999d0949bdaa5320c81117effa13644ec68a200121e93d7ebf3d3356", size = 4916, upload-time = "2022-02-04T01:53:46.028Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/27/14/26fc262ba70976eea9a42e67b05c67aa78a0ee38332d9d094cca5d2c5ec3/fancy_einsum-0.0.3-py3-none-any.whl", hash = "sha256:e0bf33587a61822b0668512ada237a0ffa5662adfb9acfcbb0356ee15a0396a1", size = 6239, upload-time = "2022-02-04T01:53:44.44Z" }, -] - [[package]] name = "fastapi" version = "0.127.0" @@ -902,18 +857,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/8f/8f6f491d595a9e5912971f3f863d81baddccc8a4d0c3749d6a0dd9ffc9df/kiwisolver-1.4.9-cp313-cp313t-win_arm64.whl", hash = "sha256:0749fd8f4218ad2e851e11cc4dc05c7cbc0cbc4267bdfdb31782e65aace4ee9c", size = 68646, upload-time = "2025-08-10T21:27:00.52Z" }, ] -[[package]] -name = "markdown-it-py" -version = "4.0.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "mdurl" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/5b/f5/4ec618ed16cc4f8fb3b701563655a69816155e79e24a17b651541804721d/markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3", size = 73070, upload-time = "2025-08-11T12:57:52.854Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" }, -] - [[package]] name = "markupsafe" version = "3.0.3" @@ -989,15 +932,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl", hash = "sha256:d56ce5156ba6085e00a9d54fead6ed29a9c47e215cd1bba2e976ef39f5710a76", size = 9516, upload-time = "2025-10-23T09:00:20.675Z" }, ] -[[package]] -name = "mdurl" -version = "0.1.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, -] - [[package]] name = "mpmath" version = "1.3.0" @@ -1096,18 +1030,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/c9/b2622292ea83fbb4ec318f5b9ab867d0a28ab43c5717bb85b0a5f6b3b0a4/networkx-3.6.1-py3-none-any.whl", hash = "sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762", size = 2068504, upload-time = "2025-12-08T17:02:38.159Z" }, ] -[[package]] -name = "nn-decompositions" -version = "0.1.0" -source = { git = "https://github.com/bartbussmann/nn_decompositions.git#b7dc504eb19d68990e7b2a611be264428a5c8896" } -dependencies = [ - { name = "datasets" }, - { name = "torch" }, - { name = "tqdm" }, - { name = "transformer-lens" }, - { name = "wandb" }, -] - [[package]] name = "nodeenv" version = "1.10.0" @@ -1859,19 +1781,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, ] -[[package]] -name = "rich" -version = "14.3.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "markdown-it-py" }, - { name = "pygments" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/b3/c6/f3b320c27991c46f43ee9d856302c70dc2d0fb2dba4842ff739d5f46b393/rich-14.3.3.tar.gz", hash = "sha256:b8daa0b9e4eef54dd8cf7c86c03713f53241884e814f4e2f5fb342fe520f639b", size = 230582, upload-time = "2026-02-19T17:23:12.474Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/14/25/b208c5683343959b670dc001595f2f3737e051da617f66c31f7c4fa93abc/rich-14.3.3-py3-none-any.whl", hash = "sha256:793431c1f8619afa7d3b52b2cdec859562b950ea0d4b6b505397612db8d5362d", size = 310458, upload-time = "2026-02-19T17:23:13.732Z" }, -] - [[package]] name = "rpds-py" version = "0.30.0" @@ -1988,30 +1897,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8e/f3/d854ff38789aca9b0cc23008d607ced9de4f7ab14fa1ca4329f86b3758ca/scipy-1.16.3-cp313-cp313t-win_arm64.whl", hash = "sha256:0c623a54f7b79dd88ef56da19bc2873afec9673a48f3b85b18e4d402bdd29a5a", size = 25803246, upload-time = "2025-10-28T17:35:42.155Z" }, ] -[[package]] -name = "sentencepiece" -version = "0.2.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/15/15/2e7a025fc62d764b151ae6d0f2a92f8081755ebe8d4a64099accc6f77ba6/sentencepiece-0.2.1.tar.gz", hash = "sha256:8138cec27c2f2282f4a34d9a016e3374cd40e5c6e9cb335063db66a0a3b71fad", size = 3228515, upload-time = "2025-08-12T07:00:51.718Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ba/4a/85fbe1706d4d04a7e826b53f327c4b80f849cf1c7b7c5e31a20a97d8f28b/sentencepiece-0.2.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:dcd8161eee7b41aae57ded06272905dbd680a0a04b91edd0f64790c796b2f706", size = 1943150, upload-time = "2025-08-12T06:59:53.588Z" }, - { url = "https://files.pythonhosted.org/packages/c2/83/4cfb393e287509fc2155480b9d184706ef8d9fa8cbf5505d02a5792bf220/sentencepiece-0.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:c6c8f42949f419ff8c7e9960dbadcfbc982d7b5efc2f6748210d3dd53a7de062", size = 1325651, upload-time = "2025-08-12T06:59:55.073Z" }, - { url = "https://files.pythonhosted.org/packages/8d/de/5a007fb53b1ab0aafc69d11a5a3dd72a289d5a3e78dcf2c3a3d9b14ffe93/sentencepiece-0.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:097f3394e99456e9e4efba1737c3749d7e23563dd1588ce71a3d007f25475fff", size = 1253641, upload-time = "2025-08-12T06:59:56.562Z" }, - { url = "https://files.pythonhosted.org/packages/2c/d2/f552be5928105588f4f4d66ee37dd4c61460d8097e62d0e2e0eec41bc61d/sentencepiece-0.2.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d7b670879c370d350557edabadbad1f6561a9e6968126e6debca4029e5547820", size = 1316271, upload-time = "2025-08-12T06:59:58.109Z" }, - { url = "https://files.pythonhosted.org/packages/96/df/0cfe748ace5485be740fed9476dee7877f109da32ed0d280312c94ec259f/sentencepiece-0.2.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c7f0fd2f2693309e6628aeeb2e2faf6edd221134dfccac3308ca0de01f8dab47", size = 1387882, upload-time = "2025-08-12T07:00:00.701Z" }, - { url = "https://files.pythonhosted.org/packages/ac/dd/f7774d42a881ced8e1739f393ab1e82ece39fc9abd4779e28050c2e975b5/sentencepiece-0.2.1-cp313-cp313-win32.whl", hash = "sha256:92b3816aa2339355fda2c8c4e021a5de92180b00aaccaf5e2808972e77a4b22f", size = 999541, upload-time = "2025-08-12T07:00:02.709Z" }, - { url = "https://files.pythonhosted.org/packages/dd/e9/932b9eae6fd7019548321eee1ab8d5e3b3d1294df9d9a0c9ac517c7b636d/sentencepiece-0.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:10ed3dab2044c47f7a2e7b4969b0c430420cdd45735d78c8f853191fa0e3148b", size = 1054669, upload-time = "2025-08-12T07:00:04.915Z" }, - { url = "https://files.pythonhosted.org/packages/c9/3a/76488a00ea7d6931689cda28726a1447d66bf1a4837943489314593d5596/sentencepiece-0.2.1-cp313-cp313-win_arm64.whl", hash = "sha256:ac650534e2251083c5f75dde4ff28896ce7c8904133dc8fef42780f4d5588fcd", size = 1033922, upload-time = "2025-08-12T07:00:06.496Z" }, - { url = "https://files.pythonhosted.org/packages/4a/b6/08fe2ce819e02ccb0296f4843e3f195764ce9829cbda61b7513f29b95718/sentencepiece-0.2.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:8dd4b477a7b069648d19363aad0cab9bad2f4e83b2d179be668efa672500dc94", size = 1946052, upload-time = "2025-08-12T07:00:08.136Z" }, - { url = "https://files.pythonhosted.org/packages/ab/d9/1ea0e740591ff4c6fc2b6eb1d7510d02f3fb885093f19b2f3abd1363b402/sentencepiece-0.2.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0c0f672da370cc490e4c59d89e12289778310a0e71d176c541e4834759e1ae07", size = 1327408, upload-time = "2025-08-12T07:00:09.572Z" }, - { url = "https://files.pythonhosted.org/packages/99/7e/1fb26e8a21613f6200e1ab88824d5d203714162cf2883248b517deb500b7/sentencepiece-0.2.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:ad8493bea8432dae8d6830365352350f3b4144415a1d09c4c8cb8d30cf3b6c3c", size = 1254857, upload-time = "2025-08-12T07:00:11.021Z" }, - { url = "https://files.pythonhosted.org/packages/bc/85/c72fd1f3c7a6010544d6ae07f8ddb38b5e2a7e33bd4318f87266c0bbafbf/sentencepiece-0.2.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b81a24733726e3678d2db63619acc5a8dccd074f7aa7a54ecd5ca33ca6d2d596", size = 1315722, upload-time = "2025-08-12T07:00:12.989Z" }, - { url = "https://files.pythonhosted.org/packages/4a/e8/661e5bd82a8aa641fd6c1020bd0e890ef73230a2b7215ddf9c8cd8e941c2/sentencepiece-0.2.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0a81799d0a68d618e89063fb423c3001a034c893069135ffe51fee439ae474d6", size = 1387452, upload-time = "2025-08-12T07:00:15.088Z" }, - { url = "https://files.pythonhosted.org/packages/99/5e/ae66c361023a470afcbc1fbb8da722c72ea678a2fcd9a18f1a12598c7501/sentencepiece-0.2.1-cp313-cp313t-win32.whl", hash = "sha256:89a3ea015517c42c0341d0d962f3e6aaf2cf10d71b1932d475c44ba48d00aa2b", size = 1002501, upload-time = "2025-08-12T07:00:16.966Z" }, - { url = "https://files.pythonhosted.org/packages/c1/03/d332828c4ff764e16c1b56c2c8f9a33488bbe796b53fb6b9c4205ddbf167/sentencepiece-0.2.1-cp313-cp313t-win_amd64.whl", hash = "sha256:33f068c9382dc2e7c228eedfd8163b52baa86bb92f50d0488bf2b7da7032e484", size = 1057555, upload-time = "2025-08-12T07:00:18.573Z" }, - { url = "https://files.pythonhosted.org/packages/88/14/5aee0bf0864df9bd82bd59e7711362908e4935e3f9cdc1f57246b5d5c9b9/sentencepiece-0.2.1-cp313-cp313t-win_arm64.whl", hash = "sha256:b3616ad246f360e52c85781e47682d31abfb6554c779e42b65333d4b5f44ecc0", size = 1036042, upload-time = "2025-08-12T07:00:20.209Z" }, -] - [[package]] name = "sentry-sdk" version = "2.48.0" @@ -2085,11 +1970,6 @@ dependencies = [ { name = "zstandard" }, ] -[package.optional-dependencies] -nn-decompositions = [ - { name = "nn-decompositions" }, -] - [package.dev-dependencies] dev = [ { name = "basedpyright" }, @@ -2112,7 +1992,6 @@ requires-dist = [ { name = "ipykernel" }, { name = "jaxtyping" }, { name = "matplotlib" }, - { name = "nn-decompositions", marker = "extra == 'nn-decompositions'", git = "https://github.com/bartbussmann/nn_decompositions.git" }, { name = "numpy" }, { name = "openrouter", specifier = ">=0.1.1" }, { name = "orjson" }, @@ -2131,7 +2010,6 @@ requires-dist = [ { name = "wandb-workspaces", specifier = "==0.1.12" }, { name = "zstandard" }, ] -provides-extras = ["nn-decompositions"] [package.metadata.requires-dev] dev = [ @@ -2373,36 +2251,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359, upload-time = "2024-04-19T11:11:46.763Z" }, ] -[[package]] -name = "transformer-lens" -version = "2.17.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "accelerate" }, - { name = "beartype" }, - { name = "better-abc" }, - { name = "datasets" }, - { name = "einops" }, - { name = "fancy-einsum" }, - { name = "huggingface-hub" }, - { name = "jaxtyping" }, - { name = "pandas" }, - { name = "protobuf" }, - { name = "rich" }, - { name = "sentencepiece" }, - { name = "torch" }, - { name = "tqdm" }, - { name = "transformers" }, - { name = "transformers-stream-generator" }, - { name = "typeguard" }, - { name = "typing-extensions" }, - { name = "wandb" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/0c/3b/e34b774ec23e146d307872920fec27e0e7d105507afb87cc155ab0f9954d/transformer_lens-2.17.0.tar.gz", hash = "sha256:93bfd6b6ac65e2b274edae430495663c8c8e99b53fd2c80ac428c4b0398e1272", size = 156918, upload-time = "2026-01-21T23:53:49.046Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ef/17/137f4e3791261f2cad2feee8f1ed31e0dea21ac7e7dec10ca133de841cbd/transformer_lens-2.17.0-py3-none-any.whl", hash = "sha256:7186cb9d29c4a20c5ea020dacc13579494e35f0b2316ebdbde972879fc4669f7", size = 195183, upload-time = "2026-01-21T23:53:50.076Z" }, -] - [[package]] name = "transformers" version = "4.57.3" @@ -2424,15 +2272,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/6b/2f416568b3c4c91c96e5a365d164f8a4a4a88030aa8ab4644181fdadce97/transformers-4.57.3-py3-none-any.whl", hash = "sha256:c77d353a4851b1880191603d36acb313411d3577f6e2897814f333841f7003f4", size = 11993463, upload-time = "2025-11-25T15:51:26.493Z" }, ] -[[package]] -name = "transformers-stream-generator" -version = "0.0.5" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "transformers" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/42/c2/65f13aec253100e1916e9bd7965fe17bde796ebabeb1265f45191ab4ddc0/transformers-stream-generator-0.0.5.tar.gz", hash = "sha256:271deace0abf9c0f83b36db472c8ba61fdc7b04d1bf89d845644acac2795ed57", size = 13033, upload-time = "2024-03-11T14:18:02.079Z" } - [[package]] name = "triton" version = "3.4.0" @@ -2445,18 +2284,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/20/63/8cb444ad5cdb25d999b7d647abac25af0ee37d292afc009940c05b82dda0/triton-3.4.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7936b18a3499ed62059414d7df563e6c163c5e16c3773678a3ee3d417865035d", size = 155659780, upload-time = "2025-07-30T19:58:51.171Z" }, ] -[[package]] -name = "typeguard" -version = "4.5.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/2b/e8/66e25efcc18542d58706ce4e50415710593721aae26e794ab1dec34fb66f/typeguard-4.5.1.tar.gz", hash = "sha256:f6f8ecbbc819c9bc749983cc67c02391e16a9b43b8b27f15dc70ed7c4a007274", size = 80121, upload-time = "2026-02-19T16:09:03.392Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/91/88/b55b3117287a8540b76dbdd87733808d4d01c8067a3b339408c250bb3600/typeguard-4.5.1-py3-none-any.whl", hash = "sha256:44d2bf329d49a244110a090b55f5f91aa82d9a9834ebfd30bcc73651e4a8cc40", size = 36745, upload-time = "2026-02-19T16:09:01.6Z" }, -] - [[package]] name = "typing-extensions" version = "4.15.0" From 65c15e184489631bc401a295c019335eb6cbba37 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Tue, 3 Mar 2026 10:47:25 +0000 Subject: [PATCH 11/13] Fix type errors in vendored transcoder code - Split encode() into encode() and encode_dense() to avoid union return type - Add type annotations to autograd.Function forward/backward methods - Type _build_loss_dict return as dict[str, Any] - Assert std is not None in postprocess_output, .grad in weight norm - Use int() for dead_features.sum() passed to min() Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/adapters/transcoders.py | 161 ++++++++++++++++++++++-------------- 1 file changed, 101 insertions(+), 60 deletions(-) diff --git a/spd/adapters/transcoders.py b/spd/adapters/transcoders.py index 8ca0cd93d..6ef4af5ba 100644 --- a/spd/adapters/transcoders.py +++ b/spd/adapters/transcoders.py @@ -3,10 +3,13 @@ Vendored from https://github.com/bartbussmann/nn_decompositions (MIT license). """ +from typing import Any, override + import torch import torch.autograd as autograd import torch.nn as nn import torch.nn.functional as F +from torch import Tensor from spd.adapters.encoder_config import EncoderConfig @@ -36,7 +39,6 @@ def __init__(self, cfg: EncoderConfig): self.W_dec = nn.Parameter( torch.nn.init.kaiming_uniform_(torch.empty(cfg.dict_size, cfg.output_size)) ) - # Initialize W_dec from W_enc only if input_size == output_size (SAE case) if cfg.input_size == cfg.output_size: self.W_dec.data[:] = self.W_enc.t().data self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True) @@ -44,22 +46,18 @@ def __init__(self, cfg: EncoderConfig): self.to(cfg.dtype).to(cfg.device) - def encode( - self, x: torch.Tensor, return_dense: bool = False - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - """Encode input to sparse activations. Subclasses must implement. + def encode(self, x: Tensor) -> Tensor: + _ = x + raise NotImplementedError - If return_dense=True, also returns the dense (pre-sparsification) activations - as a second element. - """ + def encode_dense(self, x: Tensor) -> tuple[Tensor, Tensor]: + _ = x raise NotImplementedError - def decode(self, acts: torch.Tensor) -> torch.Tensor: + def decode(self, acts: Tensor) -> Tensor: return acts @ self.W_dec + self.b_dec - def preprocess_input( - self, x: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + def preprocess_input(self, x: Tensor) -> tuple[Tensor, Tensor | None, Tensor | None]: if self.cfg.input_unit_norm: x_mean = x.mean(dim=-1, keepdim=True) x = x - x_mean @@ -68,33 +66,31 @@ def preprocess_input( return x, x_mean, x_std return x, None, None - def postprocess_output( - self, out: torch.Tensor, mean: torch.Tensor | None, std: torch.Tensor | None - ) -> torch.Tensor: + def postprocess_output(self, out: Tensor, mean: Tensor | None, std: Tensor | None) -> Tensor: if self.cfg.input_unit_norm and mean is not None: + assert std is not None return out * std + mean return out - @torch.no_grad() - def make_decoder_weights_and_grad_unit_norm(self): + @torch.no_grad() # pyright: ignore[reportUntypedFunctionDecorator] + def make_decoder_weights_and_grad_unit_norm(self) -> None: W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True) + assert self.W_dec.grad is not None W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed self.W_dec.grad -= W_dec_grad_proj self.W_dec.data = W_dec_normed - def update_inactive_features(self, acts: torch.Tensor): + def update_inactive_features(self, acts: Tensor) -> None: self.num_batches_not_active += (acts.sum(0) == 0).float() self.num_batches_not_active[acts.sum(0) > 0] = 0 - def _get_auxiliary_loss( - self, y_target: torch.Tensor, y_pred: torch.Tensor, acts: torch.Tensor - ) -> torch.Tensor: + def _get_auxiliary_loss(self, y_target: Tensor, y_pred: Tensor, acts: Tensor) -> Tensor: dead_features = self.num_batches_not_active >= self.cfg.n_batches_to_dead if dead_features.sum() > 0: residual = y_target.float() - y_pred.float() acts_topk_aux = torch.topk( acts[:, dead_features], - min(self.cfg.top_k_aux, dead_features.sum()), + min(self.cfg.top_k_aux, int(dead_features.sum().item())), dim=-1, ) acts_aux = torch.zeros_like(acts[:, dead_features]).scatter( @@ -106,13 +102,13 @@ def _get_auxiliary_loss( def _build_loss_dict( self, - y_target: torch.Tensor, - y_pred: torch.Tensor, - acts: torch.Tensor, - y_pred_out: torch.Tensor, - l0_norm: torch.Tensor, - extra_losses: dict[str, torch.Tensor] | None = None, - ) -> dict: + y_target: Tensor, + y_pred: Tensor, + acts: Tensor, + y_pred_out: Tensor, + l0_norm: Tensor, + extra_losses: dict[str, Tensor] | None = None, + ) -> dict[str, Any]: l2_loss = (y_pred.float() - y_target.float()).pow(2).mean() l1_norm = acts.float().abs().sum(-1).mean() num_dead = (self.num_batches_not_active > self.cfg.n_batches_to_dead).sum() @@ -121,7 +117,7 @@ def _build_loss_dict( if extra_losses: loss = loss + sum(extra_losses.values()) - result = { + result: dict[str, Any] = { "output": y_pred_out, "feature_acts": acts, "num_dead_features": num_dead, @@ -136,13 +132,19 @@ def _build_loss_dict( class VanillaTranscoder(SharedTranscoder): - def encode(self, x: torch.Tensor, return_dense: bool = False): + @override + def encode(self, x: Tensor) -> Tensor: use_pre_enc_bias = self.cfg.pre_enc_bias and self.input_size == self.output_size x_enc = x - self.b_dec if use_pre_enc_bias else x - acts = F.relu(x_enc @ self.W_enc + self.b_enc) - return (acts, acts) if return_dense else acts + return F.relu(x_enc @ self.W_enc + self.b_enc) - def forward(self, x_in: torch.Tensor, y_target: torch.Tensor) -> dict: + @override + def encode_dense(self, x: Tensor) -> tuple[Tensor, Tensor]: + acts = self.encode(x) + return acts, acts + + @override + def forward(self, x_in: Tensor, y_target: Tensor) -> dict[str, Any]: x_in, _, _ = self.preprocess_input(x_in) y_target, y_mean, y_std = self.preprocess_input(y_target) @@ -165,19 +167,29 @@ def forward(self, x_in: torch.Tensor, y_target: torch.Tensor) -> dict: class TopKTranscoder(SharedTranscoder): - def encode(self, x: torch.Tensor, return_dense: bool = False): + @override + def encode(self, x: Tensor) -> Tensor: + use_pre_enc_bias = self.cfg.pre_enc_bias and self.input_size == self.output_size + x_enc = x - self.b_dec if use_pre_enc_bias else x + acts = F.relu(x_enc @ self.W_enc + self.b_enc) + acts_topk = torch.topk(acts, self.cfg.top_k, dim=-1) + return torch.zeros_like(acts).scatter(-1, acts_topk.indices, acts_topk.values) + + @override + def encode_dense(self, x: Tensor) -> tuple[Tensor, Tensor]: use_pre_enc_bias = self.cfg.pre_enc_bias and self.input_size == self.output_size x_enc = x - self.b_dec if use_pre_enc_bias else x acts = F.relu(x_enc @ self.W_enc + self.b_enc) acts_topk = torch.topk(acts, self.cfg.top_k, dim=-1) acts_sparse = torch.zeros_like(acts).scatter(-1, acts_topk.indices, acts_topk.values) - return (acts_sparse, acts) if return_dense else acts_sparse + return acts_sparse, acts - def forward(self, x_in: torch.Tensor, y_target: torch.Tensor) -> dict: + @override + def forward(self, x_in: Tensor, y_target: Tensor) -> dict[str, Any]: x_in, _, _ = self.preprocess_input(x_in) y_target, y_mean, y_std = self.preprocess_input(y_target) - acts, acts_dense = self.encode(x_in, return_dense=True) + acts, acts_dense = self.encode_dense(x_in) y_pred = self.decode(acts) y_pred_out = self.postprocess_output(y_pred, y_mean, y_std) @@ -197,7 +209,20 @@ def forward(self, x_in: torch.Tensor, y_target: torch.Tensor) -> dict: class BatchTopKTranscoder(SharedTranscoder): - def encode(self, x: torch.Tensor, return_dense: bool = False): + @override + def encode(self, x: Tensor) -> Tensor: + use_pre_enc_bias = self.cfg.pre_enc_bias and self.input_size == self.output_size + x_enc = x - self.b_dec if use_pre_enc_bias else x + acts = F.relu(x_enc @ self.W_enc + self.b_enc) + acts_topk = torch.topk(acts.flatten(), self.cfg.top_k * x.shape[0], dim=-1) + return ( + torch.zeros_like(acts.flatten()) + .scatter(-1, acts_topk.indices, acts_topk.values) + .reshape(acts.shape) + ) + + @override + def encode_dense(self, x: Tensor) -> tuple[Tensor, Tensor]: use_pre_enc_bias = self.cfg.pre_enc_bias and self.input_size == self.output_size x_enc = x - self.b_dec if use_pre_enc_bias else x acts = F.relu(x_enc @ self.W_enc + self.b_enc) @@ -207,13 +232,14 @@ def encode(self, x: torch.Tensor, return_dense: bool = False): .scatter(-1, acts_topk.indices, acts_topk.values) .reshape(acts.shape) ) - return (acts_sparse, acts) if return_dense else acts_sparse + return acts_sparse, acts - def forward(self, x_in: torch.Tensor, y_target: torch.Tensor) -> dict: + @override + def forward(self, x_in: Tensor, y_target: Tensor) -> dict[str, Any]: x_in, _, _ = self.preprocess_input(x_in) y_target, y_mean, y_std = self.preprocess_input(y_target) - acts, acts_dense = self.encode(x_in, return_dense=True) + acts, acts_dense = self.encode_dense(x_in) y_pred = self.decode(acts) y_pred_out = self.postprocess_output(y_pred, y_mean, y_std) @@ -234,12 +260,14 @@ def forward(self, x_in: torch.Tensor, y_target: torch.Tensor) -> dict: class RectangleFunction(autograd.Function): @staticmethod - def forward(ctx, x): + @override + def forward(ctx: Any, x: Tensor) -> Tensor: ctx.save_for_backward(x) return ((x > -0.5) & (x < 0.5)).float() @staticmethod - def backward(ctx, grad_output): + @override + def backward(ctx: Any, grad_output: Tensor) -> Tensor: # pyright: ignore[reportIncompatibleMethodOverride] (x,) = ctx.saved_tensors grad_input = grad_output.clone() grad_input[(x <= -0.5) | (x >= 0.5)] = 0 @@ -248,13 +276,15 @@ def backward(ctx, grad_output): class JumpReLUFunction(autograd.Function): @staticmethod - def forward(ctx, x, log_threshold, bandwidth): + @override + def forward(ctx: Any, x: Tensor, log_threshold: Tensor, bandwidth: float) -> Tensor: ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth)) threshold = torch.exp(log_threshold) return x * (x > threshold).float() @staticmethod - def backward(ctx, grad_output): + @override + def backward(ctx: Any, grad_output: Tensor) -> tuple[Tensor, Tensor, None]: # pyright: ignore[reportIncompatibleMethodOverride] x, log_threshold, bandwidth_tensor = ctx.saved_tensors bandwidth = bandwidth_tensor.item() threshold = torch.exp(log_threshold) @@ -273,19 +303,24 @@ def __init__(self, feature_size: int, bandwidth: float, device: str = "cpu"): self.log_threshold = nn.Parameter(torch.zeros(feature_size, device=device)) self.bandwidth = bandwidth - def forward(self, x: torch.Tensor) -> torch.Tensor: - return JumpReLUFunction.apply(x, self.log_threshold, self.bandwidth) + @override + def forward(self, x: Tensor) -> Tensor: + result = JumpReLUFunction.apply(x, self.log_threshold, self.bandwidth) + assert isinstance(result, Tensor) + return result class StepFunction(autograd.Function): @staticmethod - def forward(ctx, x, log_threshold, bandwidth): + @override + def forward(ctx: Any, x: Tensor, log_threshold: Tensor, bandwidth: float) -> Tensor: ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth)) threshold = torch.exp(log_threshold) return (x > threshold).float() @staticmethod - def backward(ctx, grad_output): + @override + def backward(ctx: Any, grad_output: Tensor) -> tuple[Tensor, Tensor, None]: # pyright: ignore[reportIncompatibleMethodOverride] x, log_threshold, bandwidth_tensor = ctx.saved_tensors bandwidth = bandwidth_tensor.item() threshold = torch.exp(log_threshold) @@ -301,28 +336,34 @@ def __init__(self, cfg: EncoderConfig): super().__init__(cfg) self.jumprelu = JumpReLUActivation(cfg.dict_size, cfg.bandwidth, cfg.device) - def encode(self, x: torch.Tensor, return_dense: bool = False): + @override + def encode(self, x: Tensor) -> Tensor: + use_pre_enc_bias = self.cfg.pre_enc_bias and self.input_size == self.output_size + x_enc = x - self.b_dec if use_pre_enc_bias else x + pre_acts = F.relu(x_enc @ self.W_enc + self.b_enc) + return self.jumprelu(pre_acts) + + @override + def encode_dense(self, x: Tensor) -> tuple[Tensor, Tensor]: use_pre_enc_bias = self.cfg.pre_enc_bias and self.input_size == self.output_size x_enc = x - self.b_dec if use_pre_enc_bias else x pre_acts = F.relu(x_enc @ self.W_enc + self.b_enc) - acts = self.jumprelu(pre_acts) - return (acts, pre_acts) if return_dense else acts + return self.jumprelu(pre_acts), pre_acts - def forward(self, x_in: torch.Tensor, y_target: torch.Tensor) -> dict: + @override + def forward(self, x_in: Tensor, y_target: Tensor) -> dict[str, Any]: x_in, _, _ = self.preprocess_input(x_in) y_target, y_mean, y_std = self.preprocess_input(y_target) - acts, pre_acts = self.encode(x_in, return_dense=True) + acts, pre_acts = self.encode_dense(x_in) y_pred = self.decode(acts) y_pred_out = self.postprocess_output(y_pred, y_mean, y_std) self.update_inactive_features(acts) - l0_norm = ( - StepFunction.apply(pre_acts, self.jumprelu.log_threshold, self.cfg.bandwidth) - .sum(dim=-1) - .mean() - ) + step_result = StepFunction.apply(pre_acts, self.jumprelu.log_threshold, self.cfg.bandwidth) + assert isinstance(step_result, Tensor) + l0_norm = step_result.sum(dim=-1).mean() sparsity_loss = self.cfg.l1_coeff * l0_norm return self._build_loss_dict( y_target, From dedb6b2ba37781be6b44289aa532fc13401dd11a Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Tue, 3 Mar 2026 10:56:17 +0000 Subject: [PATCH 12/13] Remove pyright ignores from vendored transcoder code - Use *grad_outputs signature for autograd.Function.backward - Replace @torch.no_grad() decorator with context manager - Credit Bart Bussmann by name in vendored file docstrings Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/adapters/encoder_config.py | 2 +- spd/adapters/transcoders.py | 23 +++++++++++++---------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/spd/adapters/encoder_config.py b/spd/adapters/encoder_config.py index e9a0e2e40..c2625c77e 100644 --- a/spd/adapters/encoder_config.py +++ b/spd/adapters/encoder_config.py @@ -1,6 +1,6 @@ """Encoder configuration for transcoder architectures. -Vendored from https://github.com/bartbussmann/nn_decompositions (MIT license). +Originally by Bart Bussmann, vendored from https://github.com/bartbussmann/nn_decompositions (MIT license). Only EncoderConfig is used; CLTConfig and SAEConfig are omitted. """ diff --git a/spd/adapters/transcoders.py b/spd/adapters/transcoders.py index 6ef4af5ba..46d4e57a0 100644 --- a/spd/adapters/transcoders.py +++ b/spd/adapters/transcoders.py @@ -1,6 +1,6 @@ """Transcoder nn.Module implementations (Vanilla, TopK, BatchTopK, JumpReLU). -Vendored from https://github.com/bartbussmann/nn_decompositions (MIT license). +Originally by Bart Bussmann, vendored from https://github.com/bartbussmann/nn_decompositions (MIT license). """ from typing import Any, override @@ -72,13 +72,13 @@ def postprocess_output(self, out: Tensor, mean: Tensor | None, std: Tensor | Non return out * std + mean return out - @torch.no_grad() # pyright: ignore[reportUntypedFunctionDecorator] def make_decoder_weights_and_grad_unit_norm(self) -> None: - W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True) - assert self.W_dec.grad is not None - W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed - self.W_dec.grad -= W_dec_grad_proj - self.W_dec.data = W_dec_normed + with torch.no_grad(): + W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True) + assert self.W_dec.grad is not None + W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed + self.W_dec.grad -= W_dec_grad_proj + self.W_dec.data = W_dec_normed def update_inactive_features(self, acts: Tensor) -> None: self.num_batches_not_active += (acts.sum(0) == 0).float() @@ -267,7 +267,8 @@ def forward(ctx: Any, x: Tensor) -> Tensor: @staticmethod @override - def backward(ctx: Any, grad_output: Tensor) -> Tensor: # pyright: ignore[reportIncompatibleMethodOverride] + def backward(ctx: Any, *grad_outputs: Tensor) -> Tensor: + (grad_output,) = grad_outputs (x,) = ctx.saved_tensors grad_input = grad_output.clone() grad_input[(x <= -0.5) | (x >= 0.5)] = 0 @@ -284,7 +285,8 @@ def forward(ctx: Any, x: Tensor, log_threshold: Tensor, bandwidth: float) -> Ten @staticmethod @override - def backward(ctx: Any, grad_output: Tensor) -> tuple[Tensor, Tensor, None]: # pyright: ignore[reportIncompatibleMethodOverride] + def backward(ctx: Any, *grad_outputs: Tensor) -> tuple[Tensor, Tensor, None]: + (grad_output,) = grad_outputs x, log_threshold, bandwidth_tensor = ctx.saved_tensors bandwidth = bandwidth_tensor.item() threshold = torch.exp(log_threshold) @@ -320,7 +322,8 @@ def forward(ctx: Any, x: Tensor, log_threshold: Tensor, bandwidth: float) -> Ten @staticmethod @override - def backward(ctx: Any, grad_output: Tensor) -> tuple[Tensor, Tensor, None]: # pyright: ignore[reportIncompatibleMethodOverride] + def backward(ctx: Any, *grad_outputs: Tensor) -> tuple[Tensor, Tensor, None]: + (grad_output,) = grad_outputs x, log_threshold, bandwidth_tensor = ctx.saved_tensors bandwidth = bandwidth_tensor.item() threshold = torch.exp(log_threshold) From 31459c9f0bcb97b16439470d1a298adb98a21111 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Tue, 3 Mar 2026 11:27:22 +0000 Subject: [PATCH 13/13] Make adapter_from_id work for transcoders via harvest DB lookup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit For non-SPD decomposition IDs (e.g. tc-*), recover the full method config from the harvest DB. This means spd-autointerp, intruder eval, graph-interp, and label scoring all work with transcoders — no config passing needed, just the decomposition ID. Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/adapters/__init__.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/spd/adapters/__init__.py b/spd/adapters/__init__.py index 4168b603c..b5dcb0765 100644 --- a/spd/adapters/__init__.py +++ b/spd/adapters/__init__.py @@ -35,18 +35,32 @@ def adapter_from_config(method_config: DecompositionMethodHarvestConfig) -> Deco raise NotImplementedError("MOLT adapter not implemented yet") -def adapter_from_id(id: str) -> DecompositionAdapter: - """Construct an adapter from just a decomposition ID (e.g. "s-abc123"). +def adapter_from_id(decomposition_id: str) -> DecompositionAdapter: + """Construct an adapter from a decomposition ID (e.g. "s-abc123", "tc-1a2b3c4d"). - Only works for methods whose adapter can be constructed from an ID alone (SPD). - For transcoders, use adapter_from_config() with the full method config. + For SPD runs, the ID is sufficient. For other methods, recovers the full + method config from the harvest DB (which is always populated before downstream + steps like autointerp run). """ - from spd.adapters.spd import SPDAdapter + if decomposition_id.startswith("s-"): + from spd.adapters.spd import SPDAdapter - if id.startswith("s-"): - return SPDAdapter(id) + return SPDAdapter(decomposition_id) - raise ValueError( - f"Cannot construct adapter from ID alone: {id!r}. " - f"Use adapter_from_config() with the full method config." + return adapter_from_config(_load_method_config(decomposition_id)) + + +def _load_method_config(decomposition_id: str) -> DecompositionMethodHarvestConfig: + from pydantic import TypeAdapter + + from spd.harvest.repo import HarvestRepo + + repo = HarvestRepo.open_most_recent(decomposition_id) + assert repo is not None, ( + f"No harvest data found for {decomposition_id!r}. " + f"Run spd-harvest first to populate the method config." ) + config_dict = repo.get_config() + method_config_raw = config_dict["method_config"] + ta = TypeAdapter(DecompositionMethodHarvestConfig) + return ta.validate_python(method_config_raw)