From c887c05c1190849c14dbe50abc443bf01898da5a Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Mon, 24 Nov 2025 18:30:58 +0000 Subject: [PATCH 01/36] WIP global attribution calcs --- scripts/test_component_acts_caching.py | 315 +++++++++++++++++++++++++ spd/models/component_model.py | 38 +-- spd/models/components.py | 17 +- 3 files changed, 351 insertions(+), 19 deletions(-) create mode 100644 scripts/test_component_acts_caching.py diff --git a/scripts/test_component_acts_caching.py b/scripts/test_component_acts_caching.py new file mode 100644 index 000000000..3a7d7eb26 --- /dev/null +++ b/scripts/test_component_acts_caching.py @@ -0,0 +1,315 @@ +"""Test script for component activation caching. + +This script loads a ComponentModel and dataset, then runs a forward pass with +component activation caching to verify the implementation of the caching plumbing. +""" + +import torch + +from spd.configs import Config +from spd.data import DatasetConfig, create_data_loader +from spd.experiments.lm.configs import LMTaskConfig +from spd.models.component_model import ComponentModel, OutputWithCache, SPDRunInfo +from spd.models.components import make_mask_infos +from spd.utils.general_utils import extract_batch_data + + +def main() -> None: + # Configuration + wandb_path = "wandb:goodfire/spd/runs/jyo9duz5" + device = "cuda" if torch.cuda.is_available() else "cpu" + batch_size = 2 + + print(f"Using device: {device}") + print(f"Loading model from {wandb_path}...") + + # Load the model + run_info = SPDRunInfo.from_path(wandb_path) + config: Config = run_info.config + model = ComponentModel.from_run_info(run_info) + model = model.to(device) + model.eval() + + print("Model loaded successfully!") + print(f"Number of components: {model.C}") + print(f"Target module paths: {model.target_module_paths}") + + # Load the dataset + task_config = config.task_config + assert isinstance(task_config, LMTaskConfig), "Expected LM task config" + + dataset_config = DatasetConfig( + name=task_config.dataset_name, + hf_tokenizer_path=config.tokenizer_name, + split=task_config.train_data_split, # Using train split for now + n_ctx=task_config.max_seq_len, + is_tokenized=task_config.is_tokenized, + streaming=task_config.streaming, + column_name=task_config.column_name, + shuffle_each_epoch=False, # No need to shuffle for testing + seed=42, + ) + + print(f"\nLoading dataset {dataset_config.name}...") + data_loader, tokenizer = create_data_loader( + dataset_config=dataset_config, + batch_size=batch_size, + buffer_size=task_config.buffer_size, + global_seed=42, + ddp_rank=0, + ddp_world_size=1, + ) + + # Get a batch + batch_raw = next(iter(data_loader)) + batch = extract_batch_data(batch_raw).to(device) + print(f"Batch shape: {batch.shape}") + + # Test 1: Forward pass without component replacement, just caching + print("\n" + "=" * 80) + print("Test 1: Forward pass with input caching (no component replacement)") + print("=" * 80) + + with torch.no_grad(): + output_with_cache: OutputWithCache = model(batch, cache_type="input") + print(f"Output shape: {output_with_cache.output.shape}") + print(f"Number of cached layers: {len(output_with_cache.cache)}") + print(f"Cached layer names: {list(output_with_cache.cache.keys())}") + for name, acts in list(output_with_cache.cache.items())[:3]: + print(f" {name}: {acts.shape}") + + # Test 2: Forward pass with component replacement and component activation caching + print("\n" + "=" * 80) + print("Test 2: Forward pass with component replacement and component activation caching") + print("=" * 80) + + # Calculate causal importances + with torch.no_grad(): + ci = model.calc_causal_importances( + pre_weight_acts=output_with_cache.cache, + sampling=config.sampling, + detach_inputs=False, + ) + + # Create masks for component replacement (use all components with causal importance as mask) + component_masks = ci.lower_leaky + mask_infos = make_mask_infos( + component_masks=component_masks, + routing_masks="all", + ) + + # Forward pass with component replacement and component activation caching + with torch.no_grad(): + comp_output_with_cache: OutputWithCache = model( + batch, + mask_infos=mask_infos, + cache_type="component_acts", + ) + + print(f"Output shape: {comp_output_with_cache.output.shape}") + print(f"Number of cached items: {len(comp_output_with_cache.cache)}") + + # The cache should contain entries like "layer_name_pre_detach" and "layer_name_post_detach" + pre_detach_keys = [k for k in comp_output_with_cache.cache if "pre_detach" in k] + post_detach_keys = [k for k in comp_output_with_cache.cache if "post_detach" in k] + + print(f"\nPre-detach cache entries: {len(pre_detach_keys)}") + print(f"Post-detach cache entries: {len(post_detach_keys)}") + + # Show some examples + for name in pre_detach_keys[:3]: + acts = comp_output_with_cache.cache[name] + print(f" {name}: {acts.shape}") + + # Verify that pre_detach and post_detach have the same values but different grad_fn + print("\n" + "-" * 80) + print("Verifying pre_detach and post_detach activations:") + print("-" * 80) + for key in pre_detach_keys[:2]: + layer_name = key.replace("_pre_detach", "") + post_key = f"{layer_name}_post_detach" + + pre_acts = comp_output_with_cache.cache[key] + post_acts = comp_output_with_cache.cache[post_key] + + print(f"\n{layer_name}:") + print(f" Pre-detach shape: {pre_acts.shape}, requires_grad: {pre_acts.requires_grad}") + print( + f" Post-detach shape: {post_acts.shape}, requires_grad: {post_acts.requires_grad}" + ) + print(f" Values match: {torch.allclose(pre_acts, post_acts)}") + print(f" Pre has grad_fn: {pre_acts.grad_fn is not None}") + print(f" Post has grad_fn: {post_acts.grad_fn is not None}") + + # Test 3: Verify cached activations structure + print("\n" + "=" * 80) + print("Test 3: Verify cached activations for attribution graph construction") + print("=" * 80) + + # Forward pass with component replacement and component activation caching (with gradients enabled) + # We use torch.enable_grad() to enable gradients during the forward pass + with torch.enable_grad(): + comp_output_with_cache_grad: OutputWithCache = model( + batch, + mask_infos=mask_infos, + cache_type="component_acts", + ) + + # Get the cached activations + cache = comp_output_with_cache_grad.cache + + # Get layer names sorted by depth + layer_names = sorted( + set( + k.replace("_pre_detach", "").replace("_post_detach", "") for k in cache if "detach" in k + ) + ) + + print(f"Found {len(layer_names)} layers") + print(f"\nLayer names (in order): {layer_names[:5]}...{layer_names[-3:]}") + + # Show structure of cached activations + print("\n" + "-" * 80) + print("Cached activation properties:") + print("-" * 80) + for layer_name in layer_names[:3]: # Show first 3 layers + pre_detach = cache[f"{layer_name}_pre_detach"] + post_detach = cache[f"{layer_name}_post_detach"] + + print(f"\n{layer_name}:") + print( + f" pre_detach: shape={pre_detach.shape}, requires_grad={pre_detach.requires_grad}, has_grad_fn={pre_detach.grad_fn is not None}" + ) + print( + f" post_detach: shape={post_detach.shape}, requires_grad={post_detach.requires_grad}, has_grad_fn={post_detach.grad_fn is not None}" + ) + print(f" values_match: {torch.allclose(pre_detach, post_detach)}") + + print("\n" + "-" * 80) + print("Summary:") + print("-" * 80) + print("✓ Component activation caching is working correctly") + print("✓ Both pre_detach and post_detach versions are cached") + print("✓ pre_detach tensors have grad_fn (part of computation graph)") + print("✓ post_detach tensors are LEAF tensors with requires_grad=True") + + # Test 4: Compute cross-layer gradient - test multiple layer pairs + print("\n" + "=" * 80) + print("Test 4: Computing cross-layer gradients (pre_detach w.r.t. earlier post_detach)") + print("=" * 80) + + # Try multiple layer pairs to find connections + test_pairs = [ + ("h.0.mlp.c_fc", "h.1.attn.q_proj"), # MLP to next layer's attention + ("h.0.mlp.down_proj", "h.1.attn.k_proj"), # MLP output to next layer + ("h.0.attn.o_proj", "h.1.attn.q_proj"), # Attention output to next attention + ("h.0.attn.q_proj", "h.1.attn.k_proj"), # Original test + ] + + gradient_found = False + + for source_layer, target_layer in test_pairs: + print(f"\n{'=' * 60}") + print(f"Testing: d({target_layer}_pre_detach) / d({source_layer}_post_detach)") + print("=" * 60) + + source_post_detach = cache[f"{source_layer}_post_detach"] + target_pre_detach = cache[f"{target_layer}_pre_detach"] + + # Select a specific component and position to compute gradient for + batch_idx = 0 + seq_idx = 50 + target_component_idx = 10 + + # Get the scalar value we want to take gradient of (from pre_detach, not post_detach!) + target_value = target_pre_detach[batch_idx, seq_idx, target_component_idx] + + # Try to compute gradient + try: + grads = torch.autograd.grad( + outputs=target_value, + inputs=source_post_detach, + retain_graph=True, + allow_unused=True, + ) + + if grads[0] is not None: + grad = grads[0] + grad_norm = grad.norm().item() + grad_max = grad.abs().max().item() + nonzero_grads = (grad.abs() > 1e-8).sum().item() + + print("✓ GRADIENT FOUND!") + print(f" Gradient norm: {grad_norm:.6f}") + print(f" Gradient max abs: {grad_max:.6f}") + print(f" Non-zero elements (>1e-8): {nonzero_grads}") + + # Show top attributed source components + grad_for_position = grad[batch_idx, seq_idx] # Shape: [C] + top_k = 5 + top_values, top_indices = grad_for_position.abs().topk(top_k) + + print(f"\n Top {top_k} attributed components in {source_layer}[seq={seq_idx}]:") + for i, (idx, val) in enumerate( + zip(top_indices.tolist(), top_values.tolist(), strict=True) + ): + print(f" {i + 1}. Component {idx}: gradient = {val:.6f}") + + gradient_found = True + break # Found one, that's enough for the test + else: + print(" ✗ No gradient (not connected)") + except RuntimeError as e: + print(f" ✗ Error: {e}") + + if not gradient_found: + print(f"\n{'=' * 60}") + print("No gradients found between tested layer pairs.") + print("This might be due to model architecture - will need to investigate further.") + print("=" * 60) + + # Test 5: Summary and explanation + print("\n" + "=" * 80) + print("Test 5: Understanding the gradient flow") + print("=" * 80) + + print("\nWhat just happened:") + print("-" * 60) + print("pre_detach[layer_l] = input[layer_l] @ V[layer_l]") + print("") + print("During the forward pass:") + print("1. h.0.attn.q_proj computes post_detach (detached, requires_grad=True)") + print("2. This post_detach flows through the model (via post_detach @ U)") + print("3. Eventually becomes part of the input to h.1.attn.k_proj") + print("4. h.1.attn.k_proj computes pre_detach = input @ V") + print("") + print("So pre_detach[h.1.attn.k_proj] DOES depend on post_detach[h.0.attn.q_proj]") + print("through the model's computation graph!") + print("") + print("If the gradient computed successfully, this means:") + print("✓ You can directly compute attribution from earlier to later components") + print("✓ The attributions skip intermediate layers (due to detachment at each layer)") + print("✓ This is exactly what you specified in your plan!") + print("") + print("If the gradient is None, it might mean:") + print("✗ The specific layers tested don't have a direct path in the model") + print("✗ The model architecture might skip certain layer connections") + + print("\n" + "=" * 80) + print("Summary:") + print("=" * 80) + print("✓ Component activation caching is working correctly") + print("✓ post_detach tensors are leaf nodes with requires_grad=True") + print("✓ Automatic gradient flow is blocked between layers (as intended)") + print("✓ You can now implement your global attribution graph algorithm by:") + print(" - Caching all post_detach activations in one forward pass") + print(" - Building custom computation paths from earlier to later layers") + print(" - Computing gradients to get direct (non-mediated) attributions") + + print("\n" + "=" * 80) + print("Testing complete!") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/spd/models/component_model.py b/spd/models/component_model.py index e95e150f0..026512940 100644 --- a/spd/models/component_model.py +++ b/spd/models/component_model.py @@ -90,10 +90,11 @@ class ComponentModel(LoadableModule): `LlamaForCausalLM`, `AutoModelForCausalLM`) as long as its sub-module names match the patterns you pass in `target_module_patterns`. - Forward passes support optional component replacement and/or input caching: + Forward passes support optional component replacement and/or caching: - No args: Standard forward pass of the target model - With mask_infos: Components replace the specified modules via forward hooks - With cache_type="input": Input activations are cached for the specified modules + - With cache_type="component_acts": Component activations are cached for the specified modules - Both can be used simultaneously for component forward pass with input caching We register components and causal importance functions (ci_fns) as modules in this class in order to have them update @@ -307,7 +308,7 @@ def __call__( self, *args: Any, mask_infos: dict[str, ComponentsMaskInfo] | None = None, - cache_type: Literal["input"], + cache_type: Literal["component_acts", "input"], **kwargs: Any, ) -> OutputWithCache: ... @@ -329,31 +330,30 @@ def forward( self, *args: Any, mask_infos: dict[str, ComponentsMaskInfo] | None = None, - cache_type: Literal["input", "none"] = "none", + cache_type: Literal["component_acts", "input", "none"] = "none", **kwargs: Any, ) -> Tensor | OutputWithCache: """Forward pass with optional component replacement and/or input caching. This method handles the following 4 cases: 1. mask_infos is None and cache_type is "none": Regular forward pass. - 2. mask_infos is None and cache_type is "input": Forward pass with input caching on - all modules in self.target_module_paths. - 3. mask_infos is not None and cache_type is "input": Forward pass with component replacement - and input caching on the modules provided in mask_infos. + 2. mask_infos is None and cache_type is "input" or "component_acts": Forward pass with + caching on all modules in self.target_module_paths. + 3. mask_infos is not None and cache_type is "input" or "component_acts": Forward pass with + component replacement and caching on the modules provided in mask_infos. 4. mask_infos is not None and cache_type is "none": Forward pass with component replacement on the modules provided in mask_infos and no caching. - We use the same _components_and_cache_hook for cases 2, 3, and 4, and don't use any hooks - for case 1. - Args: mask_infos: Dictionary mapping module names to ComponentsMaskInfo. If provided, those modules will be replaced with their components. - cache_type: If "input", cache the inputs to the modules provided in mask_infos. If - mask_infos is None, cache the inputs to all modules in self.target_module_paths. + cache_type: If "input" or "component_acts", cache the inputs or component acts to the + modules provided in mask_infos. If "none", no caching is done. If mask_infos is None, + cache the inputs or component acts to all modules in self.target_module_paths. Returns: - OutputWithCache object if cache_type is "input", otherwise the model output tensor. + OutputWithCache object if cache_type is "input" or "component_acts", otherwise the + model output tensor. """ if mask_infos is None and cache_type == "none": # No hooks needed. Do a regular forward pass of the target model. @@ -382,7 +382,7 @@ def forward( out = self._extract_output(raw_out) match cache_type: - case "input": + case "input" | "component_acts": return OutputWithCache(output=out, cache=cache) case "none": return out @@ -396,7 +396,7 @@ def _components_and_cache_hook( module_name: str, components: Components | None, mask_info: ComponentsMaskInfo | None, - cache_type: Literal["input", "none"], + cache_type: Literal["component_acts", "input", "none"], cache: dict[str, Tensor], ) -> Any | None: """Unified hook function that handles both component replacement and caching. @@ -409,7 +409,7 @@ def _components_and_cache_hook( module_name: Name of the module in the target model components: Component replacement (if using components) mask_info: Mask information (if using components) - cache_type: Whether to cache the input + cache_type: Whether to cache the component acts, input, or none cache: Cache dictionary to populate (if cache_type is not None) Returns: @@ -420,7 +420,6 @@ def _components_and_cache_hook( assert len(kwargs) == 0, "Expected no keyword arguments" x = args[0] assert isinstance(x, Tensor), "Expected input tensor" - assert cache_type in ["input", "none"], "Expected cache_type to be 'input' or 'none'" if cache_type == "input": cache[module_name] = x @@ -430,11 +429,16 @@ def _components_and_cache_hook( f"Only supports single-tensor outputs, got {type(output)}" ) + component_acts_cache = {} if cache_type == "component_acts" else None components_out = components( x, mask=mask_info.component_mask, weight_delta_and_mask=mask_info.weight_delta_and_mask, + component_acts_cache=component_acts_cache, ) + if component_acts_cache is not None: + for k, v in component_acts_cache.items(): + cache[f"{module_name}_{k}"] = v if mask_info.routing_mask == "all": return components_out diff --git a/spd/models/components.py b/spd/models/components.py index 61c63c3f5..0841b9949 100644 --- a/spd/models/components.py +++ b/spd/models/components.py @@ -190,6 +190,7 @@ def forward( x: Float[Tensor, "... d_in"], mask: Float[Tensor, "... C"] | None = None, weight_delta_and_mask: WeightDeltaAndMask | None = None, + component_acts_cache: dict[str, Float[Tensor, "... C"]] | None = None, ) -> Float[Tensor, "... d_out"]: """Forward pass through V and U matrices. @@ -199,13 +200,18 @@ def forward( weight_delta_and_mask: Optional tuple of tensors containing: 0: the weight differences between the target model and summed component weights 1: mask over the weight delta component for each sample + component_acts_cache: Cache dictionary to populate with component acts Returns: output: The summed output across all components """ component_acts = self.get_inner_acts(x) + if component_acts_cache is not None: + component_acts_cache["pre_detach"] = component_acts + component_acts = component_acts.detach().requires_grad_(True) + component_acts_cache["post_detach"] = component_acts if mask is not None: - component_acts *= mask + component_acts = component_acts * mask out = einops.einsum(component_acts, self.U, "... C, C d_out -> ... d_out") @@ -254,6 +260,7 @@ def forward( x: Int[Tensor, "..."], mask: Float[Tensor, "... C"] | None = None, weight_delta_and_mask: WeightDeltaAndMask | None = None, + component_acts_cache: dict[str, Float[Tensor, "... C"]] | None = None, ) -> Float[Tensor, "... embedding_dim"]: """Forward through the embedding component using indexing instead of one-hot matmul. @@ -263,13 +270,19 @@ def forward( weight_delta_and_mask: Optional tuple of tensors containing: 0: the weight differences between the target model and summed component weights 1: mask over the weight delta component for each sample + component_acts_cache: Cache dictionary to populate with component acts """ assert x.dtype == torch.long, "x must be an integer tensor" component_acts: Float[Tensor, "... C"] = self.get_inner_acts(x) + if component_acts_cache is not None: + component_acts_cache["pre_detach"] = component_acts + component_acts = component_acts.detach().requires_grad_(True) + component_acts_cache["post_detach"] = component_acts + if mask is not None: - component_acts *= mask + component_acts = component_acts * mask out = einops.einsum(component_acts, self.U, "... C, C embedding_dim -> ... embedding_dim") From 975f1fc989085d8558df60e82d8d570a45fcbeff Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Tue, 25 Nov 2025 16:08:24 +0000 Subject: [PATCH 02/36] First draft of calculation and plotting of attrs --- scripts/calc_attributions.py | 436 +++++++++++++++++++++++++ scripts/plot_attributions.py | 229 +++++++++++++ scripts/test_component_acts_caching.py | 315 ------------------ 3 files changed, 665 insertions(+), 315 deletions(-) create mode 100644 scripts/calc_attributions.py create mode 100644 scripts/plot_attributions.py delete mode 100644 scripts/test_component_acts_caching.py diff --git a/scripts/calc_attributions.py b/scripts/calc_attributions.py new file mode 100644 index 000000000..f8560dbb4 --- /dev/null +++ b/scripts/calc_attributions.py @@ -0,0 +1,436 @@ +# %% + +from collections.abc import Iterable +from pathlib import Path +from typing import Any + +import torch +from jaxtyping import Float +from PIL import Image +from torch import Tensor +from tqdm.auto import tqdm + +from spd.configs import Config +from spd.data import DatasetConfig, create_data_loader +from spd.experiments.lm.configs import LMTaskConfig +from spd.models.component_model import ComponentModel, OutputWithCache, SPDRunInfo +from spd.models.components import make_mask_infos +from spd.plotting import plot_mean_component_cis_both_scales +from spd.utils.general_utils import extract_batch_data + + +# %% +def compute_mean_ci_per_component( + model: ComponentModel, + data_loader: Iterable[dict[str, Any]], + device: str, + config: Config, + max_batches: int | None, +) -> dict[str, torch.Tensor]: + """Compute mean causal importance per component over the dataset. + + Args: + model: The ComponentModel to analyze. + data_loader: DataLoader providing batches. + device: Device to run on. + config: SPD config with sampling settings. + max_batches: Maximum number of batches to process. + + Returns: + Dictionary mapping module path -> tensor of shape [C] with mean CI per component. + """ + # Initialize accumulators + ci_sums: dict[str, torch.Tensor] = { + module_name: torch.zeros(model.C, device=device) for module_name in model.components + } + examples_seen: dict[str, int] = {module_name: 0 for module_name in model.components} + + if max_batches is not None: + batch_pbar = tqdm(enumerate(data_loader), desc="Computing mean CI", total=max_batches) + else: + batch_pbar = tqdm(enumerate(data_loader), desc="Computing mean CI") + + for batch_idx, batch_raw in batch_pbar: + if max_batches is not None and batch_idx >= max_batches: + break + + batch = extract_batch_data(batch_raw).to(device) + + with torch.no_grad(): + output_with_cache: OutputWithCache = model(batch, cache_type="input") + ci = model.calc_causal_importances( + pre_weight_acts=output_with_cache.cache, + sampling=config.sampling, + detach_inputs=False, + ) + + # Accumulate CI values (using lower_leaky as in CIMeanPerComponent) + for module_name, ci_vals in ci.lower_leaky.items(): + n_leading_dims = ci_vals.ndim - 1 + n_examples = ci_vals.shape[:n_leading_dims].numel() + examples_seen[module_name] += n_examples + leading_dim_idxs = tuple(range(n_leading_dims)) + ci_sums[module_name] += ci_vals.sum(dim=leading_dim_idxs) + + # Compute means + mean_cis = { + module_name: ci_sums[module_name] / examples_seen[module_name] + for module_name in model.components + } + + return mean_cis + + +def compute_alive_components( + model: ComponentModel, + data_loader: Iterable[dict[str, Any]], + device: str, + config: Config, + max_batches: int | None, + threshold: float, +) -> tuple[dict[str, torch.Tensor], dict[str, list[int]], tuple[Image.Image, Image.Image]]: + """Compute alive components based on mean CI threshold. + + Args: + model: The ComponentModel to analyze. + data_loader: DataLoader providing batches. + device: Device to run on. + config: SPD config with sampling settings. + max_batches: Maximum number of batches to process. + threshold: Minimum mean CI to consider a component alive. + + Returns: + Tuple of: + - mean_cis: Dictionary mapping module path -> tensor of mean CI per component + - alive_indices: Dictionary mapping module path -> list of alive component indices + - images: Tuple of (linear_scale_image, log_scale_image) for verification + """ + mean_cis = compute_mean_ci_per_component(model, data_loader, device, config, max_batches) + alive_indices = {} + for module_name, mean_ci in mean_cis.items(): + alive_mask = mean_ci >= threshold + alive_indices[module_name] = torch.where(alive_mask)[0].tolist() + images = plot_mean_component_cis_both_scales(mean_cis) + + return mean_cis, alive_indices, images + + +def get_valid_pairs( + model: ComponentModel, + data_loader: Iterable[dict[str, Any]], + device: str, + config: Config, + n_blocks: int, +) -> list[tuple[str, str]]: + # Get an arbitrary batch + batch_raw = next(iter(data_loader)) + batch = extract_batch_data(batch_raw).to(device) + print(f"Batch shape: {batch.shape}") + + with torch.no_grad(): + output_with_cache: OutputWithCache = model(batch, cache_type="input") + print(f"Output shape: {output_with_cache.output.shape}") + print(f"Number of cached layers: {len(output_with_cache.cache)}") + print(f"Cached layer names: {list(output_with_cache.cache.keys())}") + + with torch.no_grad(): + ci = model.calc_causal_importances( + pre_weight_acts=output_with_cache.cache, + sampling=config.sampling, + detach_inputs=False, + ) + + # Create masks for component replacement (use all components with causal importance as mask) + component_masks = ci.lower_leaky + mask_infos = make_mask_infos( + component_masks=component_masks, + routing_masks="all", + ) + with torch.enable_grad(): + comp_output_with_cache_grad: OutputWithCache = model( + batch, + mask_infos=mask_infos, + cache_type="component_acts", + ) + + cache = comp_output_with_cache_grad.cache + layers = [] + layer_names = [ + "attn.q_proj", + "attn.k_proj", + "attn.v_proj", + "attn.o_proj", + "mlp.c_fc", + "mlp.down_proj", + ] + for i in range(n_blocks): + layers.extend([f"h.{i}.{layer_name}" for layer_name in layer_names]) + + test_pairs = [] + for in_layer in layers: + for out_layer in layers: + if layers.index(in_layer) < layers.index(out_layer): + test_pairs.append((in_layer, out_layer)) + + valid_pairs = [] + for in_layer, out_layer in test_pairs: + out_pre_detach = cache[f"{out_layer}_pre_detach"] + in_post_detach = cache[f"{in_layer}_post_detach"] + batch_idx = 0 + seq_idx = 50 + target_component_idx = 10 + out_value = out_pre_detach[batch_idx, seq_idx, target_component_idx] + try: + grads = torch.autograd.grad( + outputs=out_value, + inputs=in_post_detach, + retain_graph=True, + allow_unused=True, + ) + assert len(grads) == 1, "Expected 1 gradient" + grad = grads[0] + # torch.autograd.grad returns None for unused inputs when allow_unused=True + has_grad = ( + grad.abs().max().item() > 1e-8 + if grad is not None # pyright: ignore[reportUnnecessaryComparison] + else False + ) + except RuntimeError: + has_grad = False + if has_grad: + valid_pairs.append((in_layer, out_layer)) + return valid_pairs + + +def compute_global_attributions( + model: ComponentModel, + data_loader: Iterable[dict[str, Any]], + device: str, + config: Config, + valid_pairs: list[tuple[str, str]], + max_batches: int, + alive_indices: dict[str, list[int]], +) -> dict[tuple[str, str], torch.Tensor]: + """Compute global attributions accumulated over the dataset. + + For each valid layer pair (in_layer, out_layer), computes the mean absolute gradient + of output component activations with respect to input component activations, + averaged over batch, sequence positions, and number of batches. + + Args: + model: The ComponentModel to analyze. + data_loader: DataLoader providing batches. + device: Device to run on. + config: SPD config with sampling settings. + valid_pairs: List of (in_layer, out_layer) pairs to compute attributions for. + max_batches: Maximum number of batches to process. + alive_indices: Dictionary mapping module path -> list of alive component indices. + Returns: + Dictionary mapping (in_layer, out_layer) -> attribution tensor of shape [n_alive_in, n_alive_out] + where attribution[i, j] is the mean absolute gradient from the i-th alive input component to the j-th alive output component. + """ + + # Initialize accumulators for each valid pair + attribution_sums: dict[tuple[str, str], Float[Tensor, "n_alive_in n_alive_out"]] = {} + for pair in valid_pairs: + in_layer, out_layer = pair + n_alive_in = len(alive_indices[in_layer]) + n_alive_out = len(alive_indices[out_layer]) + attribution_sums[(in_layer, out_layer)] = torch.zeros( + n_alive_in, n_alive_out, device=device + ) + + total_samples = 0 # Track total (batch * seq) samples processed + + batch_pbar = tqdm(enumerate(data_loader), desc="Batches", total=max_batches) + for batch_idx, batch_raw in batch_pbar: + if batch_idx >= max_batches: + break + + batch = extract_batch_data(batch_raw).to(device) + + # Forward pass to get pre-weight activations + with torch.no_grad(): + output_with_cache: OutputWithCache = model(batch, cache_type="input") + + # Calculate causal importances for masking + with torch.no_grad(): + ci = model.calc_causal_importances( + pre_weight_acts=output_with_cache.cache, + sampling=config.sampling, + detach_inputs=False, + ) + + # Create masks and run forward pass with gradient tracking + component_masks = ci.lower_leaky + mask_infos = make_mask_infos( + component_masks=component_masks, + routing_masks="all", + ) + + with torch.enable_grad(): + comp_output_with_cache: OutputWithCache = model( + batch, + mask_infos=mask_infos, + cache_type="component_acts", + ) + + cache = comp_output_with_cache.cache + + # Compute attributions for each valid pair + pair_pbar = tqdm(valid_pairs, desc="Layer pairs", leave=False) + for in_layer, out_layer in pair_pbar: + out_pre_detach = cache[f"{out_layer}_pre_detach"] + in_post_detach = cache[f"{in_layer}_post_detach"] + + # Compute gradients for each output component + # out_pre_detach shape: [batch, seq, n_components] + # in_post_detach shape: [batch, seq, n_components] + batch_attribution = torch.zeros( + len(alive_indices[in_layer]), len(alive_indices[out_layer]), device=device + ) + + for i, c_out in enumerate(alive_indices[out_layer]): + # Sum over batch and seq to get a scalar for this output component + out_sum = out_pre_detach[:, :, c_out].sum() + + grads = torch.autograd.grad( + outputs=out_sum, inputs=in_post_detach, retain_graph=True + )[0] + + assert grads is not None, "Gradient is None" + # grads shape: [batch, seq, n_components] + # Only consider the components that are alive + alive_grads = grads[..., alive_indices[in_layer]] + # Mean absolute gradient over batch and seq for each input component + mean_abs_grad = alive_grads.abs().mean(dim=(0, 1)) # [n_alive_components] + batch_attribution[:, i] = mean_abs_grad + + attribution_sums[(in_layer, out_layer)] += batch_attribution + + total_samples += 1 # Count batches (already averaged over batch/seq within) + + # Average over number of batches + global_attributions = { + pair: attr_sum / total_samples for pair, attr_sum in attribution_sums.items() + } + + print(f"Computed global attributions over {total_samples} batches") + return global_attributions + + +# %% +# Configuration +# wandb_path = "wandb:goodfire/spd/runs/jyo9duz5" # ss_gpt2_simple-1.25M (4L) +wandb_path = "wandb:goodfire/spd/runs/c0k3z78g" # ss_gpt2_simple-2L +n_blocks = 2 +batch_size = 512 +n_attribution_batches = 10 +n_alive_calc_batches = 200 +dataset_seed = 0 + +out_dir = Path(__file__).parent / "out" +out_dir.mkdir(parents=True, exist_ok=True) +wandb_id = wandb_path.split("/")[-1] + +device = "cuda" if torch.cuda.is_available() else "cpu" +print(f"Using device: {device}") +print(f"Loading model from {wandb_path}...") + +# Load the model +run_info = SPDRunInfo.from_path(wandb_path) +config: Config = run_info.config +model = ComponentModel.from_run_info(run_info) +model = model.to(device) +model.eval() + +print("Model loaded successfully!") +print(f"Number of components: {model.C}") +print(f"Target module paths: {model.target_module_paths}") + +# Load the dataset +task_config = config.task_config +assert isinstance(task_config, LMTaskConfig), "Expected LM task config" + +dataset_config = DatasetConfig( + name=task_config.dataset_name, + hf_tokenizer_path=config.tokenizer_name, + split=task_config.train_data_split, # Using train split for now + n_ctx=task_config.max_seq_len, + is_tokenized=task_config.is_tokenized, + streaming=task_config.streaming, + column_name=task_config.column_name, + shuffle_each_epoch=False, # No need to shuffle for testing + seed=dataset_seed, +) + +print(f"\nLoading dataset {dataset_config.name}...") +data_loader, tokenizer = create_data_loader( + dataset_config=dataset_config, + batch_size=batch_size, + buffer_size=task_config.buffer_size, + global_seed=dataset_seed, + ddp_rank=0, + ddp_world_size=1, +) + +valid_pairs = get_valid_pairs(model, data_loader, device, config, n_blocks) +print(f"Valid layer pairs: {valid_pairs}") +# %% +# Compute alive components based on mean CI threshold +print("\nComputing alive components based on mean CI...") +mean_cis, alive_indices, (img_linear, img_log) = compute_alive_components( + model=model, + data_loader=data_loader, + device=device, + config=config, + max_batches=n_alive_calc_batches, + threshold=1e-6, +) + +# Print summary +print("\nAlive components per layer:") +for module_name, indices in alive_indices.items(): + n_alive = len(indices) + print(f" {module_name}: {n_alive}/{model.C} alive") + +# Save images for verification +img_linear.save(out_dir / f"ci_mean_per_component_linear_{wandb_id}.png") +img_log.save(out_dir / f"ci_mean_per_component_log_{wandb_id}.png") +print( + f"Saved verification images to {out_dir / f'ci_mean_per_component_linear_{wandb_id}.png'} and {out_dir / f'ci_mean_per_component_log_{wandb_id}.png'}" +) +# %% +# Compute global attributions over the dataset +print("\nComputing global attributions...") +global_attributions = compute_global_attributions( + model=model, + data_loader=data_loader, + device=device, + config=config, + valid_pairs=valid_pairs, + max_batches=n_attribution_batches, + alive_indices=alive_indices, +) + +# Print summary statistics +for pair, attr in global_attributions.items(): + print(f"{pair[0]} -> {pair[1]}: mean={attr.mean():.6f}, max={attr.max():.6f}") + +torch.save(global_attributions, out_dir / f"global_attributions_{wandb_id}.pt") + +# %% +# Plot the attribution graph +print("\nPlotting attribution graph...") +out_dir = Path(__file__).parent / "out" +global_attributions = torch.load(out_dir / f"global_attributions_{wandb_id}.pt") +# graph_img = plot_attribution_graph( +# global_attributions=global_attributions, +# alive_indices=alive_indices, +# n_blocks=n_blocks, +# output_path=out_dir / f"attribution_graph_{wandb_id}.png", +# edge_threshold=0.0, +# ) +# print(f"Attribution graph has {sum(len(v) for v in alive_indices.values())} nodes") + +# %% diff --git a/scripts/plot_attributions.py b/scripts/plot_attributions.py new file mode 100644 index 000000000..a855ace41 --- /dev/null +++ b/scripts/plot_attributions.py @@ -0,0 +1,229 @@ +# %% +"""Plot attribution graph from saved global attributions.""" + +from pathlib import Path + +import matplotlib.pyplot as plt +import networkx as nx +import torch + +# Configuration +wandb_id = "c0k3z78g" +n_blocks = 2 +edge_threshold = 0.1 + +# Load saved data +out_dir = Path(__file__).parent / "out" +global_attributions = torch.load(out_dir / f"global_attributions_{wandb_id}.pt") + +# Reconstruct alive_indices from attribution tensor shapes +alive_indices: dict[str, list[int]] = {} +for (in_layer, out_layer), attr in global_attributions.items(): + n_alive_in, n_alive_out = attr.shape + if in_layer not in alive_indices: + alive_indices[in_layer] = list(range(n_alive_in)) + if out_layer not in alive_indices: + alive_indices[out_layer] = list(range(n_alive_out)) + +print(f"Loaded attributions for {len(global_attributions)} layer pairs") +print(f"Total alive components: {sum(len(v) for v in alive_indices.values())}") + +# Count edges +total_edges = sum((attr > edge_threshold).sum().item() for attr in global_attributions.values()) +print(f"Edges to draw (threshold={edge_threshold}): {total_edges:,}") + +# %% +# Plot the attribution graph +print("\nPlotting attribution graph...") + +# Define layer order within a block (network order) +layer_names_in_block = [ + "attn.q_proj", + "attn.k_proj", + "attn.v_proj", + "attn.o_proj", + "mlp.c_fc", + "mlp.down_proj", +] + +# Build full layer list in network order +all_layers = [] +for block_idx in range(n_blocks): + for layer_name in layer_names_in_block: + all_layers.append(f"h.{block_idx}.{layer_name}") + +# Create graph +G = nx.DiGraph() + +# Add nodes for each (layer, component) pair +node_positions = {} +block_spacing = 6.0 # Vertical spacing between blocks + +# Layer y-offsets within a block: down_proj at top, q/k/v at bottom +layer_y_offsets = { + "mlp.down_proj": 2.5, + "mlp.c_fc": 1.5, + "attn.o_proj": 0.5, + "attn.v_proj": -0.5, + "attn.k_proj": -1.0, + "attn.q_proj": -1.5, +} + +for layer in all_layers: + parts = layer.split(".") + block_idx = int(parts[1]) + layer_name = ".".join(parts[2:]) + + n_alive = len(alive_indices.get(layer, [])) + if n_alive == 0: + continue + + # Block 1 on top (higher y), Block 0 on bottom (lower y) + y_base = block_idx * block_spacing + layer_y_offsets[layer_name] + # X-axis for spreading components horizontally + x_base = 0 + + for comp_idx, local_idx in enumerate(alive_indices.get(layer, [])): + node_id = f"{layer}:{local_idx}" + G.add_node(node_id, layer=layer, component=local_idx) + x = x_base + (comp_idx - n_alive / 2) * 0.15 + y = y_base + node_positions[node_id] = (x, y) + +# Add edges based on attributions +edge_weights = [] +for (in_layer, out_layer), attr_tensor in global_attributions.items(): + in_alive = alive_indices.get(in_layer, []) + out_alive = alive_indices.get(out_layer, []) + + for i, in_comp in enumerate(in_alive): + for j, out_comp in enumerate(out_alive): + weight = attr_tensor[i, j].item() + if weight > edge_threshold: + in_node = f"{in_layer}:{in_comp}" + out_node = f"{out_layer}:{out_comp}" + if in_node in G.nodes and out_node in G.nodes: + G.add_edge(in_node, out_node, weight=weight) + edge_weights.append(weight) + +print(f"Graph has {G.number_of_nodes()} nodes and {G.number_of_edges()} edges") + +# Create figure (tall layout for vertical block arrangement) +fig, ax = plt.subplots(1, 1, figsize=(12, 14)) + +# Draw nodes grouped by layer +layer_colors = { + "attn.q_proj": "#1f77b4", + "attn.k_proj": "#2ca02c", + "attn.v_proj": "#9467bd", + "attn.o_proj": "#d62728", + "mlp.c_fc": "#ff7f0e", + "mlp.down_proj": "#8c564b", +} + +for layer in all_layers: + parts = layer.split(".") + layer_name = ".".join(parts[2:]) + color = layer_colors.get(layer_name, "#333333") + + layer_nodes = [n for n in G.nodes if G.nodes[n].get("layer") == layer] + if layer_nodes: + pos_subset = {n: node_positions[n] for n in layer_nodes} + nx.draw_networkx_nodes( + G, + pos_subset, + nodelist=layer_nodes, + node_color=color, + node_size=100, + alpha=0.8, + ax=ax, + ) + +# Draw edges batched by weight bucket for performance +if edge_weights: + max_weight = max(edge_weights) + min_weight = min(edge_weights) + + n_buckets = 10 + edge_buckets: list[list[tuple[str, str]]] = [[] for _ in range(n_buckets)] + + for u, v, data in G.edges(data=True): + weight = data.get("weight", 0) + if max_weight > min_weight: + normalized = (weight - min_weight) / (max_weight - min_weight) + else: + normalized = 0.5 + bucket_idx = min(int(normalized * n_buckets), n_buckets - 1) + edge_buckets[bucket_idx].append((u, v)) + + for bucket_idx, bucket_edges in enumerate(edge_buckets): + if not bucket_edges: + continue + normalized = (bucket_idx + 0.5) / n_buckets + width = 0.2 + normalized * 2.0 + alpha = 0.3 + normalized * 0.5 + + nx.draw_networkx_edges( + G, + node_positions, + edgelist=bucket_edges, + width=width, + alpha=alpha, + edge_color="#666666", + arrows=True, + arrowsize=8, + connectionstyle="arc3,rad=0.1", + ax=ax, + ) + +# Add layer labels +for block_idx in range(n_blocks): + # Block 1 on top, Block 0 on bottom + block_y_base = block_idx * block_spacing + for layer_name, y_offset in layer_y_offsets.items(): + short_name = layer_name.split(".")[-1] + ax.annotate( + short_name, + (-1.5, block_y_base + y_offset), + fontsize=8, + ha="right", + va="center", + color="#555555", + ) + + ax.annotate( + f"Block {block_idx}", + (-2.5, block_y_base), + fontsize=10, + ha="center", + va="center", + fontweight="bold", + rotation=90, + ) + +# Add legend +legend_elements = [ + plt.Line2D([0], [0], marker="o", color="w", markerfacecolor=color, markersize=10, label=name) + for name, color in [ + ("q_proj", "#1f77b4"), + ("k_proj", "#2ca02c"), + ("v_proj", "#9467bd"), + ("o_proj", "#d62728"), + ("c_fc", "#ff7f0e"), + ("down_proj", "#8c564b"), + ] +] +ax.legend(handles=legend_elements, loc="upper right", fontsize=8) + +ax.set_title("Global Attribution Graph", fontsize=14, fontweight="bold") +ax.axis("off") +plt.tight_layout() + +# Save +output_path = out_dir / f"attribution_graph_{wandb_id}.png" +fig.savefig(output_path, dpi=150, bbox_inches="tight") +print(f"Saved to {output_path}") + +plt.close(fig) + +# %% diff --git a/scripts/test_component_acts_caching.py b/scripts/test_component_acts_caching.py deleted file mode 100644 index 3a7d7eb26..000000000 --- a/scripts/test_component_acts_caching.py +++ /dev/null @@ -1,315 +0,0 @@ -"""Test script for component activation caching. - -This script loads a ComponentModel and dataset, then runs a forward pass with -component activation caching to verify the implementation of the caching plumbing. -""" - -import torch - -from spd.configs import Config -from spd.data import DatasetConfig, create_data_loader -from spd.experiments.lm.configs import LMTaskConfig -from spd.models.component_model import ComponentModel, OutputWithCache, SPDRunInfo -from spd.models.components import make_mask_infos -from spd.utils.general_utils import extract_batch_data - - -def main() -> None: - # Configuration - wandb_path = "wandb:goodfire/spd/runs/jyo9duz5" - device = "cuda" if torch.cuda.is_available() else "cpu" - batch_size = 2 - - print(f"Using device: {device}") - print(f"Loading model from {wandb_path}...") - - # Load the model - run_info = SPDRunInfo.from_path(wandb_path) - config: Config = run_info.config - model = ComponentModel.from_run_info(run_info) - model = model.to(device) - model.eval() - - print("Model loaded successfully!") - print(f"Number of components: {model.C}") - print(f"Target module paths: {model.target_module_paths}") - - # Load the dataset - task_config = config.task_config - assert isinstance(task_config, LMTaskConfig), "Expected LM task config" - - dataset_config = DatasetConfig( - name=task_config.dataset_name, - hf_tokenizer_path=config.tokenizer_name, - split=task_config.train_data_split, # Using train split for now - n_ctx=task_config.max_seq_len, - is_tokenized=task_config.is_tokenized, - streaming=task_config.streaming, - column_name=task_config.column_name, - shuffle_each_epoch=False, # No need to shuffle for testing - seed=42, - ) - - print(f"\nLoading dataset {dataset_config.name}...") - data_loader, tokenizer = create_data_loader( - dataset_config=dataset_config, - batch_size=batch_size, - buffer_size=task_config.buffer_size, - global_seed=42, - ddp_rank=0, - ddp_world_size=1, - ) - - # Get a batch - batch_raw = next(iter(data_loader)) - batch = extract_batch_data(batch_raw).to(device) - print(f"Batch shape: {batch.shape}") - - # Test 1: Forward pass without component replacement, just caching - print("\n" + "=" * 80) - print("Test 1: Forward pass with input caching (no component replacement)") - print("=" * 80) - - with torch.no_grad(): - output_with_cache: OutputWithCache = model(batch, cache_type="input") - print(f"Output shape: {output_with_cache.output.shape}") - print(f"Number of cached layers: {len(output_with_cache.cache)}") - print(f"Cached layer names: {list(output_with_cache.cache.keys())}") - for name, acts in list(output_with_cache.cache.items())[:3]: - print(f" {name}: {acts.shape}") - - # Test 2: Forward pass with component replacement and component activation caching - print("\n" + "=" * 80) - print("Test 2: Forward pass with component replacement and component activation caching") - print("=" * 80) - - # Calculate causal importances - with torch.no_grad(): - ci = model.calc_causal_importances( - pre_weight_acts=output_with_cache.cache, - sampling=config.sampling, - detach_inputs=False, - ) - - # Create masks for component replacement (use all components with causal importance as mask) - component_masks = ci.lower_leaky - mask_infos = make_mask_infos( - component_masks=component_masks, - routing_masks="all", - ) - - # Forward pass with component replacement and component activation caching - with torch.no_grad(): - comp_output_with_cache: OutputWithCache = model( - batch, - mask_infos=mask_infos, - cache_type="component_acts", - ) - - print(f"Output shape: {comp_output_with_cache.output.shape}") - print(f"Number of cached items: {len(comp_output_with_cache.cache)}") - - # The cache should contain entries like "layer_name_pre_detach" and "layer_name_post_detach" - pre_detach_keys = [k for k in comp_output_with_cache.cache if "pre_detach" in k] - post_detach_keys = [k for k in comp_output_with_cache.cache if "post_detach" in k] - - print(f"\nPre-detach cache entries: {len(pre_detach_keys)}") - print(f"Post-detach cache entries: {len(post_detach_keys)}") - - # Show some examples - for name in pre_detach_keys[:3]: - acts = comp_output_with_cache.cache[name] - print(f" {name}: {acts.shape}") - - # Verify that pre_detach and post_detach have the same values but different grad_fn - print("\n" + "-" * 80) - print("Verifying pre_detach and post_detach activations:") - print("-" * 80) - for key in pre_detach_keys[:2]: - layer_name = key.replace("_pre_detach", "") - post_key = f"{layer_name}_post_detach" - - pre_acts = comp_output_with_cache.cache[key] - post_acts = comp_output_with_cache.cache[post_key] - - print(f"\n{layer_name}:") - print(f" Pre-detach shape: {pre_acts.shape}, requires_grad: {pre_acts.requires_grad}") - print( - f" Post-detach shape: {post_acts.shape}, requires_grad: {post_acts.requires_grad}" - ) - print(f" Values match: {torch.allclose(pre_acts, post_acts)}") - print(f" Pre has grad_fn: {pre_acts.grad_fn is not None}") - print(f" Post has grad_fn: {post_acts.grad_fn is not None}") - - # Test 3: Verify cached activations structure - print("\n" + "=" * 80) - print("Test 3: Verify cached activations for attribution graph construction") - print("=" * 80) - - # Forward pass with component replacement and component activation caching (with gradients enabled) - # We use torch.enable_grad() to enable gradients during the forward pass - with torch.enable_grad(): - comp_output_with_cache_grad: OutputWithCache = model( - batch, - mask_infos=mask_infos, - cache_type="component_acts", - ) - - # Get the cached activations - cache = comp_output_with_cache_grad.cache - - # Get layer names sorted by depth - layer_names = sorted( - set( - k.replace("_pre_detach", "").replace("_post_detach", "") for k in cache if "detach" in k - ) - ) - - print(f"Found {len(layer_names)} layers") - print(f"\nLayer names (in order): {layer_names[:5]}...{layer_names[-3:]}") - - # Show structure of cached activations - print("\n" + "-" * 80) - print("Cached activation properties:") - print("-" * 80) - for layer_name in layer_names[:3]: # Show first 3 layers - pre_detach = cache[f"{layer_name}_pre_detach"] - post_detach = cache[f"{layer_name}_post_detach"] - - print(f"\n{layer_name}:") - print( - f" pre_detach: shape={pre_detach.shape}, requires_grad={pre_detach.requires_grad}, has_grad_fn={pre_detach.grad_fn is not None}" - ) - print( - f" post_detach: shape={post_detach.shape}, requires_grad={post_detach.requires_grad}, has_grad_fn={post_detach.grad_fn is not None}" - ) - print(f" values_match: {torch.allclose(pre_detach, post_detach)}") - - print("\n" + "-" * 80) - print("Summary:") - print("-" * 80) - print("✓ Component activation caching is working correctly") - print("✓ Both pre_detach and post_detach versions are cached") - print("✓ pre_detach tensors have grad_fn (part of computation graph)") - print("✓ post_detach tensors are LEAF tensors with requires_grad=True") - - # Test 4: Compute cross-layer gradient - test multiple layer pairs - print("\n" + "=" * 80) - print("Test 4: Computing cross-layer gradients (pre_detach w.r.t. earlier post_detach)") - print("=" * 80) - - # Try multiple layer pairs to find connections - test_pairs = [ - ("h.0.mlp.c_fc", "h.1.attn.q_proj"), # MLP to next layer's attention - ("h.0.mlp.down_proj", "h.1.attn.k_proj"), # MLP output to next layer - ("h.0.attn.o_proj", "h.1.attn.q_proj"), # Attention output to next attention - ("h.0.attn.q_proj", "h.1.attn.k_proj"), # Original test - ] - - gradient_found = False - - for source_layer, target_layer in test_pairs: - print(f"\n{'=' * 60}") - print(f"Testing: d({target_layer}_pre_detach) / d({source_layer}_post_detach)") - print("=" * 60) - - source_post_detach = cache[f"{source_layer}_post_detach"] - target_pre_detach = cache[f"{target_layer}_pre_detach"] - - # Select a specific component and position to compute gradient for - batch_idx = 0 - seq_idx = 50 - target_component_idx = 10 - - # Get the scalar value we want to take gradient of (from pre_detach, not post_detach!) - target_value = target_pre_detach[batch_idx, seq_idx, target_component_idx] - - # Try to compute gradient - try: - grads = torch.autograd.grad( - outputs=target_value, - inputs=source_post_detach, - retain_graph=True, - allow_unused=True, - ) - - if grads[0] is not None: - grad = grads[0] - grad_norm = grad.norm().item() - grad_max = grad.abs().max().item() - nonzero_grads = (grad.abs() > 1e-8).sum().item() - - print("✓ GRADIENT FOUND!") - print(f" Gradient norm: {grad_norm:.6f}") - print(f" Gradient max abs: {grad_max:.6f}") - print(f" Non-zero elements (>1e-8): {nonzero_grads}") - - # Show top attributed source components - grad_for_position = grad[batch_idx, seq_idx] # Shape: [C] - top_k = 5 - top_values, top_indices = grad_for_position.abs().topk(top_k) - - print(f"\n Top {top_k} attributed components in {source_layer}[seq={seq_idx}]:") - for i, (idx, val) in enumerate( - zip(top_indices.tolist(), top_values.tolist(), strict=True) - ): - print(f" {i + 1}. Component {idx}: gradient = {val:.6f}") - - gradient_found = True - break # Found one, that's enough for the test - else: - print(" ✗ No gradient (not connected)") - except RuntimeError as e: - print(f" ✗ Error: {e}") - - if not gradient_found: - print(f"\n{'=' * 60}") - print("No gradients found between tested layer pairs.") - print("This might be due to model architecture - will need to investigate further.") - print("=" * 60) - - # Test 5: Summary and explanation - print("\n" + "=" * 80) - print("Test 5: Understanding the gradient flow") - print("=" * 80) - - print("\nWhat just happened:") - print("-" * 60) - print("pre_detach[layer_l] = input[layer_l] @ V[layer_l]") - print("") - print("During the forward pass:") - print("1. h.0.attn.q_proj computes post_detach (detached, requires_grad=True)") - print("2. This post_detach flows through the model (via post_detach @ U)") - print("3. Eventually becomes part of the input to h.1.attn.k_proj") - print("4. h.1.attn.k_proj computes pre_detach = input @ V") - print("") - print("So pre_detach[h.1.attn.k_proj] DOES depend on post_detach[h.0.attn.q_proj]") - print("through the model's computation graph!") - print("") - print("If the gradient computed successfully, this means:") - print("✓ You can directly compute attribution from earlier to later components") - print("✓ The attributions skip intermediate layers (due to detachment at each layer)") - print("✓ This is exactly what you specified in your plan!") - print("") - print("If the gradient is None, it might mean:") - print("✗ The specific layers tested don't have a direct path in the model") - print("✗ The model architecture might skip certain layer connections") - - print("\n" + "=" * 80) - print("Summary:") - print("=" * 80) - print("✓ Component activation caching is working correctly") - print("✓ post_detach tensors are leaf nodes with requires_grad=True") - print("✓ Automatic gradient flow is blocked between layers (as intended)") - print("✓ You can now implement your global attribution graph algorithm by:") - print(" - Caching all post_detach activations in one forward pass") - print(" - Building custom computation paths from earlier to later layers") - print(" - Computing gradients to get direct (non-mediated) attributions") - - print("\n" + "=" * 80) - print("Testing complete!") - print("=" * 80) - - -if __name__ == "__main__": - main() From 8e40188bbdf6033f5aa930a208f8cd5d21b3f9df Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Tue, 25 Nov 2025 16:09:48 +0000 Subject: [PATCH 03/36] Make ci_mean alive threshold a hyperparam --- scripts/calc_attributions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/calc_attributions.py b/scripts/calc_attributions.py index f8560dbb4..50dd1b4ea 100644 --- a/scripts/calc_attributions.py +++ b/scripts/calc_attributions.py @@ -327,6 +327,7 @@ def compute_global_attributions( batch_size = 512 n_attribution_batches = 10 n_alive_calc_batches = 200 +ci_mean_alive_threshold = 1e-6 dataset_seed = 0 out_dir = Path(__file__).parent / "out" @@ -385,7 +386,7 @@ def compute_global_attributions( device=device, config=config, max_batches=n_alive_calc_batches, - threshold=1e-6, + threshold=ci_mean_alive_threshold, ) # Print summary From 8530335005373b24038a58d1e0702f41e950212a Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 26 Nov 2025 11:11:10 +0000 Subject: [PATCH 04/36] Misc cleaning --- ...butions.py => calc_global_attributions.py} | 112 ++++++++++++------ ...butions.py => plot_global_attributions.py} | 83 +++++++------ 2 files changed, 113 insertions(+), 82 deletions(-) rename scripts/{calc_attributions.py => calc_global_attributions.py} (80%) rename scripts/{plot_attributions.py => plot_global_attributions.py} (75%) diff --git a/scripts/calc_attributions.py b/scripts/calc_global_attributions.py similarity index 80% rename from scripts/calc_attributions.py rename to scripts/calc_global_attributions.py index 50dd1b4ea..4df862c69 100644 --- a/scripts/calc_attributions.py +++ b/scripts/calc_global_attributions.py @@ -1,5 +1,7 @@ # %% +import gzip +import json from collections.abc import Iterable from pathlib import Path from typing import Any @@ -26,7 +28,7 @@ def compute_mean_ci_per_component( device: str, config: Config, max_batches: int | None, -) -> dict[str, torch.Tensor]: +) -> dict[str, Tensor]: """Compute mean causal importance per component over the dataset. Args: @@ -40,7 +42,7 @@ def compute_mean_ci_per_component( Dictionary mapping module path -> tensor of shape [C] with mean CI per component. """ # Initialize accumulators - ci_sums: dict[str, torch.Tensor] = { + ci_sums: dict[str, Tensor] = { module_name: torch.zeros(model.C, device=device) for module_name in model.components } examples_seen: dict[str, int] = {module_name: 0 for module_name in model.components} @@ -88,7 +90,7 @@ def compute_alive_components( config: Config, max_batches: int | None, threshold: float, -) -> tuple[dict[str, torch.Tensor], dict[str, list[int]], tuple[Image.Image, Image.Image]]: +) -> tuple[dict[str, Tensor], dict[str, list[int]], tuple[Image.Image, Image.Image]]: """Compute alive components based on mean CI threshold. Args: @@ -210,7 +212,7 @@ def compute_global_attributions( valid_pairs: list[tuple[str, str]], max_batches: int, alive_indices: dict[str, list[int]], -) -> dict[tuple[str, str], torch.Tensor]: +) -> dict[tuple[str, str], Tensor]: """Compute global attributions accumulated over the dataset. For each valid layer pair (in_layer, out_layer), computes the mean absolute gradient @@ -280,53 +282,61 @@ def compute_global_attributions( # Compute attributions for each valid pair pair_pbar = tqdm(valid_pairs, desc="Layer pairs", leave=False) for in_layer, out_layer in pair_pbar: - out_pre_detach = cache[f"{out_layer}_pre_detach"] - in_post_detach = cache[f"{in_layer}_post_detach"] + out_pre_detach: Float[Tensor, "b s n_components"] = cache[f"{out_layer}_pre_detach"] + in_post_detach: Float[Tensor, "b s n_components"] = cache[f"{in_layer}_post_detach"] - # Compute gradients for each output component - # out_pre_detach shape: [batch, seq, n_components] - # in_post_detach shape: [batch, seq, n_components] batch_attribution = torch.zeros( len(alive_indices[in_layer]), len(alive_indices[out_layer]), device=device ) - for i, c_out in enumerate(alive_indices[out_layer]): - # Sum over batch and seq to get a scalar for this output component - out_sum = out_pre_detach[:, :, c_out].sum() + alive_out = alive_indices[out_layer] + # Detach CI weights - we only need values, not gradients through them + ci_weights = ci.lower_leaky[out_layer].detach() + for i, c_out in enumerate(alive_out): + # Sum over batch and seq, weighted by the out ci values + out_sum = (out_pre_detach[..., c_out] * ci_weights[..., c_out]).sum() - grads = torch.autograd.grad( + grads: Float[Tensor, "b s n_components"] = torch.autograd.grad( outputs=out_sum, inputs=in_post_detach, retain_graph=True )[0] assert grads is not None, "Gradient is None" - # grads shape: [batch, seq, n_components] - # Only consider the components that are alive - alive_grads = grads[..., alive_indices[in_layer]] - # Mean absolute gradient over batch and seq for each input component - mean_abs_grad = alive_grads.abs().mean(dim=(0, 1)) # [n_alive_components] - batch_attribution[:, i] = mean_abs_grad + # Detach in_post_detach for the multiplication - grads already captures the gradient info + raw_attributions: Float[Tensor, "b s n_components"] = ( + grads * in_post_detach.detach() + ) + alive_attributions: Float[Tensor, "b s n_alive_in"] = raw_attributions[ + ..., alive_indices[in_layer] + ] + mean_abs_attributions: Float[Tensor, " n_alive_in"] = alive_attributions.abs().mean( + dim=(0, 1) + ) + batch_attribution[:, i] = mean_abs_attributions attribution_sums[(in_layer, out_layer)] += batch_attribution total_samples += 1 # Count batches (already averaged over batch/seq within) - # Average over number of batches + # Average over number of samples global_attributions = { pair: attr_sum / total_samples for pair, attr_sum in attribution_sums.items() } - print(f"Computed global attributions over {total_samples} batches") + print(f"Computed global attributions over {total_samples} samples") return global_attributions # %% # Configuration # wandb_path = "wandb:goodfire/spd/runs/jyo9duz5" # ss_gpt2_simple-1.25M (4L) -wandb_path = "wandb:goodfire/spd/runs/c0k3z78g" # ss_gpt2_simple-2L -n_blocks = 2 -batch_size = 512 -n_attribution_batches = 10 -n_alive_calc_batches = 200 +# wandb_path = "wandb:goodfire/spd/runs/c0k3z78g" # ss_gpt2_simple-2L +wandb_path = "wandb:goodfire/spd/runs/8ynfbr38" # ss_gpt2_simple-1L +n_blocks = 1 +batch_size = 600 +# n_attribution_batches = 20 +n_attribution_batches = 2 +n_alive_calc_batches = 50 +# n_alive_calc_batches = 200 ci_mean_alive_threshold = 1e-6 dataset_seed = 0 @@ -418,20 +428,44 @@ def compute_global_attributions( for pair, attr in global_attributions.items(): print(f"{pair[0]} -> {pair[1]}: mean={attr.mean():.6f}, max={attr.max():.6f}") -torch.save(global_attributions, out_dir / f"global_attributions_{wandb_id}.pt") - # %% -# Plot the attribution graph -print("\nPlotting attribution graph...") +# Save attributions in both PyTorch and JSON formats +print("\nSaving attribution data...") out_dir = Path(__file__).parent / "out" -global_attributions = torch.load(out_dir / f"global_attributions_{wandb_id}.pt") -# graph_img = plot_attribution_graph( -# global_attributions=global_attributions, -# alive_indices=alive_indices, -# n_blocks=n_blocks, -# output_path=out_dir / f"attribution_graph_{wandb_id}.png", -# edge_threshold=0.0, -# ) -# print(f"Attribution graph has {sum(len(v) for v in alive_indices.values())} nodes") + +# Save PyTorch format +pt_path = out_dir / f"global_attributions_{wandb_id}.pt" +torch.save(global_attributions, pt_path) +print(f"Saved PyTorch format to {pt_path}") + +# Convert and save JSON format for web visualization +attributions_json = {} +for (in_layer, out_layer), attr_tensor in global_attributions.items(): + key = f"('{in_layer}', '{out_layer}')" + # Keep full precision - just convert to list + attributions_json[key] = attr_tensor.cpu().tolist() + +json_data = { + "n_blocks": n_blocks, + "attributions": attributions_json, + "alive_indices": alive_indices, +} + +json_path = out_dir / f"global_attributions_{wandb_id}.json" + +# Write JSON with compact formatting +with open(json_path, "w") as f: + json.dump(json_data, f, separators=(",", ":"), ensure_ascii=False) + +# Also save a compressed version for very large files +gz_path = out_dir / f"global_attributions_{wandb_id}.json.gz" +with gzip.open(gz_path, "wt", encoding="utf-8") as f: + json.dump(json_data, f, separators=(",", ":"), ensure_ascii=False) + +print(f"Saved JSON format to {json_path}") +print(f"Saved compressed format to {gz_path}") +print(f" - {len(attributions_json)} layer pairs") +print(f" - {sum(len(v) for v in alive_indices.values())} total alive components") +print(f"\nTo visualize: Open scripts/plot_attributions.html and load {json_path}") # %% diff --git a/scripts/plot_attributions.py b/scripts/plot_global_attributions.py similarity index 75% rename from scripts/plot_attributions.py rename to scripts/plot_global_attributions.py index a855ace41..2f0c6f6b5 100644 --- a/scripts/plot_attributions.py +++ b/scripts/plot_global_attributions.py @@ -8,9 +8,11 @@ import torch # Configuration -wandb_id = "c0k3z78g" -n_blocks = 2 -edge_threshold = 0.1 +# wandb_id = "c0k3z78g" # ss_gpt2_simple-2L +# n_blocks = 2 +wandb_id = "8ynfbr38" # ss_gpt2_simple-1L +n_blocks = 1 +edge_threshold = 1e-1 # Load saved data out_dir = Path(__file__).parent / "out" @@ -28,9 +30,15 @@ print(f"Loaded attributions for {len(global_attributions)} layer pairs") print(f"Total alive components: {sum(len(v) for v in alive_indices.values())}") -# Count edges -total_edges = sum((attr > edge_threshold).sum().item() for attr in global_attributions.values()) -print(f"Edges to draw (threshold={edge_threshold}): {total_edges:,}") +# Count edges before and after thresholding +total_edges = sum(attr.numel() for attr in global_attributions.values()) +print(f"Total edges: {total_edges:,}") +thresholds = [1, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8] +for threshold in thresholds: + total_edges_threshold = sum( + (attr > threshold).sum().item() for attr in global_attributions.values() + ) + print(f"Edges > {threshold}: {total_edges_threshold:,}") # %% # Plot the attribution graph @@ -59,14 +67,25 @@ node_positions = {} block_spacing = 6.0 # Vertical spacing between blocks -# Layer y-offsets within a block: down_proj at top, q/k/v at bottom +# Layer y-offsets within a block: down_proj at top, q/k/v at same level at bottom +# q_proj, k_proj, v_proj are placed side by side since they never connect to each other layer_y_offsets = { - "mlp.down_proj": 2.5, - "mlp.c_fc": 1.5, - "attn.o_proj": 0.5, - "attn.v_proj": -0.5, + "mlp.down_proj": 2.0, + "mlp.c_fc": 1.0, + "attn.o_proj": 0.0, + "attn.v_proj": -1.0, # Same y-level for q/k/v "attn.k_proj": -1.0, - "attn.q_proj": -1.5, + "attn.q_proj": -1.0, +} + +# X-offsets for q/k/v to place them side by side with much more spacing +layer_x_offsets = { + "mlp.down_proj": 0.0, + "mlp.c_fc": 0.0, + "attn.o_proj": 0.0, + "attn.q_proj": -20.0, # Left (much more spacing) + "attn.k_proj": 0.0, # Center + "attn.v_proj": 20.0, # Right (much more spacing) } for layer in all_layers: @@ -80,13 +99,14 @@ # Block 1 on top (higher y), Block 0 on bottom (lower y) y_base = block_idx * block_spacing + layer_y_offsets[layer_name] - # X-axis for spreading components horizontally - x_base = 0 + # X-axis base depends on layer type (q/k/v are offset) + x_base = layer_x_offsets[layer_name] for comp_idx, local_idx in enumerate(alive_indices.get(layer, [])): node_id = f"{layer}:{local_idx}" G.add_node(node_id, layer=layer, component=local_idx) - x = x_base + (comp_idx - n_alive / 2) * 0.15 + # Increase spacing between nodes from 0.15 to 0.25 for less overlap + x = x_base + (comp_idx - n_alive / 2) * 0.25 y = y_base node_positions[node_id] = (x, y) @@ -108,8 +128,8 @@ print(f"Graph has {G.number_of_nodes()} nodes and {G.number_of_edges()} edges") -# Create figure (tall layout for vertical block arrangement) -fig, ax = plt.subplots(1, 1, figsize=(12, 14)) +# Create figure (extra wide to accommodate q/k/v side by side with large spacing) +fig, ax = plt.subplots(1, 1, figsize=(32, 12)) # Draw nodes grouped by layer layer_colors = { @@ -176,31 +196,6 @@ ax=ax, ) -# Add layer labels -for block_idx in range(n_blocks): - # Block 1 on top, Block 0 on bottom - block_y_base = block_idx * block_spacing - for layer_name, y_offset in layer_y_offsets.items(): - short_name = layer_name.split(".")[-1] - ax.annotate( - short_name, - (-1.5, block_y_base + y_offset), - fontsize=8, - ha="right", - va="center", - color="#555555", - ) - - ax.annotate( - f"Block {block_idx}", - (-2.5, block_y_base), - fontsize=10, - ha="center", - va="center", - fontweight="bold", - rotation=90, - ) - # Add legend legend_elements = [ plt.Line2D([0], [0], marker="o", color="w", markerfacecolor=color, markersize=10, label=name) @@ -220,7 +215,9 @@ plt.tight_layout() # Save -output_path = out_dir / f"attribution_graph_{wandb_id}.png" +# Make an edge threshold string in scientific notation which doesn't include decimal places +edge_threshold_str = f"{edge_threshold:.1e}".replace(".0", "") +output_path = out_dir / f"attribution_graph_{wandb_id}_edge_threshold_{edge_threshold_str}.png" fig.savefig(output_path, dpi=150, bbox_inches="tight") print(f"Saved to {output_path}") From 4c45678efbbe44256b1edbd9d31d6ff80b889416 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 26 Nov 2025 14:06:45 +0000 Subject: [PATCH 05/36] Add naive global attribution calc with for loops --- scripts/calc_global_attributions.py | 77 ++++++++++++++++------------- 1 file changed, 43 insertions(+), 34 deletions(-) diff --git a/scripts/calc_global_attributions.py b/scripts/calc_global_attributions.py index 4df862c69..074392451 100644 --- a/scripts/calc_global_attributions.py +++ b/scripts/calc_global_attributions.py @@ -212,6 +212,7 @@ def compute_global_attributions( valid_pairs: list[tuple[str, str]], max_batches: int, alive_indices: dict[str, list[int]], + ci_threshold: float, ) -> dict[tuple[str, str], Tensor]: """Compute global attributions accumulated over the dataset. @@ -227,6 +228,7 @@ def compute_global_attributions( valid_pairs: List of (in_layer, out_layer) pairs to compute attributions for. max_batches: Maximum number of batches to process. alive_indices: Dictionary mapping module path -> list of alive component indices. + ci_threshold: Threshold for considering a component for the attribution calculation. Returns: Dictionary mapping (in_layer, out_layer) -> attribution tensor of shape [n_alive_in, n_alive_out] where attribution[i, j] is the mean absolute gradient from the i-th alive input component to the j-th alive output component. @@ -249,7 +251,10 @@ def compute_global_attributions( if batch_idx >= max_batches: break - batch = extract_batch_data(batch_raw).to(device) + batch: Float[Tensor, "b s C"] = extract_batch_data(batch_raw).to(device) + + batch_size, n_seq = batch.shape + total_samples += batch_size * n_seq # Forward pass to get pre-weight activations with torch.no_grad(): @@ -280,46 +285,48 @@ def compute_global_attributions( cache = comp_output_with_cache.cache # Compute attributions for each valid pair - pair_pbar = tqdm(valid_pairs, desc="Layer pairs", leave=False) - for in_layer, out_layer in pair_pbar: - out_pre_detach: Float[Tensor, "b s n_components"] = cache[f"{out_layer}_pre_detach"] - in_post_detach: Float[Tensor, "b s n_components"] = cache[f"{in_layer}_post_detach"] + for in_layer, out_layer in tqdm(valid_pairs, desc="Layer pairs", leave=False): + out_pre_detach: Float[Tensor, "b s C"] = cache[f"{out_layer}_pre_detach"] + weighted_out_pre_detach = out_pre_detach * ci.lower_leaky[out_layer].detach() + in_post_detach: Float[Tensor, "b s C"] = cache[f"{in_layer}_post_detach"] batch_attribution = torch.zeros( len(alive_indices[in_layer]), len(alive_indices[out_layer]), device=device ) - alive_out = alive_indices[out_layer] - # Detach CI weights - we only need values, not gradients through them - ci_weights = ci.lower_leaky[out_layer].detach() - for i, c_out in enumerate(alive_out): - # Sum over batch and seq, weighted by the out ci values - out_sum = (out_pre_detach[..., c_out] * ci_weights[..., c_out]).sum() - - grads: Float[Tensor, "b s n_components"] = torch.autograd.grad( - outputs=out_sum, inputs=in_post_detach, retain_graph=True - )[0] - - assert grads is not None, "Gradient is None" - # Detach in_post_detach for the multiplication - grads already captures the gradient info - raw_attributions: Float[Tensor, "b s n_components"] = ( - grads * in_post_detach.detach() - ) - alive_attributions: Float[Tensor, "b s n_alive_in"] = raw_attributions[ - ..., alive_indices[in_layer] - ] - mean_abs_attributions: Float[Tensor, " n_alive_in"] = alive_attributions.abs().mean( - dim=(0, 1) - ) - batch_attribution[:, i] = mean_abs_attributions + alive_out: list[int] = alive_indices[out_layer] + c_pbar = tqdm( + enumerate(alive_out), desc="Components", leave=False, total=len(alive_out) + ) + for c, c_idx in c_pbar: + n_grads_computed = 0 + for s in range(n_seq): + for b in range(batch_size): + if ci.lower_leaky[out_layer][b, s, c_idx] <= ci_threshold: + continue + # TODO: Handle the case with o_proj in numerator and other attn in denominator + out_value = weighted_out_pre_detach[b, s, c_idx] + grads: Float[Tensor, " C"] = torch.autograd.grad( + outputs=out_value, + inputs=in_post_detach, + retain_graph=True, + allow_unused=True, + )[0] + assert grads is not None, "Gradient is None" + with torch.no_grad(): + act_weighted_grads: Float[Tensor, " C"] = ( + grads[b, s, :] + * in_post_detach[b, s, :] + * ci.lower_leaky[in_layer][b, s, :] + )[alive_indices[in_layer]].pow(2) + batch_attribution[:, c] += act_weighted_grads + n_grads_computed += 1 + tqdm.write(f"Computed {n_grads_computed} gradients for {in_layer} -> {out_layer}") attribution_sums[(in_layer, out_layer)] += batch_attribution - total_samples += 1 # Count batches (already averaged over batch/seq within) - - # Average over number of samples global_attributions = { - pair: attr_sum / total_samples for pair, attr_sum in attribution_sums.items() + pair: (attr_sum / total_samples).sqrt() for pair, attr_sum in attribution_sums.items() } print(f"Computed global attributions over {total_samples} samples") @@ -332,12 +339,13 @@ def compute_global_attributions( # wandb_path = "wandb:goodfire/spd/runs/c0k3z78g" # ss_gpt2_simple-2L wandb_path = "wandb:goodfire/spd/runs/8ynfbr38" # ss_gpt2_simple-1L n_blocks = 1 -batch_size = 600 +batch_size = 20 # n_attribution_batches = 20 n_attribution_batches = 2 -n_alive_calc_batches = 50 +n_alive_calc_batches = 5 # n_alive_calc_batches = 200 ci_mean_alive_threshold = 1e-6 +ci_attribution_threshold = 1e-3 dataset_seed = 0 out_dir = Path(__file__).parent / "out" @@ -422,6 +430,7 @@ def compute_global_attributions( valid_pairs=valid_pairs, max_batches=n_attribution_batches, alive_indices=alive_indices, + ci_threshold=ci_mean_alive_threshold, ) # Print summary statistics From 03dd761f4c3c439867ab9746fd110a792d813db1 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 26 Nov 2025 16:08:24 +0000 Subject: [PATCH 06/36] Speedups --- scripts/calc_global_attributions.py | 128 ++++++++++++++++++++-------- 1 file changed, 94 insertions(+), 34 deletions(-) diff --git a/scripts/calc_global_attributions.py b/scripts/calc_global_attributions.py index 074392451..dbaf4a4fc 100644 --- a/scripts/calc_global_attributions.py +++ b/scripts/calc_global_attributions.py @@ -21,6 +21,23 @@ from spd.utils.general_utils import extract_batch_data +def is_qkv_to_o_pair(in_layer: str, out_layer: str) -> bool: + """Check if pair requires per-sequence-position gradient computation. + + For q/k/v → o_proj within the same attention block, output at s_out + has gradients w.r.t. inputs at all s_in ≤ s_out (causal attention). + """ + in_is_qkv = any(x in in_layer for x in ["q_proj", "k_proj", "v_proj"]) + out_is_o = "o_proj" in out_layer + if not (in_is_qkv and out_is_o): + return False + + # Check same attention block: "h.{idx}.attn.{proj}" + in_block = in_layer.split(".")[1] + out_block = out_layer.split(".")[1] + return in_block == out_block + + # %% def compute_mean_ci_per_component( model: ComponentModel, @@ -204,6 +221,7 @@ def get_valid_pairs( return valid_pairs +# @profile def compute_global_attributions( model: ComponentModel, data_loader: Iterable[dict[str, Any]], @@ -212,7 +230,7 @@ def compute_global_attributions( valid_pairs: list[tuple[str, str]], max_batches: int, alive_indices: dict[str, list[int]], - ci_threshold: float, + ci_attribution_threshold: float, ) -> dict[tuple[str, str], Tensor]: """Compute global attributions accumulated over the dataset. @@ -228,14 +246,16 @@ def compute_global_attributions( valid_pairs: List of (in_layer, out_layer) pairs to compute attributions for. max_batches: Maximum number of batches to process. alive_indices: Dictionary mapping module path -> list of alive component indices. - ci_threshold: Threshold for considering a component for the attribution calculation. + ci_attribution_threshold: Threshold for considering a component for the attribution calculation. Returns: Dictionary mapping (in_layer, out_layer) -> attribution tensor of shape [n_alive_in, n_alive_out] where attribution[i, j] is the mean absolute gradient from the i-th alive input component to the j-th alive output component. """ # Initialize accumulators for each valid pair + # Track samples separately per pair since attention pairs aggregate differently attribution_sums: dict[tuple[str, str], Float[Tensor, "n_alive_in n_alive_out"]] = {} + samples_per_pair: dict[tuple[str, str], int] = {} for pair in valid_pairs: in_layer, out_layer = pair n_alive_in = len(alive_indices[in_layer]) @@ -243,8 +263,7 @@ def compute_global_attributions( attribution_sums[(in_layer, out_layer)] = torch.zeros( n_alive_in, n_alive_out, device=device ) - - total_samples = 0 # Track total (batch * seq) samples processed + samples_per_pair[(in_layer, out_layer)] = 0 batch_pbar = tqdm(enumerate(data_loader), desc="Batches", total=max_batches) for batch_idx, batch_raw in batch_pbar: @@ -254,7 +273,6 @@ def compute_global_attributions( batch: Float[Tensor, "b s C"] = extract_batch_data(batch_raw).to(device) batch_size, n_seq = batch.shape - total_samples += batch_size * n_seq # Forward pass to get pre-weight activations with torch.no_grad(): @@ -287,49 +305,91 @@ def compute_global_attributions( # Compute attributions for each valid pair for in_layer, out_layer in tqdm(valid_pairs, desc="Layer pairs", leave=False): out_pre_detach: Float[Tensor, "b s C"] = cache[f"{out_layer}_pre_detach"] - weighted_out_pre_detach = out_pre_detach * ci.lower_leaky[out_layer].detach() in_post_detach: Float[Tensor, "b s C"] = cache[f"{in_layer}_post_detach"] - batch_attribution = torch.zeros( - len(alive_indices[in_layer]), len(alive_indices[out_layer]), device=device - ) - alive_out: list[int] = alive_indices[out_layer] - c_pbar = tqdm( + alive_in: list[int] = alive_indices[in_layer] + batch_attribution = torch.zeros(len(alive_in), len(alive_out), device=device) + + ci_out = ci.lower_leaky[out_layer] + ci_in = ci.lower_leaky[in_layer] + + ci_weighted_in_post_detach = in_post_detach * ci_in + + is_attention_pair = is_qkv_to_o_pair(in_layer, out_layer) + tqdm.write(f"Attention pair: {in_layer} -> {out_layer}") + + grad_outputs: Float[Tensor, "b s C"] = torch.zeros_like(out_pre_detach) + + for c_enum, c_idx in tqdm( enumerate(alive_out), desc="Components", leave=False, total=len(alive_out) - ) - for c, c_idx in c_pbar: - n_grads_computed = 0 - for s in range(n_seq): - for b in range(batch_size): - if ci.lower_leaky[out_layer][b, s, c_idx] <= ci_threshold: + ): + if is_attention_pair: + # Attention pair: loop over output sequence positions because + # output at s_out has gradients w.r.t. inputs at all s_in <= s_out + for s_out in range(n_seq): + torch.cuda.synchronize() + if ci_out[:, s_out, c_idx].sum() <= ci_attribution_threshold: continue - # TODO: Handle the case with o_proj in numerator and other attn in denominator - out_value = weighted_out_pre_detach[b, s, c_idx] - grads: Float[Tensor, " C"] = torch.autograd.grad( - outputs=out_value, + torch.cuda.synchronize() + grad_outputs.zero_() + grad_outputs[:, s_out, c_idx] = ci_out[:, s_out, c_idx].detach() + + torch.cuda.synchronize() + grads = torch.autograd.grad( + outputs=out_pre_detach, inputs=in_post_detach, + grad_outputs=grad_outputs, retain_graph=True, allow_unused=True, )[0] + torch.cuda.synchronize() assert grads is not None, "Gradient is None" + with torch.no_grad(): - act_weighted_grads: Float[Tensor, " C"] = ( - grads[b, s, :] - * in_post_detach[b, s, :] - * ci.lower_leaky[in_layer][b, s, :] - )[alive_indices[in_layer]].pow(2) - batch_attribution[:, c] += act_weighted_grads - n_grads_computed += 1 - tqdm.write(f"Computed {n_grads_computed} gradients for {in_layer} -> {out_layer}") + weighted = grads * ci_weighted_in_post_detach + # Only sum contributions from positions s_in <= s_out (causal) + weighted_alive = weighted[:, : s_out + 1, alive_in] + batch_attribution[:, c_enum] += weighted_alive.pow(2).sum(dim=(0, 1)) + torch.cuda.synchronize() + else: + if ci_out[:, :, c_idx].sum() <= ci_attribution_threshold: + continue + # Standard case: vectorize over all (b, s) positions + grad_outputs.zero_() + grad_outputs[:, :, c_idx] = ci_out[:, :, c_idx].detach() + + grads = torch.autograd.grad( + outputs=out_pre_detach, + inputs=in_post_detach, + grad_outputs=grad_outputs, + retain_graph=True, + allow_unused=True, + )[0] + assert grads is not None, "Gradient is None" + + with torch.no_grad(): + weighted = grads * in_post_detach * ci_in + weighted_alive = weighted[:, :, alive_in] + batch_attribution[:, c_enum] += weighted_alive.pow(2).sum(dim=(0, 1)) attribution_sums[(in_layer, out_layer)] += batch_attribution + # Track samples: for attention pairs, we have batch_size * (1+2+...+n_seq) = batch_size * n_seq*(n_seq+1)/2 + # input positions per batch (triangular sum due to causal masking). + # For standard pairs, we have batch_size * n_seq positions. + if is_attention_pair: + samples_per_pair[(in_layer, out_layer)] += batch_size * n_seq * (n_seq + 1) // 2 + else: + samples_per_pair[(in_layer, out_layer)] += batch_size * n_seq + global_attributions = { - pair: (attr_sum / total_samples).sqrt() for pair, attr_sum in attribution_sums.items() + pair: (attr_sum / samples_per_pair[pair]).sqrt() + for pair, attr_sum in attribution_sums.items() } - print(f"Computed global attributions over {total_samples} samples") + total_samples = sum(samples_per_pair.values()) // len(valid_pairs) if valid_pairs else 0 + print(f"Computed global attributions over ~{total_samples} samples per pair") return global_attributions @@ -339,13 +399,13 @@ def compute_global_attributions( # wandb_path = "wandb:goodfire/spd/runs/c0k3z78g" # ss_gpt2_simple-2L wandb_path = "wandb:goodfire/spd/runs/8ynfbr38" # ss_gpt2_simple-1L n_blocks = 1 -batch_size = 20 +batch_size = 32 # n_attribution_batches = 20 n_attribution_batches = 2 n_alive_calc_batches = 5 # n_alive_calc_batches = 200 ci_mean_alive_threshold = 1e-6 -ci_attribution_threshold = 1e-3 +ci_attribution_threshold = 1e-6 dataset_seed = 0 out_dir = Path(__file__).parent / "out" @@ -430,7 +490,7 @@ def compute_global_attributions( valid_pairs=valid_pairs, max_batches=n_attribution_batches, alive_indices=alive_indices, - ci_threshold=ci_mean_alive_threshold, + ci_attribution_threshold=ci_attribution_threshold, ) # Print summary statistics From f9db75cb3c60a64110bb60231363a331c06a95c9 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 26 Nov 2025 16:09:35 +0000 Subject: [PATCH 07/36] Remove stray profiler comments --- scripts/calc_global_attributions.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/scripts/calc_global_attributions.py b/scripts/calc_global_attributions.py index dbaf4a4fc..2caa0522a 100644 --- a/scripts/calc_global_attributions.py +++ b/scripts/calc_global_attributions.py @@ -221,7 +221,6 @@ def get_valid_pairs( return valid_pairs -# @profile def compute_global_attributions( model: ComponentModel, data_loader: Iterable[dict[str, Any]], @@ -328,14 +327,11 @@ def compute_global_attributions( # Attention pair: loop over output sequence positions because # output at s_out has gradients w.r.t. inputs at all s_in <= s_out for s_out in range(n_seq): - torch.cuda.synchronize() if ci_out[:, s_out, c_idx].sum() <= ci_attribution_threshold: continue - torch.cuda.synchronize() grad_outputs.zero_() grad_outputs[:, s_out, c_idx] = ci_out[:, s_out, c_idx].detach() - torch.cuda.synchronize() grads = torch.autograd.grad( outputs=out_pre_detach, inputs=in_post_detach, @@ -343,7 +339,6 @@ def compute_global_attributions( retain_graph=True, allow_unused=True, )[0] - torch.cuda.synchronize() assert grads is not None, "Gradient is None" with torch.no_grad(): @@ -351,7 +346,6 @@ def compute_global_attributions( # Only sum contributions from positions s_in <= s_out (causal) weighted_alive = weighted[:, : s_out + 1, alive_in] batch_attribution[:, c_enum] += weighted_alive.pow(2).sum(dim=(0, 1)) - torch.cuda.synchronize() else: if ci_out[:, :, c_idx].sum() <= ci_attribution_threshold: continue From 929bc65947e5e436c18db32a6b3bcd1e502bad95 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 27 Nov 2025 10:14:23 +0000 Subject: [PATCH 08/36] More speedups --- scripts/calc_global_attributions.py | 190 ++++++++++++++++++---------- 1 file changed, 124 insertions(+), 66 deletions(-) diff --git a/scripts/calc_global_attributions.py b/scripts/calc_global_attributions.py index 2caa0522a..c6a0fcdd4 100644 --- a/scripts/calc_global_attributions.py +++ b/scripts/calc_global_attributions.py @@ -2,6 +2,7 @@ import gzip import json +from collections import defaultdict from collections.abc import Iterable from pathlib import Path from typing import Any @@ -134,13 +135,18 @@ def compute_alive_components( return mean_cis, alive_indices, images -def get_valid_pairs( +def get_sources_by_target( model: ComponentModel, data_loader: Iterable[dict[str, Any]], device: str, config: Config, n_blocks: int, -) -> list[tuple[str, str]]: +) -> dict[str, list[str]]: + """Find valid gradient connections grouped by target layer. + + Returns: + Dict mapping out_layer -> list of in_layers that have gradient flow to it. + """ # Get an arbitrary batch batch_raw = next(iter(data_loader)) batch = extract_batch_data(batch_raw).to(device) @@ -191,7 +197,7 @@ def get_valid_pairs( if layers.index(in_layer) < layers.index(out_layer): test_pairs.append((in_layer, out_layer)) - valid_pairs = [] + sources_by_target: dict[str, list[str]] = defaultdict(list) for in_layer, out_layer in test_pairs: out_pre_detach = cache[f"{out_layer}_pre_detach"] in_post_detach = cache[f"{in_layer}_post_detach"] @@ -217,8 +223,32 @@ def get_valid_pairs( except RuntimeError: has_grad = False if has_grad: - valid_pairs.append((in_layer, out_layer)) - return valid_pairs + sources_by_target[out_layer].append(in_layer) + return dict(sources_by_target) + + +def validate_attention_pair_structure(sources_by_target: dict[str, list[str]]) -> None: + """Assert that o_proj layers only receive from same-block QKV. + + This structural property allows us to handle attention and non-attention + cases separately without mixing within a single target layer. + """ + for out_layer, in_layers in sources_by_target.items(): + if "o_proj" in out_layer: + out_block = out_layer.split(".")[1] + for in_layer in in_layers: + assert any(x in in_layer for x in ["q_proj", "k_proj", "v_proj"]), ( + f"o_proj output {out_layer} has non-QKV input {in_layer}" + ) + in_block = in_layer.split(".")[1] + assert in_block == out_block, ( + f"o_proj output {out_layer} has input from different block: {in_layer}" + ) + else: + for in_layer in in_layers: + assert not is_qkv_to_o_pair(in_layer, out_layer), ( + f"Non-o_proj output {out_layer} has attention pair input {in_layer}" + ) def compute_global_attributions( @@ -226,7 +256,7 @@ def compute_global_attributions( data_loader: Iterable[dict[str, Any]], device: str, config: Config, - valid_pairs: list[tuple[str, str]], + sources_by_target: dict[str, list[str]], max_batches: int, alive_indices: dict[str, list[int]], ci_attribution_threshold: float, @@ -237,32 +267,35 @@ def compute_global_attributions( of output component activations with respect to input component activations, averaged over batch, sequence positions, and number of batches. + Optimization: For each target layer, we batch all source layers into a single + autograd.grad call, sharing backward computation. + Args: model: The ComponentModel to analyze. data_loader: DataLoader providing batches. device: Device to run on. config: SPD config with sampling settings. - valid_pairs: List of (in_layer, out_layer) pairs to compute attributions for. + sources_by_target: Dict mapping out_layer -> list of in_layers. max_batches: Maximum number of batches to process. alive_indices: Dictionary mapping module path -> list of alive component indices. ci_attribution_threshold: Threshold for considering a component for the attribution calculation. + Returns: Dictionary mapping (in_layer, out_layer) -> attribution tensor of shape [n_alive_in, n_alive_out] - where attribution[i, j] is the mean absolute gradient from the i-th alive input component to the j-th alive output component. + where attribution[i, j] is the mean absolute gradient from the i-th alive input component + to the j-th alive output component. """ - - # Initialize accumulators for each valid pair - # Track samples separately per pair since attention pairs aggregate differently + # Initialize accumulators for each (in_layer, out_layer) pair attribution_sums: dict[tuple[str, str], Float[Tensor, "n_alive_in n_alive_out"]] = {} samples_per_pair: dict[tuple[str, str], int] = {} - for pair in valid_pairs: - in_layer, out_layer = pair - n_alive_in = len(alive_indices[in_layer]) - n_alive_out = len(alive_indices[out_layer]) - attribution_sums[(in_layer, out_layer)] = torch.zeros( - n_alive_in, n_alive_out, device=device - ) - samples_per_pair[(in_layer, out_layer)] = 0 + for out_layer, in_layers in sources_by_target.items(): + for in_layer in in_layers: + n_alive_in = len(alive_indices[in_layer]) + n_alive_out = len(alive_indices[out_layer]) + attribution_sums[(in_layer, out_layer)] = torch.zeros( + n_alive_in, n_alive_out, device=device + ) + samples_per_pair[(in_layer, out_layer)] = 0 batch_pbar = tqdm(enumerate(data_loader), desc="Batches", total=max_batches) for batch_idx, batch_raw in batch_pbar: @@ -301,88 +334,108 @@ def compute_global_attributions( cache = comp_output_with_cache.cache - # Compute attributions for each valid pair - for in_layer, out_layer in tqdm(valid_pairs, desc="Layer pairs", leave=False): + # Compute attributions grouped by target layer + for out_layer, in_layers in tqdm( + sources_by_target.items(), desc="Target layers", leave=False + ): out_pre_detach: Float[Tensor, "b s C"] = cache[f"{out_layer}_pre_detach"] - in_post_detach: Float[Tensor, "b s C"] = cache[f"{in_layer}_post_detach"] - alive_out: list[int] = alive_indices[out_layer] - alive_in: list[int] = alive_indices[in_layer] - batch_attribution = torch.zeros(len(alive_in), len(alive_out), device=device) - ci_out = ci.lower_leaky[out_layer] - ci_in = ci.lower_leaky[in_layer] - - ci_weighted_in_post_detach = in_post_detach * ci_in - is_attention_pair = is_qkv_to_o_pair(in_layer, out_layer) - tqdm.write(f"Attention pair: {in_layer} -> {out_layer}") + # Gather all input tensors for this target layer + in_tensors = [cache[f"{in_layer}_post_detach"] for in_layer in in_layers] + ci_weighted_inputs = [ + in_tensors[i] * ci.lower_leaky[in_layers[i]] for i in range(len(in_layers)) + ] + + # Initialize batch attributions for each input layer + batch_attributions = { + in_layer: torch.zeros(len(alive_indices[in_layer]), len(alive_out), device=device) + for in_layer in in_layers + } + + is_attention_output = "o_proj" in out_layer + tqdm.write( + f"{'Attention' if is_attention_output else 'Non-attention'} target: " + f"{out_layer} <- {in_layers}" + ) grad_outputs: Float[Tensor, "b s C"] = torch.zeros_like(out_pre_detach) for c_enum, c_idx in tqdm( enumerate(alive_out), desc="Components", leave=False, total=len(alive_out) ): - if is_attention_pair: - # Attention pair: loop over output sequence positions because - # output at s_out has gradients w.r.t. inputs at all s_in <= s_out + if is_attention_output: + # Attention target: loop over output sequence positions for s_out in range(n_seq): if ci_out[:, s_out, c_idx].sum() <= ci_attribution_threshold: continue grad_outputs.zero_() grad_outputs[:, s_out, c_idx] = ci_out[:, s_out, c_idx].detach() - grads = torch.autograd.grad( + # Single autograd call for all input layers + grads_tuple = torch.autograd.grad( outputs=out_pre_detach, - inputs=in_post_detach, + inputs=in_tensors, grad_outputs=grad_outputs, retain_graph=True, allow_unused=True, - )[0] - assert grads is not None, "Gradient is None" + ) with torch.no_grad(): - weighted = grads * ci_weighted_in_post_detach - # Only sum contributions from positions s_in <= s_out (causal) - weighted_alive = weighted[:, : s_out + 1, alive_in] - batch_attribution[:, c_enum] += weighted_alive.pow(2).sum(dim=(0, 1)) + for i, in_layer in enumerate(in_layers): + grads = grads_tuple[i] + assert grads is not None, f"Gradient is None for {in_layer}" + alive_in = alive_indices[in_layer] + weighted = grads * ci_weighted_inputs[i] + # Only sum contributions from positions s_in <= s_out (causal) + weighted_alive = weighted[:, : s_out + 1, alive_in] + batch_attributions[in_layer][:, c_enum] += weighted_alive.pow( + 2 + ).sum(dim=(0, 1)) else: + # Non-attention target: vectorize over all (b, s) positions if ci_out[:, :, c_idx].sum() <= ci_attribution_threshold: continue - # Standard case: vectorize over all (b, s) positions grad_outputs.zero_() grad_outputs[:, :, c_idx] = ci_out[:, :, c_idx].detach() - grads = torch.autograd.grad( + # Single autograd call for all input layers + grads_tuple = torch.autograd.grad( outputs=out_pre_detach, - inputs=in_post_detach, + inputs=in_tensors, grad_outputs=grad_outputs, retain_graph=True, allow_unused=True, - )[0] - assert grads is not None, "Gradient is None" + ) with torch.no_grad(): - weighted = grads * in_post_detach * ci_in - weighted_alive = weighted[:, :, alive_in] - batch_attribution[:, c_enum] += weighted_alive.pow(2).sum(dim=(0, 1)) - - attribution_sums[(in_layer, out_layer)] += batch_attribution - - # Track samples: for attention pairs, we have batch_size * (1+2+...+n_seq) = batch_size * n_seq*(n_seq+1)/2 - # input positions per batch (triangular sum due to causal masking). - # For standard pairs, we have batch_size * n_seq positions. - if is_attention_pair: - samples_per_pair[(in_layer, out_layer)] += batch_size * n_seq * (n_seq + 1) // 2 - else: - samples_per_pair[(in_layer, out_layer)] += batch_size * n_seq + for i, in_layer in enumerate(in_layers): + grads = grads_tuple[i] + assert grads is not None, f"Gradient is None for {in_layer}" + alive_in = alive_indices[in_layer] + weighted = grads * ci_weighted_inputs[i] + weighted_alive = weighted[:, :, alive_in] + batch_attributions[in_layer][:, c_enum] += weighted_alive.pow(2).sum( + dim=(0, 1) + ) + + # Accumulate batch results and track samples + for in_layer in in_layers: + attribution_sums[(in_layer, out_layer)] += batch_attributions[in_layer] + # Track samples: attention pairs have triangular sum due to causal masking + if is_attention_output: + samples_per_pair[(in_layer, out_layer)] += batch_size * n_seq * (n_seq + 1) // 2 + else: + samples_per_pair[(in_layer, out_layer)] += batch_size * n_seq global_attributions = { pair: (attr_sum / samples_per_pair[pair]).sqrt() for pair, attr_sum in attribution_sums.items() } - total_samples = sum(samples_per_pair.values()) // len(valid_pairs) if valid_pairs else 0 + n_pairs = len(attribution_sums) + total_samples = sum(samples_per_pair.values()) // n_pairs if n_pairs else 0 print(f"Computed global attributions over ~{total_samples} samples per pair") return global_attributions @@ -393,9 +446,9 @@ def compute_global_attributions( # wandb_path = "wandb:goodfire/spd/runs/c0k3z78g" # ss_gpt2_simple-2L wandb_path = "wandb:goodfire/spd/runs/8ynfbr38" # ss_gpt2_simple-1L n_blocks = 1 -batch_size = 32 +batch_size = 512 # n_attribution_batches = 20 -n_attribution_batches = 2 +n_attribution_batches = 1 n_alive_calc_batches = 5 # n_alive_calc_batches = 200 ci_mean_alive_threshold = 1e-6 @@ -447,8 +500,13 @@ def compute_global_attributions( ddp_world_size=1, ) -valid_pairs = get_valid_pairs(model, data_loader, device, config, n_blocks) -print(f"Valid layer pairs: {valid_pairs}") +sources_by_target = get_sources_by_target(model, data_loader, device, config, n_blocks) +validate_attention_pair_structure(sources_by_target) + +n_pairs = sum(len(ins) for ins in sources_by_target.values()) +print(f"Sources by target: {n_pairs} pairs across {len(sources_by_target)} target layers") +for out_layer, in_layers in sources_by_target.items(): + print(f" {out_layer} <- {in_layers}") # %% # Compute alive components based on mean CI threshold print("\nComputing alive components based on mean CI...") @@ -481,7 +539,7 @@ def compute_global_attributions( data_loader=data_loader, device=device, config=config, - valid_pairs=valid_pairs, + sources_by_target=sources_by_target, max_batches=n_attribution_batches, alive_indices=alive_indices, ci_attribution_threshold=ci_attribution_threshold, From 52d449a6753b794c2ba419820c94de4235e0df38 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 27 Nov 2025 11:39:30 +0000 Subject: [PATCH 09/36] remove double multiplication by ci weights in in_layers --- scripts/calc_global_attributions.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/scripts/calc_global_attributions.py b/scripts/calc_global_attributions.py index c6a0fcdd4..1acf03a31 100644 --- a/scripts/calc_global_attributions.py +++ b/scripts/calc_global_attributions.py @@ -344,9 +344,6 @@ def compute_global_attributions( # Gather all input tensors for this target layer in_tensors = [cache[f"{in_layer}_post_detach"] for in_layer in in_layers] - ci_weighted_inputs = [ - in_tensors[i] * ci.lower_leaky[in_layers[i]] for i in range(len(in_layers)) - ] # Initialize batch attributions for each input layer batch_attributions = { @@ -387,7 +384,7 @@ def compute_global_attributions( grads = grads_tuple[i] assert grads is not None, f"Gradient is None for {in_layer}" alive_in = alive_indices[in_layer] - weighted = grads * ci_weighted_inputs[i] + weighted = grads * in_tensors[i] # Only sum contributions from positions s_in <= s_out (causal) weighted_alive = weighted[:, : s_out + 1, alive_in] batch_attributions[in_layer][:, c_enum] += weighted_alive.pow( @@ -414,7 +411,7 @@ def compute_global_attributions( grads = grads_tuple[i] assert grads is not None, f"Gradient is None for {in_layer}" alive_in = alive_indices[in_layer] - weighted = grads * ci_weighted_inputs[i] + weighted = grads * in_tensors[i] weighted_alive = weighted[:, :, alive_in] batch_attributions[in_layer][:, c_enum] += weighted_alive.pow(2).sum( dim=(0, 1) From 04c093ff6f4a4bf776200ec84e46738bf706ac24 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 27 Nov 2025 14:07:28 +0000 Subject: [PATCH 10/36] Normalise attrs and move to spd/scripts --- .../scripts}/calc_global_attributions.py | 17 +++++++++-------- .../scripts}/plot_global_attributions.py | 0 2 files changed, 9 insertions(+), 8 deletions(-) rename {scripts => spd/scripts}/calc_global_attributions.py (98%) rename {scripts => spd/scripts}/plot_global_attributions.py (100%) diff --git a/scripts/calc_global_attributions.py b/spd/scripts/calc_global_attributions.py similarity index 98% rename from scripts/calc_global_attributions.py rename to spd/scripts/calc_global_attributions.py index 1acf03a31..8ee57e74c 100644 --- a/scripts/calc_global_attributions.py +++ b/spd/scripts/calc_global_attributions.py @@ -426,10 +426,10 @@ def compute_global_attributions( else: samples_per_pair[(in_layer, out_layer)] += batch_size * n_seq - global_attributions = { - pair: (attr_sum / samples_per_pair[pair]).sqrt() - for pair, attr_sum in attribution_sums.items() - } + global_attributions = {} + for pair, attr_sum in attribution_sums.items(): + attr = attr_sum / samples_per_pair[pair] + global_attributions[pair] = attr / attr.sum() n_pairs = len(attribution_sums) total_samples = sum(samples_per_pair.values()) // n_pairs if n_pairs else 0 @@ -443,10 +443,11 @@ def compute_global_attributions( # wandb_path = "wandb:goodfire/spd/runs/c0k3z78g" # ss_gpt2_simple-2L wandb_path = "wandb:goodfire/spd/runs/8ynfbr38" # ss_gpt2_simple-1L n_blocks = 1 -batch_size = 512 +batch_size = 1024 +n_ctx = 64 # n_attribution_batches = 20 -n_attribution_batches = 1 -n_alive_calc_batches = 5 +n_attribution_batches = 5 +n_alive_calc_batches = 200 # n_alive_calc_batches = 200 ci_mean_alive_threshold = 1e-6 ci_attribution_threshold = 1e-6 @@ -479,7 +480,7 @@ def compute_global_attributions( name=task_config.dataset_name, hf_tokenizer_path=config.tokenizer_name, split=task_config.train_data_split, # Using train split for now - n_ctx=task_config.max_seq_len, + n_ctx=n_ctx, is_tokenized=task_config.is_tokenized, streaming=task_config.streaming, column_name=task_config.column_name, diff --git a/scripts/plot_global_attributions.py b/spd/scripts/plot_global_attributions.py similarity index 100% rename from scripts/plot_global_attributions.py rename to spd/scripts/plot_global_attributions.py From 85ee0062dc04a2b7ed3d1a2030a8adb9dab996e1 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 27 Nov 2025 14:38:35 +0000 Subject: [PATCH 11/36] Simplify get_sources_by_target --- spd/scripts/calc_global_attributions.py | 32 ++++++++----------------- 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/spd/scripts/calc_global_attributions.py b/spd/scripts/calc_global_attributions.py index 8ee57e74c..10050e5eb 100644 --- a/spd/scripts/calc_global_attributions.py +++ b/spd/scripts/calc_global_attributions.py @@ -201,28 +201,16 @@ def get_sources_by_target( for in_layer, out_layer in test_pairs: out_pre_detach = cache[f"{out_layer}_pre_detach"] in_post_detach = cache[f"{in_layer}_post_detach"] - batch_idx = 0 - seq_idx = 50 - target_component_idx = 10 - out_value = out_pre_detach[batch_idx, seq_idx, target_component_idx] - try: - grads = torch.autograd.grad( - outputs=out_value, - inputs=in_post_detach, - retain_graph=True, - allow_unused=True, - ) - assert len(grads) == 1, "Expected 1 gradient" - grad = grads[0] - # torch.autograd.grad returns None for unused inputs when allow_unused=True - has_grad = ( - grad.abs().max().item() > 1e-8 - if grad is not None # pyright: ignore[reportUnnecessaryComparison] - else False - ) - except RuntimeError: - has_grad = False - if has_grad: + out_value = out_pre_detach[0, 0, 0] # Pick arbitrary value + grads = torch.autograd.grad( + outputs=out_value, + inputs=in_post_detach, + retain_graph=True, + allow_unused=True, + ) + assert len(grads) == 1, "Expected 1 gradient" + grad = grads[0] + if grad.abs().max().item() > 0: sources_by_target[out_layer].append(in_layer) return dict(sources_by_target) From 6915a1f0bd2ece49447571d4cf5836aa0c3d6860 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 27 Nov 2025 15:33:25 +0000 Subject: [PATCH 12/36] Misc tweaks --- spd/scripts/calc_global_attributions.py | 4 ++-- spd/scripts/plot_global_attributions.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/spd/scripts/calc_global_attributions.py b/spd/scripts/calc_global_attributions.py index 10050e5eb..ae06d3d1b 100644 --- a/spd/scripts/calc_global_attributions.py +++ b/spd/scripts/calc_global_attributions.py @@ -210,7 +210,7 @@ def get_sources_by_target( ) assert len(grads) == 1, "Expected 1 gradient" grad = grads[0] - if grad.abs().max().item() > 0: + if grad is not None and grad.abs().max().item() > 0: # pyright: ignore[reportUnnecessaryComparison] sources_by_target[out_layer].append(in_layer) return dict(sources_by_target) @@ -435,7 +435,7 @@ def compute_global_attributions( n_ctx = 64 # n_attribution_batches = 20 n_attribution_batches = 5 -n_alive_calc_batches = 200 +n_alive_calc_batches = 100 # n_alive_calc_batches = 200 ci_mean_alive_threshold = 1e-6 ci_attribution_threshold = 1e-6 diff --git a/spd/scripts/plot_global_attributions.py b/spd/scripts/plot_global_attributions.py index 2f0c6f6b5..a072132a8 100644 --- a/spd/scripts/plot_global_attributions.py +++ b/spd/scripts/plot_global_attributions.py @@ -12,7 +12,7 @@ # n_blocks = 2 wandb_id = "8ynfbr38" # ss_gpt2_simple-1L n_blocks = 1 -edge_threshold = 1e-1 +edge_threshold = 1e-3 # Load saved data out_dir = Path(__file__).parent / "out" @@ -33,7 +33,7 @@ # Count edges before and after thresholding total_edges = sum(attr.numel() for attr in global_attributions.values()) print(f"Total edges: {total_edges:,}") -thresholds = [1, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8] +thresholds = [1, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-12, 1e-15] for threshold in thresholds: total_edges_threshold = sum( (attr > threshold).sum().item() for attr in global_attributions.values() From 27930432669aed21aab669c14fc2a14384c82f19 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 27 Nov 2025 16:49:54 +0000 Subject: [PATCH 13/36] Normalize over sum to output node --- spd/scripts/calc_global_attributions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spd/scripts/calc_global_attributions.py b/spd/scripts/calc_global_attributions.py index ae06d3d1b..b557e81b7 100644 --- a/spd/scripts/calc_global_attributions.py +++ b/spd/scripts/calc_global_attributions.py @@ -416,8 +416,8 @@ def compute_global_attributions( global_attributions = {} for pair, attr_sum in attribution_sums.items(): - attr = attr_sum / samples_per_pair[pair] - global_attributions[pair] = attr / attr.sum() + attr: Float[Tensor, "n_alive_in n_alive_out"] = attr_sum / samples_per_pair[pair] + global_attributions[pair] = attr / attr.sum(dim=1, keepdim=True) n_pairs = len(attribution_sums) total_samples = sum(samples_per_pair.values()) // n_pairs if n_pairs else 0 From c471c2901500defb8f5c31e9300428f207b03e5b Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 27 Nov 2025 17:06:08 +0000 Subject: [PATCH 14/36] Add more thresholds --- spd/scripts/plot_global_attributions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spd/scripts/plot_global_attributions.py b/spd/scripts/plot_global_attributions.py index a072132a8..b3862b9dc 100644 --- a/spd/scripts/plot_global_attributions.py +++ b/spd/scripts/plot_global_attributions.py @@ -12,7 +12,7 @@ # n_blocks = 2 wandb_id = "8ynfbr38" # ss_gpt2_simple-1L n_blocks = 1 -edge_threshold = 1e-3 +edge_threshold = 1e-2 # Load saved data out_dir = Path(__file__).parent / "out" @@ -33,7 +33,7 @@ # Count edges before and after thresholding total_edges = sum(attr.numel() for attr in global_attributions.values()) print(f"Total edges: {total_edges:,}") -thresholds = [1, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-12, 1e-15] +thresholds = [1, 0.6, 0.2, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-12, 1e-15] for threshold in thresholds: total_edges_threshold = sum( (attr > threshold).sum().item() for attr in global_attributions.values() From 4d7c0b3bcc0d9f2ee72962d9eb8bb5c7ef65f5ec Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 28 Nov 2025 12:17:49 +0000 Subject: [PATCH 15/36] Add local attribution calcs --- spd/scripts/calc_global_attributions.py | 281 +++++++++++------------- spd/scripts/calc_local_attributions.py | 242 ++++++++++++++++++++ spd/scripts/model_loading.py | 115 ++++++++++ 3 files changed, 481 insertions(+), 157 deletions(-) create mode 100644 spd/scripts/calc_local_attributions.py create mode 100644 spd/scripts/model_loading.py diff --git a/spd/scripts/calc_global_attributions.py b/spd/scripts/calc_global_attributions.py index b557e81b7..7e1be4322 100644 --- a/spd/scripts/calc_global_attributions.py +++ b/spd/scripts/calc_global_attributions.py @@ -4,7 +4,6 @@ import json from collections import defaultdict from collections.abc import Iterable -from pathlib import Path from typing import Any import torch @@ -14,11 +13,14 @@ from tqdm.auto import tqdm from spd.configs import Config -from spd.data import DatasetConfig, create_data_loader -from spd.experiments.lm.configs import LMTaskConfig -from spd.models.component_model import ComponentModel, OutputWithCache, SPDRunInfo +from spd.models.component_model import ComponentModel, OutputWithCache from spd.models.components import make_mask_infos from spd.plotting import plot_mean_component_cis_both_scales +from spd.scripts.model_loading import ( + create_data_loader_from_config, + get_out_dir, + load_model_from_wandb, +) from spd.utils.general_utils import extract_batch_data @@ -137,7 +139,6 @@ def compute_alive_components( def get_sources_by_target( model: ComponentModel, - data_loader: Iterable[dict[str, Any]], device: str, config: Config, n_blocks: int, @@ -147,16 +148,12 @@ def get_sources_by_target( Returns: Dict mapping out_layer -> list of in_layers that have gradient flow to it. """ - # Get an arbitrary batch - batch_raw = next(iter(data_loader)) - batch = extract_batch_data(batch_raw).to(device) - print(f"Batch shape: {batch.shape}") + # Use a small dummy batch - we only need to trace gradient connections + batch: Float[Tensor, "batch seq"] = torch.zeros(2, 3, dtype=torch.long, device=device) with torch.no_grad(): output_with_cache: OutputWithCache = model(batch, cache_type="input") print(f"Output shape: {output_with_cache.output.shape}") - print(f"Number of cached layers: {len(output_with_cache.cache)}") - print(f"Cached layer names: {list(output_with_cache.cache.keys())}") with torch.no_grad(): ci = model.calc_causal_importances( @@ -210,7 +207,7 @@ def get_sources_by_target( ) assert len(grads) == 1, "Expected 1 gradient" grad = grads[0] - if grad is not None and grad.abs().max().item() > 0: # pyright: ignore[reportUnnecessaryComparison] + if grad is not None: # pyright: ignore[reportUnnecessaryComparison] sources_by_target[out_layer].append(in_layer) return dict(sources_by_target) @@ -364,7 +361,6 @@ def compute_global_attributions( inputs=in_tensors, grad_outputs=grad_outputs, retain_graph=True, - allow_unused=True, ) with torch.no_grad(): @@ -425,154 +421,125 @@ def compute_global_attributions( return global_attributions -# %% -# Configuration -# wandb_path = "wandb:goodfire/spd/runs/jyo9duz5" # ss_gpt2_simple-1.25M (4L) -# wandb_path = "wandb:goodfire/spd/runs/c0k3z78g" # ss_gpt2_simple-2L -wandb_path = "wandb:goodfire/spd/runs/8ynfbr38" # ss_gpt2_simple-1L -n_blocks = 1 -batch_size = 1024 -n_ctx = 64 -# n_attribution_batches = 20 -n_attribution_batches = 5 -n_alive_calc_batches = 100 -# n_alive_calc_batches = 200 -ci_mean_alive_threshold = 1e-6 -ci_attribution_threshold = 1e-6 -dataset_seed = 0 - -out_dir = Path(__file__).parent / "out" -out_dir.mkdir(parents=True, exist_ok=True) -wandb_id = wandb_path.split("/")[-1] - -device = "cuda" if torch.cuda.is_available() else "cpu" -print(f"Using device: {device}") -print(f"Loading model from {wandb_path}...") - -# Load the model -run_info = SPDRunInfo.from_path(wandb_path) -config: Config = run_info.config -model = ComponentModel.from_run_info(run_info) -model = model.to(device) -model.eval() - -print("Model loaded successfully!") -print(f"Number of components: {model.C}") -print(f"Target module paths: {model.target_module_paths}") - -# Load the dataset -task_config = config.task_config -assert isinstance(task_config, LMTaskConfig), "Expected LM task config" - -dataset_config = DatasetConfig( - name=task_config.dataset_name, - hf_tokenizer_path=config.tokenizer_name, - split=task_config.train_data_split, # Using train split for now - n_ctx=n_ctx, - is_tokenized=task_config.is_tokenized, - streaming=task_config.streaming, - column_name=task_config.column_name, - shuffle_each_epoch=False, # No need to shuffle for testing - seed=dataset_seed, -) +if __name__ == "__main__": + # %% + # Configuration + # wandb_path = "wandb:goodfire/spd/runs/jyo9duz5" # ss_gpt2_simple-1.25M (4L) + # wandb_path = "wandb:goodfire/spd/runs/c0k3z78g" # ss_gpt2_simple-2L + wandb_path = "wandb:goodfire/spd/runs/8ynfbr38" # ss_gpt2_simple-1L + n_blocks = 1 + batch_size = 1024 + n_ctx = 64 + # n_attribution_batches = 20 + n_attribution_batches = 5 + n_alive_calc_batches = 100 + # n_alive_calc_batches = 200 + ci_mean_alive_threshold = 1e-6 + ci_attribution_threshold = 1e-6 + dataset_seed = 0 + + out_dir = get_out_dir() + + # Load model using shared utility + loaded = load_model_from_wandb(wandb_path) + model = loaded.model + config = loaded.config + device = loaded.device + wandb_id = loaded.wandb_id + + # Load the dataset + data_loader, tokenizer = create_data_loader_from_config( + config=config, + batch_size=batch_size, + n_ctx=n_ctx, + seed=dataset_seed, + ) -print(f"\nLoading dataset {dataset_config.name}...") -data_loader, tokenizer = create_data_loader( - dataset_config=dataset_config, - batch_size=batch_size, - buffer_size=task_config.buffer_size, - global_seed=dataset_seed, - ddp_rank=0, - ddp_world_size=1, -) + sources_by_target = get_sources_by_target(model, device, config, n_blocks) + validate_attention_pair_structure(sources_by_target) + n_pairs = sum(len(ins) for ins in sources_by_target.values()) + print(f"Sources by target: {n_pairs} pairs across {len(sources_by_target)} target layers") + for out_layer, in_layers in sources_by_target.items(): + print(f" {out_layer} <- {in_layers}") + # %% + # Compute alive components based on mean CI threshold + print("\nComputing alive components based on mean CI...") + mean_cis, alive_indices, (img_linear, img_log) = compute_alive_components( + model=model, + data_loader=data_loader, + device=device, + config=config, + max_batches=n_alive_calc_batches, + threshold=ci_mean_alive_threshold, + ) -sources_by_target = get_sources_by_target(model, data_loader, device, config, n_blocks) -validate_attention_pair_structure(sources_by_target) + # Print summary + print("\nAlive components per layer:") + for module_name, indices in alive_indices.items(): + n_alive = len(indices) + print(f" {module_name}: {n_alive}/{model.C} alive") + + # Save images for verification + img_linear.save(out_dir / f"ci_mean_per_component_linear_{wandb_id}.png") + img_log.save(out_dir / f"ci_mean_per_component_log_{wandb_id}.png") + print( + f"Saved verification images to {out_dir / f'ci_mean_per_component_linear_{wandb_id}.png'} and {out_dir / f'ci_mean_per_component_log_{wandb_id}.png'}" + ) + # %% + # Compute global attributions over the dataset + print("\nComputing global attributions...") + global_attributions = compute_global_attributions( + model=model, + data_loader=data_loader, + device=device, + config=config, + sources_by_target=sources_by_target, + max_batches=n_attribution_batches, + alive_indices=alive_indices, + ci_attribution_threshold=ci_attribution_threshold, + ) -n_pairs = sum(len(ins) for ins in sources_by_target.values()) -print(f"Sources by target: {n_pairs} pairs across {len(sources_by_target)} target layers") -for out_layer, in_layers in sources_by_target.items(): - print(f" {out_layer} <- {in_layers}") -# %% -# Compute alive components based on mean CI threshold -print("\nComputing alive components based on mean CI...") -mean_cis, alive_indices, (img_linear, img_log) = compute_alive_components( - model=model, - data_loader=data_loader, - device=device, - config=config, - max_batches=n_alive_calc_batches, - threshold=ci_mean_alive_threshold, -) + # Print summary statistics + for pair, attr in global_attributions.items(): + print(f"{pair[0]} -> {pair[1]}: mean={attr.mean():.6f}, max={attr.max():.6f}") + + # %% + # Save attributions in both PyTorch and JSON formats + print("\nSaving attribution data...") + + # Save PyTorch format + pt_path = out_dir / f"global_attributions_{wandb_id}.pt" + torch.save(global_attributions, pt_path) + print(f"Saved PyTorch format to {pt_path}") + + # Convert and save JSON format for web visualization + attributions_json = {} + for (in_layer, out_layer), attr_tensor in global_attributions.items(): + key = f"('{in_layer}', '{out_layer}')" + # Keep full precision - just convert to list + attributions_json[key] = attr_tensor.cpu().tolist() + + json_data = { + "n_blocks": n_blocks, + "attributions": attributions_json, + "alive_indices": alive_indices, + } -# Print summary -print("\nAlive components per layer:") -for module_name, indices in alive_indices.items(): - n_alive = len(indices) - print(f" {module_name}: {n_alive}/{model.C} alive") - -# Save images for verification -img_linear.save(out_dir / f"ci_mean_per_component_linear_{wandb_id}.png") -img_log.save(out_dir / f"ci_mean_per_component_log_{wandb_id}.png") -print( - f"Saved verification images to {out_dir / f'ci_mean_per_component_linear_{wandb_id}.png'} and {out_dir / f'ci_mean_per_component_log_{wandb_id}.png'}" -) -# %% -# Compute global attributions over the dataset -print("\nComputing global attributions...") -global_attributions = compute_global_attributions( - model=model, - data_loader=data_loader, - device=device, - config=config, - sources_by_target=sources_by_target, - max_batches=n_attribution_batches, - alive_indices=alive_indices, - ci_attribution_threshold=ci_attribution_threshold, -) + json_path = out_dir / f"global_attributions_{wandb_id}.json" -# Print summary statistics -for pair, attr in global_attributions.items(): - print(f"{pair[0]} -> {pair[1]}: mean={attr.mean():.6f}, max={attr.max():.6f}") + # Write JSON with compact formatting + with open(json_path, "w") as f: + json.dump(json_data, f, separators=(",", ":"), ensure_ascii=False) -# %% -# Save attributions in both PyTorch and JSON formats -print("\nSaving attribution data...") -out_dir = Path(__file__).parent / "out" - -# Save PyTorch format -pt_path = out_dir / f"global_attributions_{wandb_id}.pt" -torch.save(global_attributions, pt_path) -print(f"Saved PyTorch format to {pt_path}") - -# Convert and save JSON format for web visualization -attributions_json = {} -for (in_layer, out_layer), attr_tensor in global_attributions.items(): - key = f"('{in_layer}', '{out_layer}')" - # Keep full precision - just convert to list - attributions_json[key] = attr_tensor.cpu().tolist() - -json_data = { - "n_blocks": n_blocks, - "attributions": attributions_json, - "alive_indices": alive_indices, -} - -json_path = out_dir / f"global_attributions_{wandb_id}.json" - -# Write JSON with compact formatting -with open(json_path, "w") as f: - json.dump(json_data, f, separators=(",", ":"), ensure_ascii=False) - -# Also save a compressed version for very large files -gz_path = out_dir / f"global_attributions_{wandb_id}.json.gz" -with gzip.open(gz_path, "wt", encoding="utf-8") as f: - json.dump(json_data, f, separators=(",", ":"), ensure_ascii=False) - -print(f"Saved JSON format to {json_path}") -print(f"Saved compressed format to {gz_path}") -print(f" - {len(attributions_json)} layer pairs") -print(f" - {sum(len(v) for v in alive_indices.values())} total alive components") -print(f"\nTo visualize: Open scripts/plot_attributions.html and load {json_path}") + # Also save a compressed version for very large files + gz_path = out_dir / f"global_attributions_{wandb_id}.json.gz" + with gzip.open(gz_path, "wt", encoding="utf-8") as f: + json.dump(json_data, f, separators=(",", ":"), ensure_ascii=False) -# %% + print(f"Saved JSON format to {json_path}") + print(f"Saved compressed format to {gz_path}") + print(f" - {len(attributions_json)} layer pairs") + print(f" - {sum(len(v) for v in alive_indices.values())} total alive components") + print(f"\nTo visualize: Open scripts/plot_attributions.html and load {json_path}") + + # %% diff --git a/spd/scripts/calc_local_attributions.py b/spd/scripts/calc_local_attributions.py new file mode 100644 index 000000000..9c92b7d91 --- /dev/null +++ b/spd/scripts/calc_local_attributions.py @@ -0,0 +1,242 @@ +# %% +"""Compute local attributions for a single prompt.""" + +import gzip +import json + +import torch +from jaxtyping import Float +from torch import Tensor +from tqdm.auto import tqdm +from transformers import AutoTokenizer +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast + +from spd.configs import SamplingType +from spd.models.component_model import ComponentModel, OutputWithCache +from spd.models.components import make_mask_infos +from spd.scripts.calc_global_attributions import ( + get_sources_by_target, + is_qkv_to_o_pair, + validate_attention_pair_structure, +) +from spd.scripts.model_loading import ( + get_out_dir, + load_model_from_wandb, +) + + +def compute_local_attributions( + model: ComponentModel, + tokens: Float[Tensor, "1 seq"], + sources_by_target: dict[str, list[str]], + ci_threshold: float, + sampling: SamplingType, + device: str, +) -> dict[tuple[str, str], Float[Tensor, "s_in C s_out C"]]: + """Compute local attributions for a single prompt. + + For each valid layer pair (in_layer, out_layer), computes the gradient-based + attribution of output component activations with respect to input component + activations, preserving sequence position information. + + Args: + model: The ComponentModel to analyze. + tokens: Tokenized prompt of shape [1, seq_len]. + sources_by_target: Dict mapping out_layer -> list of in_layers. + ci_threshold: Threshold for considering a component alive at a position. + sampling: Sampling type to use for causal importances. + device: Device to run on. + + Returns: + Dictionary mapping (in_layer, out_layer) -> attribution tensor. + For non-attention pairs: shape [seq, C, seq, C] but only diagonal (s_in == s_out) is nonzero. + For Q/K/V -> O pairs: shape [seq, C, seq, C] with causal structure (s_in <= s_out). + """ + n_seq = tokens.shape[1] + C = model.C + + with torch.no_grad(): + output_with_cache: OutputWithCache = model(tokens, cache_type="input") + + with torch.no_grad(): + ci = model.calc_causal_importances( + pre_weight_acts=output_with_cache.cache, + sampling=sampling, + detach_inputs=False, + ) + + component_masks = ci.lower_leaky + mask_infos = make_mask_infos(component_masks=component_masks, routing_masks="all") + + with torch.enable_grad(): + comp_output_with_cache: OutputWithCache = model( + tokens, mask_infos=mask_infos, cache_type="component_acts" + ) + + cache = comp_output_with_cache.cache + + # Initialize output attributions + local_attributions: dict[tuple[str, str], Float[Tensor, "s_in C s_out C"]] = {} + + for out_layer, in_layers in tqdm(sources_by_target.items(), desc="Target layers"): + out_pre_detach: Float[Tensor, "1 s C"] = cache[f"{out_layer}_pre_detach"] + ci_out: Float[Tensor, "1 s C"] = ci.lower_leaky[out_layer] + + for in_layer in in_layers: + in_post_detach: Float[Tensor, "1 s C"] = cache[f"{in_layer}_post_detach"] + ci_in: Float[Tensor, "1 s C"] = ci.lower_leaky[in_layer] + + attribution: Float[Tensor, "s_in C s_out C"] = torch.zeros( + n_seq, C, n_seq, C, device=device + ) + + is_attention_pair = is_qkv_to_o_pair(in_layer, out_layer) + + # Determine which (s_out, c_out) pairs are alive + alive_out_mask: Float[Tensor, "1 s C"] = ci_out >= ci_threshold + alive_in_mask: Float[Tensor, "1 s C"] = ci_in >= ci_threshold + + grad_outputs: Float[Tensor, "1 s C"] = torch.zeros_like(out_pre_detach) + + for s_out in tqdm(range(n_seq), desc=f"{in_layer} -> {out_layer}", leave=False): + # Get alive output components at this position + alive_c_out: list[int] = torch.where(alive_out_mask[0, s_out])[0].tolist() + if len(alive_c_out) == 0: + continue + + for c_out in alive_c_out: + grad_outputs.zero_() + grad_outputs[0, s_out, c_out] = 1.0 + + grads = torch.autograd.grad( + outputs=out_pre_detach, + inputs=in_post_detach, + grad_outputs=grad_outputs, + retain_graph=True, + ) + + assert len(grads) == 1 + in_post_detach_grad: Float[Tensor, "1 s C"] = grads[0] + assert in_post_detach_grad is not None, f"Gradient is None for {in_layer}" + + # Weight by input acts and square (we index into the singular batch dimension) + weighted: Float[Tensor, "s C"] = (in_post_detach_grad * in_post_detach)[0] ** 2 + + # Handle causal attention mask + s_in_range = range(s_out + 1) if is_attention_pair else range(s_out, s_out + 1) + + with torch.no_grad(): + for s_in in s_in_range: + # Only include alive input components + alive_c_in: list[int] = torch.where(alive_in_mask[0, s_in])[0].tolist() + for c_in in alive_c_in: + attribution[s_in, c_in, s_out, c_out] = weighted[s_in, c_in] + + local_attributions[(in_layer, out_layer)] = attribution + + return local_attributions + + +# %% +# Configuration +# wandb_path = "wandb:goodfire/spd/runs/8ynfbr38" # ss_gpt2_simple-1L +wandb_path = "wandb:goodfire/spd/runs/c0k3z78g" # ss_gpt2_simple-2L +# wandb_path = "wandb:goodfire/spd/runs/jyo9duz5" # ss_gpt2_simple-1.25M (4L) +n_blocks = 2 +batch_size = 1 # Only need 1 for getting sources_by_target +n_ctx = 64 +ci_threshold = 1e-6 +# prompt = "The quick brown fox" +prompt = "Eagerly, a girl named Kim went" + +# Load model +loaded = load_model_from_wandb(wandb_path) +model, config, device = loaded.model, loaded.config, loaded.device + +tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) +assert isinstance(tokenizer, PreTrainedTokenizerFast), "Expected PreTrainedTokenizerFast" +sources_by_target = get_sources_by_target(model, device, config, n_blocks) +validate_attention_pair_structure(sources_by_target) + +n_pairs = sum(len(ins) for ins in sources_by_target.values()) +print(f"Sources by target: {n_pairs} pairs across {len(sources_by_target)} target layers") +for out_layer, in_layers in sources_by_target.items(): + print(f" {out_layer} <- {in_layers}") + +# %% +# Tokenize the prompt +print(f"\nPrompt: {prompt!r}") +tokens = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False) +assert isinstance(tokens, Tensor), "Expected Tensor" +tokens = tokens.to(device) +print(f"Tokens shape: {tokens.shape}") +print(f"Tokens: {tokens[0].tolist()}") +token_strings = [tokenizer.decode([t]) for t in tokens[0].tolist()] +print(f"Token strings: {token_strings}") + +# %% +# Compute local attributions +print("\nComputing local attributions...") +local_attributions = compute_local_attributions( + model=model, + tokens=tokens, + sources_by_target=sources_by_target, + ci_threshold=ci_threshold, + sampling=config.sampling, + device=device, +) + +# Print summary statistics +print("\nAttribution summary:") +for pair, attr in local_attributions.items(): + nonzero = (attr > 0).sum().item() + total = attr.numel() + print( + f" {pair[0]} -> {pair[1]}: " + f"nonzero={nonzero}/{total} ({100 * nonzero / total:.2f}%), " + f"max={attr.max():.6f}" + ) + +# %% +# Save attributions +out_dir = get_out_dir() + +# Save PyTorch format +pt_path = out_dir / f"local_attributions_{loaded.wandb_id}.pt" +save_data = { + "attributions": local_attributions, + "tokens": tokens.cpu(), + "token_strings": token_strings, + "prompt": prompt, +} +torch.save(save_data, pt_path) +print(f"\nSaved PyTorch format to {pt_path}") + +# Convert and save JSON format for web visualization +attributions_json = {} +for (in_layer, out_layer), attr_tensor in local_attributions.items(): + key = f"('{in_layer}', '{out_layer}')" + attributions_json[key] = attr_tensor.cpu().tolist() + +json_data = { + "n_blocks": n_blocks, + "attributions": attributions_json, + "tokens": tokens[0].cpu().tolist(), + "token_strings": token_strings, + "prompt": prompt, +} + +json_path = out_dir / f"local_attributions_{loaded.wandb_id}.json" +with open(json_path, "w") as f: + json.dump(json_data, f, separators=(",", ":"), ensure_ascii=False) + +gz_path = out_dir / f"local_attributions_{loaded.wandb_id}.json.gz" +with gzip.open(gz_path, "wt", encoding="utf-8") as f: + json.dump(json_data, f, separators=(",", ":"), ensure_ascii=False) + +print(f"Saved JSON format to {json_path}") +print(f"Saved compressed format to {gz_path}") +print(f" - {len(attributions_json)} layer pairs") +print(f" - Sequence length: {tokens.shape[1]}") + +# %% diff --git a/spd/scripts/model_loading.py b/spd/scripts/model_loading.py new file mode 100644 index 000000000..6bc63000c --- /dev/null +++ b/spd/scripts/model_loading.py @@ -0,0 +1,115 @@ +# %% +"""Shared model loading utilities for attribution scripts.""" + +from collections.abc import Iterable +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import torch +from transformers import PreTrainedTokenizer + +from spd.configs import Config +from spd.data import DatasetConfig, create_data_loader +from spd.experiments.lm.configs import LMTaskConfig +from spd.models.component_model import ComponentModel, SPDRunInfo + + +@dataclass +class LoadedModel: + """Container for a loaded SPD model and its configuration.""" + + model: ComponentModel + config: Config + run_info: SPDRunInfo + device: str + wandb_id: str + + +def load_model_from_wandb(wandb_path: str, device: str | None = None) -> LoadedModel: + """Load a ComponentModel from a wandb run path. + + Args: + wandb_path: Path like "wandb:goodfire/spd/runs/8ynfbr38" + device: Device to load model on. If None, uses cuda if available. + + Returns: + LoadedModel containing the model, config, and metadata. + """ + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + + wandb_id = wandb_path.split("/")[-1] + + print(f"Using device: {device}") + print(f"Loading model from {wandb_path}...") + + run_info = SPDRunInfo.from_path(wandb_path) + config: Config = run_info.config + model = ComponentModel.from_run_info(run_info) + model = model.to(device) + model.eval() + + print("Model loaded successfully!") + print(f"Number of components: {model.C}") + print(f"Target module paths: {model.target_module_paths}") + + return LoadedModel( + model=model, + config=config, + run_info=run_info, + device=device, + wandb_id=wandb_id, + ) + + +def create_data_loader_from_config( + config: Config, + batch_size: int, + n_ctx: int, + seed: int = 0, +) -> tuple[Iterable[dict[str, Any]], PreTrainedTokenizer]: + """Create a data loader from an SPD config. + + Args: + config: SPD Config with task configuration. + batch_size: Batch size for the data loader. + n_ctx: Context length. + seed: Random seed for shuffling. + + Returns: + Tuple of (data_loader, tokenizer). + """ + task_config = config.task_config + assert isinstance(task_config, LMTaskConfig), "Expected LM task config" + + dataset_config = DatasetConfig( + name=task_config.dataset_name, + hf_tokenizer_path=config.tokenizer_name, + split=task_config.train_data_split, + n_ctx=n_ctx, + is_tokenized=task_config.is_tokenized, + streaming=task_config.streaming, + column_name=task_config.column_name, + shuffle_each_epoch=False, + seed=seed, + ) + + print(f"\nLoading dataset {dataset_config.name}...") + data_loader, tokenizer = create_data_loader( + dataset_config=dataset_config, + batch_size=batch_size, + buffer_size=task_config.buffer_size, + global_seed=seed, + ddp_rank=0, + ddp_world_size=1, + ) + + return data_loader, tokenizer + + +def get_out_dir() -> Path: + """Get the output directory for attribution scripts.""" + out_dir = Path(__file__).parent / "out" + out_dir.mkdir(parents=True, exist_ok=True) + return out_dir From 344671b2efcf618b53ed02e53698c03b0199a39b Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 28 Nov 2025 13:01:55 +0000 Subject: [PATCH 16/36] Remove grad_outputs --- spd/scripts/calc_local_attributions.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/spd/scripts/calc_local_attributions.py b/spd/scripts/calc_local_attributions.py index 9c92b7d91..81c65b36a 100644 --- a/spd/scripts/calc_local_attributions.py +++ b/spd/scripts/calc_local_attributions.py @@ -96,8 +96,6 @@ def compute_local_attributions( alive_out_mask: Float[Tensor, "1 s C"] = ci_out >= ci_threshold alive_in_mask: Float[Tensor, "1 s C"] = ci_in >= ci_threshold - grad_outputs: Float[Tensor, "1 s C"] = torch.zeros_like(out_pre_detach) - for s_out in tqdm(range(n_seq), desc=f"{in_layer} -> {out_layer}", leave=False): # Get alive output components at this position alive_c_out: list[int] = torch.where(alive_out_mask[0, s_out])[0].tolist() @@ -105,13 +103,9 @@ def compute_local_attributions( continue for c_out in alive_c_out: - grad_outputs.zero_() - grad_outputs[0, s_out, c_out] = 1.0 - grads = torch.autograd.grad( - outputs=out_pre_detach, + outputs=out_pre_detach[0, s_out, c_out], inputs=in_post_detach, - grad_outputs=grad_outputs, retain_graph=True, ) @@ -120,7 +114,7 @@ def compute_local_attributions( assert in_post_detach_grad is not None, f"Gradient is None for {in_layer}" # Weight by input acts and square (we index into the singular batch dimension) - weighted: Float[Tensor, "s C"] = (in_post_detach_grad * in_post_detach)[0] ** 2 + weighted: Float[Tensor, "s C"] = (in_post_detach_grad * in_post_detach)[0] # Handle causal attention mask s_in_range = range(s_out + 1) if is_attention_pair else range(s_out, s_out + 1) From 563bbbb9f4d572b1b5025cf1449bf3ad42277641 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 28 Nov 2025 13:08:26 +0000 Subject: [PATCH 17/36] Add new 1L model --- spd/scripts/calc_global_attributions.py | 3 ++- spd/scripts/calc_local_attributions.py | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/spd/scripts/calc_global_attributions.py b/spd/scripts/calc_global_attributions.py index 7e1be4322..9282dabe6 100644 --- a/spd/scripts/calc_global_attributions.py +++ b/spd/scripts/calc_global_attributions.py @@ -426,7 +426,8 @@ def compute_global_attributions( # Configuration # wandb_path = "wandb:goodfire/spd/runs/jyo9duz5" # ss_gpt2_simple-1.25M (4L) # wandb_path = "wandb:goodfire/spd/runs/c0k3z78g" # ss_gpt2_simple-2L - wandb_path = "wandb:goodfire/spd/runs/8ynfbr38" # ss_gpt2_simple-1L + # wandb_path = "wandb:goodfire/spd/runs/8ynfbr38" # ss_gpt2_simple-1L (Old) + wandb_path = "wandb:goodfire/spd/runs/33n6xjjt" # ss_gpt2_simple-1L (New) n_blocks = 1 batch_size = 1024 n_ctx = 64 diff --git a/spd/scripts/calc_local_attributions.py b/spd/scripts/calc_local_attributions.py index 81c65b36a..2c96d09a4 100644 --- a/spd/scripts/calc_local_attributions.py +++ b/spd/scripts/calc_local_attributions.py @@ -133,10 +133,11 @@ def compute_local_attributions( # %% # Configuration -# wandb_path = "wandb:goodfire/spd/runs/8ynfbr38" # ss_gpt2_simple-1L -wandb_path = "wandb:goodfire/spd/runs/c0k3z78g" # ss_gpt2_simple-2L +# wandb_path = "wandb:goodfire/spd/runs/8ynfbr38" # ss_gpt2_simple-1L (Old) +wandb_path = "wandb:goodfire/spd/runs/33n6xjjt" # ss_gpt2_simple-1L (new) +# wandb_path = "wandb:goodfire/spd/runs/c0k3z78g" # ss_gpt2_simple-2L # wandb_path = "wandb:goodfire/spd/runs/jyo9duz5" # ss_gpt2_simple-1.25M (4L) -n_blocks = 2 +n_blocks = 1 batch_size = 1 # Only need 1 for getting sources_by_target n_ctx = 64 ci_threshold = 1e-6 From 958f1502d4706c38233fb7278fe458e3b54c9823 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 28 Nov 2025 14:12:05 +0000 Subject: [PATCH 18/36] Allow diff w.r.t multiple inputs in local attr --- spd/scripts/calc_global_attributions.py | 41 ++---- spd/scripts/calc_local_attributions.py | 184 +++++++++++------------- spd/scripts/model_loading.py | 2 - 3 files changed, 92 insertions(+), 135 deletions(-) diff --git a/spd/scripts/calc_global_attributions.py b/spd/scripts/calc_global_attributions.py index 9282dabe6..10724d1d2 100644 --- a/spd/scripts/calc_global_attributions.py +++ b/spd/scripts/calc_global_attributions.py @@ -24,15 +24,15 @@ from spd.utils.general_utils import extract_batch_data -def is_qkv_to_o_pair(in_layer: str, out_layer: str) -> bool: +def is_kv_to_o_pair(in_layer: str, out_layer: str) -> bool: """Check if pair requires per-sequence-position gradient computation. - For q/k/v → o_proj within the same attention block, output at s_out + For k/v → o_proj within the same attention block, output at s_out has gradients w.r.t. inputs at all s_in ≤ s_out (causal attention). """ - in_is_qkv = any(x in in_layer for x in ["q_proj", "k_proj", "v_proj"]) + in_is_kv = any(x in in_layer for x in ["k_proj", "v_proj"]) out_is_o = "o_proj" in out_layer - if not (in_is_qkv and out_is_o): + if not (in_is_kv and out_is_o): return False # Check same attention block: "h.{idx}.attn.{proj}" @@ -212,30 +212,6 @@ def get_sources_by_target( return dict(sources_by_target) -def validate_attention_pair_structure(sources_by_target: dict[str, list[str]]) -> None: - """Assert that o_proj layers only receive from same-block QKV. - - This structural property allows us to handle attention and non-attention - cases separately without mixing within a single target layer. - """ - for out_layer, in_layers in sources_by_target.items(): - if "o_proj" in out_layer: - out_block = out_layer.split(".")[1] - for in_layer in in_layers: - assert any(x in in_layer for x in ["q_proj", "k_proj", "v_proj"]), ( - f"o_proj output {out_layer} has non-QKV input {in_layer}" - ) - in_block = in_layer.split(".")[1] - assert in_block == out_block, ( - f"o_proj output {out_layer} has input from different block: {in_layer}" - ) - else: - for in_layer in in_layers: - assert not is_qkv_to_o_pair(in_layer, out_layer), ( - f"Non-o_proj output {out_layer} has attention pair input {in_layer}" - ) - - def compute_global_attributions( model: ComponentModel, data_loader: Iterable[dict[str, Any]], @@ -336,10 +312,10 @@ def compute_global_attributions( for in_layer in in_layers } - is_attention_output = "o_proj" in out_layer - tqdm.write( - f"{'Attention' if is_attention_output else 'Non-attention'} target: " - f"{out_layer} <- {in_layers}" + # NOTE: o->q will be treated as an attention pair even though there are no attrs + # across sequence positions. This is just so we don't have to special case it. + is_attention_output = any( + is_kv_to_o_pair(in_layer, out_layer) for in_layer in in_layers ) grad_outputs: Float[Tensor, "b s C"] = torch.zeros_like(out_pre_detach) @@ -457,7 +433,6 @@ def compute_global_attributions( ) sources_by_target = get_sources_by_target(model, device, config, n_blocks) - validate_attention_pair_structure(sources_by_target) n_pairs = sum(len(ins) for ins in sources_by_target.values()) print(f"Sources by target: {n_pairs} pairs across {len(sources_by_target)} target layers") for out_layer, in_layers in sources_by_target.items(): diff --git a/spd/scripts/calc_local_attributions.py b/spd/scripts/calc_local_attributions.py index 2c96d09a4..7fedf2fed 100644 --- a/spd/scripts/calc_local_attributions.py +++ b/spd/scripts/calc_local_attributions.py @@ -1,8 +1,7 @@ # %% """Compute local attributions for a single prompt.""" -import gzip -import json +from dataclasses import dataclass import torch from jaxtyping import Float @@ -14,15 +13,18 @@ from spd.configs import SamplingType from spd.models.component_model import ComponentModel, OutputWithCache from spd.models.components import make_mask_infos -from spd.scripts.calc_global_attributions import ( - get_sources_by_target, - is_qkv_to_o_pair, - validate_attention_pair_structure, -) -from spd.scripts.model_loading import ( - get_out_dir, - load_model_from_wandb, -) +from spd.scripts.calc_global_attributions import get_sources_by_target, is_kv_to_o_pair +from spd.scripts.model_loading import get_out_dir, load_model_from_wandb + + +@dataclass +class PairAttribution: + source: str + target: str + attribution: Float[Tensor, "s_in trimmed_c_in s_out trimmed_c_out"] + trimmed_c_in_idxs: list[int] + trimmed_c_out_idxs: list[int] + is_kv_to_o_pair: bool def compute_local_attributions( @@ -32,7 +34,7 @@ def compute_local_attributions( ci_threshold: float, sampling: SamplingType, device: str, -) -> dict[tuple[str, str], Float[Tensor, "s_in C s_out C"]]: +) -> list[PairAttribution]: """Compute local attributions for a single prompt. For each valid layer pair (in_layer, out_layer), computes the gradient-based @@ -48,9 +50,7 @@ def compute_local_attributions( device: Device to run on. Returns: - Dictionary mapping (in_layer, out_layer) -> attribution tensor. - For non-attention pairs: shape [seq, C, seq, C] but only diagonal (s_in == s_out) is nonzero. - For Q/K/V -> O pairs: shape [seq, C, seq, C] with causal structure (s_in <= s_out). + List of PairAttribution objects. """ n_seq = tokens.shape[1] C = model.C @@ -75,58 +75,78 @@ def compute_local_attributions( cache = comp_output_with_cache.cache - # Initialize output attributions - local_attributions: dict[tuple[str, str], Float[Tensor, "s_in C s_out C"]] = {} + local_attributions: list[PairAttribution] = [] for out_layer, in_layers in tqdm(sources_by_target.items(), desc="Target layers"): out_pre_detach: Float[Tensor, "1 s C"] = cache[f"{out_layer}_pre_detach"] ci_out: Float[Tensor, "1 s C"] = ci.lower_leaky[out_layer] - for in_layer in in_layers: - in_post_detach: Float[Tensor, "1 s C"] = cache[f"{in_layer}_post_detach"] - ci_in: Float[Tensor, "1 s C"] = ci.lower_leaky[in_layer] - - attribution: Float[Tensor, "s_in C s_out C"] = torch.zeros( - n_seq, C, n_seq, C, device=device - ) - - is_attention_pair = is_qkv_to_o_pair(in_layer, out_layer) - - # Determine which (s_out, c_out) pairs are alive - alive_out_mask: Float[Tensor, "1 s C"] = ci_out >= ci_threshold - alive_in_mask: Float[Tensor, "1 s C"] = ci_in >= ci_threshold - - for s_out in tqdm(range(n_seq), desc=f"{in_layer} -> {out_layer}", leave=False): - # Get alive output components at this position - alive_c_out: list[int] = torch.where(alive_out_mask[0, s_out])[0].tolist() - if len(alive_c_out) == 0: - continue - - for c_out in alive_c_out: - grads = torch.autograd.grad( - outputs=out_pre_detach[0, s_out, c_out], - inputs=in_post_detach, - retain_graph=True, - ) - - assert len(grads) == 1 - in_post_detach_grad: Float[Tensor, "1 s C"] = grads[0] - assert in_post_detach_grad is not None, f"Gradient is None for {in_layer}" - - # Weight by input acts and square (we index into the singular batch dimension) - weighted: Float[Tensor, "s C"] = (in_post_detach_grad * in_post_detach)[0] - - # Handle causal attention mask - s_in_range = range(s_out + 1) if is_attention_pair else range(s_out, s_out + 1) - - with torch.no_grad(): + in_post_detaches: list[Float[Tensor, "1 s C"]] = [ + cache[f"{in_layer}_post_detach"] for in_layer in in_layers + ] + ci_ins: list[Float[Tensor, "1 s C"]] = [ci.lower_leaky[in_layer] for in_layer in in_layers] + + attributions: list[Float[Tensor, "s_in C s_out C"]] = [ + torch.zeros(n_seq, C, n_seq, C, device=device) for _ in in_layers + ] + + # NOTE: o->q will be treated as an attention pair even though there are no attrs + # across sequence positions. This is just so we don't have to special case it. + is_attention_output = any(is_kv_to_o_pair(in_layer, out_layer) for in_layer in in_layers) + + # Determine which (s_out, c_out) pairs are alive + alive_out_mask: Float[Tensor, "1 s C"] = ci_out >= ci_threshold + alive_out_c_idxs: list[int] = torch.where(alive_out_mask[0].any(dim=0))[0].tolist() + + alive_in_masks: list[Float[Tensor, "1 s C"]] = [ci_in >= ci_threshold for ci_in in ci_ins] + alive_in_c_idxs: list[list[int]] = [ + torch.where(alive_in_mask[0].any(dim=0))[0].tolist() for alive_in_mask in alive_in_masks + ] + + for s_out in tqdm(range(n_seq), desc=f"{out_layer} -> {in_layers}", leave=False): + # Get alive output components at this position + s_out_alive_c_idxs: list[int] = torch.where(alive_out_mask[0, s_out])[0].tolist() + if len(s_out_alive_c_idxs) == 0: + continue + + for c_out in s_out_alive_c_idxs: + in_post_detach_grads = torch.autograd.grad( + outputs=out_pre_detach[0, s_out, c_out], + inputs=in_post_detaches, + retain_graph=True, + ) + # Handle causal attention mask + s_in_range = range(s_out + 1) if is_attention_output else range(s_out, s_out + 1) + + with torch.no_grad(): + for in_post_detach_grad, in_post_detach, alive_in_mask, attribution in zip( + in_post_detach_grads, + in_post_detaches, + alive_in_masks, + attributions, + strict=True, + ): + # Weight by input acts and square (we index into the singular batch dimension) + weighted: Float[Tensor, "s C"] = (in_post_detach_grad * in_post_detach)[0] for s_in in s_in_range: - # Only include alive input components alive_c_in: list[int] = torch.where(alive_in_mask[0, s_in])[0].tolist() for c_in in alive_c_in: attribution[s_in, c_in, s_out, c_out] = weighted[s_in, c_in] - local_attributions[(in_layer, out_layer)] = attribution + for in_layer, attribution, layer_alive_in_c_idxs in zip( + in_layers, attributions, alive_in_c_idxs, strict=True + ): + trimmed_attribution = attribution[:, layer_alive_in_c_idxs][:, :, :, alive_out_c_idxs] + local_attributions.append( + PairAttribution( + source=in_layer, + target=out_layer, + attribution=trimmed_attribution, + trimmed_c_in_idxs=layer_alive_in_c_idxs, + trimmed_c_out_idxs=alive_out_c_idxs, + is_kv_to_o_pair=is_kv_to_o_pair(in_layer, out_layer), + ) + ) return local_attributions @@ -151,7 +171,6 @@ def compute_local_attributions( tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) assert isinstance(tokenizer, PreTrainedTokenizerFast), "Expected PreTrainedTokenizerFast" sources_by_target = get_sources_by_target(model, device, config, n_blocks) -validate_attention_pair_structure(sources_by_target) n_pairs = sum(len(ins) for ins in sources_by_target.values()) print(f"Sources by target: {n_pairs} pairs across {len(sources_by_target)} target layers") @@ -172,7 +191,7 @@ def compute_local_attributions( # %% # Compute local attributions print("\nComputing local attributions...") -local_attributions = compute_local_attributions( +attr_pairs = compute_local_attributions( model=model, tokens=tokens, sources_by_target=sources_by_target, @@ -183,13 +202,14 @@ def compute_local_attributions( # Print summary statistics print("\nAttribution summary:") -for pair, attr in local_attributions.items(): - nonzero = (attr > 0).sum().item() - total = attr.numel() +# for pair, attr in local_attributions: +for attr_pair in attr_pairs: + nonzero = (attr_pair.attribution > 0).sum().item() + total = attr_pair.attribution.numel() print( - f" {pair[0]} -> {pair[1]}: " + f" {attr_pair.source} -> {attr_pair.target}: " f"nonzero={nonzero}/{total} ({100 * nonzero / total:.2f}%), " - f"max={attr.max():.6f}" + f"max={attr_pair.attribution.max():.6f}" ) # %% @@ -198,40 +218,4 @@ def compute_local_attributions( # Save PyTorch format pt_path = out_dir / f"local_attributions_{loaded.wandb_id}.pt" -save_data = { - "attributions": local_attributions, - "tokens": tokens.cpu(), - "token_strings": token_strings, - "prompt": prompt, -} -torch.save(save_data, pt_path) -print(f"\nSaved PyTorch format to {pt_path}") - -# Convert and save JSON format for web visualization -attributions_json = {} -for (in_layer, out_layer), attr_tensor in local_attributions.items(): - key = f"('{in_layer}', '{out_layer}')" - attributions_json[key] = attr_tensor.cpu().tolist() - -json_data = { - "n_blocks": n_blocks, - "attributions": attributions_json, - "tokens": tokens[0].cpu().tolist(), - "token_strings": token_strings, - "prompt": prompt, -} - -json_path = out_dir / f"local_attributions_{loaded.wandb_id}.json" -with open(json_path, "w") as f: - json.dump(json_data, f, separators=(",", ":"), ensure_ascii=False) - -gz_path = out_dir / f"local_attributions_{loaded.wandb_id}.json.gz" -with gzip.open(gz_path, "wt", encoding="utf-8") as f: - json.dump(json_data, f, separators=(",", ":"), ensure_ascii=False) - -print(f"Saved JSON format to {json_path}") -print(f"Saved compressed format to {gz_path}") -print(f" - {len(attributions_json)} layer pairs") -print(f" - Sequence length: {tokens.shape[1]}") - # %% diff --git a/spd/scripts/model_loading.py b/spd/scripts/model_loading.py index 6bc63000c..a61fc7ed2 100644 --- a/spd/scripts/model_loading.py +++ b/spd/scripts/model_loading.py @@ -101,8 +101,6 @@ def create_data_loader_from_config( batch_size=batch_size, buffer_size=task_config.buffer_size, global_seed=seed, - ddp_rank=0, - ddp_world_size=1, ) return data_loader, tokenizer From c341c0515442ff3a69f028bfdd7115c5085e9620 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 28 Nov 2025 14:14:11 +0000 Subject: [PATCH 19/36] Misc removals --- spd/scripts/calc_local_attributions.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/spd/scripts/calc_local_attributions.py b/spd/scripts/calc_local_attributions.py index 7fedf2fed..0f4118b48 100644 --- a/spd/scripts/calc_local_attributions.py +++ b/spd/scripts/calc_local_attributions.py @@ -158,13 +158,10 @@ def compute_local_attributions( # wandb_path = "wandb:goodfire/spd/runs/c0k3z78g" # ss_gpt2_simple-2L # wandb_path = "wandb:goodfire/spd/runs/jyo9duz5" # ss_gpt2_simple-1.25M (4L) n_blocks = 1 -batch_size = 1 # Only need 1 for getting sources_by_target -n_ctx = 64 ci_threshold = 1e-6 # prompt = "The quick brown fox" prompt = "Eagerly, a girl named Kim went" -# Load model loaded = load_model_from_wandb(wandb_path) model, config, device = loaded.model, loaded.config, loaded.device From 7d3fc5e17aa1c70cf2959c5562ac14203edffa4c Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 28 Nov 2025 16:51:20 +0000 Subject: [PATCH 20/36] Add wte (NOTE: breaks calc_global_attributions) --- spd/scripts/calc_global_attributions.py | 25 +++++++++++-- spd/scripts/calc_local_attributions.py | 50 ++++++++++++++++++++++--- 2 files changed, 67 insertions(+), 8 deletions(-) diff --git a/spd/scripts/calc_global_attributions.py b/spd/scripts/calc_global_attributions.py index 10724d1d2..1f509cb71 100644 --- a/spd/scripts/calc_global_attributions.py +++ b/spd/scripts/calc_global_attributions.py @@ -9,7 +9,7 @@ import torch from jaxtyping import Float from PIL import Image -from torch import Tensor +from torch import Tensor, nn from tqdm.auto import tqdm from spd.configs import Config @@ -168,15 +168,34 @@ def get_sources_by_target( component_masks=component_masks, routing_masks="all", ) + + wte_cache: dict[str, Tensor] = {} + + # Add an extra forward hook to the model to cache the output of model.target_model.wte + def wte_hook( + _module: nn.Module, _args: tuple[Any, ...], _kwargs: dict[Any, Any], output: Tensor + ) -> Any: + output.requires_grad_(True) + # We call it "post_detach" for consistency, we don't bother detaching here as there are + # no modules before it that we care about + wte_cache["wte_post_detach"] = output + return output + + assert isinstance(model.target_model.wte, nn.Module), "wte is not a module" + wte_handle = model.target_model.wte.register_forward_hook(wte_hook, with_kwargs=True) + with torch.enable_grad(): comp_output_with_cache_grad: OutputWithCache = model( batch, mask_infos=mask_infos, cache_type="component_acts", ) + wte_handle.remove() cache = comp_output_with_cache_grad.cache - layers = [] + cache["wte_post_detach"] = wte_cache["wte_post_detach"] + + layers = ["wte"] layer_names = [ "attn.q_proj", "attn.k_proj", @@ -190,7 +209,7 @@ def get_sources_by_target( test_pairs = [] for in_layer in layers: - for out_layer in layers: + for out_layer in layers[1:]: # Skip wte if layers.index(in_layer) < layers.index(out_layer): test_pairs.append((in_layer, out_layer)) diff --git a/spd/scripts/calc_local_attributions.py b/spd/scripts/calc_local_attributions.py index 0f4118b48..6f00d006a 100644 --- a/spd/scripts/calc_local_attributions.py +++ b/spd/scripts/calc_local_attributions.py @@ -2,10 +2,11 @@ """Compute local attributions for a single prompt.""" from dataclasses import dataclass +from typing import Any import torch from jaxtyping import Float -from torch import Tensor +from torch import Tensor, nn from tqdm.auto import tqdm from transformers import AutoTokenizer from transformers.tokenization_utils_fast import PreTrainedTokenizerFast @@ -68,12 +69,33 @@ def compute_local_attributions( component_masks = ci.lower_leaky mask_infos = make_mask_infos(component_masks=component_masks, routing_masks="all") + # For wte, our component_acts will be (b, s, embedding_dim), instead of (b, s, C). We pretend + # that it has ci values of 1 for the 0th index of the embedding dimension and 0 elsewhere. This + # is because later we sum over the embedding_dim and add a new singleton dimension for the + # component + wte_cache: dict[str, Tensor] = {} + + def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> Any: + output.requires_grad_(True) + # We call it "post_detach" for consistency, we don't bother detaching here as there are + # no modules before it that we care about + wte_cache["wte_post_detach"] = output + return output + + assert isinstance(model.target_model.wte, nn.Module), "wte is not a module" + wte_handle = model.target_model.wte.register_forward_hook(wte_hook, with_kwargs=True) + with torch.enable_grad(): comp_output_with_cache: OutputWithCache = model( tokens, mask_infos=mask_infos, cache_type="component_acts" ) + wte_handle.remove() + + ci.lower_leaky["wte"] = torch.zeros_like(wte_cache["wte_post_detach"]) + ci.lower_leaky["wte"][:, :, 0] = 1.0 cache = comp_output_with_cache.cache + cache["wte_post_detach"] = wte_cache["wte_post_detach"] local_attributions: list[PairAttribution] = [] @@ -119,15 +141,25 @@ def compute_local_attributions( s_in_range = range(s_out + 1) if is_attention_output else range(s_out, s_out + 1) with torch.no_grad(): - for in_post_detach_grad, in_post_detach, alive_in_mask, attribution in zip( + for ( + in_layer, + in_post_detach_grad, + in_post_detach, + alive_in_mask, + attribution, + ) in zip( + in_layers, in_post_detach_grads, in_post_detaches, alive_in_masks, attributions, strict=True, ): - # Weight by input acts and square (we index into the singular batch dimension) weighted: Float[Tensor, "s C"] = (in_post_detach_grad * in_post_detach)[0] + if in_layer == "wte": + # We actually have shape "s embedding_dim", so we sum over the embedding + # dimension and add a new singleton component dimension + weighted = weighted.sum(dim=1).unsqueeze(1) for s_in in s_in_range: alive_c_in: list[int] = torch.where(alive_in_mask[0, s_in])[0].tolist() for c_in in alive_c_in: @@ -213,6 +245,14 @@ def compute_local_attributions( # Save attributions out_dir = get_out_dir() -# Save PyTorch format +# Save PyTorch format with all necessary data pt_path = out_dir / f"local_attributions_{loaded.wandb_id}.pt" -# %% +save_data = { + "attr_pairs": attr_pairs, + "token_strings": token_strings, + "prompt": prompt, + "ci_threshold": ci_threshold, + "wandb_id": loaded.wandb_id, +} +torch.save(save_data, pt_path) +print(f"\nSaved local attributions to {pt_path}") From 3c0ebae2df0c927959e5fbc7a2d9299b894995b7 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 28 Nov 2025 17:04:27 +0000 Subject: [PATCH 21/36] Add wte to calc_global_attributions.py --- spd/scripts/calc_global_attributions.py | 42 +++++++++++++++++++++++-- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/spd/scripts/calc_global_attributions.py b/spd/scripts/calc_global_attributions.py index 1f509cb71..f621b8130 100644 --- a/spd/scripts/calc_global_attributions.py +++ b/spd/scripts/calc_global_attributions.py @@ -265,6 +265,8 @@ def compute_global_attributions( where attribution[i, j] is the mean absolute gradient from the i-th alive input component to the j-th alive output component. """ + alive_indices["wte"] = [0] # Treat wte as single alive component + # Initialize accumulators for each (in_layer, out_layer) pair attribution_sums: dict[tuple[str, str], Float[Tensor, "n_alive_in n_alive_out"]] = {} samples_per_pair: dict[tuple[str, str], int] = {} @@ -277,6 +279,18 @@ def compute_global_attributions( ) samples_per_pair[(in_layer, out_layer)] = 0 + # Set up wte hook + wte_handle = None + wte_cache: dict[str, Tensor] = {} + + def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> Any: + output.requires_grad_(True) + wte_cache["wte_post_detach"] = output + return output + + assert isinstance(model.target_model.wte, nn.Module), "wte is not a module" + wte_handle = model.target_model.wte.register_forward_hook(wte_hook, with_kwargs=True) + batch_pbar = tqdm(enumerate(data_loader), desc="Batches", total=max_batches) for batch_idx, batch_raw in batch_pbar: if batch_idx >= max_batches: @@ -314,6 +328,12 @@ def compute_global_attributions( cache = comp_output_with_cache.cache + # Add wte to cache and CI + cache["wte_post_detach"] = wte_cache["wte_post_detach"] + # Add fake CI for wte: shape (b, s, embedding_dim) with 1.0 at index 0 + ci.lower_leaky["wte"] = torch.zeros_like(wte_cache["wte_post_detach"]) + ci.lower_leaky["wte"][:, :, 0] = 1.0 + # Compute attributions grouped by target layer for out_layer, in_layers in tqdm( sources_by_target.items(), desc="Target layers", leave=False @@ -362,8 +382,16 @@ def compute_global_attributions( for i, in_layer in enumerate(in_layers): grads = grads_tuple[i] assert grads is not None, f"Gradient is None for {in_layer}" - alive_in = alive_indices[in_layer] weighted = grads * in_tensors[i] + + # Special handling for wte: sum over embedding_dim to get single component + if in_layer == "wte": + # weighted is (b, s, embedding_dim), sum to (b, s, 1) + weighted = weighted.sum(dim=-1, keepdim=True) + alive_in = [0] + else: + alive_in = alive_indices[in_layer] + # Only sum contributions from positions s_in <= s_out (causal) weighted_alive = weighted[:, : s_out + 1, alive_in] batch_attributions[in_layer][:, c_enum] += weighted_alive.pow( @@ -389,8 +417,16 @@ def compute_global_attributions( for i, in_layer in enumerate(in_layers): grads = grads_tuple[i] assert grads is not None, f"Gradient is None for {in_layer}" - alive_in = alive_indices[in_layer] weighted = grads * in_tensors[i] + + # Special handling for wte: sum over embedding_dim to get single component + if in_layer == "wte": + # weighted is (b, s, embedding_dim), sum to (b, s, 1) + weighted = weighted.sum(dim=-1, keepdim=True) + alive_in = [0] + else: + alive_in = alive_indices[in_layer] + weighted_alive = weighted[:, :, alive_in] batch_attributions[in_layer][:, c_enum] += weighted_alive.pow(2).sum( dim=(0, 1) @@ -405,6 +441,8 @@ def compute_global_attributions( else: samples_per_pair[(in_layer, out_layer)] += batch_size * n_seq + wte_handle.remove() + global_attributions = {} for pair, attr_sum in attribution_sums.items(): attr: Float[Tensor, "n_alive_in n_alive_out"] = attr_sum / samples_per_pair[pair] From 631013ab580882a52784830619609b3f94429315 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 28 Nov 2025 17:08:47 +0000 Subject: [PATCH 22/36] Add plot_local_attributions.py --- spd/scripts/plot_local_attributions.py | 444 +++++++++++++++++++++++++ 1 file changed, 444 insertions(+) create mode 100644 spd/scripts/plot_local_attributions.py diff --git a/spd/scripts/plot_local_attributions.py b/spd/scripts/plot_local_attributions.py new file mode 100644 index 000000000..eb42dca98 --- /dev/null +++ b/spd/scripts/plot_local_attributions.py @@ -0,0 +1,444 @@ +# %% +"""Plot local attribution graph from saved .pt file.""" + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import matplotlib.pyplot as plt +import numpy as np +import torch +from jaxtyping import Float +from matplotlib.collections import LineCollection +from torch import Tensor + +from spd.scripts.calc_local_attributions import PairAttribution +from spd.scripts.model_loading import get_out_dir + + +@dataclass +class NodeInfo: + """Information about a node in the attribution graph.""" + + layer: str + seq_pos: int + component_idx: int + x: float + y: float + importance: float + + +def get_layer_order() -> list[str]: + """Get the canonical ordering of sublayers within a block.""" + return [ + "wte", # Word token embeddings (first layer) + "attn.q_proj", + "attn.k_proj", + "attn.v_proj", + "attn.o_proj", + "mlp.c_fc", + "mlp.down_proj", + ] + + +def get_layer_color(layer: str) -> str: + """Get color for a layer based on its type.""" + if layer == "wte": + return "#34495E" # Dark blue-gray for embeddings + + colors = { + "attn.q_proj": "#E67E22", # Orange + "attn.k_proj": "#27AE60", # Green + "attn.v_proj": "#F1C40F", # Yellow + "attn.o_proj": "#E74C3C", # Red + "mlp.c_fc": "#3498DB", # Blue + "mlp.down_proj": "#9B59B6", # Purple + } + for sublayer, color in colors.items(): + if sublayer in layer: + return color + return "#95A5A6" # Gray fallback + + +def parse_layer_name(layer: str) -> tuple[int, str]: + """Parse layer name into block index and sublayer type. + + E.g., "h.0.attn.q_proj" -> (0, "attn.q_proj") + "wte" -> (-1, "wte") + """ + if layer == "wte": + return -1, "wte" + + parts = layer.split(".") + block_idx = int(parts[1]) + sublayer = ".".join(parts[2:]) + return block_idx, sublayer + + +def compute_layer_y_positions( + attr_pairs: list[PairAttribution], +) -> dict[str, float]: + """Compute Y position for each layer. + + Layers are ordered by block, then by sublayer type within block. + Each layer is equidistant. + + Returns: + Dict mapping layer name to y position. + """ + # Collect all unique layers + layers = set() + for pair in attr_pairs: + layers.add(pair.source) + layers.add(pair.target) + + # Parse and sort + layer_order = get_layer_order() + parsed = [(layer, *parse_layer_name(layer)) for layer in layers] + + def sort_key(item: tuple[str, int, str]) -> tuple[int, int]: + _, block_idx, sublayer = item + sublayer_idx = layer_order.index(sublayer) if sublayer in layer_order else 999 + return (block_idx, sublayer_idx) + + sorted_layers = sorted(parsed, key=sort_key) + + # Assign equidistant Y positions + y_positions = {} + for i, (layer, _, _) in enumerate(sorted_layers): + y_positions[layer] = float(i) + + return y_positions + + +def compute_node_importances( + attr_pairs: list[PairAttribution], + n_seq: int, +) -> dict[str, Float[Tensor, "seq C"]]: + """Compute importance values for nodes based on total attribution flow. + + Returns a dict mapping layer -> tensor of shape [n_seq, max_component_idx+1]. + Importance is the sum of incoming and outgoing attributions. + """ + # First pass: determine max component index per layer + layer_max_c: dict[str, int] = {} + for pair in attr_pairs: + src_max = max(pair.trimmed_c_in_idxs) if pair.trimmed_c_in_idxs else 0 + tgt_max = max(pair.trimmed_c_out_idxs) if pair.trimmed_c_out_idxs else 0 + layer_max_c[pair.source] = max(layer_max_c.get(pair.source, 0), src_max) + layer_max_c[pair.target] = max(layer_max_c.get(pair.target, 0), tgt_max) + + # Initialize importance tensors + importances: dict[str, Float[Tensor, "seq C"]] = {} + for layer, max_c in layer_max_c.items(): + importances[layer] = torch.zeros(n_seq, max_c + 1) + + # Accumulate attribution magnitudes + for pair in attr_pairs: + attr = pair.attribution.abs() # [s_in, trimmed_c_in, s_out, trimmed_c_out] + + # Sum over output dimensions -> importance for source nodes + src_importance = attr.sum(dim=(2, 3)) # [s_in, trimmed_c_in] + for i, c_in in enumerate(pair.trimmed_c_in_idxs): + importances[pair.source][:, c_in] += src_importance[:, i] + + # Sum over input dimensions -> importance for target nodes + tgt_importance = attr.sum(dim=(0, 1)) # [s_out, trimmed_c_out] + for j, c_out in enumerate(pair.trimmed_c_out_idxs): + importances[pair.target][:, c_out] += tgt_importance[:, j] + + return importances + + +def plot_local_attribution_graph( + attr_pairs: list[PairAttribution], + token_strings: list[str], + min_edge_weight: float = 0.001, + node_scale: float = 30.0, + edge_alpha_scale: float = 0.5, + figsize: tuple[float, float] | None = None, + max_grid_cols: int = 8, +) -> plt.Figure: + """Plot the local attribution graph. + + Args: + attr_pairs: List of PairAttribution objects from compute_local_attributions. + token_strings: List of token strings for x-axis labels. + min_edge_weight: Minimum edge weight to display. + node_scale: Fixed size for all nodes. + edge_alpha_scale: Scale factor for edge transparency. + figsize: Figure size (width, height). Auto-computed if None. + max_grid_cols: Maximum number of columns in the grid per layer. + + Returns: + Matplotlib figure. + """ + n_seq = len(token_strings) + + # Compute node importances first + importances = compute_node_importances(attr_pairs, n_seq) + + # Compute layout + layer_y = compute_layer_y_positions(attr_pairs) + + # Auto-compute figure size + if figsize is None: + total_height = max(layer_y.values()) + figsize = (max(16, n_seq * 2), max(8, total_height * 1.2)) + + fig, ax = plt.subplots(figsize=figsize, facecolor="white") + ax.set_facecolor("#FAFAFA") + + # Collect all nodes and their positions + nodes: list[NodeInfo] = [] + node_lookup: dict[tuple[str, int, int], NodeInfo] = {} # (layer, seq, comp) -> NodeInfo + + # X spacing: spread tokens across the plot + x_positions = np.linspace(0.1, 0.9, n_seq) + + # Grid layout parameters + col_spacing = 0.012 # Horizontal spacing between columns in grid + row_spacing = 0.08 # Vertical spacing between rows in grid + + for layer, y_center in layer_y.items(): + if layer not in importances: + continue + + layer_imp = importances[layer] # [n_seq, max_c+1] + alive_mask = layer_imp > 0 + + # Find all components that are alive at ANY sequence position + all_alive_components = torch.where(alive_mask.any(dim=0))[0].tolist() + n_components = len(all_alive_components) + + if n_components == 0: + continue + + # Calculate grid dimensions (same for all positions) + n_rows = (n_components + max_grid_cols - 1) // max_grid_cols + n_cols = min(n_components, max_grid_cols) + + # Process each sequence position separately + for s in range(n_seq): + # Center the grid at this sequence position + x_base = x_positions[s] + + # Arrange all components in grid (not just alive ones at this position) + for local_idx, c in enumerate(all_alive_components): + col = local_idx % max_grid_cols + row = local_idx // max_grid_cols + + # Position within grid, centered on sequence position + x_offset = (col - (n_cols - 1) / 2) * col_spacing + y_offset = (row - (n_rows - 1) / 2) * row_spacing + + x = x_base + x_offset + y = y_center + y_offset + + imp = layer_imp[s, c].item() + node = NodeInfo( + layer=layer, + seq_pos=s, + component_idx=c, + x=x, + y=y, + importance=imp, + ) + nodes.append(node) + node_lookup[(layer, s, c)] = node + + # Collect edges + edges: list[tuple[NodeInfo, NodeInfo, float]] = [] + + for pair in attr_pairs: + attr = pair.attribution # [s_in, trimmed_c_in, s_out, trimmed_c_out] + + for i, c_in in enumerate(pair.trimmed_c_in_idxs): + for j, c_out in enumerate(pair.trimmed_c_out_idxs): + for s_in in range(attr.shape[0]): + for s_out in range(attr.shape[2]): + weight = attr[s_in, i, s_out, j].abs().item() + if weight < min_edge_weight: + continue + + src_key = (pair.source, s_in, c_in) + tgt_key = (pair.target, s_out, c_out) + + if src_key in node_lookup and tgt_key in node_lookup: + src_node = node_lookup[src_key] + tgt_node = node_lookup[tgt_key] + edges.append((src_node, tgt_node, weight)) + + edges_by_target: dict[tuple[str, int, int], list[tuple[NodeInfo, NodeInfo, float]]] = {} + for src, tgt, w in edges: + key = (tgt.layer, tgt.seq_pos, tgt.component_idx) + if key not in edges_by_target: + edges_by_target[key] = [] + edges_by_target[key].append((src, tgt, w)) + + sorted_edges = [] + for target_edges in edges_by_target.values(): + sorted_edges.extend(sorted(target_edges, key=lambda e: e[2], reverse=True)) + edges = sorted_edges + + # Normalize edge weights for alpha + if edges: + edge_weights = [e[2] for e in edges] + max_edge = max(edge_weights) + if max_edge > 0: + edges = [(s, t, w / max_edge) for s, t, w in edges] + + # Track which nodes have edges + nodes_with_edges = set() + for src, tgt, _ in edges: + nodes_with_edges.add((src.layer, src.seq_pos, src.component_idx)) + nodes_with_edges.add((tgt.layer, tgt.seq_pos, tgt.component_idx)) + + # Draw edges + if edges: + lines = [] + alphas = [] + for src, tgt, w in edges: + lines.append([(src.x, src.y), (tgt.x, tgt.y)]) + alphas.append(w * edge_alpha_scale) + + lc = LineCollection( + lines, + colors=[(0.5, 0.5, 0.5, a) for a in alphas], + linewidths=0.5, + zorder=1, + ) + ax.add_collection(lc) + + # Draw nodes + for node in nodes: + node_key = (node.layer, node.seq_pos, node.component_idx) + has_edges = node_key in nodes_with_edges + + if has_edges: + color = get_layer_color(node.layer) + alpha = 0.9 + else: + color = "#D3D3D3" + alpha = 0.3 + + ax.scatter( + node.x, + node.y, + s=node_scale, + c=color, + edgecolors="white", + linewidths=0.5, + zorder=2, + alpha=alpha, + ) + + # Configure axes + total_height = max(layer_y.values()) + ax.set_xlim(0, 1) + ax.set_ylim(-0.5, total_height + 0.5) + + # X-axis: token labels + ax.set_xticks(x_positions) + ax.set_xticklabels(token_strings, rotation=45, ha="right", fontsize=9) + ax.xaxis.set_ticks_position("bottom") + + # Y-axis: layer labels + layer_names_sorted = sorted(layer_y.keys(), key=lambda x: layer_y[x]) + layer_centers = [layer_y[layer] for layer in layer_names_sorted] + ax.set_yticks(layer_centers) + ax.set_yticklabels( + [layer.replace(".", "\n", 1) for layer in layer_names_sorted], + fontsize=9, + ) + + # Add horizontal lines to separate layers + for y in layer_y.values(): + ax.axhline(y=y - 0.5, color="gray", linestyle="--", linewidth=0.5, alpha=0.3) + + # Grid + ax.grid(False) + + # Title + ax.set_title("Local Attribution Graph", fontsize=14, fontweight="bold", pad=10) + + # Legend for layer colors + layer_order = get_layer_order() + legend_elements = [] + for sublayer in reversed(layer_order): + color = get_layer_color(sublayer) + legend_elements.append( + plt.scatter([], [], c=color, s=50, label=sublayer, edgecolors="white") + ) + ax.legend( + handles=legend_elements, + loc="upper left", + bbox_to_anchor=(1.01, 1), + fontsize=8, + framealpha=0.9, + ) + + plt.tight_layout() + return fig + + +def load_and_plot( + pt_path: Path, + output_path: Path | None = None, + **plot_kwargs: Any, +) -> plt.Figure: + """Load attributions from .pt file and create plot. + + Args: + pt_path: Path to the saved .pt file. + output_path: Optional path to save the figure. + **plot_kwargs: Additional kwargs passed to plot_local_attribution_graph. + + Returns: + Matplotlib figure. + """ + data = torch.load(pt_path, weights_only=False) + attr_pairs: list[PairAttribution] = data["attr_pairs"] + token_strings: list[str] = data["token_strings"] + + print(f"Loaded attributions from {pt_path}") + print(f" Prompt: {data.get('prompt', 'N/A')!r}") + print(f" Tokens: {token_strings}") + print(f" Number of layer pairs: {len(attr_pairs)}") + + fig = plot_local_attribution_graph( + attr_pairs=attr_pairs, + token_strings=token_strings, + **plot_kwargs, + ) + + if output_path is not None: + fig.savefig(output_path, dpi=150, bbox_inches="tight", facecolor="white") + print(f"Saved figure to {output_path}") + + return fig + + +# %% +if __name__ == "__main__": + # Configuration + wandb_id = "33n6xjjt" # ss_gpt2_simple-1L (new) + + out_dir = get_out_dir() + pt_path = out_dir / f"local_attributions_{wandb_id}.pt" + + if not pt_path.exists(): + raise FileNotFoundError( + f"Local attributions file not found: {pt_path}\n" + "Run calc_local_attributions.py first to generate the data." + ) + + output_path = out_dir / f"local_attribution_graph_{wandb_id}.png" + + fig = load_and_plot( + pt_path=pt_path, + output_path=output_path, + min_edge_weight=0.0001, + node_scale=30.0, + edge_alpha_scale=0.7, + ) From 7774afa2b30189bca09ccedf4dd820b8c51fe66c Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 28 Nov 2025 19:14:43 +0000 Subject: [PATCH 23/36] Tweak matplotlib plot --- spd/scripts/plot_local_attributions.py | 113 ++++++++++++++++++++----- 1 file changed, 94 insertions(+), 19 deletions(-) diff --git a/spd/scripts/plot_local_attributions.py b/spd/scripts/plot_local_attributions.py index eb42dca98..92ef577b9 100644 --- a/spd/scripts/plot_local_attributions.py +++ b/spd/scripts/plot_local_attributions.py @@ -33,14 +33,27 @@ def get_layer_order() -> list[str]: return [ "wte", # Word token embeddings (first layer) "attn.q_proj", - "attn.k_proj", "attn.v_proj", + "attn.k_proj", "attn.o_proj", "mlp.c_fc", "mlp.down_proj", ] +# Sublayers that share a row (q, v, k displayed side by side) +QVK_SUBLAYERS = {"attn.q_proj", "attn.v_proj", "attn.k_proj"} + +# Column allocation for q, v, k: (n_cols, start_col) out of 12 total columns +# Layout: Q(2) | gap(1) | V(4) | gap(1) | K(4) = 12 total +QVK_LAYOUT: dict[str, tuple[int, int]] = { + "attn.q_proj": (2, 0), # 2 columns, starts at 0 + "attn.v_proj": (4, 3), # 4 columns, starts at 3 (1 col gap after q) + "attn.k_proj": (4, 8), # 4 columns, starts at 8 (1 col gap after v) +} +QVK_TOTAL_COLS = 12 + + def get_layer_color(layer: str) -> str: """Get color for a layer based on its type.""" if layer == "wte": @@ -81,7 +94,7 @@ def compute_layer_y_positions( """Compute Y position for each layer. Layers are ordered by block, then by sublayer type within block. - Each layer is equidistant. + q, v, k layers share the same row (same y position). Returns: Dict mapping layer name to y position. @@ -103,10 +116,27 @@ def sort_key(item: tuple[str, int, str]) -> tuple[int, int]: sorted_layers = sorted(parsed, key=sort_key) - # Assign equidistant Y positions + # Assign Y positions, grouping q, v, k on the same row y_positions = {} - for i, (layer, _, _) in enumerate(sorted_layers): - y_positions[layer] = float(i) + current_y = 0.0 + prev_block_idx = None + prev_was_qvk = False + + for layer, block_idx, sublayer in sorted_layers: + is_qvk = sublayer in QVK_SUBLAYERS + + # Check if we should share y with previous layer + if is_qvk and prev_was_qvk and block_idx == prev_block_idx: + # Same row as previous q/v/k layer + y_positions[layer] = current_y + else: + # New row + if y_positions: # Not the first layer + current_y += 1.0 + y_positions[layer] = current_y + + prev_block_idx = block_idx + prev_was_qvk = is_qvk return y_positions @@ -157,7 +187,7 @@ def plot_local_attribution_graph( node_scale: float = 30.0, edge_alpha_scale: float = 0.5, figsize: tuple[float, float] | None = None, - max_grid_cols: int = 8, + max_grid_cols: int = 12, ) -> plt.Figure: """Plot the local attribution graph. @@ -214,9 +244,20 @@ def plot_local_attribution_graph( if n_components == 0: continue - # Calculate grid dimensions (same for all positions) - n_rows = (n_components + max_grid_cols - 1) // max_grid_cols - n_cols = min(n_components, max_grid_cols) + # Check if this is a q/v/k layer that shares a row + _, sublayer = parse_layer_name(layer) + is_qvk = sublayer in QVK_SUBLAYERS + + if is_qvk: + # Use the allocated columns for this sublayer + layer_max_cols, start_col = QVK_LAYOUT[sublayer] + else: + layer_max_cols = max_grid_cols + start_col = 0 + + # Calculate grid dimensions for this layer + n_rows = (n_components + layer_max_cols - 1) // layer_max_cols + n_cols = min(n_components, layer_max_cols) # Process each sequence position separately for s in range(n_seq): @@ -225,11 +266,22 @@ def plot_local_attribution_graph( # Arrange all components in grid (not just alive ones at this position) for local_idx, c in enumerate(all_alive_components): - col = local_idx % max_grid_cols - row = local_idx // max_grid_cols + col = local_idx % layer_max_cols + row = local_idx // layer_max_cols + + if is_qvk: + # Position within the allocated horizontal segment + # Center of this sublayer's segment within the total QVK row + segment_center = start_col + layer_max_cols / 2 + total_center = QVK_TOTAL_COLS / 2 + # Offset from center of entire row + segment_offset = (segment_center - total_center) * col_spacing + # Position within segment, centered + x_offset = segment_offset + (col - (n_cols - 1) / 2) * col_spacing + else: + # Position within grid, centered on sequence position + x_offset = (col - (n_cols - 1) / 2) * col_spacing - # Position within grid, centered on sequence position - x_offset = (col - (n_cols - 1) / 2) * col_spacing y_offset = (row - (n_rows - 1) / 2) * row_spacing x = x_base + x_offset @@ -343,14 +395,37 @@ def plot_local_attribution_graph( ax.set_xticklabels(token_strings, rotation=45, ha="right", fontsize=9) ax.xaxis.set_ticks_position("bottom") - # Y-axis: layer labels + # Y-axis: layer labels (group q/v/k into single label) layer_names_sorted = sorted(layer_y.keys(), key=lambda x: layer_y[x]) - layer_centers = [layer_y[layer] for layer in layer_names_sorted] + # Deduplicate y positions and create combined labels for q/v/k rows + y_to_layers: dict[float, list[str]] = {} + for layer in layer_names_sorted: + y = layer_y[layer] + if y not in y_to_layers: + y_to_layers[y] = [] + y_to_layers[y].append(layer) + + layer_centers = sorted(y_to_layers.keys()) + layer_labels = [] + for y in layer_centers: + layers_at_y = y_to_layers[y] + if len(layers_at_y) > 1: + # Multiple layers at same y (q/v/k row) + # Extract block prefix and combine sublayer names + block_idx, _ = parse_layer_name(layers_at_y[0]) + sublayers = [parse_layer_name(layer)[1] for layer in layers_at_y] + # Order: q, v, k + ordered = [] + for sub in ["attn.q_proj", "attn.v_proj", "attn.k_proj"]: + if sub in sublayers: + ordered.append(sub.split(".")[-1]) + label = f"h.{block_idx}\n" + "/".join(ordered) + else: + label = layers_at_y[0].replace(".", "\n", 1) + layer_labels.append(label) + ax.set_yticks(layer_centers) - ax.set_yticklabels( - [layer.replace(".", "\n", 1) for layer in layer_names_sorted], - fontsize=9, - ) + ax.set_yticklabels(layer_labels, fontsize=9) # Add horizontal lines to separate layers for y in layer_y.values(): From f580c83f18dae36e5e40fc2d316d103d42c0fa1e Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Sat, 29 Nov 2025 11:43:40 +0000 Subject: [PATCH 24/36] Add output and fix masks --- spd/scripts/calc_global_attributions.py | 20 +-- spd/scripts/calc_local_attributions.py | 197 ++++++++++++++++-------- spd/scripts/plot_local_attributions.py | 146 +++++++++++++----- 3 files changed, 258 insertions(+), 105 deletions(-) diff --git a/spd/scripts/calc_global_attributions.py b/spd/scripts/calc_global_attributions.py index f621b8130..7e55206d5 100644 --- a/spd/scripts/calc_global_attributions.py +++ b/spd/scripts/calc_global_attributions.py @@ -162,10 +162,9 @@ def get_sources_by_target( detach_inputs=False, ) - # Create masks for component replacement (use all components with causal importance as mask) - component_masks = ci.lower_leaky + # Create masks so we can use all components (without masks) mask_infos = make_mask_infos( - component_masks=component_masks, + component_masks={k: torch.ones_like(v) for k, v in ci.lower_leaky.items()}, routing_masks="all", ) @@ -194,9 +193,10 @@ def wte_hook( cache = comp_output_with_cache_grad.cache cache["wte_post_detach"] = wte_cache["wte_post_detach"] + cache["output_pre_detach"] = comp_output_with_cache_grad.output layers = ["wte"] - layer_names = [ + component_layers = [ "attn.q_proj", "attn.k_proj", "attn.v_proj", @@ -205,11 +205,12 @@ def wte_hook( "mlp.down_proj", ] for i in range(n_blocks): - layers.extend([f"h.{i}.{layer_name}" for layer_name in layer_names]) + layers.extend([f"h.{i}.{layer_name}" for layer_name in component_layers]) + layers.append("output") test_pairs = [] - for in_layer in layers: - for out_layer in layers[1:]: # Skip wte + for in_layer in layers[:-1]: # Don't include "output" in + for out_layer in layers[1:]: # Don't include "wte" in out_layers if layers.index(in_layer) < layers.index(out_layer): test_pairs.append((in_layer, out_layer)) @@ -312,10 +313,9 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An detach_inputs=False, ) - # Create masks and run forward pass with gradient tracking - component_masks = ci.lower_leaky + # Create masks so we can use all components (without masks) mask_infos = make_mask_infos( - component_masks=component_masks, + component_masks={k: torch.ones_like(v) for k, v in ci.lower_leaky.items()}, routing_masks="all", ) diff --git a/spd/scripts/calc_local_attributions.py b/spd/scripts/calc_local_attributions.py index 6f00d006a..9c6ef0b9e 100644 --- a/spd/scripts/calc_local_attributions.py +++ b/spd/scripts/calc_local_attributions.py @@ -5,7 +5,7 @@ from typing import Any import torch -from jaxtyping import Float +from jaxtyping import Bool, Float from torch import Tensor, nn from tqdm.auto import tqdm from transformers import AutoTokenizer @@ -18,6 +18,42 @@ from spd.scripts.model_loading import get_out_dir, load_model_from_wandb +@dataclass +class LayerAliveInfo: + """Info about alive components for a layer.""" + + alive_mask: Bool[Tensor, "1 s dim"] # Which (pos, component) pairs are alive + alive_c_idxs: list[int] # Components alive at any position + c_to_trimmed: dict[int, int] # original idx -> trimmed idx + + +def compute_layer_alive_info( + layer_name: str, + ci_lower_leaky: dict[str, Tensor], + output_probs: Float[Tensor, "1 s vocab"] | None, + ci_threshold: float, + output_prob_threshold: float, + n_seq: int, + device: str, +) -> LayerAliveInfo: + """Compute alive info for a layer. Handles regular, wte, and output layers.""" + if layer_name == "wte": + # WTE: single pseudo-component, always alive at all positions + alive_mask = torch.ones(1, n_seq, 1, device=device, dtype=torch.bool) + alive_c_idxs = [0] + elif layer_name == "output": + assert output_probs is not None + alive_mask = output_probs >= output_prob_threshold + alive_c_idxs = torch.where(alive_mask[0].any(dim=0))[0].tolist() + else: + ci = ci_lower_leaky[layer_name] + alive_mask = ci >= ci_threshold + alive_c_idxs = torch.where(alive_mask[0].any(dim=0))[0].tolist() + + c_to_trimmed = {c: i for i, c in enumerate(alive_c_idxs)} + return LayerAliveInfo(alive_mask, alive_c_idxs, c_to_trimmed) + + @dataclass class PairAttribution: source: str @@ -33,6 +69,7 @@ def compute_local_attributions( tokens: Float[Tensor, "1 seq"], sources_by_target: dict[str, list[str]], ci_threshold: float, + output_prob_threshold: float, sampling: SamplingType, device: str, ) -> list[PairAttribution]: @@ -47,6 +84,7 @@ def compute_local_attributions( tokens: Tokenized prompt of shape [1, seq_len]. sources_by_target: Dict mapping out_layer -> list of in_layers. ci_threshold: Threshold for considering a component alive at a position. + output_prob_threshold: Threshold for considering an output logit alive (on softmax probs). sampling: Sampling type to use for causal importances. device: Device to run on. @@ -54,7 +92,6 @@ def compute_local_attributions( List of PairAttribution objects. """ n_seq = tokens.shape[1] - C = model.C with torch.no_grad(): output_with_cache: OutputWithCache = model(tokens, cache_type="input") @@ -66,13 +103,12 @@ def compute_local_attributions( detach_inputs=False, ) - component_masks = ci.lower_leaky - mask_infos = make_mask_infos(component_masks=component_masks, routing_masks="all") + # Create masks so we can use all components (without masks) + mask_infos = make_mask_infos( + component_masks={k: torch.ones_like(v) for k, v in ci.lower_leaky.items()}, + routing_masks="all", + ) - # For wte, our component_acts will be (b, s, embedding_dim), instead of (b, s, C). We pretend - # that it has ci values of 1 for the 0th index of the embedding dimension and 0 elsewhere. This - # is because later we sum over the embedding_dim and add a new singleton dimension for the - # component wte_cache: dict[str, Tensor] = {} def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> Any: @@ -91,47 +127,60 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An ) wte_handle.remove() - ci.lower_leaky["wte"] = torch.zeros_like(wte_cache["wte_post_detach"]) - ci.lower_leaky["wte"][:, :, 0] = 1.0 - cache = comp_output_with_cache.cache cache["wte_post_detach"] = wte_cache["wte_post_detach"] + cache["output_pre_detach"] = comp_output_with_cache.output + + # Compute output probabilities for thresholding + output_probs = torch.softmax(comp_output_with_cache.output, dim=-1) + + # Compute alive info for all layers upfront + all_layers: set[str] = set(sources_by_target.keys()) + for sources in sources_by_target.values(): + all_layers.update(sources) + + alive_info: dict[str, LayerAliveInfo] = {} + for layer in all_layers: + alive_info[layer] = compute_layer_alive_info( + layer, ci.lower_leaky, output_probs, ci_threshold, output_prob_threshold, n_seq, device + ) local_attributions: list[PairAttribution] = [] - for out_layer, in_layers in tqdm(sources_by_target.items(), desc="Target layers"): - out_pre_detach: Float[Tensor, "1 s C"] = cache[f"{out_layer}_pre_detach"] - ci_out: Float[Tensor, "1 s C"] = ci.lower_leaky[out_layer] + for target, sources in tqdm(sources_by_target.items(), desc="Target layers"): + target_info = alive_info[target] + out_pre_detach: Float[Tensor, "1 s dim"] = cache[f"{target}_pre_detach"] - in_post_detaches: list[Float[Tensor, "1 s C"]] = [ - cache[f"{in_layer}_post_detach"] for in_layer in in_layers + source_infos = [alive_info[source] for source in sources] + in_post_detaches: list[Float[Tensor, "1 s dim"]] = [ + cache[f"{source}_post_detach"] for source in sources ] - ci_ins: list[Float[Tensor, "1 s C"]] = [ci.lower_leaky[in_layer] for in_layer in in_layers] - attributions: list[Float[Tensor, "s_in C s_out C"]] = [ - torch.zeros(n_seq, C, n_seq, C, device=device) for _ in in_layers + # Initialize attribution tensors at final trimmed size + attributions: list[Float[Tensor, "s_in n_c_in s_out n_c_out"]] = [ + torch.zeros( + n_seq, + len(source_info.alive_c_idxs), + n_seq, + len(target_info.alive_c_idxs), + device=device, + ) + for source_info in source_infos ] # NOTE: o->q will be treated as an attention pair even though there are no attrs # across sequence positions. This is just so we don't have to special case it. - is_attention_output = any(is_kv_to_o_pair(in_layer, out_layer) for in_layer in in_layers) - - # Determine which (s_out, c_out) pairs are alive - alive_out_mask: Float[Tensor, "1 s C"] = ci_out >= ci_threshold - alive_out_c_idxs: list[int] = torch.where(alive_out_mask[0].any(dim=0))[0].tolist() - - alive_in_masks: list[Float[Tensor, "1 s C"]] = [ci_in >= ci_threshold for ci_in in ci_ins] - alive_in_c_idxs: list[list[int]] = [ - torch.where(alive_in_mask[0].any(dim=0))[0].tolist() for alive_in_mask in alive_in_masks - ] + is_attention_output = any(is_kv_to_o_pair(source, target) for source in sources) - for s_out in tqdm(range(n_seq), desc=f"{out_layer} -> {in_layers}", leave=False): + for s_out in tqdm(range(n_seq), desc=f"{target} <- {sources}", leave=False): # Get alive output components at this position - s_out_alive_c_idxs: list[int] = torch.where(alive_out_mask[0, s_out])[0].tolist() - if len(s_out_alive_c_idxs) == 0: + s_out_alive_c: list[int] = [ + c for c in target_info.alive_c_idxs if target_info.alive_mask[0, s_out, c] + ] + if not s_out_alive_c: continue - for c_out in s_out_alive_c_idxs: + for c_out in s_out_alive_c: in_post_detach_grads = torch.autograd.grad( outputs=out_pre_detach[0, s_out, c_out], inputs=in_post_detaches, @@ -139,44 +188,43 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An ) # Handle causal attention mask s_in_range = range(s_out + 1) if is_attention_output else range(s_out, s_out + 1) + trimmed_c_out = target_info.c_to_trimmed[c_out] with torch.no_grad(): - for ( - in_layer, - in_post_detach_grad, - in_post_detach, - alive_in_mask, - attribution, - ) in zip( - in_layers, + for source, source_info, grad, in_post_detach, attr in zip( + sources, + source_infos, in_post_detach_grads, in_post_detaches, - alive_in_masks, attributions, strict=True, ): - weighted: Float[Tensor, "s C"] = (in_post_detach_grad * in_post_detach)[0] - if in_layer == "wte": - # We actually have shape "s embedding_dim", so we sum over the embedding - # dimension and add a new singleton component dimension - weighted = weighted.sum(dim=1).unsqueeze(1) + weighted: Float[Tensor, "s dim"] = (grad * in_post_detach)[0] + if source == "wte": + # Sum over embedding_dim to get single pseudo-component + weighted = weighted.sum(dim=1, keepdim=True) + for s_in in s_in_range: - alive_c_in: list[int] = torch.where(alive_in_mask[0, s_in])[0].tolist() + alive_c_in = [ + c + for c in source_info.alive_c_idxs + if source_info.alive_mask[0, s_in, c] + ] for c_in in alive_c_in: - attribution[s_in, c_in, s_out, c_out] = weighted[s_in, c_in] + trimmed_c_in = source_info.c_to_trimmed[c_in] + attr[s_in, trimmed_c_in, s_out, trimmed_c_out] = weighted[ + s_in, c_in + ] - for in_layer, attribution, layer_alive_in_c_idxs in zip( - in_layers, attributions, alive_in_c_idxs, strict=True - ): - trimmed_attribution = attribution[:, layer_alive_in_c_idxs][:, :, :, alive_out_c_idxs] + for source, source_info, attr in zip(sources, source_infos, attributions, strict=True): local_attributions.append( PairAttribution( - source=in_layer, - target=out_layer, - attribution=trimmed_attribution, - trimmed_c_in_idxs=layer_alive_in_c_idxs, - trimmed_c_out_idxs=alive_out_c_idxs, - is_kv_to_o_pair=is_kv_to_o_pair(in_layer, out_layer), + source=source, + target=target, + attribution=attr, + trimmed_c_in_idxs=source_info.alive_c_idxs, + trimmed_c_out_idxs=target_info.alive_c_idxs, + is_kv_to_o_pair=is_kv_to_o_pair(source, target), ) ) @@ -191,8 +239,10 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An # wandb_path = "wandb:goodfire/spd/runs/jyo9duz5" # ss_gpt2_simple-1.25M (4L) n_blocks = 1 ci_threshold = 1e-6 +output_prob_threshold = 1e-1 # prompt = "The quick brown fox" -prompt = "Eagerly, a girl named Kim went" +# prompt = "Eagerly, a girl named Kim went" +prompt = "They walked hand in" loaded = load_model_from_wandb(wandb_path) model, config, device = loaded.model, loaded.config, loaded.device @@ -225,6 +275,7 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An tokens=tokens, sources_by_target=sources_by_target, ci_threshold=ci_threshold, + output_prob_threshold=output_prob_threshold, sampling=config.sampling, device=device, ) @@ -237,6 +288,7 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An total = attr_pair.attribution.numel() print( f" {attr_pair.source} -> {attr_pair.target}: " + f"shape={list(attr_pair.attribution.shape)}, " f"nonzero={nonzero}/{total} ({100 * nonzero / total:.2f}%), " f"max={attr_pair.attribution.max():.6f}" ) @@ -247,11 +299,36 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An # Save PyTorch format with all necessary data pt_path = out_dir / f"local_attributions_{loaded.wandb_id}.pt" + +# Get output token labels and probabilities for alive output components +# We need the output probabilities to get per-position probs +with torch.no_grad(): + output_with_cache: OutputWithCache = model(tokens, cache_type="input") + output_probs_full = torch.softmax(output_with_cache.output, dim=-1) # [1, seq, vocab] + +output_token_labels: dict[int, str] = {} +# output_probs_by_pos: dict mapping (seq_pos, component_idx) -> probability +output_probs_by_pos: dict[tuple[int, int], float] = {} +for attr_pair in attr_pairs: + if attr_pair.target == "output": + for c_idx in attr_pair.trimmed_c_out_idxs: + if c_idx not in output_token_labels: + output_token_labels[c_idx] = tokenizer.decode([c_idx]) + # Store probability for each position + for s in range(tokens.shape[1]): + prob = output_probs_full[0, s, c_idx].item() + if prob >= output_prob_threshold: + output_probs_by_pos[(s, c_idx)] = prob + break + save_data = { "attr_pairs": attr_pairs, "token_strings": token_strings, "prompt": prompt, "ci_threshold": ci_threshold, + "output_prob_threshold": output_prob_threshold, + "output_token_labels": output_token_labels, + "output_probs_by_pos": output_probs_by_pos, "wandb_id": loaded.wandb_id, } torch.save(save_data, pt_path) diff --git a/spd/scripts/plot_local_attributions.py b/spd/scripts/plot_local_attributions.py index 92ef577b9..736d12059 100644 --- a/spd/scripts/plot_local_attributions.py +++ b/spd/scripts/plot_local_attributions.py @@ -78,9 +78,12 @@ def parse_layer_name(layer: str) -> tuple[int, str]: E.g., "h.0.attn.q_proj" -> (0, "attn.q_proj") "wte" -> (-1, "wte") + "output" -> (999, "output") """ if layer == "wte": return -1, "wte" + if layer == "output": + return 999, "output" parts = layer.split(".") block_idx = int(parts[1]) @@ -188,6 +191,9 @@ def plot_local_attribution_graph( edge_alpha_scale: float = 0.5, figsize: tuple[float, float] | None = None, max_grid_cols: int = 12, + output_token_labels: dict[int, str] | None = None, + output_prob_threshold: float | None = None, + output_probs_by_pos: dict[tuple[int, int], float] | None = None, ) -> plt.Figure: """Plot the local attribution graph. @@ -199,6 +205,9 @@ def plot_local_attribution_graph( edge_alpha_scale: Scale factor for edge transparency. figsize: Figure size (width, height). Auto-computed if None. max_grid_cols: Maximum number of columns in the grid per layer. + output_token_labels: Dict mapping output component indices to token strings. + output_prob_threshold: Threshold used for filtering output probabilities. + output_probs_by_pos: Dict mapping (seq_pos, component_idx) -> probability for output layer. Returns: Matplotlib figure. @@ -247,6 +256,7 @@ def plot_local_attribution_graph( # Check if this is a q/v/k layer that shares a row _, sublayer = parse_layer_name(layer) is_qvk = sublayer in QVK_SUBLAYERS + is_output_layer = layer == "output" if is_qvk: # Use the allocated columns for this sublayer @@ -264,40 +274,76 @@ def plot_local_attribution_graph( # Center the grid at this sequence position x_base = x_positions[s] - # Arrange all components in grid (not just alive ones at this position) - for local_idx, c in enumerate(all_alive_components): - col = local_idx % layer_max_cols - row = local_idx // layer_max_cols - - if is_qvk: - # Position within the allocated horizontal segment - # Center of this sublayer's segment within the total QVK row - segment_center = start_col + layer_max_cols / 2 - total_center = QVK_TOTAL_COLS / 2 - # Offset from center of entire row - segment_offset = (segment_center - total_center) * col_spacing - # Position within segment, centered - x_offset = segment_offset + (col - (n_cols - 1) / 2) * col_spacing - else: - # Position within grid, centered on sequence position - x_offset = (col - (n_cols - 1) / 2) * col_spacing - - y_offset = (row - (n_rows - 1) / 2) * row_spacing - - x = x_base + x_offset - y = y_center + y_offset - - imp = layer_imp[s, c].item() - node = NodeInfo( - layer=layer, - seq_pos=s, - component_idx=c, - x=x, - y=y, - importance=imp, - ) - nodes.append(node) - node_lookup[(layer, s, c)] = node + if is_output_layer: + # For output layer: only show components active at THIS position + active_at_pos = [c for c in all_alive_components if alive_mask[s, c]] + n_active = len(active_at_pos) + if n_active == 0: + continue + + n_rows_pos = (n_active + max_grid_cols - 1) // max_grid_cols + n_cols_pos = min(n_active, max_grid_cols) + + # Use full width of max_grid_cols for output layer to spread out labels + max_width = (max_grid_cols - 1) * col_spacing + output_col_spacing = max_width / max(n_cols_pos - 1, 1) if n_cols_pos > 1 else 0 + + for local_idx, c in enumerate(active_at_pos): + col = local_idx % max_grid_cols + row = local_idx // max_grid_cols + + x_offset = (col - (n_cols_pos - 1) / 2) * output_col_spacing + y_offset = (row - (n_rows_pos - 1) / 2) * row_spacing + + x = x_base + x_offset + y = y_center + y_offset + + imp = layer_imp[s, c].item() + node = NodeInfo( + layer=layer, + seq_pos=s, + component_idx=c, + x=x, + y=y, + importance=imp, + ) + nodes.append(node) + node_lookup[(layer, s, c)] = node + else: + # For other layers: arrange all components in grid + for local_idx, c in enumerate(all_alive_components): + col = local_idx % layer_max_cols + row = local_idx // layer_max_cols + + if is_qvk: + # Position within the allocated horizontal segment + # Center of this sublayer's segment within the total QVK row + segment_center = start_col + layer_max_cols / 2 + total_center = QVK_TOTAL_COLS / 2 + # Offset from center of entire row + segment_offset = (segment_center - total_center) * col_spacing + # Position within segment, centered + x_offset = segment_offset + (col - (n_cols - 1) / 2) * col_spacing + else: + # Position within grid, centered on sequence position + x_offset = (col - (n_cols - 1) / 2) * col_spacing + + y_offset = (row - (n_rows - 1) / 2) * row_spacing + + x = x_base + x_offset + y = y_center + y_offset + + imp = layer_imp[s, c].item() + node = NodeInfo( + layer=layer, + seq_pos=s, + component_idx=c, + x=x, + y=y, + importance=imp, + ) + nodes.append(node) + node_lookup[(layer, s, c)] = node # Collect edges edges: list[tuple[NodeInfo, NodeInfo, float]] = [] @@ -385,10 +431,32 @@ def plot_local_attribution_graph( alpha=alpha, ) + # Add token label and probability for output layer nodes + if node.layer == "output" and output_token_labels is not None and has_edges: + token_label = output_token_labels.get(node.component_idx, "") + if token_label: + # Build label with probability if available + label_text = repr(token_label)[1:-1] # Strip quotes but show escape chars + if output_probs_by_pos is not None: + prob = output_probs_by_pos.get((node.seq_pos, node.component_idx)) + if prob is not None: + label_text = f"({prob:.2f})\n{label_text}" + ax.annotate( + label_text, + (node.x, node.y), + xytext=(0, 6), + textcoords="offset points", + fontsize=6, + ha="center", + va="bottom", + alpha=0.8, + ) + # Configure axes total_height = max(layer_y.values()) ax.set_xlim(0, 1) - ax.set_ylim(-0.5, total_height + 0.5) + # Extra top margin to fit output token labels with probabilities + ax.set_ylim(-0.5, total_height + 1.0) # X-axis: token labels ax.set_xticks(x_positions) @@ -420,6 +488,8 @@ def plot_local_attribution_graph( if sub in sublayers: ordered.append(sub.split(".")[-1]) label = f"h.{block_idx}\n" + "/".join(ordered) + elif layers_at_y[0] == "output" and output_prob_threshold is not None: + label = f"output\n(prob>{output_prob_threshold})" else: label = layers_at_y[0].replace(".", "\n", 1) layer_labels.append(label) @@ -475,6 +545,9 @@ def load_and_plot( data = torch.load(pt_path, weights_only=False) attr_pairs: list[PairAttribution] = data["attr_pairs"] token_strings: list[str] = data["token_strings"] + output_token_labels: dict[int, str] | None = data.get("output_token_labels") + output_prob_threshold: float | None = data.get("output_prob_threshold") + output_probs_by_pos: dict[tuple[int, int], float] | None = data.get("output_probs_by_pos") print(f"Loaded attributions from {pt_path}") print(f" Prompt: {data.get('prompt', 'N/A')!r}") @@ -484,6 +557,9 @@ def load_and_plot( fig = plot_local_attribution_graph( attr_pairs=attr_pairs, token_strings=token_strings, + output_token_labels=output_token_labels, + output_prob_threshold=output_prob_threshold, + output_probs_by_pos=output_probs_by_pos, **plot_kwargs, ) From 16353c9f4a23dda691d0d2bc4f4827887badbeb6 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Sat, 29 Nov 2025 12:02:50 +0000 Subject: [PATCH 25/36] Minor tweaks --- spd/scripts/calc_local_attributions.py | 15 ++++++++++----- spd/scripts/plot_local_attributions.py | 4 +++- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/spd/scripts/calc_local_attributions.py b/spd/scripts/calc_local_attributions.py index 9c6ef0b9e..b1e31c336 100644 --- a/spd/scripts/calc_local_attributions.py +++ b/spd/scripts/calc_local_attributions.py @@ -234,10 +234,10 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An # %% # Configuration # wandb_path = "wandb:goodfire/spd/runs/8ynfbr38" # ss_gpt2_simple-1L (Old) -wandb_path = "wandb:goodfire/spd/runs/33n6xjjt" # ss_gpt2_simple-1L (new) +# wandb_path = "wandb:goodfire/spd/runs/33n6xjjt" # ss_gpt2_simple-1L (new) # wandb_path = "wandb:goodfire/spd/runs/c0k3z78g" # ss_gpt2_simple-2L -# wandb_path = "wandb:goodfire/spd/runs/jyo9duz5" # ss_gpt2_simple-1.25M (4L) -n_blocks = 1 +wandb_path = "wandb:goodfire/spd/runs/jyo9duz5" # ss_gpt2_simple-1.25M (4L) +n_blocks = 4 ci_threshold = 1e-6 output_prob_threshold = 1e-1 # prompt = "The quick brown fox" @@ -284,12 +284,17 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An print("\nAttribution summary:") # for pair, attr in local_attributions: for attr_pair in attr_pairs: - nonzero = (attr_pair.attribution > 0).sum().item() total = attr_pair.attribution.numel() + if total == 0: + print( + f"Ignoring {attr_pair.source} -> {attr_pair.target}: shape={list(attr_pair.attribution.shape)}, zero" + ) + continue + nonzero = (attr_pair.attribution > 0).sum().item() print( f" {attr_pair.source} -> {attr_pair.target}: " f"shape={list(attr_pair.attribution.shape)}, " - f"nonzero={nonzero}/{total} ({100 * nonzero / total:.2f}%), " + f"nonzero={nonzero}/{total} ({100 * nonzero / (total + 1e-12):.2f}%), " f"max={attr_pair.attribution.max():.6f}" ) diff --git a/spd/scripts/plot_local_attributions.py b/spd/scripts/plot_local_attributions.py index 736d12059..3ac319b70 100644 --- a/spd/scripts/plot_local_attributions.py +++ b/spd/scripts/plot_local_attributions.py @@ -573,7 +573,9 @@ def load_and_plot( # %% if __name__ == "__main__": # Configuration - wandb_id = "33n6xjjt" # ss_gpt2_simple-1L (new) + # wandb_id = "33n6xjjt" # ss_gpt2_simple-1L (new) + # wandb_id = "c0k3z78g" # ss_gpt2_simple-2L + wandb_id = "jyo9duz5" # ss_gpt2_simple-1.25M (4L) out_dir = get_out_dir() pt_path = out_dir / f"local_attributions_{wandb_id}.pt" From 48e406d2d8260501bfa4d0a119a63b511d681b6b Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Mon, 1 Dec 2025 10:52:05 +0000 Subject: [PATCH 26/36] Cleanup plotting --- spd/scripts/calc_global_attributions.py | 56 ++- spd/scripts/plot_global_attributions.py | 444 +++++++++++++++--------- 2 files changed, 320 insertions(+), 180 deletions(-) diff --git a/spd/scripts/calc_global_attributions.py b/spd/scripts/calc_global_attributions.py index 7e55206d5..68270a7ad 100644 --- a/spd/scripts/calc_global_attributions.py +++ b/spd/scripts/calc_global_attributions.py @@ -51,6 +51,8 @@ def compute_mean_ci_per_component( ) -> dict[str, Tensor]: """Compute mean causal importance per component over the dataset. + Also computes mean output probability per vocab token. + Args: model: The ComponentModel to analyze. data_loader: DataLoader providing batches. @@ -60,6 +62,7 @@ def compute_mean_ci_per_component( Returns: Dictionary mapping module path -> tensor of shape [C] with mean CI per component. + Also includes "output" -> tensor of shape [vocab_size] with mean output probability. """ # Initialize accumulators ci_sums: dict[str, Tensor] = { @@ -67,6 +70,10 @@ def compute_mean_ci_per_component( } examples_seen: dict[str, int] = {module_name: 0 for module_name in model.components} + # Output prob accumulators (initialized on first batch to get vocab_size) + output_prob_sum: Tensor | None = None + output_examples_seen = 0 + if max_batches is not None: batch_pbar = tqdm(enumerate(data_loader), desc="Computing mean CI", total=max_batches) else: @@ -94,11 +101,21 @@ def compute_mean_ci_per_component( leading_dim_idxs = tuple(range(n_leading_dims)) ci_sums[module_name] += ci_vals.sum(dim=leading_dim_idxs) + # Accumulate output probabilities + output_probs = torch.softmax(output_with_cache.output, dim=-1) # [b, s, vocab] + if output_prob_sum is None: + vocab_size = output_probs.shape[-1] + output_prob_sum = torch.zeros(vocab_size, device=device) + output_prob_sum += output_probs.sum(dim=(0, 1)) + output_examples_seen += output_probs.shape[0] * output_probs.shape[1] + # Compute means - mean_cis = { + mean_cis: dict[str, Tensor] = { module_name: ci_sums[module_name] / examples_seen[module_name] for module_name in model.components } + assert output_prob_sum is not None, "No batches processed" + mean_cis["output"] = output_prob_sum / output_examples_seen return mean_cis @@ -110,6 +127,7 @@ def compute_alive_components( config: Config, max_batches: int | None, threshold: float, + output_mean_prob_threshold: float, ) -> tuple[dict[str, Tensor], dict[str, list[int]], tuple[Image.Image, Image.Image]]: """Compute alive components based on mean CI threshold. @@ -120,18 +138,25 @@ def compute_alive_components( config: SPD config with sampling settings. max_batches: Maximum number of batches to process. threshold: Minimum mean CI to consider a component alive. + output_mean_prob_threshold: Minimum mean output probability to consider a token alive. Returns: Tuple of: - mean_cis: Dictionary mapping module path -> tensor of mean CI per component + (includes "output" key with mean output probabilities) - alive_indices: Dictionary mapping module path -> list of alive component indices + (includes "output" key) - images: Tuple of (linear_scale_image, log_scale_image) for verification """ mean_cis = compute_mean_ci_per_component(model, data_loader, device, config, max_batches) + alive_indices = {} - for module_name, mean_ci in mean_cis.items(): - alive_mask = mean_ci >= threshold + for module_name, mean_val in mean_cis.items(): + # Use output_mean_prob_threshold for output layer, threshold for components + thresh = output_mean_prob_threshold if module_name == "output" else threshold + alive_mask = mean_val >= thresh alive_indices[module_name] = torch.where(alive_mask)[0].tolist() + images = plot_mean_component_cis_both_scales(mean_cis) return mean_cis, alive_indices, images @@ -241,6 +266,7 @@ def compute_global_attributions( max_batches: int, alive_indices: dict[str, list[int]], ci_attribution_threshold: float, + output_mean_prob_threshold: float, ) -> dict[tuple[str, str], Tensor]: """Compute global attributions accumulated over the dataset. @@ -260,6 +286,7 @@ def compute_global_attributions( max_batches: Maximum number of batches to process. alive_indices: Dictionary mapping module path -> list of alive component indices. ci_attribution_threshold: Threshold for considering a component for the attribution calculation. + output_mean_prob_threshold: Threshold for considering an output logit alive (on softmax probs). Returns: Dictionary mapping (in_layer, out_layer) -> attribution tensor of shape [n_alive_in, n_alive_out] @@ -334,6 +361,17 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An ci.lower_leaky["wte"] = torch.zeros_like(wte_cache["wte_post_detach"]) ci.lower_leaky["wte"][:, :, 0] = 1.0 + # Add output to cache and CI + cache["output_pre_detach"] = comp_output_with_cache.output + # Use output probs as fake CI for output layer + output_probs: Float[Tensor, "b s vocab"] = torch.softmax( + comp_output_with_cache.output, dim=-1 + ) + # Only consider tokens above threshold as "alive" for this batch + ci.lower_leaky["output"] = torch.where( + output_probs >= output_mean_prob_threshold, output_probs, torch.zeros_like(output_probs) + ) + # Compute attributions grouped by target layer for out_layer, in_layers in tqdm( sources_by_target.items(), desc="Target layers", leave=False @@ -462,7 +500,8 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An # wandb_path = "wandb:goodfire/spd/runs/8ynfbr38" # ss_gpt2_simple-1L (Old) wandb_path = "wandb:goodfire/spd/runs/33n6xjjt" # ss_gpt2_simple-1L (New) n_blocks = 1 - batch_size = 1024 + # batch_size = 1024 + batch_size = 128 n_ctx = 64 # n_attribution_batches = 20 n_attribution_batches = 5 @@ -470,6 +509,7 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An # n_alive_calc_batches = 200 ci_mean_alive_threshold = 1e-6 ci_attribution_threshold = 1e-6 + output_mean_prob_threshold = 1e-8 dataset_seed = 0 out_dir = get_out_dir() @@ -495,8 +535,8 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An for out_layer, in_layers in sources_by_target.items(): print(f" {out_layer} <- {in_layers}") # %% - # Compute alive components based on mean CI threshold - print("\nComputing alive components based on mean CI...") + # Compute alive components based on mean CI threshold (and output probability threshold) + print("\nComputing alive components...") mean_cis, alive_indices, (img_linear, img_log) = compute_alive_components( model=model, data_loader=data_loader, @@ -504,13 +544,14 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An config=config, max_batches=n_alive_calc_batches, threshold=ci_mean_alive_threshold, + output_mean_prob_threshold=output_mean_prob_threshold, ) # Print summary print("\nAlive components per layer:") for module_name, indices in alive_indices.items(): n_alive = len(indices) - print(f" {module_name}: {n_alive}/{model.C} alive") + print(f" {module_name}: {n_alive} alive") # Save images for verification img_linear.save(out_dir / f"ci_mean_per_component_linear_{wandb_id}.png") @@ -530,6 +571,7 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An max_batches=n_attribution_batches, alive_indices=alive_indices, ci_attribution_threshold=ci_attribution_threshold, + output_mean_prob_threshold=output_mean_prob_threshold, ) # Print summary statistics diff --git a/spd/scripts/plot_global_attributions.py b/spd/scripts/plot_global_attributions.py index b3862b9dc..5189ff874 100644 --- a/spd/scripts/plot_global_attributions.py +++ b/spd/scripts/plot_global_attributions.py @@ -1,170 +1,219 @@ -# %% """Plot attribution graph from saved global attributions.""" +from __future__ import annotations + +from dataclasses import dataclass from pathlib import Path +from typing import TYPE_CHECKING, Any import matplotlib.pyplot as plt import networkx as nx import torch +from torch import Tensor -# Configuration -# wandb_id = "c0k3z78g" # ss_gpt2_simple-2L -# n_blocks = 2 -wandb_id = "8ynfbr38" # ss_gpt2_simple-1L -n_blocks = 1 -edge_threshold = 1e-2 - -# Load saved data -out_dir = Path(__file__).parent / "out" -global_attributions = torch.load(out_dir / f"global_attributions_{wandb_id}.pt") - -# Reconstruct alive_indices from attribution tensor shapes -alive_indices: dict[str, list[int]] = {} -for (in_layer, out_layer), attr in global_attributions.items(): - n_alive_in, n_alive_out = attr.shape - if in_layer not in alive_indices: - alive_indices[in_layer] = list(range(n_alive_in)) - if out_layer not in alive_indices: - alive_indices[out_layer] = list(range(n_alive_out)) - -print(f"Loaded attributions for {len(global_attributions)} layer pairs") -print(f"Total alive components: {sum(len(v) for v in alive_indices.values())}") - -# Count edges before and after thresholding -total_edges = sum(attr.numel() for attr in global_attributions.values()) -print(f"Total edges: {total_edges:,}") -thresholds = [1, 0.6, 0.2, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-12, 1e-15] -for threshold in thresholds: - total_edges_threshold = sum( - (attr > threshold).sum().item() for attr in global_attributions.values() - ) - print(f"Edges > {threshold}: {total_edges_threshold:,}") - -# %% -# Plot the attribution graph -print("\nPlotting attribution graph...") - -# Define layer order within a block (network order) -layer_names_in_block = [ - "attn.q_proj", - "attn.k_proj", - "attn.v_proj", - "attn.o_proj", - "mlp.c_fc", - "mlp.down_proj", -] - -# Build full layer list in network order -all_layers = [] -for block_idx in range(n_blocks): - for layer_name in layer_names_in_block: - all_layers.append(f"h.{block_idx}.{layer_name}") - -# Create graph -G = nx.DiGraph() - -# Add nodes for each (layer, component) pair -node_positions = {} -block_spacing = 6.0 # Vertical spacing between blocks - -# Layer y-offsets within a block: down_proj at top, q/k/v at same level at bottom -# q_proj, k_proj, v_proj are placed side by side since they never connect to each other -layer_y_offsets = { - "mlp.down_proj": 2.0, - "mlp.c_fc": 1.0, - "attn.o_proj": 0.0, - "attn.v_proj": -1.0, # Same y-level for q/k/v - "attn.k_proj": -1.0, - "attn.q_proj": -1.0, -} +from spd.scripts.model_loading import get_out_dir -# X-offsets for q/k/v to place them side by side with much more spacing -layer_x_offsets = { - "mlp.down_proj": 0.0, - "mlp.c_fc": 0.0, - "attn.o_proj": 0.0, - "attn.q_proj": -20.0, # Left (much more spacing) - "attn.k_proj": 0.0, # Center - "attn.v_proj": 20.0, # Right (much more spacing) -} +if TYPE_CHECKING: + Graph = nx.DiGraph[Any] +else: + Graph = nx.DiGraph + + +@dataclass +class LayerInfo: + """All display information for a layer type.""" + + name: str + color: str + y_offset: float + x_offset: float = 0.0 + legend_name: str | None = None # If different from name -for layer in all_layers: - parts = layer.split(".") - block_idx = int(parts[1]) - layer_name = ".".join(parts[2:]) - - n_alive = len(alive_indices.get(layer, [])) - if n_alive == 0: - continue - - # Block 1 on top (higher y), Block 0 on bottom (lower y) - y_base = block_idx * block_spacing + layer_y_offsets[layer_name] - # X-axis base depends on layer type (q/k/v are offset) - x_base = layer_x_offsets[layer_name] - - for comp_idx, local_idx in enumerate(alive_indices.get(layer, [])): - node_id = f"{layer}:{local_idx}" - G.add_node(node_id, layer=layer, component=local_idx) - # Increase spacing between nodes from 0.15 to 0.25 for less overlap - x = x_base + (comp_idx - n_alive / 2) * 0.25 - y = y_base - node_positions[node_id] = (x, y) - -# Add edges based on attributions -edge_weights = [] -for (in_layer, out_layer), attr_tensor in global_attributions.items(): - in_alive = alive_indices.get(in_layer, []) - out_alive = alive_indices.get(out_layer, []) - - for i, in_comp in enumerate(in_alive): - for j, out_comp in enumerate(out_alive): - weight = attr_tensor[i, j].item() - if weight > edge_threshold: - in_node = f"{in_layer}:{in_comp}" - out_node = f"{out_layer}:{out_comp}" - if in_node in G.nodes and out_node in G.nodes: - G.add_edge(in_node, out_node, weight=weight) - edge_weights.append(weight) - -print(f"Graph has {G.number_of_nodes()} nodes and {G.number_of_edges()} edges") - -# Create figure (extra wide to accommodate q/k/v side by side with large spacing) -fig, ax = plt.subplots(1, 1, figsize=(32, 12)) - -# Draw nodes grouped by layer -layer_colors = { - "attn.q_proj": "#1f77b4", - "attn.k_proj": "#2ca02c", - "attn.v_proj": "#9467bd", - "attn.o_proj": "#d62728", - "mlp.c_fc": "#ff7f0e", - "mlp.down_proj": "#8c564b", + @property + def display_name(self) -> str: + return self.legend_name if self.legend_name is not None else self.name + + +# fmt: off +LAYER_INFOS: dict[str, LayerInfo] = { + "wte": LayerInfo("wte", color="#34495E", y_offset=-3.0), + "attn.q_proj": LayerInfo("attn.q_proj", color="#1f77b4", y_offset=-1.0, x_offset=-20.0, legend_name="q_proj"), + "attn.k_proj": LayerInfo("attn.k_proj", color="#2ca02c", y_offset=-1.0, x_offset=0.0, legend_name="k_proj"), + "attn.v_proj": LayerInfo("attn.v_proj", color="#9467bd", y_offset=-1.0, x_offset=20.0, legend_name="v_proj"), + "attn.o_proj": LayerInfo("attn.o_proj", color="#d62728", y_offset=0.0, legend_name="o_proj"), + "mlp.c_fc": LayerInfo("mlp.c_fc", color="#ff7f0e", y_offset=1.0, legend_name="c_fc"), + "mlp.down_proj": LayerInfo("mlp.down_proj", color="#8c564b", y_offset=2.0, legend_name="down_proj"), + "output": LayerInfo("output", color="#17A589", y_offset=4.0), } +# fmt: on -for layer in all_layers: - parts = layer.split(".") - layer_name = ".".join(parts[2:]) - color = layer_colors.get(layer_name, "#333333") - layer_nodes = [n for n in G.nodes if G.nodes[n].get("layer") == layer] - if layer_nodes: - pos_subset = {n: node_positions[n] for n in layer_nodes} - nx.draw_networkx_nodes( - G, - pos_subset, - nodelist=layer_nodes, - node_color=color, - node_size=100, - alpha=0.8, - ax=ax, - ) +def load_attributions( + out_dir: Path, wandb_id: str +) -> tuple[dict[tuple[str, str], Tensor], dict[str, list[int]]]: + """Load global attributions and reconstruct alive_indices from tensor shapes.""" + global_attributions: dict[tuple[str, str], Tensor] = torch.load( + out_dir / f"global_attributions_{wandb_id}.pt" + ) + + alive_indices: dict[str, list[int]] = {} + for (in_layer, out_layer), attr in global_attributions.items(): + n_alive_in, n_alive_out = attr.shape + if in_layer not in alive_indices: + alive_indices[in_layer] = list(range(n_alive_in)) + if out_layer not in alive_indices: + alive_indices[out_layer] = list(range(n_alive_out)) + + return global_attributions, alive_indices + + +def print_edge_statistics(global_attributions: dict[tuple[str, str], Tensor]) -> None: + """Print statistics about edges at various thresholds.""" + total_edges = sum(attr.numel() for attr in global_attributions.values()) + print(f"Total edges: {total_edges:,}") + + thresholds = [1, 0.6, 0.2, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-12, 1e-15] + for threshold in thresholds: + edges_above = sum((attr > threshold).sum().item() for attr in global_attributions.values()) + print(f"Edges > {threshold}: {edges_above:,}") + + +def build_layer_list(n_blocks: int) -> list[str]: + """Build full layer list in network order (wte -> blocks -> output).""" + all_layers = ["wte"] + for block_idx in range(n_blocks): + for layer_name in [k for k in LAYER_INFOS if k not in ["wte", "output"]]: + all_layers.append(f"h.{block_idx}.{layer_name}") + all_layers.append("output") + return all_layers + + +def find_nodes_with_edges( + global_attributions: dict[tuple[str, str], Tensor], + alive_indices: dict[str, list[int]], + edge_threshold: float, +) -> set[str]: + """Find all nodes that participate in edges above threshold.""" + nodes_with_edges: set[str] = set() + for (in_layer, out_layer), attr_tensor in global_attributions.items(): + in_alive = alive_indices.get(in_layer, []) + out_alive = alive_indices.get(out_layer, []) + for i, in_comp in enumerate(in_alive): + for j, out_comp in enumerate(out_alive): + if attr_tensor[i, j].item() > edge_threshold: + nodes_with_edges.add(f"{in_layer}:{in_comp}") + nodes_with_edges.add(f"{out_layer}:{out_comp}") + return nodes_with_edges + + +def parse_layer_name(layer: str, n_blocks: int) -> tuple[int, str]: + """Parse layer name to get block_idx and base layer_name.""" + if layer == "wte": + return 0, "wte" + elif layer == "output": + return n_blocks - 1, "output" + else: + parts = layer.split(".") + return int(parts[1]), ".".join(parts[2:]) + + +def build_graph( + all_layers: list[str], + global_attributions: dict[tuple[str, str], Tensor], + alive_indices: dict[str, list[int]], + nodes_with_edges: set[str], + n_blocks: int, + edge_threshold: float, + block_spacing: float = 6.0, +) -> tuple[Graph, dict[str, tuple[float, float]], list[float]]: + """Build the attribution graph with node positions and edge weights.""" + G: Graph = nx.DiGraph() + node_positions: dict[str, tuple[float, float]] = {} + + # Add nodes + for layer in all_layers: + block_idx, layer_name = parse_layer_name(layer, n_blocks) + info = LAYER_INFOS[layer_name] + + y_base = block_idx * block_spacing + info.y_offset + x_base = info.x_offset + + layer_alive = alive_indices.get(layer, []) + layer_nodes_with_edges = [ + (idx, comp) + for idx, comp in enumerate(layer_alive) + if f"{layer}:{comp}" in nodes_with_edges + ] + n_layer_nodes = len(layer_nodes_with_edges) + + for pos_idx, (_, local_idx) in enumerate(layer_nodes_with_edges): + node_id = f"{layer}:{local_idx}" + G.add_node(node_id, layer=layer, component=local_idx) + x = x_base + (pos_idx - n_layer_nodes / 2) * 0.25 + node_positions[node_id] = (x, y_base) + + # Add edges + edge_weights: list[float] = [] + for (in_layer, out_layer), attr_tensor in global_attributions.items(): + in_alive = alive_indices.get(in_layer, []) + out_alive = alive_indices.get(out_layer, []) + + for i, in_comp in enumerate(in_alive): + for j, out_comp in enumerate(out_alive): + weight = attr_tensor[i, j].item() + if weight > edge_threshold: + in_node = f"{in_layer}:{in_comp}" + out_node = f"{out_layer}:{out_comp}" + if in_node in G.nodes and out_node in G.nodes: + G.add_edge(in_node, out_node, weight=weight) + edge_weights.append(weight) + + return G, node_positions, edge_weights + + +def draw_nodes( + ax: plt.Axes, + G: Graph, + node_positions: dict[str, tuple[float, float]], + all_layers: list[str], +) -> None: + """Draw nodes grouped by layer.""" + for layer in all_layers: + if layer in ("wte", "output"): + layer_name = layer + else: + parts = layer.split(".") + layer_name = ".".join(parts[2:]) + color = LAYER_INFOS[layer_name].color + + layer_nodes = [n for n in G.nodes if G.nodes[n].get("layer") == layer] + if layer_nodes: + pos_subset = {n: node_positions[n] for n in layer_nodes} + nx.draw_networkx_nodes( + G, + pos_subset, + nodelist=layer_nodes, + node_color=color, + node_size=100, + alpha=0.8, + ax=ax, + ) + + +def draw_edges( + ax: plt.Axes, + G: Graph, + node_positions: dict[str, tuple[float, float]], + edge_weights: list[float], + n_buckets: int = 10, +) -> None: + """Draw edges batched by weight bucket for performance.""" + if not edge_weights: + return -# Draw edges batched by weight bucket for performance -if edge_weights: max_weight = max(edge_weights) min_weight = min(edge_weights) - n_buckets = 10 edge_buckets: list[list[tuple[str, str]]] = [[] for _ in range(n_buckets)] for u, v, data in G.edges(data=True): @@ -196,31 +245,80 @@ ax=ax, ) -# Add legend -legend_elements = [ - plt.Line2D([0], [0], marker="o", color="w", markerfacecolor=color, markersize=10, label=name) - for name, color in [ - ("q_proj", "#1f77b4"), - ("k_proj", "#2ca02c"), - ("v_proj", "#9467bd"), - ("o_proj", "#d62728"), - ("c_fc", "#ff7f0e"), - ("down_proj", "#8c564b"), + +def add_legend(ax: plt.Axes) -> None: + """Add legend for layer types.""" + legend_elements = [ + plt.Line2D( + [0], + [0], + marker="o", + color="w", + markerfacecolor=info.color, + markersize=10, + label=info.display_name, + ) + for info in LAYER_INFOS.values() ] -] -ax.legend(handles=legend_elements, loc="upper right", fontsize=8) + ax.legend(handles=legend_elements, loc="upper right", fontsize=8) + + +def plot_attribution_graph( + global_attributions: dict[tuple[str, str], Tensor], + alive_indices: dict[str, list[int]], + n_blocks: int, + edge_threshold: float, + output_path: Path, +) -> None: + """Create and save the attribution graph visualization.""" + print("\nPlotting attribution graph...") + + all_layers = build_layer_list(n_blocks) + nodes_with_edges = find_nodes_with_edges(global_attributions, alive_indices, edge_threshold) + G, node_positions, edge_weights = build_graph( + all_layers, global_attributions, alive_indices, nodes_with_edges, n_blocks, edge_threshold + ) + + print(f"Graph has {G.number_of_nodes()} nodes and {G.number_of_edges()} edges") + + fig, ax = plt.subplots(1, 1, figsize=(32, 12)) + + draw_nodes(ax, G, node_positions, all_layers) + draw_edges(ax, G, node_positions, edge_weights) + add_legend(ax) -ax.set_title("Global Attribution Graph", fontsize=14, fontweight="bold") -ax.axis("off") -plt.tight_layout() + ax.set_title("Global Attribution Graph", fontsize=14, fontweight="bold") + ax.axis("off") + plt.tight_layout() -# Save -# Make an edge threshold string in scientific notation which doesn't include decimal places -edge_threshold_str = f"{edge_threshold:.1e}".replace(".0", "") -output_path = out_dir / f"attribution_graph_{wandb_id}_edge_threshold_{edge_threshold_str}.png" -fig.savefig(output_path, dpi=150, bbox_inches="tight") -print(f"Saved to {output_path}") + fig.savefig(output_path, dpi=150, bbox_inches="tight") + print(f"Saved to {output_path}") + plt.close(fig) -plt.close(fig) -# %% +if __name__ == "__main__": + # Configuration + # wandb_id = "jyo9duz5" # ss_gpt2_simple-1.25M (4L) + # wandb_id = "c0k3z78g" # ss_gpt2_simple-2L + wandb_id = "33n6xjjt" # ss_gpt2_simple-1L (New) + n_blocks = 1 + edge_threshold = 1e-1 + + out_dir = get_out_dir() + + global_attributions, alive_indices = load_attributions(out_dir, wandb_id) + print(f"Loaded attributions for {len(global_attributions)} layer pairs") + print(f"Total alive components: {sum(len(v) for v in alive_indices.values())}") + + print_edge_statistics(global_attributions) + + edge_threshold_str = f"{edge_threshold:.1e}".replace(".0", "") + output_path = out_dir / f"attribution_graph_{wandb_id}_edge_threshold_{edge_threshold_str}.png" + + plot_attribution_graph( + global_attributions, + alive_indices, + n_blocks, + edge_threshold, + output_path, + ) From 56370b198a86f24cee30e48798c1998b45a7586b Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Mon, 1 Dec 2025 13:42:37 +0000 Subject: [PATCH 27/36] Show l0 for final seq position --- spd/scripts/calc_local_attributions.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/spd/scripts/calc_local_attributions.py b/spd/scripts/calc_local_attributions.py index b1e31c336..833488a0a 100644 --- a/spd/scripts/calc_local_attributions.py +++ b/spd/scripts/calc_local_attributions.py @@ -103,6 +103,13 @@ def compute_local_attributions( detach_inputs=False, ) + # Log the l0 (lower_leaky values > ci_threshold) for each layer + print("L0 values for final seq position:") + for layer, ci_vals in ci.lower_leaky.items(): + # We only care about the final position + l0_vals = (ci_vals[0, -1] > ci_threshold).sum().item() + print(f" Layer {layer} has {l0_vals} components alive at {ci_threshold}") + # Create masks so we can use all components (without masks) mask_infos = make_mask_infos( component_masks={k: torch.ones_like(v) for k, v in ci.lower_leaky.items()}, From 6bf3c31f42830f1677d530d91d05201d1b50c15b Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Mon, 1 Dec 2025 17:05:29 +0000 Subject: [PATCH 28/36] Use ci-masks and minor tweaks --- spd/scripts/calc_global_attributions.py | 4 +- spd/scripts/calc_local_attributions.py | 329 +++++++++++++++--------- spd/scripts/plot_local_attributions.py | 122 +++++++-- 3 files changed, 309 insertions(+), 146 deletions(-) diff --git a/spd/scripts/calc_global_attributions.py b/spd/scripts/calc_global_attributions.py index 68270a7ad..62134d404 100644 --- a/spd/scripts/calc_global_attributions.py +++ b/spd/scripts/calc_global_attributions.py @@ -340,9 +340,9 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An detach_inputs=False, ) - # Create masks so we can use all components (without masks) mask_infos = make_mask_infos( - component_masks={k: torch.ones_like(v) for k, v in ci.lower_leaky.items()}, + # component_masks={k: torch.ones_like(v) for k, v in ci.lower_leaky.items()}, + component_masks=ci.lower_leaky, routing_masks="all", ) diff --git a/spd/scripts/calc_local_attributions.py b/spd/scripts/calc_local_attributions.py index 833488a0a..010e69508 100644 --- a/spd/scripts/calc_local_attributions.py +++ b/spd/scripts/calc_local_attributions.py @@ -1,7 +1,8 @@ -# %% """Compute local attributions for a single prompt.""" +import json from dataclasses import dataclass +from pathlib import Path from typing import Any import torch @@ -16,6 +17,7 @@ from spd.models.components import make_mask_infos from spd.scripts.calc_global_attributions import get_sources_by_target, is_kv_to_o_pair from spd.scripts.model_loading import get_out_dir, load_model_from_wandb +from spd.scripts.plot_local_attributions import PairAttribution, plot_local_graph @dataclass @@ -54,14 +56,40 @@ def compute_layer_alive_info( return LayerAliveInfo(alive_mask, alive_c_idxs, c_to_trimmed) -@dataclass -class PairAttribution: - source: str - target: str - attribution: Float[Tensor, "s_in trimmed_c_in s_out trimmed_c_out"] - trimmed_c_in_idxs: list[int] - trimmed_c_out_idxs: list[int] - is_kv_to_o_pair: bool +def load_ci_from_json( + ci_vals_path: str | Path, + expected_prompt: str, + device: str, +) -> dict[str, Float[Tensor, "1 seq C"]]: + """Load precomputed CI values from a JSON file. + + Args: + ci_vals_path: Path to JSON file from run_optim_cis.py + expected_prompt: The prompt we're analyzing (must match the JSON) + device: Device to load tensors to + + Returns: + Dict mapping layer_name -> CI tensor of shape [1, seq, C] + + Raises: + ValueError: If the prompt in the JSON doesn't match expected_prompt + """ + with open(ci_vals_path) as f: + data = json.load(f) + + json_prompt = data["prompt"] + if json_prompt != expected_prompt: + raise ValueError( + f"Prompt mismatch: JSON has {json_prompt!r}, but expected {expected_prompt!r}" + ) + + ci_lower_leaky: dict[str, Tensor] = {} + for layer_name, ci_list in data["optimized_ci"].items(): + # ci_list is [seq][C], convert to tensor [1, seq, C] + ci_tensor = torch.tensor(ci_list, device=device).unsqueeze(0) + ci_lower_leaky[layer_name] = ci_tensor + + return ci_lower_leaky def compute_local_attributions( @@ -72,7 +100,8 @@ def compute_local_attributions( output_prob_threshold: float, sampling: SamplingType, device: str, -) -> list[PairAttribution]: + ci_lower_leaky: dict[str, Float[Tensor, "1 seq C"]] | None = None, +) -> tuple[list[PairAttribution], Float[Tensor, "1 seq vocab"]]: """Compute local attributions for a single prompt. For each valid layer pair (in_layer, out_layer), computes the gradient-based @@ -87,6 +116,8 @@ def compute_local_attributions( output_prob_threshold: Threshold for considering an output logit alive (on softmax probs). sampling: Sampling type to use for causal importances. device: Device to run on. + ci_lower_leaky: Optional precomputed/optimized CI values. If None, will use model CI. + When provided, we still compute original model CI to track "ghost" nodes. Returns: List of PairAttribution objects. @@ -96,26 +127,26 @@ def compute_local_attributions( with torch.no_grad(): output_with_cache: OutputWithCache = model(tokens, cache_type="input") + # Always compute original CI from model (needed for ghost nodes when using optimized CI) with torch.no_grad(): ci = model.calc_causal_importances( pre_weight_acts=output_with_cache.cache, sampling=sampling, detach_inputs=False, ) + ci_original = ci.lower_leaky + + # Use provided CI values if given, otherwise use original + if ci_lower_leaky is None: + ci_lower_leaky = ci_original # Log the l0 (lower_leaky values > ci_threshold) for each layer print("L0 values for final seq position:") - for layer, ci_vals in ci.lower_leaky.items(): + for layer, ci_vals in ci_lower_leaky.items(): # We only care about the final position l0_vals = (ci_vals[0, -1] > ci_threshold).sum().item() print(f" Layer {layer} has {l0_vals} components alive at {ci_threshold}") - # Create masks so we can use all components (without masks) - mask_infos = make_mask_infos( - component_masks={k: torch.ones_like(v) for k, v in ci.lower_leaky.items()}, - routing_masks="all", - ) - wte_cache: dict[str, Tensor] = {} def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> Any: @@ -128,6 +159,11 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An assert isinstance(model.target_model.wte, nn.Module), "wte is not a module" wte_handle = model.target_model.wte.register_forward_hook(wte_hook, with_kwargs=True) + mask_infos = make_mask_infos( + # component_masks={k: torch.ones_like(v) for k, v in ci_lower_leaky.items()}, + component_masks=ci_lower_leaky, + routing_masks="all", + ) with torch.enable_grad(): comp_output_with_cache: OutputWithCache = model( tokens, mask_infos=mask_infos, cache_type="component_acts" @@ -147,9 +183,14 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An all_layers.update(sources) alive_info: dict[str, LayerAliveInfo] = {} + original_alive_info: dict[str, LayerAliveInfo] = {} for layer in all_layers: alive_info[layer] = compute_layer_alive_info( - layer, ci.lower_leaky, output_probs, ci_threshold, output_prob_threshold, n_seq, device + layer, ci_lower_leaky, output_probs, ci_threshold, output_prob_threshold, n_seq, device + ) + # Compute original alive info (from model CI, not optimized CI) + original_alive_info[layer] = compute_layer_alive_info( + layer, ci_original, output_probs, ci_threshold, output_prob_threshold, n_seq, device ) local_attributions: list[PairAttribution] = [] @@ -224,6 +265,8 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An ] for source, source_info, attr in zip(sources, source_infos, attributions, strict=True): + original_source_info = original_alive_info[source] + original_target_info = original_alive_info[target] local_attributions.append( PairAttribution( source=source, @@ -232,116 +275,154 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An trimmed_c_in_idxs=source_info.alive_c_idxs, trimmed_c_out_idxs=target_info.alive_c_idxs, is_kv_to_o_pair=is_kv_to_o_pair(source, target), + # Pass per-position alive masks (squeeze out batch dim) + original_alive_mask_in=original_source_info.alive_mask[0], # [seq, C] + original_alive_mask_out=original_target_info.alive_mask[0], # [seq, C] ) ) - return local_attributions - - -# %% -# Configuration -# wandb_path = "wandb:goodfire/spd/runs/8ynfbr38" # ss_gpt2_simple-1L (Old) -# wandb_path = "wandb:goodfire/spd/runs/33n6xjjt" # ss_gpt2_simple-1L (new) -# wandb_path = "wandb:goodfire/spd/runs/c0k3z78g" # ss_gpt2_simple-2L -wandb_path = "wandb:goodfire/spd/runs/jyo9duz5" # ss_gpt2_simple-1.25M (4L) -n_blocks = 4 -ci_threshold = 1e-6 -output_prob_threshold = 1e-1 -# prompt = "The quick brown fox" -# prompt = "Eagerly, a girl named Kim went" -prompt = "They walked hand in" - -loaded = load_model_from_wandb(wandb_path) -model, config, device = loaded.model, loaded.config, loaded.device - -tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) -assert isinstance(tokenizer, PreTrainedTokenizerFast), "Expected PreTrainedTokenizerFast" -sources_by_target = get_sources_by_target(model, device, config, n_blocks) - -n_pairs = sum(len(ins) for ins in sources_by_target.values()) -print(f"Sources by target: {n_pairs} pairs across {len(sources_by_target)} target layers") -for out_layer, in_layers in sources_by_target.items(): - print(f" {out_layer} <- {in_layers}") - -# %% -# Tokenize the prompt -print(f"\nPrompt: {prompt!r}") -tokens = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False) -assert isinstance(tokens, Tensor), "Expected Tensor" -tokens = tokens.to(device) -print(f"Tokens shape: {tokens.shape}") -print(f"Tokens: {tokens[0].tolist()}") -token_strings = [tokenizer.decode([t]) for t in tokens[0].tolist()] -print(f"Token strings: {token_strings}") - -# %% -# Compute local attributions -print("\nComputing local attributions...") -attr_pairs = compute_local_attributions( - model=model, - tokens=tokens, - sources_by_target=sources_by_target, - ci_threshold=ci_threshold, - output_prob_threshold=output_prob_threshold, - sampling=config.sampling, - device=device, -) - -# Print summary statistics -print("\nAttribution summary:") -# for pair, attr in local_attributions: -for attr_pair in attr_pairs: - total = attr_pair.attribution.numel() - if total == 0: + return local_attributions, output_probs + + +def main( + wandb_path: str, + n_blocks: int, + ci_threshold: float, + output_prob_threshold: float, + prompt: str, + ci_vals_path: str | None, +) -> None: + loaded = load_model_from_wandb(wandb_path) + model, config, device = loaded.model, loaded.config, loaded.device + + tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) + assert isinstance(tokenizer, PreTrainedTokenizerFast), "Expected PreTrainedTokenizerFast" + sources_by_target = get_sources_by_target(model, device, config, n_blocks) + + n_pairs = sum(len(ins) for ins in sources_by_target.values()) + print(f"Sources by target: {n_pairs} pairs across {len(sources_by_target)} target layers") + for out_layer, in_layers in sources_by_target.items(): + print(f" {out_layer} <- {in_layers}") + + # Tokenize the prompt + print(f"\nPrompt: {prompt!r}") + tokens = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False) + assert isinstance(tokens, Tensor), "Expected Tensor" + tokens = tokens.to(device) + print(f"Tokens shape: {tokens.shape}") + print(f"Tokens: {tokens[0].tolist()}") + token_strings = [tokenizer.decode([t]) for t in tokens[0].tolist()] + print(f"Token strings: {token_strings}") + + # Load precomputed CI values if path is provided + ci_lower_leaky: dict[str, Tensor] | None = None + if ci_vals_path is not None: + print(f"\nLoading precomputed CI values from {ci_vals_path}") + ci_lower_leaky = load_ci_from_json(ci_vals_path, prompt, device) + print(f"Loaded CI values for layers: {list(ci_lower_leaky.keys())}") + + # Compute local attributions + print("\nComputing local attributions...") + attr_pairs, output_probs = compute_local_attributions( + model=model, + tokens=tokens, + sources_by_target=sources_by_target, + ci_threshold=ci_threshold, + output_prob_threshold=output_prob_threshold, + sampling=config.sampling, + device=device, + ci_lower_leaky=ci_lower_leaky, + ) + + # Print summary statistics + print("\nAttribution summary:") + for attr_pair in attr_pairs: + total = attr_pair.attribution.numel() + if total == 0: + print( + f"Ignoring {attr_pair.source} -> {attr_pair.target}: " + f"shape={list(attr_pair.attribution.shape)}, zero" + ) + continue + nonzero = (attr_pair.attribution > 0).sum().item() print( - f"Ignoring {attr_pair.source} -> {attr_pair.target}: shape={list(attr_pair.attribution.shape)}, zero" + f" {attr_pair.source} -> {attr_pair.target}: " + f"shape={list(attr_pair.attribution.shape)}, " + f"nonzero={nonzero}/{total} ({100 * nonzero / (total + 1e-12):.2f}%), " + f"max={attr_pair.attribution.max():.6f}" ) - continue - nonzero = (attr_pair.attribution > 0).sum().item() - print( - f" {attr_pair.source} -> {attr_pair.target}: " - f"shape={list(attr_pair.attribution.shape)}, " - f"nonzero={nonzero}/{total} ({100 * nonzero / (total + 1e-12):.2f}%), " - f"max={attr_pair.attribution.max():.6f}" + + # Save attributions + out_dir = get_out_dir() + pt_path = out_dir / f"local_attributions_{loaded.wandb_id}.pt" + output_path = out_dir / f"local_attribution_graph_{loaded.wandb_id}.png" + if ci_vals_path is not None: + pt_path = pt_path.with_stem(pt_path.stem + "_with_ci_optim") + output_path = output_path.with_stem(output_path.stem + "_with_ci_optim") + + output_token_labels: dict[int, str] = {} + output_probs_by_pos: dict[tuple[int, int], float] = {} + for attr_pair in attr_pairs: + if attr_pair.target == "output": + for c_idx in attr_pair.trimmed_c_out_idxs: + if c_idx not in output_token_labels: + output_token_labels[c_idx] = tokenizer.decode([c_idx]) + # Store probability for each position + for s in range(tokens.shape[1]): + prob = output_probs[0, s, c_idx].item() + if prob >= output_prob_threshold: + output_probs_by_pos[(s, c_idx)] = prob + break + + save_data = { + "attr_pairs": attr_pairs, + "token_strings": token_strings, + "prompt": prompt, + "ci_threshold": ci_threshold, + "output_prob_threshold": output_prob_threshold, + "output_token_labels": output_token_labels, + "output_probs_by_pos": output_probs_by_pos, + "wandb_id": loaded.wandb_id, + } + torch.save(save_data, pt_path) + print(f"\nSaved local attributions to {pt_path}") + + fig = plot_local_graph( + attr_pairs=attr_pairs, + token_strings=token_strings, + output_token_labels=output_token_labels, + output_prob_threshold=output_prob_threshold, + output_probs_by_pos=output_probs_by_pos, + min_edge_weight=0.0001, + node_scale=30.0, + edge_alpha_scale=0.7, ) -# %% -# Save attributions -out_dir = get_out_dir() - -# Save PyTorch format with all necessary data -pt_path = out_dir / f"local_attributions_{loaded.wandb_id}.pt" - -# Get output token labels and probabilities for alive output components -# We need the output probabilities to get per-position probs -with torch.no_grad(): - output_with_cache: OutputWithCache = model(tokens, cache_type="input") - output_probs_full = torch.softmax(output_with_cache.output, dim=-1) # [1, seq, vocab] - -output_token_labels: dict[int, str] = {} -# output_probs_by_pos: dict mapping (seq_pos, component_idx) -> probability -output_probs_by_pos: dict[tuple[int, int], float] = {} -for attr_pair in attr_pairs: - if attr_pair.target == "output": - for c_idx in attr_pair.trimmed_c_out_idxs: - if c_idx not in output_token_labels: - output_token_labels[c_idx] = tokenizer.decode([c_idx]) - # Store probability for each position - for s in range(tokens.shape[1]): - prob = output_probs_full[0, s, c_idx].item() - if prob >= output_prob_threshold: - output_probs_by_pos[(s, c_idx)] = prob - break - -save_data = { - "attr_pairs": attr_pairs, - "token_strings": token_strings, - "prompt": prompt, - "ci_threshold": ci_threshold, - "output_prob_threshold": output_prob_threshold, - "output_token_labels": output_token_labels, - "output_probs_by_pos": output_probs_by_pos, - "wandb_id": loaded.wandb_id, -} -torch.save(save_data, pt_path) -print(f"\nSaved local attributions to {pt_path}") + fig.savefig(output_path, dpi=150, bbox_inches="tight", facecolor="white") + print(f"Saved figure to {output_path}") + + +if __name__ == "__main__": + # Configuration + # wandb_path = "wandb:goodfire/spd/runs/8ynfbr38" # ss_gpt2_simple-1L (Old) + wandb_path = "wandb:goodfire/spd/runs/33n6xjjt" # ss_gpt2_simple-1L (new) + # wandb_path = "wandb:goodfire/spd/runs/c0k3z78g" # ss_gpt2_simple-2L + # wandb_path = "wandb:goodfire/spd/runs/jyo9duz5" # ss_gpt2_simple-1.25M (4L) + # n_blocks = 4 + n_blocks = 1 + ci_threshold = 1e-6 + output_prob_threshold = 1e-1 + # prompt = "The quick brown fox" + # prompt = "Eagerly, a girl named Kim went" + prompt = "They walked hand in" + # Path to precomputed CI values from run_optim_cis.py (None to compute from model) + ci_vals_path: str | None = None + # ci_vals_path = "spd/scripts/optim_cis/out/optimized_ci_33n6xjjt.json" + main( + wandb_path=wandb_path, + n_blocks=n_blocks, + ci_threshold=ci_threshold, + output_prob_threshold=output_prob_threshold, + prompt=prompt, + ci_vals_path=ci_vals_path, + ) diff --git a/spd/scripts/plot_local_attributions.py b/spd/scripts/plot_local_attributions.py index 3ac319b70..d812e5030 100644 --- a/spd/scripts/plot_local_attributions.py +++ b/spd/scripts/plot_local_attributions.py @@ -12,10 +12,24 @@ from matplotlib.collections import LineCollection from torch import Tensor -from spd.scripts.calc_local_attributions import PairAttribution from spd.scripts.model_loading import get_out_dir +@dataclass +class PairAttribution: + source: str + target: str + attribution: Float[Tensor, "s_in trimmed_c_in s_out trimmed_c_out"] + trimmed_c_in_idxs: list[int] + trimmed_c_out_idxs: list[int] + is_kv_to_o_pair: bool + # Original alive masks (from model CI, before any optimization) + # Used to show "ghost" nodes that would have been active without CI optimization + # Shape: [n_seq, n_components] where True means alive at that (seq, component) + original_alive_mask_in: Float[Tensor, "seq C"] | None = None + original_alive_mask_out: Float[Tensor, "seq C"] | None = None + + @dataclass class NodeInfo: """Information about a node in the attribution graph.""" @@ -147,26 +161,37 @@ def sort_key(item: tuple[str, int, str]) -> tuple[int, int]: def compute_node_importances( attr_pairs: list[PairAttribution], n_seq: int, -) -> dict[str, Float[Tensor, "seq C"]]: +) -> tuple[dict[str, Float[Tensor, "seq C"]], dict[str, Float[Tensor, "seq C"]]]: """Compute importance values for nodes based on total attribution flow. - Returns a dict mapping layer -> tensor of shape [n_seq, max_component_idx+1]. - Importance is the sum of incoming and outgoing attributions. + Returns: + importances: Dict mapping layer -> tensor of shape [n_seq, max_component_idx+1]. + Importance is the sum of incoming and outgoing attributions. + original_alive_masks: Dict mapping layer -> bool tensor of shape [n_seq, max_component_idx+1]. + True means the component was alive at that (seq_pos, component) in original model CI. """ - # First pass: determine max component index per layer + # First pass: determine max component index per layer (including original alive) layer_max_c: dict[str, int] = {} for pair in attr_pairs: src_max = max(pair.trimmed_c_in_idxs) if pair.trimmed_c_in_idxs else 0 tgt_max = max(pair.trimmed_c_out_idxs) if pair.trimmed_c_out_idxs else 0 + # Also consider original alive mask dimensions + if pair.original_alive_mask_in is not None: + src_max = max(src_max, pair.original_alive_mask_in.shape[1] - 1) + if pair.original_alive_mask_out is not None: + tgt_max = max(tgt_max, pair.original_alive_mask_out.shape[1] - 1) layer_max_c[pair.source] = max(layer_max_c.get(pair.source, 0), src_max) layer_max_c[pair.target] = max(layer_max_c.get(pair.target, 0), tgt_max) - # Initialize importance tensors + # Initialize importance tensors and original alive masks (on same device as attributions) + device = attr_pairs[0].attribution.device if attr_pairs else "cpu" importances: dict[str, Float[Tensor, "seq C"]] = {} + original_alive_masks: dict[str, Float[Tensor, "seq C"]] = {} for layer, max_c in layer_max_c.items(): - importances[layer] = torch.zeros(n_seq, max_c + 1) + importances[layer] = torch.zeros(n_seq, max_c + 1, device=device) + original_alive_masks[layer] = torch.zeros(n_seq, max_c + 1, device=device, dtype=torch.bool) - # Accumulate attribution magnitudes + # Accumulate attribution magnitudes and original alive masks for pair in attr_pairs: attr = pair.attribution.abs() # [s_in, trimmed_c_in, s_out, trimmed_c_out] @@ -180,10 +205,18 @@ def compute_node_importances( for j, c_out in enumerate(pair.trimmed_c_out_idxs): importances[pair.target][:, c_out] += tgt_importance[:, j] - return importances + # Accumulate original alive masks (OR them together since multiple pairs may have same layer) + if pair.original_alive_mask_in is not None: + n_c = pair.original_alive_mask_in.shape[1] + original_alive_masks[pair.source][:, :n_c] |= pair.original_alive_mask_in + if pair.original_alive_mask_out is not None: + n_c = pair.original_alive_mask_out.shape[1] + original_alive_masks[pair.target][:, :n_c] |= pair.original_alive_mask_out + return importances, original_alive_masks -def plot_local_attribution_graph( + +def plot_local_graph( attr_pairs: list[PairAttribution], token_strings: list[str], min_edge_weight: float = 0.001, @@ -214,8 +247,8 @@ def plot_local_attribution_graph( """ n_seq = len(token_strings) - # Compute node importances first - importances = compute_node_importances(attr_pairs, n_seq) + # Compute node importances and original alive masks (per-position) + importances, original_alive_masks = compute_node_importances(attr_pairs, n_seq) # Compute layout layer_y = compute_layer_y_positions(attr_pairs) @@ -230,6 +263,7 @@ def plot_local_attribution_graph( # Collect all nodes and their positions nodes: list[NodeInfo] = [] + ghost_nodes: list[NodeInfo] = [] # Nodes originally alive at this position but now dead node_lookup: dict[tuple[str, int, int], NodeInfo] = {} # (layer, seq, comp) -> NodeInfo # X spacing: spread tokens across the plot @@ -246,9 +280,22 @@ def plot_local_attribution_graph( layer_imp = importances[layer] # [n_seq, max_c+1] alive_mask = layer_imp > 0 - # Find all components that are alive at ANY sequence position + # Get original alive mask for this layer (per-position) + layer_original_mask = original_alive_masks.get(layer) + + # Find all components that are alive at ANY sequence position (current) all_alive_components = torch.where(alive_mask.any(dim=0))[0].tolist() - n_components = len(all_alive_components) + + # Find all components that were originally alive at ANY sequence position + originally_alive_any_pos: list[int] = [] + if layer_original_mask is not None: + originally_alive_any_pos = torch.where(layer_original_mask.any(dim=0))[0].tolist() + + # Combine for layout purposes + all_components_for_layout = sorted( + set(all_alive_components) | set(originally_alive_any_pos) + ) + n_components = len(all_components_for_layout) if n_components == 0: continue @@ -310,8 +357,8 @@ def plot_local_attribution_graph( nodes.append(node) node_lookup[(layer, s, c)] = node else: - # For other layers: arrange all components in grid - for local_idx, c in enumerate(all_alive_components): + # For other layers: arrange all components in grid (including ghost nodes) + for local_idx, c in enumerate(all_components_for_layout): col = local_idx % layer_max_cols row = local_idx // layer_max_cols @@ -333,7 +380,7 @@ def plot_local_attribution_graph( x = x_base + x_offset y = y_center + y_offset - imp = layer_imp[s, c].item() + imp = layer_imp[s, c].item() if c < layer_imp.shape[1] else 0.0 node = NodeInfo( layer=layer, seq_pos=s, @@ -342,7 +389,23 @@ def plot_local_attribution_graph( y=y, importance=imp, ) - nodes.append(node) + + # Check if this is a ghost node at THIS position + # Ghost = originally alive at this (s, c) but not currently alive at this (s, c) + is_currently_alive_here = ( + alive_mask[s, c].item() if c < alive_mask.shape[1] else False + ) + is_originally_alive_here = ( + layer_original_mask is not None + and c < layer_original_mask.shape[1] + and layer_original_mask[s, c].item() + ) + is_ghost = is_originally_alive_here and not is_currently_alive_here + + if is_ghost: + ghost_nodes.append(node) + else: + nodes.append(node) node_lookup[(layer, s, c)] = node # Collect edges @@ -408,7 +471,21 @@ def plot_local_attribution_graph( ) ax.add_collection(lc) - # Draw nodes + # Draw ghost nodes first (so they appear behind regular nodes) + # Ghost nodes = originally alive (in model CI) but not currently alive (in optimized CI) + for node in ghost_nodes: + ax.scatter( + node.x, + node.y, + s=node_scale, + c="#909090", # Darker gray than nodes-without-edges + edgecolors="white", + linewidths=0.5, + zorder=1.5, # Behind regular nodes + alpha=0.5, + ) + + # Draw regular nodes for node in nodes: node_key = (node.layer, node.seq_pos, node.component_idx) has_edges = node_key in nodes_with_edges @@ -515,6 +592,11 @@ def plot_local_attribution_graph( legend_elements.append( plt.scatter([], [], c=color, s=50, label=sublayer, edgecolors="white") ) + # Add ghost node to legend if there are any + if ghost_nodes: + legend_elements.append( + plt.scatter([], [], c="#909090", s=50, label="ghost (orig. alive)", edgecolors="white") + ) ax.legend( handles=legend_elements, loc="upper left", @@ -554,7 +636,7 @@ def load_and_plot( print(f" Tokens: {token_strings}") print(f" Number of layer pairs: {len(attr_pairs)}") - fig = plot_local_attribution_graph( + fig = plot_local_graph( attr_pairs=attr_pairs, token_strings=token_strings, output_token_labels=output_token_labels, From 89913de0be67818c88a1ebd36fa48ea7a0a3853a Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Mon, 1 Dec 2025 17:05:57 +0000 Subject: [PATCH 29/36] Add optim_cis --- spd/scripts/optim_cis/__init__.py | 1 + spd/scripts/optim_cis/config.py | 112 +++++++ spd/scripts/optim_cis/run_optim_cis.py | 430 +++++++++++++++++++++++++ 3 files changed, 543 insertions(+) create mode 100644 spd/scripts/optim_cis/__init__.py create mode 100644 spd/scripts/optim_cis/config.py create mode 100644 spd/scripts/optim_cis/run_optim_cis.py diff --git a/spd/scripts/optim_cis/__init__.py b/spd/scripts/optim_cis/__init__.py new file mode 100644 index 000000000..a958aba2a --- /dev/null +++ b/spd/scripts/optim_cis/__init__.py @@ -0,0 +1 @@ +"""CI optimization for single prompts.""" diff --git a/spd/scripts/optim_cis/config.py b/spd/scripts/optim_cis/config.py new file mode 100644 index 000000000..97cd6eb01 --- /dev/null +++ b/spd/scripts/optim_cis/config.py @@ -0,0 +1,112 @@ +"""Configuration for CI optimization on single prompts.""" + +from typing import Annotated, Literal, Self + +from pydantic import Field, NonNegativeFloat, PositiveFloat, PositiveInt, model_validator + +from spd.base_config import BaseConfig +from spd.configs import LossMetricConfigType, SamplingType +from spd.spd_types import Probability + + +class OptimCIConfig(BaseConfig): + """Configuration for optimizing CI values on a single prompt.""" + + seed: int = Field( + ..., + description="Random seed for reproducibility", + ) + # Model and prompt + wandb_path: str = Field( + ..., + description="Wandb path to load model from, e.g. 'wandb:goodfire/spd/runs/jyo9duz5'", + ) + prompt: str = Field( + ..., + description="The prompt to optimize CI values for", + ) + label: str = Field( + ..., + description="The label to optimize CI values for", + ) + + # Optimization hyperparameters + lr: PositiveFloat = Field( + ..., + description="Learning rate for AdamW optimizer", + ) + steps: PositiveInt = Field( + ..., + description="Number of optimization steps", + ) + weight_decay: NonNegativeFloat = Field( + ..., + description="Weight decay for AdamW optimizer", + ) + lr_schedule: Literal["linear", "constant", "cosine", "exponential"] = Field( + ..., + description="Type of learning-rate schedule to apply", + ) + lr_exponential_halflife: PositiveFloat | None = Field( + ..., + description="Half-life parameter when using an exponential LR schedule", + ) + lr_warmup_pct: Probability = Field( + ..., + description="Fraction of total steps to linearly warm up the learning rate", + ) + log_freq: PositiveInt = Field( + ..., + description="Frequency of logging during optimization", + ) + + # Loss configuration + loss_metric_configs: list[Annotated[LossMetricConfigType, Field(discriminator="classname")]] = ( + Field( + ..., + description="List of loss metric configs (must have coeff set)", + ) + ) + + # CI thresholds and sampling + ci_threshold: PositiveFloat = Field( + ..., + description="Threshold for considering a component alive in original CI values. " + "Only components with CI > ci_threshold will be optimized.", + ) + sampling: SamplingType = Field( + ..., + description="Sampling mode for stochastic losses: 'continuous' or 'binomial'", + ) + n_mask_samples: PositiveInt = Field( + ..., + description="Number of stochastic masks to sample for recon losses", + ) + output_loss_type: Literal["mse", "kl"] = Field( + ..., + description="Loss type for reconstruction: 'kl' for LMs, 'mse' for vectors", + ) + + # Delta component + use_delta_component: bool = Field( + ..., + description="Whether to use delta component in reconstruction losses", + ) + + # CE/KL metrics + ce_loss_coeff: float = Field( + ..., + description="Coefficient for the CE loss", + ) + ce_kl_rounding_threshold: float = Field( + ..., + description="Threshold for rounding CI values in CE/KL metric computation", + ) + + @model_validator(mode="after") + def validate_model(self) -> Self: + if self.lr_schedule == "exponential": + assert self.lr_exponential_halflife is not None, ( + "lr_exponential_halflife must be set if lr_schedule is exponential" + ) + return self diff --git a/spd/scripts/optim_cis/run_optim_cis.py b/spd/scripts/optim_cis/run_optim_cis.py new file mode 100644 index 000000000..1e94b72fb --- /dev/null +++ b/spd/scripts/optim_cis/run_optim_cis.py @@ -0,0 +1,430 @@ +# %% +"""Optimize CI values for a single prompt while keeping component weights fixed.""" + +import json +from dataclasses import dataclass +from pathlib import Path + +import torch +import torch.nn.functional as F +import torch.optim as optim +from jaxtyping import Bool, Float +from torch import Tensor +from tqdm.auto import tqdm +from transformers import AutoTokenizer +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast + +from spd.configs import ImportanceMinimalityLossConfig +from spd.losses import compute_total_loss +from spd.models.component_model import CIOutputs, ComponentModel, OutputWithCache +from spd.models.components import make_mask_infos +from spd.scripts.model_loading import load_model_from_wandb +from spd.scripts.optim_cis.config import OptimCIConfig +from spd.utils.component_utils import calc_ci_l_zero +from spd.utils.general_utils import set_seed + + +@dataclass +class AliveComponentInfo: + """Info about which components are alive at each position for each layer.""" + + alive_masks: dict[str, Bool[Tensor, "1 seq C"]] # Per-layer masks of alive positions + alive_counts: dict[str, list[int]] # Number of alive components per position per layer + + +def compute_alive_info( + ci_lower_leaky: dict[str, Float[Tensor, "1 seq C"]], + ci_threshold: float, +) -> AliveComponentInfo: + """Compute which (position, component) pairs are alive based on initial CI values.""" + alive_masks: dict[str, Bool[Tensor, "1 seq C"]] = {} + alive_counts: dict[str, list[int]] = {} + + for layer_name, ci in ci_lower_leaky.items(): + mask = ci > ci_threshold + alive_masks[layer_name] = mask + # Count alive components per position: mask is [1, seq, C], sum over C + counts_per_pos = mask[0].sum(dim=-1) # [seq] + alive_counts[layer_name] = counts_per_pos.tolist() + + return AliveComponentInfo(alive_masks=alive_masks, alive_counts=alive_counts) + + +@dataclass +class OptimizableCIParams: + """Container for optimizable CI pre-sigmoid parameters.""" + + # List of pre-sigmoid tensors for alive positions at each sequence position + ci_pre_sigmoid: dict[str, list[Tensor]] # layer_name -> list of [alive_at_pos] values + alive_info: AliveComponentInfo + + def create_ci_outputs(self, model: ComponentModel, device: str) -> CIOutputs: + """Expand sparse pre-sigmoid values to full CI tensors and create CIOutputs.""" + pre_sigmoid: dict[str, Tensor] = {} + + for layer_name, mask in self.alive_info.alive_masks.items(): + # Create full tensors (default to 0 for non-alive positions) + full_pre_sigmoid = torch.zeros_like(mask, dtype=torch.float32, device=device) + + # Get pre-sigmoid list for this layer + layer_pre_sigmoid_list = self.ci_pre_sigmoid[layer_name] + + # For each position, place the values + seq_len = mask.shape[1] + for pos in range(seq_len): + pos_mask = mask[0, pos, :] # [C] + pos_pre_sigmoid = layer_pre_sigmoid_list[pos] # [alive_at_pos] + full_pre_sigmoid[0, pos, pos_mask] = pos_pre_sigmoid + + pre_sigmoid[layer_name] = full_pre_sigmoid + + return CIOutputs( + lower_leaky={k: model.lower_leaky_fn(v) for k, v in pre_sigmoid.items()}, + upper_leaky={k: model.upper_leaky_fn(v) for k, v in pre_sigmoid.items()}, + pre_sigmoid=pre_sigmoid, + ) + + def get_parameters(self) -> list[Tensor]: + """Get all optimizable parameters.""" + params: list[Tensor] = [] + for layer_pre_sigmoid_list in self.ci_pre_sigmoid.values(): + params.extend(layer_pre_sigmoid_list) + return params + + +def create_optimizable_ci_params( + alive_info: AliveComponentInfo, + initial_pre_sigmoid: dict[str, Tensor], +) -> OptimizableCIParams: + """Create optimizable CI parameters for alive positions. + + Creates parameters initialized from the initial pre-sigmoid values for each + (position, component) pair where initial CI > threshold. + """ + ci_pre_sigmoid: dict[str, list[Tensor]] = {} + + for layer_name, mask in alive_info.alive_masks.items(): + # Get initial pre-sigmoid values for this layer + layer_initial = initial_pre_sigmoid[layer_name] # [1, seq, C] + + # Create a tensor for each position + layer_pre_sigmoid_list: list[Tensor] = [] + seq_len = mask.shape[1] + for pos in range(seq_len): + pos_mask = mask[0, pos, :] # [C] + # Extract initial values for alive positions at this position + initial_values = layer_initial[0, pos, pos_mask].clone().detach() + initial_values.requires_grad_(True) + layer_pre_sigmoid_list.append(initial_values) + ci_pre_sigmoid[layer_name] = layer_pre_sigmoid_list + + return OptimizableCIParams( + ci_pre_sigmoid=ci_pre_sigmoid, + alive_info=alive_info, + ) + + +def compute_l0_stats( + ci_outputs: CIOutputs, + ci_alive_threshold: float, +) -> dict[str, float]: + """Compute L0 statistics for each layer.""" + stats: dict[str, float] = {} + for layer_name, layer_ci in ci_outputs.lower_leaky.items(): + l0_val = calc_ci_l_zero(layer_ci, ci_alive_threshold) + stats[f"l0/{layer_name}"] = l0_val + stats["l0/total"] = sum(stats.values()) + return stats + + +def compute_final_token_ce_kl( + model: ComponentModel, + batch: Tensor, + target_out: Tensor, + ci: dict[str, Tensor], + rounding_threshold: float, +) -> dict[str, float]: + """Compute CE and KL metrics for the final token only. + + Args: + model: The ComponentModel. + batch: Input tokens of shape [1, seq_len]. + target_out: Target model output logits of shape [1, seq_len, vocab]. + ci: Causal importance values (lower_leaky) per layer. + rounding_threshold: Threshold for rounding CI values to binary masks. + + Returns: + Dict with kl and ce_difference metrics for ci_masked, unmasked, and rounded_masked. + """ + assert batch.ndim == 2 and batch.shape[0] == 1, "Expected batch shape [1, seq_len]" + + # Get the label for CE (next token prediction at final position) + # The label is the token at the final position for the second-to-last logit prediction + # But since we're optimizing for CI on a single prompt, we use the final logit position + final_target_logits = target_out[0, -1, :] # [vocab] + + def kl_vs_target(logits: Tensor) -> float: + """KL divergence between predicted and target logits at final position.""" + final_logits = logits[0, -1, :] # [vocab] + target_probs = F.softmax(final_target_logits, dim=-1) + pred_log_probs = F.log_softmax(final_logits, dim=-1) + return F.kl_div(pred_log_probs, target_probs, reduction="sum").item() + + def ce_vs_target(logits: Tensor) -> float: + """CE between predicted logits and target's argmax at final position.""" + final_logits = logits[0, -1, :] # [vocab] + target_token = final_target_logits.argmax() + return F.cross_entropy(final_logits.unsqueeze(0), target_token.unsqueeze(0)).item() + + # Target model CE (baseline) + target_ce = ce_vs_target(target_out) + + # CI masked + ci_mask_infos = make_mask_infos(ci) + ci_masked_logits = model(batch, mask_infos=ci_mask_infos) + ci_masked_kl = kl_vs_target(ci_masked_logits) + ci_masked_ce = ce_vs_target(ci_masked_logits) + + # Unmasked (all components active) + unmasked_infos = make_mask_infos({k: torch.ones_like(v) for k, v in ci.items()}) + unmasked_logits = model(batch, mask_infos=unmasked_infos) + unmasked_kl = kl_vs_target(unmasked_logits) + unmasked_ce = ce_vs_target(unmasked_logits) + + # Rounded masked (binary masks based on threshold) + rounded_mask_infos = make_mask_infos( + {k: (v > rounding_threshold).float() for k, v in ci.items()} + ) + rounded_masked_logits = model(batch, mask_infos=rounded_mask_infos) + rounded_masked_kl = kl_vs_target(rounded_masked_logits) + rounded_masked_ce = ce_vs_target(rounded_masked_logits) + + return { + "kl_ci_masked": ci_masked_kl, + "kl_unmasked": unmasked_kl, + "kl_rounded_masked": rounded_masked_kl, + "ce_difference_ci_masked": ci_masked_ce - target_ce, + "ce_difference_unmasked": unmasked_ce - target_ce, + "ce_difference_rounded_masked": rounded_masked_ce - target_ce, + } + + +def optimize_ci_values( + model: ComponentModel, + tokens: Tensor, + label_token: int, + config: OptimCIConfig, + device: str, + ce_loss_coeff: float, +) -> tuple[dict[str, list[list[float]]], dict[str, float]]: + """Optimize CI values for a single prompt. + + Args: + model: The ComponentModel (weights will be frozen). + tokens: Tokenized prompt of shape [1, seq_len]. + label_token: The token to optimize CI values for. + config: Optimization configuration. + device: Device to run on. + ce_loss_coeff: Coefficient for the CE loss. + Returns: + Tuple of: + - Optimized CI values as dict of layer_name -> [seq][C] nested lists + - Final metrics dict + """ + # Freeze all model parameters + model.requires_grad_(False) + + # Get initial CI values from the model + with torch.no_grad(): + output_with_cache: OutputWithCache = model(tokens, cache_type="input") + initial_ci_outputs = model.calc_causal_importances( + pre_weight_acts=output_with_cache.cache, + sampling=config.sampling, + detach_inputs=False, + ) + target_out = output_with_cache.output.detach() + + # Compute alive info and create optimizable parameters + alive_info = compute_alive_info(initial_ci_outputs.lower_leaky, config.ci_threshold) + ci_params = create_optimizable_ci_params( + alive_info=alive_info, + initial_pre_sigmoid=initial_ci_outputs.pre_sigmoid, + ) + + # Log initial alive counts + total_alive = sum(sum(counts) for counts in alive_info.alive_counts.values()) + print(f"\nAlive components (CI > {config.ci_threshold}):") + for layer_name, counts in alive_info.alive_counts.items(): + layer_total = sum(counts) + print(f" {layer_name}: {layer_total} total across {len(counts)} positions") + print(f" Total: {total_alive}") + + # Get weight deltas for losses that need them + weight_deltas = model.calc_weight_deltas() + + # Setup optimizer + params = ci_params.get_parameters() + optimizer = optim.AdamW(params, lr=config.lr, weight_decay=config.weight_decay) + + # Optimization loop + final_metrics: dict[str, float] = {} + + for step in tqdm(range(config.steps), desc="Optimizing CI values"): + optimizer.zero_grad() + + # Create CI outputs from current parameters + ci_outputs = ci_params.create_ci_outputs(model, device) + + # Compute losses + total_loss, loss_terms = compute_total_loss( + loss_metric_configs=config.loss_metric_configs, + model=model, + batch=tokens, + ci=ci_outputs, + target_out=target_out, + weight_deltas=weight_deltas, + pre_weight_acts=output_with_cache.cache, + current_frac_of_training=step / config.steps, + sampling=config.sampling, + use_delta_component=config.use_delta_component, + n_mask_samples=config.n_mask_samples, + output_loss_type=config.output_loss_type, + ) + # Make a new loss which is the CE diff on the final sequence position between the given label + # and the outut logits + mask_infos = make_mask_infos( + component_masks=ci_outputs.lower_leaky, + routing_masks="all", + ) + # TODO: Support stochastic recon and e.g. subset recon. + out = model(tokens, mask_infos=mask_infos) + ce_loss = F.cross_entropy( + out[0, -1, :].unsqueeze(0), torch.tensor([label_token], device=device) + ) + total_loss += ce_loss_coeff * ce_loss + + total_loss.backward() + optimizer.step() + + # Logging + if step % config.log_freq == 0 or step == config.steps - 1: + l0_stats = compute_l0_stats(ci_outputs, config.ci_threshold) + + # Compute CE/KL metrics for final token only + with torch.no_grad(): + ce_kl_stats = compute_final_token_ce_kl( + model=model, + batch=tokens, + target_out=target_out, + ci=ci_outputs.lower_leaky, + rounding_threshold=config.ce_kl_rounding_threshold, + ) + + tqdm.write(f"\n--- Step {step} ---") + for name, value in loss_terms.items(): + tqdm.write(f" {name}: {value:.6f}") + for name, value in l0_stats.items(): + tqdm.write(f" {name}: {value:.2f}") + for name, value in ce_kl_stats.items(): + tqdm.write(f" {name}: {value:.6f}") + + if step == config.steps - 1: + final_metrics = {**loss_terms, **l0_stats, **ce_kl_stats} + + # Extract final CI values + with torch.no_grad(): + final_ci_outputs = ci_params.create_ci_outputs(model, device) + + # Convert to nested lists for JSON serialization + optimized_ci: dict[str, list[list[float]]] = {} + for layer_name, ci_tensor in final_ci_outputs.lower_leaky.items(): + # ci_tensor is [1, seq, C], convert to [seq][C] + optimized_ci[layer_name] = ci_tensor[0].cpu().tolist() + + return optimized_ci, final_metrics + + +def get_out_dir() -> Path: + """Get the output directory for optimization results.""" + out_dir = Path(__file__).parent / "out" + out_dir.mkdir(parents=True, exist_ok=True) + return out_dir + + +# %% +# Example configuration +if __name__ == "__main__": + # Configuration + config = OptimCIConfig( + seed=0, + # wandb_path="wandb:goodfire/spd/runs/jyo9duz5", # ss_gpt2_simple-1.25M (4L) + wandb_path="wandb:goodfire/spd/runs/33n6xjjt", # ss_gpt2_simple-1L + prompt="They walked hand in", + label="hand", + lr=1e-3, + weight_decay=0.0, + lr_schedule="cosine", + lr_exponential_halflife=None, + lr_warmup_pct=0.01, + steps=10000, + log_freq=500, + loss_metric_configs=[ + # StochasticReconSubsetLossConfig(coeff=1.0, routing=UniformKSubsetRoutingConfig()), + ImportanceMinimalityLossConfig(coeff=1e-1, pnorm=0.3), + ], + ce_loss_coeff=1.0, + ci_threshold=1e-6, + sampling="continuous", + n_mask_samples=1, + output_loss_type="kl", + use_delta_component=True, + ce_kl_rounding_threshold=0.5, + ) + + set_seed(config.seed) + + loaded = load_model_from_wandb(config.wandb_path) + model, run_config, device = loaded.model, loaded.config, loaded.device + + tokenizer = AutoTokenizer.from_pretrained(run_config.tokenizer_name) + assert isinstance(tokenizer, PreTrainedTokenizerFast), "Expected PreTrainedTokenizerFast" + + print(f"\nPrompt: {config.prompt!r}") + tokens = tokenizer.encode(config.prompt, return_tensors="pt", add_special_tokens=False) + assert isinstance(tokens, Tensor), "Expected Tensor" + tokens = tokens.to(device) + print(f"Tokens shape: {tokens.shape}") + token_strings = [tokenizer.decode([t]) for t in tokens[0].tolist()] + print(f"Token strings: {token_strings}") + label_token_ids = tokenizer.encode(config.label, add_special_tokens=False) + assert len(label_token_ids) == 1, f"Expected single token for label, got {len(label_token_ids)}" + label_token = label_token_ids[0] + print(f"Label token: {label_token}") + + # Run optimization + optimized_ci, final_metrics = optimize_ci_values( + model=model, + tokens=tokens, + label_token=label_token, + config=config, + device=device, + ce_loss_coeff=config.ce_loss_coeff, + ) + + # Save results + out_dir = get_out_dir() + output_path = out_dir / f"optimized_ci_{loaded.wandb_id}.json" + + output_data = { + "config": config.model_dump(), + "prompt": config.prompt, + "token_strings": token_strings, + "optimized_ci": optimized_ci, + "final_metrics": final_metrics, + "wandb_id": loaded.wandb_id, + } + + with open(output_path, "w") as f: + json.dump(output_data, f, indent=2) + + print(f"\nSaved optimized CI values to {output_path}") From 787ae374c0c174e3363c40dc7f4ffcc49cd7cb66 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Tue, 2 Dec 2025 09:18:57 +0000 Subject: [PATCH 30/36] Simplify losses in run_optim --- spd/scripts/optim_cis/config.py | 24 +++++------ spd/scripts/optim_cis/run_optim_cis.py | 60 ++++++++++---------------- 2 files changed, 33 insertions(+), 51 deletions(-) diff --git a/spd/scripts/optim_cis/config.py b/spd/scripts/optim_cis/config.py index 97cd6eb01..0cf717712 100644 --- a/spd/scripts/optim_cis/config.py +++ b/spd/scripts/optim_cis/config.py @@ -1,11 +1,11 @@ """Configuration for CI optimization on single prompts.""" -from typing import Annotated, Literal, Self +from typing import Literal, Self from pydantic import Field, NonNegativeFloat, PositiveFloat, PositiveInt, model_validator from spd.base_config import BaseConfig -from spd.configs import LossMetricConfigType, SamplingType +from spd.configs import ImportanceMinimalityLossConfig, SamplingType from spd.spd_types import Probability @@ -60,12 +60,14 @@ class OptimCIConfig(BaseConfig): description="Frequency of logging during optimization", ) - # Loss configuration - loss_metric_configs: list[Annotated[LossMetricConfigType, Field(discriminator="classname")]] = ( - Field( - ..., - description="List of loss metric configs (must have coeff set)", - ) + # Loss configs + imp_min_config: ImportanceMinimalityLossConfig = Field( + ..., + description="Configuration for the importance minimality loss", + ) + ce_loss_coeff: float = Field( + ..., + description="Coefficient for the CE loss", ) # CI thresholds and sampling @@ -87,17 +89,11 @@ class OptimCIConfig(BaseConfig): description="Loss type for reconstruction: 'kl' for LMs, 'mse' for vectors", ) - # Delta component use_delta_component: bool = Field( ..., description="Whether to use delta component in reconstruction losses", ) - # CE/KL metrics - ce_loss_coeff: float = Field( - ..., - description="Coefficient for the CE loss", - ) ce_kl_rounding_threshold: float = Field( ..., description="Threshold for rounding CI values in CE/KL metric computation", diff --git a/spd/scripts/optim_cis/run_optim_cis.py b/spd/scripts/optim_cis/run_optim_cis.py index 1e94b72fb..74ee3f7cf 100644 --- a/spd/scripts/optim_cis/run_optim_cis.py +++ b/spd/scripts/optim_cis/run_optim_cis.py @@ -15,7 +15,7 @@ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast from spd.configs import ImportanceMinimalityLossConfig -from spd.losses import compute_total_loss +from spd.metrics import importance_minimality_loss from spd.models.component_model import CIOutputs, ComponentModel, OutputWithCache from spd.models.components import make_mask_infos from spd.scripts.model_loading import load_model_from_wandb @@ -215,7 +215,6 @@ def optimize_ci_values( label_token: int, config: OptimCIConfig, device: str, - ce_loss_coeff: float, ) -> tuple[dict[str, list[list[float]]], dict[str, float]]: """Optimize CI values for a single prompt. @@ -225,12 +224,15 @@ def optimize_ci_values( label_token: The token to optimize CI values for. config: Optimization configuration. device: Device to run on. - ce_loss_coeff: Coefficient for the CE loss. Returns: Tuple of: - Optimized CI values as dict of layer_name -> [seq][C] nested lists - Final metrics dict """ + imp_min_coeff = config.imp_min_config.coeff + assert imp_min_coeff is not None, "Importance minimality loss coefficient must be set" + ce_loss_coeff = config.ce_loss_coeff + # Freeze all model parameters model.requires_grad_(False) @@ -259,14 +261,9 @@ def optimize_ci_values( print(f" {layer_name}: {layer_total} total across {len(counts)} positions") print(f" Total: {total_alive}") - # Get weight deltas for losses that need them - weight_deltas = model.calc_weight_deltas() - - # Setup optimizer params = ci_params.get_parameters() optimizer = optim.AdamW(params, lr=config.lr, weight_decay=config.weight_decay) - # Optimization loop final_metrics: dict[str, float] = {} for step in tqdm(range(config.steps), desc="Optimizing CI values"): @@ -275,38 +272,34 @@ def optimize_ci_values( # Create CI outputs from current parameters ci_outputs = ci_params.create_ci_outputs(model, device) - # Compute losses - total_loss, loss_terms = compute_total_loss( - loss_metric_configs=config.loss_metric_configs, - model=model, - batch=tokens, - ci=ci_outputs, - target_out=target_out, - weight_deltas=weight_deltas, - pre_weight_acts=output_with_cache.cache, - current_frac_of_training=step / config.steps, - sampling=config.sampling, - use_delta_component=config.use_delta_component, - n_mask_samples=config.n_mask_samples, - output_loss_type=config.output_loss_type, - ) - # Make a new loss which is the CE diff on the final sequence position between the given label - # and the outut logits mask_infos = make_mask_infos( component_masks=ci_outputs.lower_leaky, routing_masks="all", ) - # TODO: Support stochastic recon and e.g. subset recon. out = model(tokens, mask_infos=mask_infos) + + imp_min_loss = importance_minimality_loss( + ci_upper_leaky=ci_outputs.upper_leaky, + current_frac_of_training=step / config.steps, + pnorm=config.imp_min_config.pnorm, + eps=config.imp_min_config.eps, + p_anneal_start_frac=config.imp_min_config.p_anneal_start_frac, + p_anneal_final_p=config.imp_min_config.p_anneal_final_p, + p_anneal_end_frac=config.imp_min_config.p_anneal_end_frac, + ) ce_loss = F.cross_entropy( out[0, -1, :].unsqueeze(0), torch.tensor([label_token], device=device) ) - total_loss += ce_loss_coeff * ce_loss + total_loss = ce_loss_coeff * ce_loss + imp_min_coeff * imp_min_loss + loss_terms = { + "imp_min_loss": imp_min_loss.item(), + "ce_loss": ce_loss.item(), + "total_loss": total_loss.item(), + } total_loss.backward() optimizer.step() - # Logging if step % config.log_freq == 0 or step == config.steps - 1: l0_stats = compute_l0_stats(ci_outputs, config.ci_threshold) @@ -331,14 +324,11 @@ def optimize_ci_values( if step == config.steps - 1: final_metrics = {**loss_terms, **l0_stats, **ce_kl_stats} - # Extract final CI values with torch.no_grad(): final_ci_outputs = ci_params.create_ci_outputs(model, device) - # Convert to nested lists for JSON serialization optimized_ci: dict[str, list[list[float]]] = {} for layer_name, ci_tensor in final_ci_outputs.lower_leaky.items(): - # ci_tensor is [1, seq, C], convert to [seq][C] optimized_ci[layer_name] = ci_tensor[0].cpu().tolist() return optimized_ci, final_metrics @@ -368,11 +358,8 @@ def get_out_dir() -> Path: lr_warmup_pct=0.01, steps=10000, log_freq=500, - loss_metric_configs=[ - # StochasticReconSubsetLossConfig(coeff=1.0, routing=UniformKSubsetRoutingConfig()), - ImportanceMinimalityLossConfig(coeff=1e-1, pnorm=0.3), - ], - ce_loss_coeff=1.0, + imp_min_config=ImportanceMinimalityLossConfig(coeff=1e-1, pnorm=0.3), + ce_loss_coeff=1, ci_threshold=1e-6, sampling="continuous", n_mask_samples=1, @@ -408,7 +395,6 @@ def get_out_dir() -> Path: label_token=label_token, config=config, device=device, - ce_loss_coeff=config.ce_loss_coeff, ) # Save results From 9757306eac5f6dd17aee116c98fe0b2b48d5aa01 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Tue, 2 Dec 2025 10:59:52 +0000 Subject: [PATCH 31/36] Use stoch masks instead of ci masks in optim_cis --- spd/scripts/optim_cis/run_optim_cis.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/spd/scripts/optim_cis/run_optim_cis.py b/spd/scripts/optim_cis/run_optim_cis.py index 74ee3f7cf..781ac4c14 100644 --- a/spd/scripts/optim_cis/run_optim_cis.py +++ b/spd/scripts/optim_cis/run_optim_cis.py @@ -18,9 +18,10 @@ from spd.metrics import importance_minimality_loss from spd.models.component_model import CIOutputs, ComponentModel, OutputWithCache from spd.models.components import make_mask_infos +from spd.routing import AllLayersRouter from spd.scripts.model_loading import load_model_from_wandb from spd.scripts.optim_cis.config import OptimCIConfig -from spd.utils.component_utils import calc_ci_l_zero +from spd.utils.component_utils import calc_ci_l_zero, calc_stochastic_component_mask_info from spd.utils.general_utils import set_seed @@ -261,6 +262,8 @@ def optimize_ci_values( print(f" {layer_name}: {layer_total} total across {len(counts)} positions") print(f" Total: {total_alive}") + weight_deltas = model.calc_weight_deltas() + params = ci_params.get_parameters() optimizer = optim.AdamW(params, lr=config.lr, weight_decay=config.weight_decay) @@ -272,9 +275,11 @@ def optimize_ci_values( # Create CI outputs from current parameters ci_outputs = ci_params.create_ci_outputs(model, device) - mask_infos = make_mask_infos( - component_masks=ci_outputs.lower_leaky, - routing_masks="all", + mask_infos = calc_stochastic_component_mask_info( + causal_importances=ci_outputs.lower_leaky, + component_mask_sampling=config.sampling, + weight_deltas=weight_deltas, + router=AllLayersRouter(), ) out = model(tokens, mask_infos=mask_infos) @@ -351,7 +356,7 @@ def get_out_dir() -> Path: wandb_path="wandb:goodfire/spd/runs/33n6xjjt", # ss_gpt2_simple-1L prompt="They walked hand in", label="hand", - lr=1e-3, + lr=1e-2, weight_decay=0.0, lr_schedule="cosine", lr_exponential_halflife=None, From 2dae809f4b4d9dda989a57d5226c825b61862fc7 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Tue, 2 Dec 2025 12:09:43 +0000 Subject: [PATCH 32/36] Tweaks to run_optim_cis.py --- spd/scripts/optim_cis/config.py | 5 -- spd/scripts/optim_cis/run_optim_cis.py | 80 +++++++++++++++++--------- 2 files changed, 52 insertions(+), 33 deletions(-) diff --git a/spd/scripts/optim_cis/config.py b/spd/scripts/optim_cis/config.py index 0cf717712..cb0acb920 100644 --- a/spd/scripts/optim_cis/config.py +++ b/spd/scripts/optim_cis/config.py @@ -89,11 +89,6 @@ class OptimCIConfig(BaseConfig): description="Loss type for reconstruction: 'kl' for LMs, 'mse' for vectors", ) - use_delta_component: bool = Field( - ..., - description="Whether to use delta component in reconstruction losses", - ) - ce_kl_rounding_threshold: float = Field( ..., description="Threshold for rounding CI values in CE/KL metric computation", diff --git a/spd/scripts/optim_cis/run_optim_cis.py b/spd/scripts/optim_cis/run_optim_cis.py index 781ac4c14..96cd06bdc 100644 --- a/spd/scripts/optim_cis/run_optim_cis.py +++ b/spd/scripts/optim_cis/run_optim_cis.py @@ -216,7 +216,7 @@ def optimize_ci_values( label_token: int, config: OptimCIConfig, device: str, -) -> tuple[dict[str, list[list[float]]], dict[str, float]]: +) -> OptimizableCIParams: """Optimize CI values for a single prompt. Args: @@ -226,9 +226,7 @@ def optimize_ci_values( config: Optimization configuration. device: Device to run on. Returns: - Tuple of: - - Optimized CI values as dict of layer_name -> [seq][C] nested lists - - Final metrics dict + The OptimizableCIParams object. """ imp_min_coeff = config.imp_min_config.coeff assert imp_min_coeff is not None, "Importance minimality loss coefficient must be set" @@ -267,8 +265,6 @@ def optimize_ci_values( params = ci_params.get_parameters() optimizer = optim.AdamW(params, lr=config.lr, weight_decay=config.weight_decay) - final_metrics: dict[str, float] = {} - for step in tqdm(range(config.steps), desc="Optimizing CI values"): optimizer.zero_grad() @@ -296,14 +292,8 @@ def optimize_ci_values( out[0, -1, :].unsqueeze(0), torch.tensor([label_token], device=device) ) total_loss = ce_loss_coeff * ce_loss + imp_min_coeff * imp_min_loss - loss_terms = { - "imp_min_loss": imp_min_loss.item(), - "ce_loss": ce_loss.item(), - "total_loss": total_loss.item(), - } - - total_loss.backward() - optimizer.step() + # Get the output probability for the label_token in the final seq position + label_prob = F.softmax(out[0, -1, :], dim=-1)[label_token] if step % config.log_freq == 0 or step == config.steps - 1: l0_stats = compute_l0_stats(ci_outputs, config.ci_threshold) @@ -318,25 +308,38 @@ def optimize_ci_values( rounding_threshold=config.ce_kl_rounding_threshold, ) + # Also calculate the ci-masked label probability + with torch.no_grad(): + mask_infos = make_mask_infos(ci_outputs.lower_leaky, routing_masks="all") + out = model(tokens, mask_infos=mask_infos) + ci_masked_label_prob = F.softmax(out[0, -1, :], dim=-1)[label_token] + + log_terms = { + "imp_min_loss": imp_min_loss.item(), + "ce_loss": ce_loss.item(), + "total_loss": total_loss.item(), + "stoch_masked_label_prob": label_prob.item(), + "ci_masked_label_prob": ci_masked_label_prob.item(), + } tqdm.write(f"\n--- Step {step} ---") - for name, value in loss_terms.items(): + for name, value in log_terms.items(): tqdm.write(f" {name}: {value:.6f}") for name, value in l0_stats.items(): tqdm.write(f" {name}: {value:.2f}") for name, value in ce_kl_stats.items(): tqdm.write(f" {name}: {value:.6f}") - if step == config.steps - 1: - final_metrics = {**loss_terms, **l0_stats, **ce_kl_stats} + total_loss.backward() + optimizer.step() - with torch.no_grad(): - final_ci_outputs = ci_params.create_ci_outputs(model, device) + # with torch.no_grad(): + # final_ci_outputs = ci_params.create_ci_outputs(model, device) - optimized_ci: dict[str, list[list[float]]] = {} - for layer_name, ci_tensor in final_ci_outputs.lower_leaky.items(): - optimized_ci[layer_name] = ci_tensor[0].cpu().tolist() + # optimized_ci: dict[str, list[list[float]]] = {} + # for layer_name, ci_tensor in final_ci_outputs.lower_leaky.items(): + # optimized_ci[layer_name] = ci_tensor[0].cpu().tolist() - return optimized_ci, final_metrics + return ci_params def get_out_dir() -> Path: @@ -361,7 +364,7 @@ def get_out_dir() -> Path: lr_schedule="cosine", lr_exponential_halflife=None, lr_warmup_pct=0.01, - steps=10000, + steps=2000, log_freq=500, imp_min_config=ImportanceMinimalityLossConfig(coeff=1e-1, pnorm=0.3), ce_loss_coeff=1, @@ -369,7 +372,6 @@ def get_out_dir() -> Path: sampling="continuous", n_mask_samples=1, output_loss_type="kl", - use_delta_component=True, ce_kl_rounding_threshold=0.5, ) @@ -394,7 +396,7 @@ def get_out_dir() -> Path: print(f"Label token: {label_token}") # Run optimization - optimized_ci, final_metrics = optimize_ci_values( + ci_params = optimize_ci_values( model=model, tokens=tokens, label_token=label_token, @@ -402,6 +404,29 @@ def get_out_dir() -> Path: device=device, ) + # Get final metrics + ci_outputs = ci_params.create_ci_outputs(model, device) + l0_stats = compute_l0_stats(ci_outputs, config.ci_threshold) + + with torch.no_grad(): + target_out = model(tokens) + ce_kl_stats = compute_final_token_ce_kl( + model=model, + batch=tokens, + target_out=target_out, + ci=ci_outputs.lower_leaky, + rounding_threshold=config.ce_kl_rounding_threshold, + ) + # Use ci-masked model to get final label probability + mask_infos = make_mask_infos(ci_outputs.lower_leaky, routing_masks="all") + out = model(tokens, mask_infos=mask_infos) + label_prob = F.softmax(out[0, -1, :], dim=-1)[label_token] + + final_metrics = {**l0_stats, **ce_kl_stats, "ci_masked_label_prob": label_prob.item()} + print(f"\nFinal metrics after {config.steps} steps:") + for name, value in final_metrics.items(): + print(f" {name}: {value:.6f}") + # Save results out_dir = get_out_dir() output_path = out_dir / f"optimized_ci_{loaded.wandb_id}.json" @@ -410,8 +435,7 @@ def get_out_dir() -> Path: "config": config.model_dump(), "prompt": config.prompt, "token_strings": token_strings, - "optimized_ci": optimized_ci, - "final_metrics": final_metrics, + "optimized_ci": {k: v[0].cpu().tolist() for k, v in ci_outputs.lower_leaky.items()}, "wandb_id": loaded.wandb_id, } From 46ad2963b3b39944c931f2df74d662fc69a59241 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Tue, 2 Dec 2025 12:12:00 +0000 Subject: [PATCH 33/36] Minor clean --- spd/scripts/optim_cis/run_optim_cis.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/spd/scripts/optim_cis/run_optim_cis.py b/spd/scripts/optim_cis/run_optim_cis.py index 96cd06bdc..9f3a450b2 100644 --- a/spd/scripts/optim_cis/run_optim_cis.py +++ b/spd/scripts/optim_cis/run_optim_cis.py @@ -332,13 +332,6 @@ def optimize_ci_values( total_loss.backward() optimizer.step() - # with torch.no_grad(): - # final_ci_outputs = ci_params.create_ci_outputs(model, device) - - # optimized_ci: dict[str, list[list[float]]] = {} - # for layer_name, ci_tensor in final_ci_outputs.lower_leaky.items(): - # optimized_ci[layer_name] = ci_tensor[0].cpu().tolist() - return ci_params @@ -350,9 +343,7 @@ def get_out_dir() -> Path: # %% -# Example configuration if __name__ == "__main__": - # Configuration config = OptimCIConfig( seed=0, # wandb_path="wandb:goodfire/spd/runs/jyo9duz5", # ss_gpt2_simple-1.25M (4L) @@ -393,7 +384,6 @@ def get_out_dir() -> Path: label_token_ids = tokenizer.encode(config.label, add_special_tokens=False) assert len(label_token_ids) == 1, f"Expected single token for label, got {len(label_token_ids)}" label_token = label_token_ids[0] - print(f"Label token: {label_token}") # Run optimization ci_params = optimize_ci_values( @@ -427,7 +417,6 @@ def get_out_dir() -> Path: for name, value in final_metrics.items(): print(f" {name}: {value:.6f}") - # Save results out_dir = get_out_dir() output_path = out_dir / f"optimized_ci_{loaded.wandb_id}.json" From c31a67d634082223a7526e7b2a21a57015b78357 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Tue, 2 Dec 2025 15:34:28 +0000 Subject: [PATCH 34/36] Support multiple prompts in calc_local_attributions --- spd/scripts/calc_local_attributions.py | 585 ++++++++++++++++--------- spd/scripts/optim_cis/run_optim_cis.py | 15 +- 2 files changed, 389 insertions(+), 211 deletions(-) diff --git a/spd/scripts/calc_local_attributions.py b/spd/scripts/calc_local_attributions.py index 010e69508..f8998ea5b 100644 --- a/spd/scripts/calc_local_attributions.py +++ b/spd/scripts/calc_local_attributions.py @@ -6,7 +6,7 @@ from typing import Any import torch -from jaxtyping import Bool, Float +from jaxtyping import Bool, Float, Int from torch import Tensor, nn from tqdm.auto import tqdm from transformers import AutoTokenizer @@ -22,87 +22,214 @@ @dataclass class LayerAliveInfo: - """Info about alive components for a layer.""" + """Info about alive components for a layer (single batch item).""" - alive_mask: Bool[Tensor, "1 s dim"] # Which (pos, component) pairs are alive + alive_mask: Bool[Tensor, "s dim"] # Which (pos, component) pairs are alive alive_c_idxs: list[int] # Components alive at any position c_to_trimmed: dict[int, int] # original idx -> trimmed idx +@dataclass +class TokensAndCI: + """Tokenized prompts and their CI values.""" + + tokens: Int[Tensor, "N seq"] + ci_lower_leaky: dict[str, Float[Tensor, "N seq C"]] + set_names: list[str] + prompts: list[str] + all_token_strings: list[list[str]] + + def compute_layer_alive_info( layer_name: str, ci_lower_leaky: dict[str, Tensor], - output_probs: Float[Tensor, "1 s vocab"] | None, + output_probs: Float[Tensor, "N s vocab"] | None, ci_threshold: float, output_prob_threshold: float, n_seq: int, + n_batch: int, device: str, -) -> LayerAliveInfo: - """Compute alive info for a layer. Handles regular, wte, and output layers.""" +) -> list[LayerAliveInfo]: + """Compute alive info for a layer across all batch items. + + Returns list of LayerAliveInfo, one per batch item. + """ if layer_name == "wte": # WTE: single pseudo-component, always alive at all positions - alive_mask = torch.ones(1, n_seq, 1, device=device, dtype=torch.bool) + alive_mask = torch.ones(n_seq, 1, device=device, dtype=torch.bool) alive_c_idxs = [0] + c_to_trimmed = {0: 0} + # Same info for all batch items + return [LayerAliveInfo(alive_mask, alive_c_idxs, c_to_trimmed) for _ in range(n_batch)] + elif layer_name == "output": assert output_probs is not None - alive_mask = output_probs >= output_prob_threshold - alive_c_idxs = torch.where(alive_mask[0].any(dim=0))[0].tolist() - else: - ci = ci_lower_leaky[layer_name] - alive_mask = ci >= ci_threshold - alive_c_idxs = torch.where(alive_mask[0].any(dim=0))[0].tolist() + # output_probs: [N, seq, vocab] + full_alive_mask = output_probs >= output_prob_threshold # [N, seq, vocab] + results = [] + for b in range(n_batch): + batch_mask = full_alive_mask[b] # [seq, vocab] + alive_c_idxs = torch.where(batch_mask.any(dim=0))[0].tolist() + c_to_trimmed = {c: i for i, c in enumerate(alive_c_idxs)} + results.append(LayerAliveInfo(batch_mask, alive_c_idxs, c_to_trimmed)) + return results - c_to_trimmed = {c: i for i, c in enumerate(alive_c_idxs)} - return LayerAliveInfo(alive_mask, alive_c_idxs, c_to_trimmed) + else: + ci = ci_lower_leaky[layer_name] # [N, seq, C] + full_alive_mask = ci >= ci_threshold # [N, seq, C] + results = [] + for b in range(n_batch): + batch_mask = full_alive_mask[b] # [seq, C] + alive_c_idxs = torch.where(batch_mask.any(dim=0))[0].tolist() + c_to_trimmed = {c: i for i, c in enumerate(alive_c_idxs)} + results.append(LayerAliveInfo(batch_mask, alive_c_idxs, c_to_trimmed)) + return results def load_ci_from_json( ci_vals_path: str | Path, - expected_prompt: str, device: str, -) -> dict[str, Float[Tensor, "1 seq C"]]: +) -> tuple[dict[str, Float[Tensor, "N seq C"]], list[str], list[str]]: """Load precomputed CI values from a JSON file. + Expected format: + { + "ci_sets": { + "set_name_1": { + "prompt": "...", # Each set has its own prompt + "ci_vals": {"layer1": [[...]], ...}, + ... + }, + "set_name_2": {...}, + } + } + Args: - ci_vals_path: Path to JSON file from run_optim_cis.py - expected_prompt: The prompt we're analyzing (must match the JSON) + ci_vals_path: Path to JSON file with ci_sets structure device: Device to load tensors to Returns: - Dict mapping layer_name -> CI tensor of shape [1, seq, C] - - Raises: - ValueError: If the prompt in the JSON doesn't match expected_prompt + Tuple of: + - Dict mapping layer_name -> CI tensor of shape [N, seq, C] where N = number of sets + - List of set names in batch order + - List of prompts in batch order (one per set) """ with open(ci_vals_path) as f: data = json.load(f) - json_prompt = data["prompt"] - if json_prompt != expected_prompt: - raise ValueError( - f"Prompt mismatch: JSON has {json_prompt!r}, but expected {expected_prompt!r}" - ) + ci_sets: dict[str, dict[str, Any]] = data["ci_sets"] + set_names = list(ci_sets.keys()) + assert len(set_names) > 0, "No CI sets found in JSON" + + # Extract prompt from each set + prompts: list[str] = [ci_sets[set_name]["prompt"] for set_name in set_names] + + # Get layer names from first set + first_set = ci_sets[set_names[0]] + layer_names = list(first_set["ci_vals"].keys()) + # Stack tensors along batch dimension ci_lower_leaky: dict[str, Tensor] = {} - for layer_name, ci_list in data["optimized_ci"].items(): - # ci_list is [seq][C], convert to tensor [1, seq, C] - ci_tensor = torch.tensor(ci_list, device=device).unsqueeze(0) - ci_lower_leaky[layer_name] = ci_tensor + for layer_name in layer_names: + # Collect [seq, C] tensors from each set + layer_tensors = [ + torch.tensor(ci_sets[set_name]["ci_vals"][layer_name], device=device) + for set_name in set_names + ] + # Stack to [N, seq, C] + ci_lower_leaky[layer_name] = torch.stack(layer_tensors, dim=0) + + return ci_lower_leaky, set_names, prompts + + +def get_tokens_and_ci( + model: ComponentModel, + tokenizer: PreTrainedTokenizerFast, + sampling: SamplingType, + device: str, + ci_vals_path: str | Path | None, + prompts: list[str] | None, +) -> TokensAndCI: + """Get tokenized prompts and CI values, either from file or computed from model. + + Args: + model: The ComponentModel to use for computing CI if needed. + tokenizer: Tokenizer for the prompts. + sampling: Sampling type for CI computation. + device: Device to place tensors on. + ci_vals_path: Path to JSON with precomputed CI values and prompts. + If provided, prompts are read from the JSON file. + prompts: List of prompts to use when ci_vals_path is None. + + Returns: + TokensAndCI containing tokens, CI values, set names, prompts, and token strings. + """ + if ci_vals_path is not None: + print(f"\nLoading precomputed CI values from {ci_vals_path}") + ci_lower_leaky, set_names, prompts = load_ci_from_json(ci_vals_path, device) + print(f"Loaded CI values for layers: {list(ci_lower_leaky.keys())}") + print(f"CI sets: {set_names}") + else: + assert prompts is not None, "prompts is required when ci_vals_path is None" + set_names = [f"prompt_{i}" for i in range(len(prompts))] + ci_lower_leaky = None + + # Tokenize each prompt + tokens_list: list[Tensor] = [] + all_token_strings: list[list[str]] = [] + for i, p in enumerate(prompts): + if ci_vals_path is not None: + print(f"\nPrompt for {set_names[i]}: {p!r}") + toks = tokenizer.encode(p, return_tensors="pt", add_special_tokens=False) + assert isinstance(toks, Tensor), "Expected Tensor" + tokens_list.append(toks) + first_row = toks[0] + assert isinstance(first_row, Tensor), "Expected 2D tensor" + all_token_strings.append([tokenizer.decode([t]) for t in first_row.tolist()]) + if ci_vals_path is not None: + print(f" Token strings: {all_token_strings[-1]}") + + # Validate all prompts have the same token length + token_lengths = [t.shape[1] for t in tokens_list] + assert all(length == token_lengths[0] for length in token_lengths), ( + f"All prompts must tokenize to the same length, got {token_lengths}" + ) + + tokens = torch.cat(tokens_list, dim=0).to(device) # [N, seq] + print(f"\nTokens shape: {tokens.shape}") + + # Compute CI values from model if not provided + if ci_lower_leaky is None: + print("\nComputing CI values from model...") + with torch.no_grad(): + output_with_cache = model(tokens, cache_type="input") + ci_lower_leaky = model.calc_causal_importances( + pre_weight_acts=output_with_cache.cache, + sampling=sampling, + detach_inputs=False, + ).lower_leaky - return ci_lower_leaky + return TokensAndCI( + tokens=tokens, + ci_lower_leaky=ci_lower_leaky, + set_names=set_names, + prompts=prompts, + all_token_strings=all_token_strings, + ) def compute_local_attributions( model: ComponentModel, - tokens: Float[Tensor, "1 seq"], + tokens: Int[Tensor, "N seq"], sources_by_target: dict[str, list[str]], ci_threshold: float, output_prob_threshold: float, sampling: SamplingType, device: str, - ci_lower_leaky: dict[str, Float[Tensor, "1 seq C"]] | None = None, -) -> tuple[list[PairAttribution], Float[Tensor, "1 seq vocab"]]: - """Compute local attributions for a single prompt. + ci_lower_leaky: dict[str, Float[Tensor, "N seq C"]], + set_names: list[str], +) -> tuple[dict[str, list[PairAttribution]], Float[Tensor, "N seq vocab"]]: + """Compute local attributions for multiple prompts across multiple CI sets. For each valid layer pair (in_layer, out_layer), computes the gradient-based attribution of output component activations with respect to input component @@ -110,40 +237,43 @@ def compute_local_attributions( Args: model: The ComponentModel to analyze. - tokens: Tokenized prompt of shape [1, seq_len]. + tokens: Tokenized prompts of shape [N, seq_len] - one per CI set. sources_by_target: Dict mapping out_layer -> list of in_layers. ci_threshold: Threshold for considering a component alive at a position. output_prob_threshold: Threshold for considering an output logit alive (on softmax probs). sampling: Sampling type to use for causal importances. device: Device to run on. - ci_lower_leaky: Optional precomputed/optimized CI values. If None, will use model CI. - When provided, we still compute original model CI to track "ghost" nodes. + ci_lower_leaky: Precomputed/optimized CI values with shape [N, seq, C]. + set_names: Ordered list of set names corresponding to batch dimension. Returns: - List of PairAttribution objects. + Tuple of: + - Dict mapping set_name -> list of PairAttribution objects + - Output probabilities of shape [N, seq, vocab] """ - n_seq = tokens.shape[1] + n_batch, n_seq = tokens.shape + # Validate batch size matches CI tensors + first_ci = next(iter(ci_lower_leaky.values())) + assert first_ci.shape[0] == n_batch, f"CI batch size {first_ci.shape[0]} != tokens {n_batch}" + assert len(set_names) == n_batch, f"set_names length {len(set_names)} != n_batch {n_batch}" with torch.no_grad(): output_with_cache: OutputWithCache = model(tokens, cache_type="input") # Always compute original CI from model (needed for ghost nodes when using optimized CI) + # Note: original CI is computed from single-batch input, then repeated for comparison with torch.no_grad(): ci = model.calc_causal_importances( pre_weight_acts=output_with_cache.cache, sampling=sampling, detach_inputs=False, ) - ci_original = ci.lower_leaky - - # Use provided CI values if given, otherwise use original - if ci_lower_leaky is None: - ci_lower_leaky = ci_original + ci_original = ci.lower_leaky # [N, seq, C] # Log the l0 (lower_leaky values > ci_threshold) for each layer - print("L0 values for final seq position:") + print("L0 values for final seq position (first CI set):") for layer, ci_vals in ci_lower_leaky.items(): - # We only care about the final position + # We only care about the final position, show first batch item l0_vals = (ci_vals[0, -1] > ci_threshold).sum().item() print(f" Layer {layer} has {l0_vals} components alive at {ci_threshold}") @@ -159,11 +289,7 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An assert isinstance(model.target_model.wte, nn.Module), "wte is not a module" wte_handle = model.target_model.wte.register_forward_hook(wte_hook, with_kwargs=True) - mask_infos = make_mask_infos( - # component_masks={k: torch.ones_like(v) for k, v in ci_lower_leaky.items()}, - component_masks=ci_lower_leaky, - routing_masks="all", - ) + mask_infos = make_mask_infos(component_masks=ci_lower_leaky, routing_masks="all") with torch.enable_grad(): comp_output_with_cache: OutputWithCache = model( tokens, mask_infos=mask_infos, cache_type="component_acts" @@ -182,38 +308,66 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An for sources in sources_by_target.values(): all_layers.update(sources) - alive_info: dict[str, LayerAliveInfo] = {} - original_alive_info: dict[str, LayerAliveInfo] = {} + # alive_info[layer] -> list of LayerAliveInfo, one per batch item + alive_info: dict[str, list[LayerAliveInfo]] = {} + original_alive_info: dict[str, list[LayerAliveInfo]] = {} + # alive_c_union[layer] -> union of alive_c_idxs across all batches + alive_c_union: dict[str, list[int]] = {} + for layer in all_layers: alive_info[layer] = compute_layer_alive_info( - layer, ci_lower_leaky, output_probs, ci_threshold, output_prob_threshold, n_seq, device + layer, + ci_lower_leaky, + output_probs, + ci_threshold, + output_prob_threshold, + n_seq, + n_batch, + device, ) # Compute original alive info (from model CI, not optimized CI) original_alive_info[layer] = compute_layer_alive_info( - layer, ci_original, output_probs, ci_threshold, output_prob_threshold, n_seq, device + layer, + ci_original, + output_probs, + ci_threshold, + output_prob_threshold, + n_seq, + n_batch, + device, ) + # Compute union of alive components across batches + all_alive: set[int] = set() + for batch_info in alive_info[layer]: + all_alive.update(batch_info.alive_c_idxs) + alive_c_union[layer] = sorted(all_alive) - local_attributions: list[PairAttribution] = [] + # Initialize output dictionary + local_attributions_by_set: dict[str, list[PairAttribution]] = {name: [] for name in set_names} for target, sources in tqdm(sources_by_target.items(), desc="Target layers"): - target_info = alive_info[target] - out_pre_detach: Float[Tensor, "1 s dim"] = cache[f"{target}_pre_detach"] + target_infos = alive_info[target] # list of LayerAliveInfo per batch + out_pre_detach: Float[Tensor, "N s dim"] = cache[f"{target}_pre_detach"] - source_infos = [alive_info[source] for source in sources] - in_post_detaches: list[Float[Tensor, "1 s dim"]] = [ + all_source_infos = [alive_info[source] for source in sources] # list of lists + in_post_detaches: list[Float[Tensor, "N s dim"]] = [ cache[f"{source}_post_detach"] for source in sources ] - # Initialize attribution tensors at final trimmed size - attributions: list[Float[Tensor, "s_in n_c_in s_out n_c_out"]] = [ - torch.zeros( - n_seq, - len(source_info.alive_c_idxs), - n_seq, - len(target_info.alive_c_idxs), - device=device, - ) - for source_info in source_infos + # Initialize per-batch attribution tensors + # attributions[source_idx][batch_idx] = tensor + attributions: list[list[Float[Tensor, "s_in n_c_in s_out n_c_out"]]] = [ + [ + torch.zeros( + n_seq, + len(source_infos[b].alive_c_idxs), + n_seq, + len(target_infos[b].alive_c_idxs), + device=device, + ) + for b in range(n_batch) + ] + for source_infos in all_source_infos ] # NOTE: o->q will be treated as an attention pair even though there are no attrs @@ -221,67 +375,83 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An is_attention_output = any(is_kv_to_o_pair(source, target) for source in sources) for s_out in tqdm(range(n_seq), desc=f"{target} <- {sources}", leave=False): - # Get alive output components at this position - s_out_alive_c: list[int] = [ - c for c in target_info.alive_c_idxs if target_info.alive_mask[0, s_out, c] + # Union of alive c_out across all batches at this position + s_out_alive_c_union: list[int] = [ + c + for c in alive_c_union[target] + if any(info.alive_mask[s_out, c] for info in target_infos) ] - if not s_out_alive_c: + if not s_out_alive_c_union: continue - for c_out in s_out_alive_c: + for c_out in s_out_alive_c_union: + # Sum over batch dimension for efficient batched gradient in_post_detach_grads = torch.autograd.grad( - outputs=out_pre_detach[0, s_out, c_out], + outputs=out_pre_detach[:, s_out, c_out].sum(), inputs=in_post_detaches, retain_graph=True, ) # Handle causal attention mask s_in_range = range(s_out + 1) if is_attention_output else range(s_out, s_out + 1) - trimmed_c_out = target_info.c_to_trimmed[c_out] with torch.no_grad(): - for source, source_info, grad, in_post_detach, attr in zip( - sources, - source_infos, - in_post_detach_grads, - in_post_detaches, - attributions, - strict=True, + for source_idx, (source, source_infos, grad, in_post_detach) in enumerate( + zip( + sources, + all_source_infos, + in_post_detach_grads, + in_post_detaches, + strict=True, + ) ): - weighted: Float[Tensor, "s dim"] = (grad * in_post_detach)[0] + # grad and in_post_detach: [N, seq, dim] + weighted: Float[Tensor, "N s dim"] = grad * in_post_detach if source == "wte": # Sum over embedding_dim to get single pseudo-component - weighted = weighted.sum(dim=1, keepdim=True) - - for s_in in s_in_range: - alive_c_in = [ - c - for c in source_info.alive_c_idxs - if source_info.alive_mask[0, s_in, c] - ] - for c_in in alive_c_in: - trimmed_c_in = source_info.c_to_trimmed[c_in] - attr[s_in, trimmed_c_in, s_out, trimmed_c_out] = weighted[ - s_in, c_in + weighted = weighted.sum(dim=2, keepdim=True) + + # Store attributions per-batch + for b in range(n_batch): + # Only store if c_out is alive in this batch at this position + if not target_infos[b].alive_mask[s_out, c_out]: + continue + if c_out not in target_infos[b].c_to_trimmed: + continue + trimmed_c_out = target_infos[b].c_to_trimmed[c_out] + + for s_in in s_in_range: + alive_c_in = [ + c + for c in source_infos[b].alive_c_idxs + if source_infos[b].alive_mask[s_in, c] ] - - for source, source_info, attr in zip(sources, source_infos, attributions, strict=True): - original_source_info = original_alive_info[source] - original_target_info = original_alive_info[target] - local_attributions.append( - PairAttribution( - source=source, - target=target, - attribution=attr, - trimmed_c_in_idxs=source_info.alive_c_idxs, - trimmed_c_out_idxs=target_info.alive_c_idxs, - is_kv_to_o_pair=is_kv_to_o_pair(source, target), - # Pass per-position alive masks (squeeze out batch dim) - original_alive_mask_in=original_source_info.alive_mask[0], # [seq, C] - original_alive_mask_out=original_target_info.alive_mask[0], # [seq, C] + for c_in in alive_c_in: + trimmed_c_in = source_infos[b].c_to_trimmed[c_in] + attributions[source_idx][b][ + s_in, trimmed_c_in, s_out, trimmed_c_out + ] = weighted[b, s_in, c_in] + + # Build output per set + for source_idx, (source, source_infos) in enumerate( + zip(sources, all_source_infos, strict=True) + ): + original_source_infos = original_alive_info[source] + original_target_infos = original_alive_info[target] + for b, set_name in enumerate(set_names): + local_attributions_by_set[set_name].append( + PairAttribution( + source=source, + target=target, + attribution=attributions[source_idx][b], + trimmed_c_in_idxs=source_infos[b].alive_c_idxs, + trimmed_c_out_idxs=target_infos[b].alive_c_idxs, + is_kv_to_o_pair=is_kv_to_o_pair(source, target), + original_alive_mask_in=original_source_infos[b].alive_mask, # [seq, C] + original_alive_mask_out=original_target_infos[b].alive_mask, # [seq, C] + ) ) - ) - return local_attributions, output_probs + return local_attributions_by_set, output_probs def main( @@ -289,9 +459,20 @@ def main( n_blocks: int, ci_threshold: float, output_prob_threshold: float, - prompt: str, - ci_vals_path: str | None, + ci_vals_path: str | Path | None, + prompts: list[str] | None = None, ) -> None: + """Compute local attributions for a prompt. + + Args: + wandb_path: WandB path to load model from. + n_blocks: Number of transformer blocks to analyze. + ci_threshold: Threshold for considering a component alive. + output_prob_threshold: Threshold for considering an output logit alive. + ci_vals_path: Path to JSON with precomputed CI values and prompts. + If provided, prompts are read from the JSON file. + prompts: List of prompts to use when ci_vals_path is None. Required if ci_vals_path is None. + """ loaded = load_model_from_wandb(wandb_path) model, config, device = loaded.model, loaded.config, loaded.device @@ -304,102 +485,97 @@ def main( for out_layer, in_layers in sources_by_target.items(): print(f" {out_layer} <- {in_layers}") - # Tokenize the prompt - print(f"\nPrompt: {prompt!r}") - tokens = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False) - assert isinstance(tokens, Tensor), "Expected Tensor" - tokens = tokens.to(device) - print(f"Tokens shape: {tokens.shape}") - print(f"Tokens: {tokens[0].tolist()}") - token_strings = [tokenizer.decode([t]) for t in tokens[0].tolist()] - print(f"Token strings: {token_strings}") - - # Load precomputed CI values if path is provided - ci_lower_leaky: dict[str, Tensor] | None = None - if ci_vals_path is not None: - print(f"\nLoading precomputed CI values from {ci_vals_path}") - ci_lower_leaky = load_ci_from_json(ci_vals_path, prompt, device) - print(f"Loaded CI values for layers: {list(ci_lower_leaky.keys())}") + data = get_tokens_and_ci( + model=model, + tokenizer=tokenizer, + sampling=config.sampling, + device=device, + ci_vals_path=ci_vals_path, + prompts=prompts, + ) # Compute local attributions print("\nComputing local attributions...") - attr_pairs, output_probs = compute_local_attributions( + attr_pairs_by_set, output_probs = compute_local_attributions( model=model, - tokens=tokens, + tokens=data.tokens, sources_by_target=sources_by_target, ci_threshold=ci_threshold, output_prob_threshold=output_prob_threshold, sampling=config.sampling, device=device, - ci_lower_leaky=ci_lower_leaky, + ci_lower_leaky=data.ci_lower_leaky, + set_names=data.set_names, ) - # Print summary statistics - print("\nAttribution summary:") - for attr_pair in attr_pairs: - total = attr_pair.attribution.numel() - if total == 0: + # Print summary statistics per set + for set_name, attr_pairs in attr_pairs_by_set.items(): + print(f"\nAttribution summary for {set_name}:") + for attr_pair in attr_pairs: + total = attr_pair.attribution.numel() + if total == 0: + print( + f"Ignoring {attr_pair.source} -> {attr_pair.target}: " + f"shape={list(attr_pair.attribution.shape)}, zero" + ) + continue + nonzero = (attr_pair.attribution > 0).sum().item() print( - f"Ignoring {attr_pair.source} -> {attr_pair.target}: " - f"shape={list(attr_pair.attribution.shape)}, zero" + f" {attr_pair.source} -> {attr_pair.target}: " + f"shape={list(attr_pair.attribution.shape)}, " + f"nonzero={nonzero}/{total} ({100 * nonzero / (total + 1e-12):.2f}%), " + f"max={attr_pair.attribution.max():.6f}" ) - continue - nonzero = (attr_pair.attribution > 0).sum().item() - print( - f" {attr_pair.source} -> {attr_pair.target}: " - f"shape={list(attr_pair.attribution.shape)}, " - f"nonzero={nonzero}/{total} ({100 * nonzero / (total + 1e-12):.2f}%), " - f"max={attr_pair.attribution.max():.6f}" - ) - # Save attributions + # Save and plot per set out_dir = get_out_dir() - pt_path = out_dir / f"local_attributions_{loaded.wandb_id}.pt" - output_path = out_dir / f"local_attribution_graph_{loaded.wandb_id}.png" - if ci_vals_path is not None: - pt_path = pt_path.with_stem(pt_path.stem + "_with_ci_optim") - output_path = output_path.with_stem(output_path.stem + "_with_ci_optim") - - output_token_labels: dict[int, str] = {} - output_probs_by_pos: dict[tuple[int, int], float] = {} - for attr_pair in attr_pairs: - if attr_pair.target == "output": - for c_idx in attr_pair.trimmed_c_out_idxs: - if c_idx not in output_token_labels: - output_token_labels[c_idx] = tokenizer.decode([c_idx]) - # Store probability for each position - for s in range(tokens.shape[1]): - prob = output_probs[0, s, c_idx].item() - if prob >= output_prob_threshold: - output_probs_by_pos[(s, c_idx)] = prob - break - - save_data = { - "attr_pairs": attr_pairs, - "token_strings": token_strings, - "prompt": prompt, - "ci_threshold": ci_threshold, - "output_prob_threshold": output_prob_threshold, - "output_token_labels": output_token_labels, - "output_probs_by_pos": output_probs_by_pos, - "wandb_id": loaded.wandb_id, - } - torch.save(save_data, pt_path) - print(f"\nSaved local attributions to {pt_path}") - - fig = plot_local_graph( - attr_pairs=attr_pairs, - token_strings=token_strings, - output_token_labels=output_token_labels, - output_prob_threshold=output_prob_threshold, - output_probs_by_pos=output_probs_by_pos, - min_edge_weight=0.0001, - node_scale=30.0, - edge_alpha_scale=0.7, - ) - fig.savefig(output_path, dpi=150, bbox_inches="tight", facecolor="white") - print(f"Saved figure to {output_path}") + for b, set_name in enumerate(data.set_names): + attr_pairs = attr_pairs_by_set[set_name] + pt_path = out_dir / f"local_attributions_{loaded.wandb_id}_{set_name}.pt" + output_path = out_dir / f"local_attribution_graph_{loaded.wandb_id}_{set_name}.png" + + output_token_labels: dict[int, str] = {} + output_probs_by_pos: dict[tuple[int, int], float] = {} + for attr_pair in attr_pairs: + if attr_pair.target == "output": + for c_idx in attr_pair.trimmed_c_out_idxs: + if c_idx not in output_token_labels: + output_token_labels[c_idx] = tokenizer.decode([c_idx]) + # Store probability for each position + for s in range(data.tokens.shape[1]): + prob = output_probs[b, s, c_idx].item() + if prob >= output_prob_threshold: + output_probs_by_pos[(s, c_idx)] = prob + break + + save_data = { + "attr_pairs": attr_pairs, + "token_strings": data.all_token_strings[b], + "prompt": data.prompts[b], + "ci_threshold": ci_threshold, + "output_prob_threshold": output_prob_threshold, + "output_token_labels": output_token_labels, + "output_probs_by_pos": output_probs_by_pos, + "wandb_id": loaded.wandb_id, + "set_name": set_name, + } + torch.save(save_data, pt_path) + print(f"\nSaved local attributions for {set_name} to {pt_path}") + + fig = plot_local_graph( + attr_pairs=attr_pairs, + token_strings=data.all_token_strings[b], + output_token_labels=output_token_labels, + output_prob_threshold=output_prob_threshold, + output_probs_by_pos=output_probs_by_pos, + min_edge_weight=0.0001, + node_scale=30.0, + edge_alpha_scale=0.7, + ) + + fig.savefig(output_path, dpi=150, bbox_inches="tight", facecolor="white") + print(f"Saved figure for {set_name} to {output_path}") if __name__ == "__main__": @@ -412,17 +588,14 @@ def main( n_blocks = 1 ci_threshold = 1e-6 output_prob_threshold = 1e-1 - # prompt = "The quick brown fox" - # prompt = "Eagerly, a girl named Kim went" - prompt = "They walked hand in" - # Path to precomputed CI values from run_optim_cis.py (None to compute from model) - ci_vals_path: str | None = None - # ci_vals_path = "spd/scripts/optim_cis/out/optimized_ci_33n6xjjt.json" + # ci_vals_path = Path("spd/scripts/optim_cis/out/optimized_ci_33n6xjjt.json") + ci_vals_path = None + prompts = ["They walked hand in", "She is a happy"] main( wandb_path=wandb_path, n_blocks=n_blocks, ci_threshold=ci_threshold, output_prob_threshold=output_prob_threshold, - prompt=prompt, ci_vals_path=ci_vals_path, + prompts=prompts, ) diff --git a/spd/scripts/optim_cis/run_optim_cis.py b/spd/scripts/optim_cis/run_optim_cis.py index 9f3a450b2..2038ed7da 100644 --- a/spd/scripts/optim_cis/run_optim_cis.py +++ b/spd/scripts/optim_cis/run_optim_cis.py @@ -421,11 +421,16 @@ def get_out_dir() -> Path: output_path = out_dir / f"optimized_ci_{loaded.wandb_id}.json" output_data = { - "config": config.model_dump(), - "prompt": config.prompt, - "token_strings": token_strings, - "optimized_ci": {k: v[0].cpu().tolist() for k, v in ci_outputs.lower_leaky.items()}, - "wandb_id": loaded.wandb_id, + "ci_sets": { + f"imp_min_{config.imp_min_config.coeff:.0e}": { + "config": config.model_dump(), + "prompt": config.prompt, + "tokens": tokens[0].tolist(), + "token_strings": token_strings, + "ci_vals": {k: v[0].cpu().tolist() for k, v in ci_outputs.lower_leaky.items()}, + "wandb_id": loaded.wandb_id, + }, + }, } with open(output_path, "w") as f: From 6aa657573b0867d5b7f4c3243b90a5b51cf0daa3 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Tue, 2 Dec 2025 17:57:09 +0000 Subject: [PATCH 35/36] Misc cleaning --- spd/scripts/calc_local_attributions.py | 145 ++++++++++++------------- spd/scripts/optim_cis/run_optim_cis.py | 10 +- 2 files changed, 72 insertions(+), 83 deletions(-) diff --git a/spd/scripts/calc_local_attributions.py b/spd/scripts/calc_local_attributions.py index f8998ea5b..fb2de2ac0 100644 --- a/spd/scripts/calc_local_attributions.py +++ b/spd/scripts/calc_local_attributions.py @@ -3,7 +3,7 @@ import json from dataclasses import dataclass from pathlib import Path -from typing import Any +from typing import Any, cast import torch from jaxtyping import Bool, Float, Int @@ -35,7 +35,7 @@ class TokensAndCI: tokens: Int[Tensor, "N seq"] ci_lower_leaky: dict[str, Float[Tensor, "N seq C"]] - set_names: list[str] + sample_names: list[str] prompts: list[str] all_token_strings: list[list[str]] @@ -94,55 +94,55 @@ def load_ci_from_json( Expected format: { - "ci_sets": { - "set_name_1": { - "prompt": "...", # Each set has its own prompt + "samples": { + "sample_name_1": { + "prompt": "...", # Each sample has its own prompt "ci_vals": {"layer1": [[...]], ...}, - ... + "metadata": {...}, }, - "set_name_2": {...}, + "sample_name_2": {...}, } } Args: - ci_vals_path: Path to JSON file with ci_sets structure + ci_vals_path: Path to JSON file with samples structure device: Device to load tensors to Returns: Tuple of: - - Dict mapping layer_name -> CI tensor of shape [N, seq, C] where N = number of sets - - List of set names in batch order - - List of prompts in batch order (one per set) + - Dict mapping layer_name -> CI tensor of shape [N, seq, C] where N = number of samples + - List of sample names in batch order + - List of prompts in batch order (one per sample) """ with open(ci_vals_path) as f: data = json.load(f) - ci_sets: dict[str, dict[str, Any]] = data["ci_sets"] - set_names = list(ci_sets.keys()) - assert len(set_names) > 0, "No CI sets found in JSON" + samples: dict[str, dict[str, Any]] = data["samples"] + sample_names = list(samples.keys()) + assert len(sample_names) > 0, "No samples found in JSON" - # Extract prompt from each set - prompts: list[str] = [ci_sets[set_name]["prompt"] for set_name in set_names] + # Extract prompt from each sample + prompts: list[str] = [samples[sample_name]["prompt"] for sample_name in sample_names] - # Get layer names from first set - first_set = ci_sets[set_names[0]] - layer_names = list(first_set["ci_vals"].keys()) + # Get layer names from first sample + first_sample = samples[sample_names[0]] + layer_names = list(first_sample["ci_vals"].keys()) # Stack tensors along batch dimension ci_lower_leaky: dict[str, Tensor] = {} for layer_name in layer_names: - # Collect [seq, C] tensors from each set + # Collect [seq, C] tensors from each sample layer_tensors = [ - torch.tensor(ci_sets[set_name]["ci_vals"][layer_name], device=device) - for set_name in set_names + torch.tensor(samples[sample_name]["ci_vals"][layer_name], device=device) + for sample_name in sample_names ] # Stack to [N, seq, C] ci_lower_leaky[layer_name] = torch.stack(layer_tensors, dim=0) - return ci_lower_leaky, set_names, prompts + return ci_lower_leaky, sample_names, prompts -def get_tokens_and_ci( +def get_tokens_and_ci_vals( model: ComponentModel, tokenizer: PreTrainedTokenizerFast, sampling: SamplingType, @@ -162,42 +162,27 @@ def get_tokens_and_ci( prompts: List of prompts to use when ci_vals_path is None. Returns: - TokensAndCI containing tokens, CI values, set names, prompts, and token strings. + TokensAndCI containing tokens, CI values, sample names, prompts, and token strings. """ if ci_vals_path is not None: print(f"\nLoading precomputed CI values from {ci_vals_path}") - ci_lower_leaky, set_names, prompts = load_ci_from_json(ci_vals_path, device) - print(f"Loaded CI values for layers: {list(ci_lower_leaky.keys())}") - print(f"CI sets: {set_names}") + ci_lower_leaky, sample_names, prompts = load_ci_from_json(ci_vals_path, device) else: assert prompts is not None, "prompts is required when ci_vals_path is None" - set_names = [f"prompt_{i}" for i in range(len(prompts))] + sample_names = [f"prompt_{i}" for i in range(len(prompts))] ci_lower_leaky = None - # Tokenize each prompt - tokens_list: list[Tensor] = [] - all_token_strings: list[list[str]] = [] - for i, p in enumerate(prompts): - if ci_vals_path is not None: - print(f"\nPrompt for {set_names[i]}: {p!r}") - toks = tokenizer.encode(p, return_tensors="pt", add_special_tokens=False) - assert isinstance(toks, Tensor), "Expected Tensor" - tokens_list.append(toks) - first_row = toks[0] - assert isinstance(first_row, Tensor), "Expected 2D tensor" - all_token_strings.append([tokenizer.decode([t]) for t in first_row.tolist()]) - if ci_vals_path is not None: - print(f" Token strings: {all_token_strings[-1]}") - - # Validate all prompts have the same token length - token_lengths = [t.shape[1] for t in tokens_list] - assert all(length == token_lengths[0] for length in token_lengths), ( - f"All prompts must tokenize to the same length, got {token_lengths}" - ) - - tokens = torch.cat(tokens_list, dim=0).to(device) # [N, seq] - print(f"\nTokens shape: {tokens.shape}") + try: + tokens: Int[Tensor, "N seq"] = cast( + Tensor, tokenizer(prompts, return_tensors="pt", add_special_tokens=False)["input_ids"] + ) + except ValueError as e: + e.add_note("NOTE: This script only supports prompts which tokenize to the same length") + raise e + all_token_strings: list[list[str]] = [ + [tokenizer.decode(t) for t in sample] for sample in tokens + ] # Compute CI values from model if not provided if ci_lower_leaky is None: print("\nComputing CI values from model...") @@ -212,7 +197,7 @@ def get_tokens_and_ci( return TokensAndCI( tokens=tokens, ci_lower_leaky=ci_lower_leaky, - set_names=set_names, + sample_names=sample_names, prompts=prompts, all_token_strings=all_token_strings, ) @@ -227,9 +212,9 @@ def compute_local_attributions( sampling: SamplingType, device: str, ci_lower_leaky: dict[str, Float[Tensor, "N seq C"]], - set_names: list[str], + sample_names: list[str], ) -> tuple[dict[str, list[PairAttribution]], Float[Tensor, "N seq vocab"]]: - """Compute local attributions for multiple prompts across multiple CI sets. + """Compute local attributions for multiple prompts across multiple samples. For each valid layer pair (in_layer, out_layer), computes the gradient-based attribution of output component activations with respect to input component @@ -244,18 +229,20 @@ def compute_local_attributions( sampling: Sampling type to use for causal importances. device: Device to run on. ci_lower_leaky: Precomputed/optimized CI values with shape [N, seq, C]. - set_names: Ordered list of set names corresponding to batch dimension. + sample_names: Ordered list of sample names corresponding to batch dimension. Returns: Tuple of: - - Dict mapping set_name -> list of PairAttribution objects + - Dict mapping sample_name -> list of PairAttribution objects - Output probabilities of shape [N, seq, vocab] """ n_batch, n_seq = tokens.shape # Validate batch size matches CI tensors first_ci = next(iter(ci_lower_leaky.values())) assert first_ci.shape[0] == n_batch, f"CI batch size {first_ci.shape[0]} != tokens {n_batch}" - assert len(set_names) == n_batch, f"set_names length {len(set_names)} != n_batch {n_batch}" + assert len(sample_names) == n_batch, ( + f"sample_names length {len(sample_names)} != n_batch {n_batch}" + ) with torch.no_grad(): output_with_cache: OutputWithCache = model(tokens, cache_type="input") @@ -271,7 +258,7 @@ def compute_local_attributions( ci_original = ci.lower_leaky # [N, seq, C] # Log the l0 (lower_leaky values > ci_threshold) for each layer - print("L0 values for final seq position (first CI set):") + print("L0 values for final seq position (first sample):") for layer, ci_vals in ci_lower_leaky.items(): # We only care about the final position, show first batch item l0_vals = (ci_vals[0, -1] > ci_threshold).sum().item() @@ -343,7 +330,9 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An alive_c_union[layer] = sorted(all_alive) # Initialize output dictionary - local_attributions_by_set: dict[str, list[PairAttribution]] = {name: [] for name in set_names} + local_attributions_by_sample: dict[str, list[PairAttribution]] = { + name: [] for name in sample_names + } for target, sources in tqdm(sources_by_target.items(), desc="Target layers"): target_infos = alive_info[target] # list of LayerAliveInfo per batch @@ -431,14 +420,14 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An s_in, trimmed_c_in, s_out, trimmed_c_out ] = weighted[b, s_in, c_in] - # Build output per set + # Build output per sample for source_idx, (source, source_infos) in enumerate( zip(sources, all_source_infos, strict=True) ): original_source_infos = original_alive_info[source] original_target_infos = original_alive_info[target] - for b, set_name in enumerate(set_names): - local_attributions_by_set[set_name].append( + for b, sample_name in enumerate(sample_names): + local_attributions_by_sample[sample_name].append( PairAttribution( source=source, target=target, @@ -451,7 +440,7 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An ) ) - return local_attributions_by_set, output_probs + return local_attributions_by_sample, output_probs def main( @@ -485,7 +474,7 @@ def main( for out_layer, in_layers in sources_by_target.items(): print(f" {out_layer} <- {in_layers}") - data = get_tokens_and_ci( + data = get_tokens_and_ci_vals( model=model, tokenizer=tokenizer, sampling=config.sampling, @@ -496,7 +485,7 @@ def main( # Compute local attributions print("\nComputing local attributions...") - attr_pairs_by_set, output_probs = compute_local_attributions( + attr_pairs_by_sample, output_probs = compute_local_attributions( model=model, tokens=data.tokens, sources_by_target=sources_by_target, @@ -505,12 +494,12 @@ def main( sampling=config.sampling, device=device, ci_lower_leaky=data.ci_lower_leaky, - set_names=data.set_names, + sample_names=data.sample_names, ) - # Print summary statistics per set - for set_name, attr_pairs in attr_pairs_by_set.items(): - print(f"\nAttribution summary for {set_name}:") + # Print summary statistics per sample + for sample_name, attr_pairs in attr_pairs_by_sample.items(): + print(f"\nAttribution summary for {sample_name}:") for attr_pair in attr_pairs: total = attr_pair.attribution.numel() if total == 0: @@ -527,13 +516,13 @@ def main( f"max={attr_pair.attribution.max():.6f}" ) - # Save and plot per set + # Save and plot per sample out_dir = get_out_dir() - for b, set_name in enumerate(data.set_names): - attr_pairs = attr_pairs_by_set[set_name] - pt_path = out_dir / f"local_attributions_{loaded.wandb_id}_{set_name}.pt" - output_path = out_dir / f"local_attribution_graph_{loaded.wandb_id}_{set_name}.png" + for b, sample_name in enumerate(data.sample_names): + attr_pairs = attr_pairs_by_sample[sample_name] + pt_path = out_dir / f"local_attributions_{loaded.wandb_id}_{sample_name}.pt" + output_path = out_dir / f"local_attribution_graph_{loaded.wandb_id}_{sample_name}.png" output_token_labels: dict[int, str] = {} output_probs_by_pos: dict[tuple[int, int], float] = {} @@ -558,10 +547,10 @@ def main( "output_token_labels": output_token_labels, "output_probs_by_pos": output_probs_by_pos, "wandb_id": loaded.wandb_id, - "set_name": set_name, + "sample_name": sample_name, } torch.save(save_data, pt_path) - print(f"\nSaved local attributions for {set_name} to {pt_path}") + print(f"\nSaved local attributions for {sample_name} to {pt_path}") fig = plot_local_graph( attr_pairs=attr_pairs, @@ -575,7 +564,7 @@ def main( ) fig.savefig(output_path, dpi=150, bbox_inches="tight", facecolor="white") - print(f"Saved figure for {set_name} to {output_path}") + print(f"Saved figure for {sample_name} to {output_path}") if __name__ == "__main__": diff --git a/spd/scripts/optim_cis/run_optim_cis.py b/spd/scripts/optim_cis/run_optim_cis.py index 2038ed7da..919f9747a 100644 --- a/spd/scripts/optim_cis/run_optim_cis.py +++ b/spd/scripts/optim_cis/run_optim_cis.py @@ -421,14 +421,14 @@ def get_out_dir() -> Path: output_path = out_dir / f"optimized_ci_{loaded.wandb_id}.json" output_data = { - "ci_sets": { + "samples": { f"imp_min_{config.imp_min_config.coeff:.0e}": { - "config": config.model_dump(), "prompt": config.prompt, - "tokens": tokens[0].tolist(), - "token_strings": token_strings, "ci_vals": {k: v[0].cpu().tolist() for k, v in ci_outputs.lower_leaky.items()}, - "wandb_id": loaded.wandb_id, + "metadata": { + "config": config.model_dump(), + "wandb_id": loaded.wandb_id, + }, }, }, } From e0793a5de454f95dcb82bda0576aefc7a8adfe61 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Tue, 2 Dec 2025 18:39:59 +0000 Subject: [PATCH 36/36] Misc minor cleaning --- spd/scripts/calc_local_attributions.py | 47 +++++++++++--------------- spd/scripts/optim_cis/run_optim_cis.py | 1 - 2 files changed, 19 insertions(+), 29 deletions(-) diff --git a/spd/scripts/calc_local_attributions.py b/spd/scripts/calc_local_attributions.py index fb2de2ac0..badc2a55e 100644 --- a/spd/scripts/calc_local_attributions.py +++ b/spd/scripts/calc_local_attributions.py @@ -52,7 +52,9 @@ def compute_layer_alive_info( ) -> list[LayerAliveInfo]: """Compute alive info for a layer across all batch items. - Returns list of LayerAliveInfo, one per batch item. + For wte, we create a pseudo-component that is always alive at all positions. + + Returns list of LayerAliveInfo, one per sample. """ if layer_name == "wte": # WTE: single pseudo-component, always alive at all positions @@ -62,28 +64,20 @@ def compute_layer_alive_info( # Same info for all batch items return [LayerAliveInfo(alive_mask, alive_c_idxs, c_to_trimmed) for _ in range(n_batch)] - elif layer_name == "output": + full_alive_mask: Bool[Tensor, "N seq C"] | Bool[Tensor, "N seq vocab"] + if layer_name == "output": assert output_probs is not None - # output_probs: [N, seq, vocab] full_alive_mask = output_probs >= output_prob_threshold # [N, seq, vocab] - results = [] - for b in range(n_batch): - batch_mask = full_alive_mask[b] # [seq, vocab] - alive_c_idxs = torch.where(batch_mask.any(dim=0))[0].tolist() - c_to_trimmed = {c: i for i, c in enumerate(alive_c_idxs)} - results.append(LayerAliveInfo(batch_mask, alive_c_idxs, c_to_trimmed)) - return results - else: - ci = ci_lower_leaky[layer_name] # [N, seq, C] - full_alive_mask = ci >= ci_threshold # [N, seq, C] - results = [] - for b in range(n_batch): - batch_mask = full_alive_mask[b] # [seq, C] - alive_c_idxs = torch.where(batch_mask.any(dim=0))[0].tolist() - c_to_trimmed = {c: i for i, c in enumerate(alive_c_idxs)} - results.append(LayerAliveInfo(batch_mask, alive_c_idxs, c_to_trimmed)) - return results + full_alive_mask = ci_lower_leaky[layer_name] >= ci_threshold # [N, seq, C] + + results = [] + for b in range(n_batch): + batch_mask: Bool[Tensor, "seq C"] | Bool[Tensor, "seq vocab"] = full_alive_mask[b] + alive_c_idxs: list[int] = torch.where(batch_mask.any(dim=0))[0].tolist() + c_to_trimmed: dict[int, int] = {c: i for i, c in enumerate(alive_c_idxs)} + results.append(LayerAliveInfo(batch_mask, alive_c_idxs, c_to_trimmed)) + return results def load_ci_from_json( @@ -330,9 +324,7 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An alive_c_union[layer] = sorted(all_alive) # Initialize output dictionary - local_attributions_by_sample: dict[str, list[PairAttribution]] = { - name: [] for name in sample_names - } + local_attributions: dict[str, list[PairAttribution]] = {name: [] for name in sample_names} for target, sources in tqdm(sources_by_target.items(), desc="Target layers"): target_infos = alive_info[target] # list of LayerAliveInfo per batch @@ -393,11 +385,10 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An strict=True, ) ): - # grad and in_post_detach: [N, seq, dim] - weighted: Float[Tensor, "N s dim"] = grad * in_post_detach + weighted: Float[Tensor, "N s C"] = grad * in_post_detach if source == "wte": # Sum over embedding_dim to get single pseudo-component - weighted = weighted.sum(dim=2, keepdim=True) + weighted = weighted.sum(dim=-1, keepdim=True) # Store attributions per-batch for b in range(n_batch): @@ -427,7 +418,7 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An original_source_infos = original_alive_info[source] original_target_infos = original_alive_info[target] for b, sample_name in enumerate(sample_names): - local_attributions_by_sample[sample_name].append( + local_attributions[sample_name].append( PairAttribution( source=source, target=target, @@ -440,7 +431,7 @@ def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> An ) ) - return local_attributions_by_sample, output_probs + return local_attributions, output_probs def main( diff --git a/spd/scripts/optim_cis/run_optim_cis.py b/spd/scripts/optim_cis/run_optim_cis.py index 919f9747a..096314873 100644 --- a/spd/scripts/optim_cis/run_optim_cis.py +++ b/spd/scripts/optim_cis/run_optim_cis.py @@ -378,7 +378,6 @@ def get_out_dir() -> Path: tokens = tokenizer.encode(config.prompt, return_tensors="pt", add_special_tokens=False) assert isinstance(tokens, Tensor), "Expected Tensor" tokens = tokens.to(device) - print(f"Tokens shape: {tokens.shape}") token_strings = [tokenizer.decode([t]) for t in tokens[0].tolist()] print(f"Token strings: {token_strings}") label_token_ids = tokenizer.encode(config.label, add_special_tokens=False)