diff --git a/spd/models/component_model.py b/spd/models/component_model.py index e95e150f0..026512940 100644 --- a/spd/models/component_model.py +++ b/spd/models/component_model.py @@ -90,10 +90,11 @@ class ComponentModel(LoadableModule): `LlamaForCausalLM`, `AutoModelForCausalLM`) as long as its sub-module names match the patterns you pass in `target_module_patterns`. - Forward passes support optional component replacement and/or input caching: + Forward passes support optional component replacement and/or caching: - No args: Standard forward pass of the target model - With mask_infos: Components replace the specified modules via forward hooks - With cache_type="input": Input activations are cached for the specified modules + - With cache_type="component_acts": Component activations are cached for the specified modules - Both can be used simultaneously for component forward pass with input caching We register components and causal importance functions (ci_fns) as modules in this class in order to have them update @@ -307,7 +308,7 @@ def __call__( self, *args: Any, mask_infos: dict[str, ComponentsMaskInfo] | None = None, - cache_type: Literal["input"], + cache_type: Literal["component_acts", "input"], **kwargs: Any, ) -> OutputWithCache: ... @@ -329,31 +330,30 @@ def forward( self, *args: Any, mask_infos: dict[str, ComponentsMaskInfo] | None = None, - cache_type: Literal["input", "none"] = "none", + cache_type: Literal["component_acts", "input", "none"] = "none", **kwargs: Any, ) -> Tensor | OutputWithCache: """Forward pass with optional component replacement and/or input caching. This method handles the following 4 cases: 1. mask_infos is None and cache_type is "none": Regular forward pass. - 2. mask_infos is None and cache_type is "input": Forward pass with input caching on - all modules in self.target_module_paths. - 3. 
mask_infos is not None and cache_type is "input": Forward pass with component replacement - and input caching on the modules provided in mask_infos. + 2. mask_infos is None and cache_type is "input" or "component_acts": Forward pass with + caching on all modules in self.target_module_paths. + 3. mask_infos is not None and cache_type is "input" or "component_acts": Forward pass with + component replacement and caching on the modules provided in mask_infos. 4. mask_infos is not None and cache_type is "none": Forward pass with component replacement on the modules provided in mask_infos and no caching. - We use the same _components_and_cache_hook for cases 2, 3, and 4, and don't use any hooks - for case 1. - Args: mask_infos: Dictionary mapping module names to ComponentsMaskInfo. If provided, those modules will be replaced with their components. - cache_type: If "input", cache the inputs to the modules provided in mask_infos. If - mask_infos is None, cache the inputs to all modules in self.target_module_paths. + cache_type: If "input" or "component_acts", cache the inputs or component acts to the + modules provided in mask_infos. If "none", no caching is done. If mask_infos is None, + cache the inputs or component acts to all modules in self.target_module_paths. Returns: - OutputWithCache object if cache_type is "input", otherwise the model output tensor. + OutputWithCache object if cache_type is "input" or "component_acts", otherwise the + model output tensor. """ if mask_infos is None and cache_type == "none": # No hooks needed. Do a regular forward pass of the target model. 
@@ -382,7 +382,7 @@ def forward( out = self._extract_output(raw_out) match cache_type: - case "input": + case "input" | "component_acts": return OutputWithCache(output=out, cache=cache) case "none": return out @@ -396,7 +396,7 @@ def _components_and_cache_hook( module_name: str, components: Components | None, mask_info: ComponentsMaskInfo | None, - cache_type: Literal["input", "none"], + cache_type: Literal["component_acts", "input", "none"], cache: dict[str, Tensor], ) -> Any | None: """Unified hook function that handles both component replacement and caching. @@ -409,7 +409,7 @@ def _components_and_cache_hook( module_name: Name of the module in the target model components: Component replacement (if using components) mask_info: Mask information (if using components) - cache_type: Whether to cache the input + cache_type: Whether to cache the component acts, input, or none cache: Cache dictionary to populate (if cache_type is not None) Returns: @@ -420,7 +420,6 @@ def _components_and_cache_hook( assert len(kwargs) == 0, "Expected no keyword arguments" x = args[0] assert isinstance(x, Tensor), "Expected input tensor" - assert cache_type in ["input", "none"], "Expected cache_type to be 'input' or 'none'" if cache_type == "input": cache[module_name] = x @@ -430,11 +429,16 @@ def _components_and_cache_hook( f"Only supports single-tensor outputs, got {type(output)}" ) + component_acts_cache = {} if cache_type == "component_acts" else None components_out = components( x, mask=mask_info.component_mask, weight_delta_and_mask=mask_info.weight_delta_and_mask, + component_acts_cache=component_acts_cache, ) + if component_acts_cache is not None: + for k, v in component_acts_cache.items(): + cache[f"{module_name}_{k}"] = v if mask_info.routing_mask == "all": return components_out diff --git a/spd/models/components.py b/spd/models/components.py index 61c63c3f5..0841b9949 100644 --- a/spd/models/components.py +++ b/spd/models/components.py @@ -190,6 +190,7 @@ def forward( x: 
Float[Tensor, "... d_in"], mask: Float[Tensor, "... C"] | None = None, weight_delta_and_mask: WeightDeltaAndMask | None = None, + component_acts_cache: dict[str, Float[Tensor, "... C"]] | None = None, ) -> Float[Tensor, "... d_out"]: """Forward pass through V and U matrices. @@ -199,13 +200,18 @@ def forward( weight_delta_and_mask: Optional tuple of tensors containing: 0: the weight differences between the target model and summed component weights 1: mask over the weight delta component for each sample + component_acts_cache: Cache dictionary to populate with component acts Returns: output: The summed output across all components """ component_acts = self.get_inner_acts(x) + if component_acts_cache is not None: + component_acts_cache["pre_detach"] = component_acts + component_acts = component_acts.detach().requires_grad_(True) + component_acts_cache["post_detach"] = component_acts if mask is not None: - component_acts *= mask + component_acts = component_acts * mask out = einops.einsum(component_acts, self.U, "... C, C d_out -> ... d_out") @@ -254,6 +260,7 @@ def forward( x: Int[Tensor, "..."], mask: Float[Tensor, "... C"] | None = None, weight_delta_and_mask: WeightDeltaAndMask | None = None, + component_acts_cache: dict[str, Float[Tensor, "... C"]] | None = None, ) -> Float[Tensor, "... embedding_dim"]: """Forward through the embedding component using indexing instead of one-hot matmul. @@ -263,13 +270,19 @@ def forward( weight_delta_and_mask: Optional tuple of tensors containing: 0: the weight differences between the target model and summed component weights 1: mask over the weight delta component for each sample + component_acts_cache: Cache dictionary to populate with component acts """ assert x.dtype == torch.long, "x must be an integer tensor" component_acts: Float[Tensor, "... 
C"] = self.get_inner_acts(x) + if component_acts_cache is not None: + component_acts_cache["pre_detach"] = component_acts + component_acts = component_acts.detach().requires_grad_(True) + component_acts_cache["post_detach"] = component_acts + if mask is not None: - component_acts *= mask + component_acts = component_acts * mask out = einops.einsum(component_acts, self.U, "... C, C embedding_dim -> ... embedding_dim") diff --git a/spd/scripts/calc_global_attributions.py b/spd/scripts/calc_global_attributions.py new file mode 100644 index 000000000..62134d404 --- /dev/null +++ b/spd/scripts/calc_global_attributions.py @@ -0,0 +1,620 @@ +# %% + +import gzip +import json +from collections import defaultdict +from collections.abc import Iterable +from typing import Any + +import torch +from jaxtyping import Float +from PIL import Image +from torch import Tensor, nn +from tqdm.auto import tqdm + +from spd.configs import Config +from spd.models.component_model import ComponentModel, OutputWithCache +from spd.models.components import make_mask_infos +from spd.plotting import plot_mean_component_cis_both_scales +from spd.scripts.model_loading import ( + create_data_loader_from_config, + get_out_dir, + load_model_from_wandb, +) +from spd.utils.general_utils import extract_batch_data + + +def is_kv_to_o_pair(in_layer: str, out_layer: str) -> bool: + """Check if pair requires per-sequence-position gradient computation. + + For k/v → o_proj within the same attention block, output at s_out + has gradients w.r.t. inputs at all s_in ≤ s_out (causal attention). 
+ """ + in_is_kv = any(x in in_layer for x in ["k_proj", "v_proj"]) + out_is_o = "o_proj" in out_layer + if not (in_is_kv and out_is_o): + return False + + # Check same attention block: "h.{idx}.attn.{proj}" + in_block = in_layer.split(".")[1] + out_block = out_layer.split(".")[1] + return in_block == out_block + + +# %% +def compute_mean_ci_per_component( + model: ComponentModel, + data_loader: Iterable[dict[str, Any]], + device: str, + config: Config, + max_batches: int | None, +) -> dict[str, Tensor]: + """Compute mean causal importance per component over the dataset. + + Also computes mean output probability per vocab token. + + Args: + model: The ComponentModel to analyze. + data_loader: DataLoader providing batches. + device: Device to run on. + config: SPD config with sampling settings. + max_batches: Maximum number of batches to process. + + Returns: + Dictionary mapping module path -> tensor of shape [C] with mean CI per component. + Also includes "output" -> tensor of shape [vocab_size] with mean output probability. 
+ """ + # Initialize accumulators + ci_sums: dict[str, Tensor] = { + module_name: torch.zeros(model.C, device=device) for module_name in model.components + } + examples_seen: dict[str, int] = {module_name: 0 for module_name in model.components} + + # Output prob accumulators (initialized on first batch to get vocab_size) + output_prob_sum: Tensor | None = None + output_examples_seen = 0 + + if max_batches is not None: + batch_pbar = tqdm(enumerate(data_loader), desc="Computing mean CI", total=max_batches) + else: + batch_pbar = tqdm(enumerate(data_loader), desc="Computing mean CI") + + for batch_idx, batch_raw in batch_pbar: + if max_batches is not None and batch_idx >= max_batches: + break + + batch = extract_batch_data(batch_raw).to(device) + + with torch.no_grad(): + output_with_cache: OutputWithCache = model(batch, cache_type="input") + ci = model.calc_causal_importances( + pre_weight_acts=output_with_cache.cache, + sampling=config.sampling, + detach_inputs=False, + ) + + # Accumulate CI values (using lower_leaky as in CIMeanPerComponent) + for module_name, ci_vals in ci.lower_leaky.items(): + n_leading_dims = ci_vals.ndim - 1 + n_examples = ci_vals.shape[:n_leading_dims].numel() + examples_seen[module_name] += n_examples + leading_dim_idxs = tuple(range(n_leading_dims)) + ci_sums[module_name] += ci_vals.sum(dim=leading_dim_idxs) + + # Accumulate output probabilities + output_probs = torch.softmax(output_with_cache.output, dim=-1) # [b, s, vocab] + if output_prob_sum is None: + vocab_size = output_probs.shape[-1] + output_prob_sum = torch.zeros(vocab_size, device=device) + output_prob_sum += output_probs.sum(dim=(0, 1)) + output_examples_seen += output_probs.shape[0] * output_probs.shape[1] + + # Compute means + mean_cis: dict[str, Tensor] = { + module_name: ci_sums[module_name] / examples_seen[module_name] + for module_name in model.components + } + assert output_prob_sum is not None, "No batches processed" + mean_cis["output"] = output_prob_sum / 
output_examples_seen + + return mean_cis + + +def compute_alive_components( + model: ComponentModel, + data_loader: Iterable[dict[str, Any]], + device: str, + config: Config, + max_batches: int | None, + threshold: float, + output_mean_prob_threshold: float, +) -> tuple[dict[str, Tensor], dict[str, list[int]], tuple[Image.Image, Image.Image]]: + """Compute alive components based on mean CI threshold. + + Args: + model: The ComponentModel to analyze. + data_loader: DataLoader providing batches. + device: Device to run on. + config: SPD config with sampling settings. + max_batches: Maximum number of batches to process. + threshold: Minimum mean CI to consider a component alive. + output_mean_prob_threshold: Minimum mean output probability to consider a token alive. + + Returns: + Tuple of: + - mean_cis: Dictionary mapping module path -> tensor of mean CI per component + (includes "output" key with mean output probabilities) + - alive_indices: Dictionary mapping module path -> list of alive component indices + (includes "output" key) + - images: Tuple of (linear_scale_image, log_scale_image) for verification + """ + mean_cis = compute_mean_ci_per_component(model, data_loader, device, config, max_batches) + + alive_indices = {} + for module_name, mean_val in mean_cis.items(): + # Use output_mean_prob_threshold for output layer, threshold for components + thresh = output_mean_prob_threshold if module_name == "output" else threshold + alive_mask = mean_val >= thresh + alive_indices[module_name] = torch.where(alive_mask)[0].tolist() + + images = plot_mean_component_cis_both_scales(mean_cis) + + return mean_cis, alive_indices, images + + +def get_sources_by_target( + model: ComponentModel, + device: str, + config: Config, + n_blocks: int, +) -> dict[str, list[str]]: + """Find valid gradient connections grouped by target layer. + + Returns: + Dict mapping out_layer -> list of in_layers that have gradient flow to it. 
+ """ + # Use a small dummy batch - we only need to trace gradient connections + batch: Float[Tensor, "batch seq"] = torch.zeros(2, 3, dtype=torch.long, device=device) + + with torch.no_grad(): + output_with_cache: OutputWithCache = model(batch, cache_type="input") + print(f"Output shape: {output_with_cache.output.shape}") + + with torch.no_grad(): + ci = model.calc_causal_importances( + pre_weight_acts=output_with_cache.cache, + sampling=config.sampling, + detach_inputs=False, + ) + + # Create masks so we can use all components (without masks) + mask_infos = make_mask_infos( + component_masks={k: torch.ones_like(v) for k, v in ci.lower_leaky.items()}, + routing_masks="all", + ) + + wte_cache: dict[str, Tensor] = {} + + # Add an extra forward hook to the model to cache the output of model.target_model.wte + def wte_hook( + _module: nn.Module, _args: tuple[Any, ...], _kwargs: dict[Any, Any], output: Tensor + ) -> Any: + output.requires_grad_(True) + # We call it "post_detach" for consistency, we don't bother detaching here as there are + # no modules before it that we care about + wte_cache["wte_post_detach"] = output + return output + + assert isinstance(model.target_model.wte, nn.Module), "wte is not a module" + wte_handle = model.target_model.wte.register_forward_hook(wte_hook, with_kwargs=True) + + with torch.enable_grad(): + comp_output_with_cache_grad: OutputWithCache = model( + batch, + mask_infos=mask_infos, + cache_type="component_acts", + ) + wte_handle.remove() + + cache = comp_output_with_cache_grad.cache + cache["wte_post_detach"] = wte_cache["wte_post_detach"] + cache["output_pre_detach"] = comp_output_with_cache_grad.output + + layers = ["wte"] + component_layers = [ + "attn.q_proj", + "attn.k_proj", + "attn.v_proj", + "attn.o_proj", + "mlp.c_fc", + "mlp.down_proj", + ] + for i in range(n_blocks): + layers.extend([f"h.{i}.{layer_name}" for layer_name in component_layers]) + layers.append("output") + + test_pairs = [] + for in_layer in layers[:-1]: # 
Don't include "output" in + for out_layer in layers[1:]: # Don't include "wte" in out_layers + if layers.index(in_layer) < layers.index(out_layer): + test_pairs.append((in_layer, out_layer)) + + sources_by_target: dict[str, list[str]] = defaultdict(list) + for in_layer, out_layer in test_pairs: + out_pre_detach = cache[f"{out_layer}_pre_detach"] + in_post_detach = cache[f"{in_layer}_post_detach"] + out_value = out_pre_detach[0, 0, 0] # Pick arbitrary value + grads = torch.autograd.grad( + outputs=out_value, + inputs=in_post_detach, + retain_graph=True, + allow_unused=True, + ) + assert len(grads) == 1, "Expected 1 gradient" + grad = grads[0] + if grad is not None: # pyright: ignore[reportUnnecessaryComparison] + sources_by_target[out_layer].append(in_layer) + return dict(sources_by_target) + + +def compute_global_attributions( + model: ComponentModel, + data_loader: Iterable[dict[str, Any]], + device: str, + config: Config, + sources_by_target: dict[str, list[str]], + max_batches: int, + alive_indices: dict[str, list[int]], + ci_attribution_threshold: float, + output_mean_prob_threshold: float, +) -> dict[tuple[str, str], Tensor]: + """Compute global attributions accumulated over the dataset. + + For each valid layer pair (in_layer, out_layer), computes the mean absolute gradient + of output component activations with respect to input component activations, + averaged over batch, sequence positions, and number of batches. + + Optimization: For each target layer, we batch all source layers into a single + autograd.grad call, sharing backward computation. + + Args: + model: The ComponentModel to analyze. + data_loader: DataLoader providing batches. + device: Device to run on. + config: SPD config with sampling settings. + sources_by_target: Dict mapping out_layer -> list of in_layers. + max_batches: Maximum number of batches to process. + alive_indices: Dictionary mapping module path -> list of alive component indices. 
+ ci_attribution_threshold: Threshold for considering a component for the attribution calculation. + output_mean_prob_threshold: Threshold for considering an output logit alive (on softmax probs). + + Returns: + Dictionary mapping (in_layer, out_layer) -> attribution tensor of shape [n_alive_in, n_alive_out] + where attribution[i, j] is the mean absolute gradient from the i-th alive input component + to the j-th alive output component. + """ + alive_indices["wte"] = [0] # Treat wte as single alive component + + # Initialize accumulators for each (in_layer, out_layer) pair + attribution_sums: dict[tuple[str, str], Float[Tensor, "n_alive_in n_alive_out"]] = {} + samples_per_pair: dict[tuple[str, str], int] = {} + for out_layer, in_layers in sources_by_target.items(): + for in_layer in in_layers: + n_alive_in = len(alive_indices[in_layer]) + n_alive_out = len(alive_indices[out_layer]) + attribution_sums[(in_layer, out_layer)] = torch.zeros( + n_alive_in, n_alive_out, device=device + ) + samples_per_pair[(in_layer, out_layer)] = 0 + + # Set up wte hook + wte_handle = None + wte_cache: dict[str, Tensor] = {} + + def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> Any: + output.requires_grad_(True) + wte_cache["wte_post_detach"] = output + return output + + assert isinstance(model.target_model.wte, nn.Module), "wte is not a module" + wte_handle = model.target_model.wte.register_forward_hook(wte_hook, with_kwargs=True) + + batch_pbar = tqdm(enumerate(data_loader), desc="Batches", total=max_batches) + for batch_idx, batch_raw in batch_pbar: + if batch_idx >= max_batches: + break + + batch: Float[Tensor, "b s C"] = extract_batch_data(batch_raw).to(device) + + batch_size, n_seq = batch.shape + + # Forward pass to get pre-weight activations + with torch.no_grad(): + output_with_cache: OutputWithCache = model(batch, cache_type="input") + + # Calculate causal importances for masking + with torch.no_grad(): + ci = model.calc_causal_importances( + 
pre_weight_acts=output_with_cache.cache, + sampling=config.sampling, + detach_inputs=False, + ) + + mask_infos = make_mask_infos( + # component_masks={k: torch.ones_like(v) for k, v in ci.lower_leaky.items()}, + component_masks=ci.lower_leaky, + routing_masks="all", + ) + + with torch.enable_grad(): + comp_output_with_cache: OutputWithCache = model( + batch, + mask_infos=mask_infos, + cache_type="component_acts", + ) + + cache = comp_output_with_cache.cache + + # Add wte to cache and CI + cache["wte_post_detach"] = wte_cache["wte_post_detach"] + # Add fake CI for wte: shape (b, s, embedding_dim) with 1.0 at index 0 + ci.lower_leaky["wte"] = torch.zeros_like(wte_cache["wte_post_detach"]) + ci.lower_leaky["wte"][:, :, 0] = 1.0 + + # Add output to cache and CI + cache["output_pre_detach"] = comp_output_with_cache.output + # Use output probs as fake CI for output layer + output_probs: Float[Tensor, "b s vocab"] = torch.softmax( + comp_output_with_cache.output, dim=-1 + ) + # Only consider tokens above threshold as "alive" for this batch + ci.lower_leaky["output"] = torch.where( + output_probs >= output_mean_prob_threshold, output_probs, torch.zeros_like(output_probs) + ) + + # Compute attributions grouped by target layer + for out_layer, in_layers in tqdm( + sources_by_target.items(), desc="Target layers", leave=False + ): + out_pre_detach: Float[Tensor, "b s C"] = cache[f"{out_layer}_pre_detach"] + alive_out: list[int] = alive_indices[out_layer] + ci_out = ci.lower_leaky[out_layer] + + # Gather all input tensors for this target layer + in_tensors = [cache[f"{in_layer}_post_detach"] for in_layer in in_layers] + + # Initialize batch attributions for each input layer + batch_attributions = { + in_layer: torch.zeros(len(alive_indices[in_layer]), len(alive_out), device=device) + for in_layer in in_layers + } + + # NOTE: o->q will be treated as an attention pair even though there are no attrs + # across sequence positions. This is just so we don't have to special case it. 
+ is_attention_output = any( + is_kv_to_o_pair(in_layer, out_layer) for in_layer in in_layers + ) + + grad_outputs: Float[Tensor, "b s C"] = torch.zeros_like(out_pre_detach) + + for c_enum, c_idx in tqdm( + enumerate(alive_out), desc="Components", leave=False, total=len(alive_out) + ): + if is_attention_output: + # Attention target: loop over output sequence positions + for s_out in range(n_seq): + if ci_out[:, s_out, c_idx].sum() <= ci_attribution_threshold: + continue + grad_outputs.zero_() + grad_outputs[:, s_out, c_idx] = ci_out[:, s_out, c_idx].detach() + + # Single autograd call for all input layers + grads_tuple = torch.autograd.grad( + outputs=out_pre_detach, + inputs=in_tensors, + grad_outputs=grad_outputs, + retain_graph=True, + ) + + with torch.no_grad(): + for i, in_layer in enumerate(in_layers): + grads = grads_tuple[i] + assert grads is not None, f"Gradient is None for {in_layer}" + weighted = grads * in_tensors[i] + + # Special handling for wte: sum over embedding_dim to get single component + if in_layer == "wte": + # weighted is (b, s, embedding_dim), sum to (b, s, 1) + weighted = weighted.sum(dim=-1, keepdim=True) + alive_in = [0] + else: + alive_in = alive_indices[in_layer] + + # Only sum contributions from positions s_in <= s_out (causal) + weighted_alive = weighted[:, : s_out + 1, alive_in] + batch_attributions[in_layer][:, c_enum] += weighted_alive.pow( + 2 + ).sum(dim=(0, 1)) + else: + # Non-attention target: vectorize over all (b, s) positions + if ci_out[:, :, c_idx].sum() <= ci_attribution_threshold: + continue + grad_outputs.zero_() + grad_outputs[:, :, c_idx] = ci_out[:, :, c_idx].detach() + + # Single autograd call for all input layers + grads_tuple = torch.autograd.grad( + outputs=out_pre_detach, + inputs=in_tensors, + grad_outputs=grad_outputs, + retain_graph=True, + allow_unused=True, + ) + + with torch.no_grad(): + for i, in_layer in enumerate(in_layers): + grads = grads_tuple[i] + assert grads is not None, f"Gradient is None for 
{in_layer}" + weighted = grads * in_tensors[i] + + # Special handling for wte: sum over embedding_dim to get single component + if in_layer == "wte": + # weighted is (b, s, embedding_dim), sum to (b, s, 1) + weighted = weighted.sum(dim=-1, keepdim=True) + alive_in = [0] + else: + alive_in = alive_indices[in_layer] + + weighted_alive = weighted[:, :, alive_in] + batch_attributions[in_layer][:, c_enum] += weighted_alive.pow(2).sum( + dim=(0, 1) + ) + + # Accumulate batch results and track samples + for in_layer in in_layers: + attribution_sums[(in_layer, out_layer)] += batch_attributions[in_layer] + # Track samples: attention pairs have triangular sum due to causal masking + if is_attention_output: + samples_per_pair[(in_layer, out_layer)] += batch_size * n_seq * (n_seq + 1) // 2 + else: + samples_per_pair[(in_layer, out_layer)] += batch_size * n_seq + + wte_handle.remove() + + global_attributions = {} + for pair, attr_sum in attribution_sums.items(): + attr: Float[Tensor, "n_alive_in n_alive_out"] = attr_sum / samples_per_pair[pair] + global_attributions[pair] = attr / attr.sum(dim=1, keepdim=True) + + n_pairs = len(attribution_sums) + total_samples = sum(samples_per_pair.values()) // n_pairs if n_pairs else 0 + print(f"Computed global attributions over ~{total_samples} samples per pair") + return global_attributions + + +if __name__ == "__main__": + # %% + # Configuration + # wandb_path = "wandb:goodfire/spd/runs/jyo9duz5" # ss_gpt2_simple-1.25M (4L) + # wandb_path = "wandb:goodfire/spd/runs/c0k3z78g" # ss_gpt2_simple-2L + # wandb_path = "wandb:goodfire/spd/runs/8ynfbr38" # ss_gpt2_simple-1L (Old) + wandb_path = "wandb:goodfire/spd/runs/33n6xjjt" # ss_gpt2_simple-1L (New) + n_blocks = 1 + # batch_size = 1024 + batch_size = 128 + n_ctx = 64 + # n_attribution_batches = 20 + n_attribution_batches = 5 + n_alive_calc_batches = 100 + # n_alive_calc_batches = 200 + ci_mean_alive_threshold = 1e-6 + ci_attribution_threshold = 1e-6 + output_mean_prob_threshold = 1e-8 + 
dataset_seed = 0 + + out_dir = get_out_dir() + + # Load model using shared utility + loaded = load_model_from_wandb(wandb_path) + model = loaded.model + config = loaded.config + device = loaded.device + wandb_id = loaded.wandb_id + + # Load the dataset + data_loader, tokenizer = create_data_loader_from_config( + config=config, + batch_size=batch_size, + n_ctx=n_ctx, + seed=dataset_seed, + ) + + sources_by_target = get_sources_by_target(model, device, config, n_blocks) + n_pairs = sum(len(ins) for ins in sources_by_target.values()) + print(f"Sources by target: {n_pairs} pairs across {len(sources_by_target)} target layers") + for out_layer, in_layers in sources_by_target.items(): + print(f" {out_layer} <- {in_layers}") + # %% + # Compute alive components based on mean CI threshold (and output probability threshold) + print("\nComputing alive components...") + mean_cis, alive_indices, (img_linear, img_log) = compute_alive_components( + model=model, + data_loader=data_loader, + device=device, + config=config, + max_batches=n_alive_calc_batches, + threshold=ci_mean_alive_threshold, + output_mean_prob_threshold=output_mean_prob_threshold, + ) + + # Print summary + print("\nAlive components per layer:") + for module_name, indices in alive_indices.items(): + n_alive = len(indices) + print(f" {module_name}: {n_alive} alive") + + # Save images for verification + img_linear.save(out_dir / f"ci_mean_per_component_linear_{wandb_id}.png") + img_log.save(out_dir / f"ci_mean_per_component_log_{wandb_id}.png") + print( + f"Saved verification images to {out_dir / f'ci_mean_per_component_linear_{wandb_id}.png'} and {out_dir / f'ci_mean_per_component_log_{wandb_id}.png'}" + ) + # %% + # Compute global attributions over the dataset + print("\nComputing global attributions...") + global_attributions = compute_global_attributions( + model=model, + data_loader=data_loader, + device=device, + config=config, + sources_by_target=sources_by_target, + max_batches=n_attribution_batches, + 
alive_indices=alive_indices, + ci_attribution_threshold=ci_attribution_threshold, + output_mean_prob_threshold=output_mean_prob_threshold, + ) + + # Print summary statistics + for pair, attr in global_attributions.items(): + print(f"{pair[0]} -> {pair[1]}: mean={attr.mean():.6f}, max={attr.max():.6f}") + + # %% + # Save attributions in both PyTorch and JSON formats + print("\nSaving attribution data...") + + # Save PyTorch format + pt_path = out_dir / f"global_attributions_{wandb_id}.pt" + torch.save(global_attributions, pt_path) + print(f"Saved PyTorch format to {pt_path}") + + # Convert and save JSON format for web visualization + attributions_json = {} + for (in_layer, out_layer), attr_tensor in global_attributions.items(): + key = f"('{in_layer}', '{out_layer}')" + # Keep full precision - just convert to list + attributions_json[key] = attr_tensor.cpu().tolist() + + json_data = { + "n_blocks": n_blocks, + "attributions": attributions_json, + "alive_indices": alive_indices, + } + + json_path = out_dir / f"global_attributions_{wandb_id}.json" + + # Write JSON with compact formatting + with open(json_path, "w") as f: + json.dump(json_data, f, separators=(",", ":"), ensure_ascii=False) + + # Also save a compressed version for very large files + gz_path = out_dir / f"global_attributions_{wandb_id}.json.gz" + with gzip.open(gz_path, "wt", encoding="utf-8") as f: + json.dump(json_data, f, separators=(",", ":"), ensure_ascii=False) + + print(f"Saved JSON format to {json_path}") + print(f"Saved compressed format to {gz_path}") + print(f" - {len(attributions_json)} layer pairs") + print(f" - {sum(len(v) for v in alive_indices.values())} total alive components") + print(f"\nTo visualize: Open scripts/plot_attributions.html and load {json_path}") + + # %% diff --git a/spd/scripts/calc_local_attributions.py b/spd/scripts/calc_local_attributions.py new file mode 100644 index 000000000..badc2a55e --- /dev/null +++ b/spd/scripts/calc_local_attributions.py @@ -0,0 +1,581 @@ 
+"""Compute local attributions for a single prompt.""" + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any, cast + +import torch +from jaxtyping import Bool, Float, Int +from torch import Tensor, nn +from tqdm.auto import tqdm +from transformers import AutoTokenizer +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast + +from spd.configs import SamplingType +from spd.models.component_model import ComponentModel, OutputWithCache +from spd.models.components import make_mask_infos +from spd.scripts.calc_global_attributions import get_sources_by_target, is_kv_to_o_pair +from spd.scripts.model_loading import get_out_dir, load_model_from_wandb +from spd.scripts.plot_local_attributions import PairAttribution, plot_local_graph + + +@dataclass +class LayerAliveInfo: + """Info about alive components for a layer (single batch item).""" + + alive_mask: Bool[Tensor, "s dim"] # Which (pos, component) pairs are alive + alive_c_idxs: list[int] # Components alive at any position + c_to_trimmed: dict[int, int] # original idx -> trimmed idx + + +@dataclass +class TokensAndCI: + """Tokenized prompts and their CI values.""" + + tokens: Int[Tensor, "N seq"] + ci_lower_leaky: dict[str, Float[Tensor, "N seq C"]] + sample_names: list[str] + prompts: list[str] + all_token_strings: list[list[str]] + + +def compute_layer_alive_info( + layer_name: str, + ci_lower_leaky: dict[str, Tensor], + output_probs: Float[Tensor, "N s vocab"] | None, + ci_threshold: float, + output_prob_threshold: float, + n_seq: int, + n_batch: int, + device: str, +) -> list[LayerAliveInfo]: + """Compute alive info for a layer across all batch items. + + For wte, we create a pseudo-component that is always alive at all positions. + + Returns list of LayerAliveInfo, one per sample. 
+ """ + if layer_name == "wte": + # WTE: single pseudo-component, always alive at all positions + alive_mask = torch.ones(n_seq, 1, device=device, dtype=torch.bool) + alive_c_idxs = [0] + c_to_trimmed = {0: 0} + # Same info for all batch items + return [LayerAliveInfo(alive_mask, alive_c_idxs, c_to_trimmed) for _ in range(n_batch)] + + full_alive_mask: Bool[Tensor, "N seq C"] | Bool[Tensor, "N seq vocab"] + if layer_name == "output": + assert output_probs is not None + full_alive_mask = output_probs >= output_prob_threshold # [N, seq, vocab] + else: + full_alive_mask = ci_lower_leaky[layer_name] >= ci_threshold # [N, seq, C] + + results = [] + for b in range(n_batch): + batch_mask: Bool[Tensor, "seq C"] | Bool[Tensor, "seq vocab"] = full_alive_mask[b] + alive_c_idxs: list[int] = torch.where(batch_mask.any(dim=0))[0].tolist() + c_to_trimmed: dict[int, int] = {c: i for i, c in enumerate(alive_c_idxs)} + results.append(LayerAliveInfo(batch_mask, alive_c_idxs, c_to_trimmed)) + return results + + +def load_ci_from_json( + ci_vals_path: str | Path, + device: str, +) -> tuple[dict[str, Float[Tensor, "N seq C"]], list[str], list[str]]: + """Load precomputed CI values from a JSON file. 
+ + Expected format: + { + "samples": { + "sample_name_1": { + "prompt": "...", # Each sample has its own prompt + "ci_vals": {"layer1": [[...]], ...}, + "metadata": {...}, + }, + "sample_name_2": {...}, + } + } + + Args: + ci_vals_path: Path to JSON file with samples structure + device: Device to load tensors to + + Returns: + Tuple of: + - Dict mapping layer_name -> CI tensor of shape [N, seq, C] where N = number of samples + - List of sample names in batch order + - List of prompts in batch order (one per sample) + """ + with open(ci_vals_path) as f: + data = json.load(f) + + samples: dict[str, dict[str, Any]] = data["samples"] + sample_names = list(samples.keys()) + assert len(sample_names) > 0, "No samples found in JSON" + + # Extract prompt from each sample + prompts: list[str] = [samples[sample_name]["prompt"] for sample_name in sample_names] + + # Get layer names from first sample + first_sample = samples[sample_names[0]] + layer_names = list(first_sample["ci_vals"].keys()) + + # Stack tensors along batch dimension + ci_lower_leaky: dict[str, Tensor] = {} + for layer_name in layer_names: + # Collect [seq, C] tensors from each sample + layer_tensors = [ + torch.tensor(samples[sample_name]["ci_vals"][layer_name], device=device) + for sample_name in sample_names + ] + # Stack to [N, seq, C] + ci_lower_leaky[layer_name] = torch.stack(layer_tensors, dim=0) + + return ci_lower_leaky, sample_names, prompts + + +def get_tokens_and_ci_vals( + model: ComponentModel, + tokenizer: PreTrainedTokenizerFast, + sampling: SamplingType, + device: str, + ci_vals_path: str | Path | None, + prompts: list[str] | None, +) -> TokensAndCI: + """Get tokenized prompts and CI values, either from file or computed from model. + + Args: + model: The ComponentModel to use for computing CI if needed. + tokenizer: Tokenizer for the prompts. + sampling: Sampling type for CI computation. + device: Device to place tensors on. + ci_vals_path: Path to JSON with precomputed CI values and prompts. 
+ If provided, prompts are read from the JSON file. + prompts: List of prompts to use when ci_vals_path is None. + + Returns: + TokensAndCI containing tokens, CI values, sample names, prompts, and token strings. + """ + if ci_vals_path is not None: + print(f"\nLoading precomputed CI values from {ci_vals_path}") + ci_lower_leaky, sample_names, prompts = load_ci_from_json(ci_vals_path, device) + else: + assert prompts is not None, "prompts is required when ci_vals_path is None" + sample_names = [f"prompt_{i}" for i in range(len(prompts))] + ci_lower_leaky = None + + try: + tokens: Int[Tensor, "N seq"] = cast( + Tensor, tokenizer(prompts, return_tensors="pt", add_special_tokens=False)["input_ids"] + ) + except ValueError as e: + e.add_note("NOTE: This script only supports prompts which tokenize to the same length") + raise e + + all_token_strings: list[list[str]] = [ + [tokenizer.decode(t) for t in sample] for sample in tokens + ] + # Compute CI values from model if not provided + if ci_lower_leaky is None: + print("\nComputing CI values from model...") + with torch.no_grad(): + output_with_cache = model(tokens, cache_type="input") + ci_lower_leaky = model.calc_causal_importances( + pre_weight_acts=output_with_cache.cache, + sampling=sampling, + detach_inputs=False, + ).lower_leaky + + return TokensAndCI( + tokens=tokens, + ci_lower_leaky=ci_lower_leaky, + sample_names=sample_names, + prompts=prompts, + all_token_strings=all_token_strings, + ) + + +def compute_local_attributions( + model: ComponentModel, + tokens: Int[Tensor, "N seq"], + sources_by_target: dict[str, list[str]], + ci_threshold: float, + output_prob_threshold: float, + sampling: SamplingType, + device: str, + ci_lower_leaky: dict[str, Float[Tensor, "N seq C"]], + sample_names: list[str], +) -> tuple[dict[str, list[PairAttribution]], Float[Tensor, "N seq vocab"]]: + """Compute local attributions for multiple prompts across multiple samples. 
+
+    For each valid layer pair (in_layer, out_layer), computes the gradient-based
+    attribution of output component activations with respect to input component
+    activations, preserving sequence position information.
+
+    Args:
+        model: The ComponentModel to analyze.
+        tokens: Tokenized prompts of shape [N, seq_len] - one per CI set.
+        sources_by_target: Dict mapping out_layer -> list of in_layers.
+        ci_threshold: Threshold for considering a component alive at a position.
+        output_prob_threshold: Threshold for considering an output logit alive (on softmax probs).
+        sampling: Sampling type to use for causal importances.
+        device: Device to run on.
+        ci_lower_leaky: Precomputed/optimized CI values with shape [N, seq, C].
+        sample_names: Ordered list of sample names corresponding to batch dimension.
+
+    Returns:
+        Tuple of:
+        - Dict mapping sample_name -> list of PairAttribution objects
+        - Output probabilities of shape [N, seq, vocab]
+    """
+    n_batch, n_seq = tokens.shape
+    # Validate batch size matches CI tensors
+    first_ci = next(iter(ci_lower_leaky.values()))
+    assert first_ci.shape[0] == n_batch, f"CI batch size {first_ci.shape[0]} != tokens {n_batch}"
+    assert len(sample_names) == n_batch, (
+        f"sample_names length {len(sample_names)} != n_batch {n_batch}"
+    )
+
+    with torch.no_grad():
+        output_with_cache: OutputWithCache = model(tokens, cache_type="input")
+
+    # Always compute original CI from model (needed for ghost nodes when using optimized CI)
+    # Note: original CI is computed from the same [N, seq] token batch as the optimized CI
+    with torch.no_grad():
+        ci = model.calc_causal_importances(
+            pre_weight_acts=output_with_cache.cache,
+            sampling=sampling,
+            detach_inputs=False,
+        )
+    ci_original = ci.lower_leaky  # [N, seq, C]
+
+    # Log the l0 (lower_leaky values > ci_threshold) for each layer
+    print("L0 values for final seq position (first sample):")
+    for layer, ci_vals in ci_lower_leaky.items():
+        # We only care about the final position, show first batch 
item + l0_vals = (ci_vals[0, -1] > ci_threshold).sum().item() + print(f" Layer {layer} has {l0_vals} components alive at {ci_threshold}") + + wte_cache: dict[str, Tensor] = {} + + def wte_hook(_module: nn.Module, _args: Any, _kwargs: Any, output: Tensor) -> Any: + output.requires_grad_(True) + # We call it "post_detach" for consistency, we don't bother detaching here as there are + # no modules before it that we care about + wte_cache["wte_post_detach"] = output + return output + + assert isinstance(model.target_model.wte, nn.Module), "wte is not a module" + wte_handle = model.target_model.wte.register_forward_hook(wte_hook, with_kwargs=True) + + mask_infos = make_mask_infos(component_masks=ci_lower_leaky, routing_masks="all") + with torch.enable_grad(): + comp_output_with_cache: OutputWithCache = model( + tokens, mask_infos=mask_infos, cache_type="component_acts" + ) + wte_handle.remove() + + cache = comp_output_with_cache.cache + cache["wte_post_detach"] = wte_cache["wte_post_detach"] + cache["output_pre_detach"] = comp_output_with_cache.output + + # Compute output probabilities for thresholding + output_probs = torch.softmax(comp_output_with_cache.output, dim=-1) + + # Compute alive info for all layers upfront + all_layers: set[str] = set(sources_by_target.keys()) + for sources in sources_by_target.values(): + all_layers.update(sources) + + # alive_info[layer] -> list of LayerAliveInfo, one per batch item + alive_info: dict[str, list[LayerAliveInfo]] = {} + original_alive_info: dict[str, list[LayerAliveInfo]] = {} + # alive_c_union[layer] -> union of alive_c_idxs across all batches + alive_c_union: dict[str, list[int]] = {} + + for layer in all_layers: + alive_info[layer] = compute_layer_alive_info( + layer, + ci_lower_leaky, + output_probs, + ci_threshold, + output_prob_threshold, + n_seq, + n_batch, + device, + ) + # Compute original alive info (from model CI, not optimized CI) + original_alive_info[layer] = compute_layer_alive_info( + layer, + ci_original, + 
output_probs, + ci_threshold, + output_prob_threshold, + n_seq, + n_batch, + device, + ) + # Compute union of alive components across batches + all_alive: set[int] = set() + for batch_info in alive_info[layer]: + all_alive.update(batch_info.alive_c_idxs) + alive_c_union[layer] = sorted(all_alive) + + # Initialize output dictionary + local_attributions: dict[str, list[PairAttribution]] = {name: [] for name in sample_names} + + for target, sources in tqdm(sources_by_target.items(), desc="Target layers"): + target_infos = alive_info[target] # list of LayerAliveInfo per batch + out_pre_detach: Float[Tensor, "N s dim"] = cache[f"{target}_pre_detach"] + + all_source_infos = [alive_info[source] for source in sources] # list of lists + in_post_detaches: list[Float[Tensor, "N s dim"]] = [ + cache[f"{source}_post_detach"] for source in sources + ] + + # Initialize per-batch attribution tensors + # attributions[source_idx][batch_idx] = tensor + attributions: list[list[Float[Tensor, "s_in n_c_in s_out n_c_out"]]] = [ + [ + torch.zeros( + n_seq, + len(source_infos[b].alive_c_idxs), + n_seq, + len(target_infos[b].alive_c_idxs), + device=device, + ) + for b in range(n_batch) + ] + for source_infos in all_source_infos + ] + + # NOTE: o->q will be treated as an attention pair even though there are no attrs + # across sequence positions. This is just so we don't have to special case it. 
+ is_attention_output = any(is_kv_to_o_pair(source, target) for source in sources) + + for s_out in tqdm(range(n_seq), desc=f"{target} <- {sources}", leave=False): + # Union of alive c_out across all batches at this position + s_out_alive_c_union: list[int] = [ + c + for c in alive_c_union[target] + if any(info.alive_mask[s_out, c] for info in target_infos) + ] + if not s_out_alive_c_union: + continue + + for c_out in s_out_alive_c_union: + # Sum over batch dimension for efficient batched gradient + in_post_detach_grads = torch.autograd.grad( + outputs=out_pre_detach[:, s_out, c_out].sum(), + inputs=in_post_detaches, + retain_graph=True, + ) + # Handle causal attention mask + s_in_range = range(s_out + 1) if is_attention_output else range(s_out, s_out + 1) + + with torch.no_grad(): + for source_idx, (source, source_infos, grad, in_post_detach) in enumerate( + zip( + sources, + all_source_infos, + in_post_detach_grads, + in_post_detaches, + strict=True, + ) + ): + weighted: Float[Tensor, "N s C"] = grad * in_post_detach + if source == "wte": + # Sum over embedding_dim to get single pseudo-component + weighted = weighted.sum(dim=-1, keepdim=True) + + # Store attributions per-batch + for b in range(n_batch): + # Only store if c_out is alive in this batch at this position + if not target_infos[b].alive_mask[s_out, c_out]: + continue + if c_out not in target_infos[b].c_to_trimmed: + continue + trimmed_c_out = target_infos[b].c_to_trimmed[c_out] + + for s_in in s_in_range: + alive_c_in = [ + c + for c in source_infos[b].alive_c_idxs + if source_infos[b].alive_mask[s_in, c] + ] + for c_in in alive_c_in: + trimmed_c_in = source_infos[b].c_to_trimmed[c_in] + attributions[source_idx][b][ + s_in, trimmed_c_in, s_out, trimmed_c_out + ] = weighted[b, s_in, c_in] + + # Build output per sample + for source_idx, (source, source_infos) in enumerate( + zip(sources, all_source_infos, strict=True) + ): + original_source_infos = original_alive_info[source] + original_target_infos = 
original_alive_info[target] + for b, sample_name in enumerate(sample_names): + local_attributions[sample_name].append( + PairAttribution( + source=source, + target=target, + attribution=attributions[source_idx][b], + trimmed_c_in_idxs=source_infos[b].alive_c_idxs, + trimmed_c_out_idxs=target_infos[b].alive_c_idxs, + is_kv_to_o_pair=is_kv_to_o_pair(source, target), + original_alive_mask_in=original_source_infos[b].alive_mask, # [seq, C] + original_alive_mask_out=original_target_infos[b].alive_mask, # [seq, C] + ) + ) + + return local_attributions, output_probs + + +def main( + wandb_path: str, + n_blocks: int, + ci_threshold: float, + output_prob_threshold: float, + ci_vals_path: str | Path | None, + prompts: list[str] | None = None, +) -> None: + """Compute local attributions for a prompt. + + Args: + wandb_path: WandB path to load model from. + n_blocks: Number of transformer blocks to analyze. + ci_threshold: Threshold for considering a component alive. + output_prob_threshold: Threshold for considering an output logit alive. + ci_vals_path: Path to JSON with precomputed CI values and prompts. + If provided, prompts are read from the JSON file. + prompts: List of prompts to use when ci_vals_path is None. Required if ci_vals_path is None. 
+ """ + loaded = load_model_from_wandb(wandb_path) + model, config, device = loaded.model, loaded.config, loaded.device + + tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) + assert isinstance(tokenizer, PreTrainedTokenizerFast), "Expected PreTrainedTokenizerFast" + sources_by_target = get_sources_by_target(model, device, config, n_blocks) + + n_pairs = sum(len(ins) for ins in sources_by_target.values()) + print(f"Sources by target: {n_pairs} pairs across {len(sources_by_target)} target layers") + for out_layer, in_layers in sources_by_target.items(): + print(f" {out_layer} <- {in_layers}") + + data = get_tokens_and_ci_vals( + model=model, + tokenizer=tokenizer, + sampling=config.sampling, + device=device, + ci_vals_path=ci_vals_path, + prompts=prompts, + ) + + # Compute local attributions + print("\nComputing local attributions...") + attr_pairs_by_sample, output_probs = compute_local_attributions( + model=model, + tokens=data.tokens, + sources_by_target=sources_by_target, + ci_threshold=ci_threshold, + output_prob_threshold=output_prob_threshold, + sampling=config.sampling, + device=device, + ci_lower_leaky=data.ci_lower_leaky, + sample_names=data.sample_names, + ) + + # Print summary statistics per sample + for sample_name, attr_pairs in attr_pairs_by_sample.items(): + print(f"\nAttribution summary for {sample_name}:") + for attr_pair in attr_pairs: + total = attr_pair.attribution.numel() + if total == 0: + print( + f"Ignoring {attr_pair.source} -> {attr_pair.target}: " + f"shape={list(attr_pair.attribution.shape)}, zero" + ) + continue + nonzero = (attr_pair.attribution > 0).sum().item() + print( + f" {attr_pair.source} -> {attr_pair.target}: " + f"shape={list(attr_pair.attribution.shape)}, " + f"nonzero={nonzero}/{total} ({100 * nonzero / (total + 1e-12):.2f}%), " + f"max={attr_pair.attribution.max():.6f}" + ) + + # Save and plot per sample + out_dir = get_out_dir() + + for b, sample_name in enumerate(data.sample_names): + attr_pairs = 
attr_pairs_by_sample[sample_name] + pt_path = out_dir / f"local_attributions_{loaded.wandb_id}_{sample_name}.pt" + output_path = out_dir / f"local_attribution_graph_{loaded.wandb_id}_{sample_name}.png" + + output_token_labels: dict[int, str] = {} + output_probs_by_pos: dict[tuple[int, int], float] = {} + for attr_pair in attr_pairs: + if attr_pair.target == "output": + for c_idx in attr_pair.trimmed_c_out_idxs: + if c_idx not in output_token_labels: + output_token_labels[c_idx] = tokenizer.decode([c_idx]) + # Store probability for each position + for s in range(data.tokens.shape[1]): + prob = output_probs[b, s, c_idx].item() + if prob >= output_prob_threshold: + output_probs_by_pos[(s, c_idx)] = prob + break + + save_data = { + "attr_pairs": attr_pairs, + "token_strings": data.all_token_strings[b], + "prompt": data.prompts[b], + "ci_threshold": ci_threshold, + "output_prob_threshold": output_prob_threshold, + "output_token_labels": output_token_labels, + "output_probs_by_pos": output_probs_by_pos, + "wandb_id": loaded.wandb_id, + "sample_name": sample_name, + } + torch.save(save_data, pt_path) + print(f"\nSaved local attributions for {sample_name} to {pt_path}") + + fig = plot_local_graph( + attr_pairs=attr_pairs, + token_strings=data.all_token_strings[b], + output_token_labels=output_token_labels, + output_prob_threshold=output_prob_threshold, + output_probs_by_pos=output_probs_by_pos, + min_edge_weight=0.0001, + node_scale=30.0, + edge_alpha_scale=0.7, + ) + + fig.savefig(output_path, dpi=150, bbox_inches="tight", facecolor="white") + print(f"Saved figure for {sample_name} to {output_path}") + + +if __name__ == "__main__": + # Configuration + # wandb_path = "wandb:goodfire/spd/runs/8ynfbr38" # ss_gpt2_simple-1L (Old) + wandb_path = "wandb:goodfire/spd/runs/33n6xjjt" # ss_gpt2_simple-1L (new) + # wandb_path = "wandb:goodfire/spd/runs/c0k3z78g" # ss_gpt2_simple-2L + # wandb_path = "wandb:goodfire/spd/runs/jyo9duz5" # ss_gpt2_simple-1.25M (4L) + # n_blocks = 4 + 
n_blocks = 1 + ci_threshold = 1e-6 + output_prob_threshold = 1e-1 + # ci_vals_path = Path("spd/scripts/optim_cis/out/optimized_ci_33n6xjjt.json") + ci_vals_path = None + prompts = ["They walked hand in", "She is a happy"] + main( + wandb_path=wandb_path, + n_blocks=n_blocks, + ci_threshold=ci_threshold, + output_prob_threshold=output_prob_threshold, + ci_vals_path=ci_vals_path, + prompts=prompts, + ) diff --git a/spd/scripts/model_loading.py b/spd/scripts/model_loading.py new file mode 100644 index 000000000..a61fc7ed2 --- /dev/null +++ b/spd/scripts/model_loading.py @@ -0,0 +1,113 @@ +# %% +"""Shared model loading utilities for attribution scripts.""" + +from collections.abc import Iterable +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import torch +from transformers import PreTrainedTokenizer + +from spd.configs import Config +from spd.data import DatasetConfig, create_data_loader +from spd.experiments.lm.configs import LMTaskConfig +from spd.models.component_model import ComponentModel, SPDRunInfo + + +@dataclass +class LoadedModel: + """Container for a loaded SPD model and its configuration.""" + + model: ComponentModel + config: Config + run_info: SPDRunInfo + device: str + wandb_id: str + + +def load_model_from_wandb(wandb_path: str, device: str | None = None) -> LoadedModel: + """Load a ComponentModel from a wandb run path. + + Args: + wandb_path: Path like "wandb:goodfire/spd/runs/8ynfbr38" + device: Device to load model on. If None, uses cuda if available. + + Returns: + LoadedModel containing the model, config, and metadata. 
+ """ + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + + wandb_id = wandb_path.split("/")[-1] + + print(f"Using device: {device}") + print(f"Loading model from {wandb_path}...") + + run_info = SPDRunInfo.from_path(wandb_path) + config: Config = run_info.config + model = ComponentModel.from_run_info(run_info) + model = model.to(device) + model.eval() + + print("Model loaded successfully!") + print(f"Number of components: {model.C}") + print(f"Target module paths: {model.target_module_paths}") + + return LoadedModel( + model=model, + config=config, + run_info=run_info, + device=device, + wandb_id=wandb_id, + ) + + +def create_data_loader_from_config( + config: Config, + batch_size: int, + n_ctx: int, + seed: int = 0, +) -> tuple[Iterable[dict[str, Any]], PreTrainedTokenizer]: + """Create a data loader from an SPD config. + + Args: + config: SPD Config with task configuration. + batch_size: Batch size for the data loader. + n_ctx: Context length. + seed: Random seed for shuffling. + + Returns: + Tuple of (data_loader, tokenizer). 
+ """ + task_config = config.task_config + assert isinstance(task_config, LMTaskConfig), "Expected LM task config" + + dataset_config = DatasetConfig( + name=task_config.dataset_name, + hf_tokenizer_path=config.tokenizer_name, + split=task_config.train_data_split, + n_ctx=n_ctx, + is_tokenized=task_config.is_tokenized, + streaming=task_config.streaming, + column_name=task_config.column_name, + shuffle_each_epoch=False, + seed=seed, + ) + + print(f"\nLoading dataset {dataset_config.name}...") + data_loader, tokenizer = create_data_loader( + dataset_config=dataset_config, + batch_size=batch_size, + buffer_size=task_config.buffer_size, + global_seed=seed, + ) + + return data_loader, tokenizer + + +def get_out_dir() -> Path: + """Get the output directory for attribution scripts.""" + out_dir = Path(__file__).parent / "out" + out_dir.mkdir(parents=True, exist_ok=True) + return out_dir diff --git a/spd/scripts/optim_cis/__init__.py b/spd/scripts/optim_cis/__init__.py new file mode 100644 index 000000000..a958aba2a --- /dev/null +++ b/spd/scripts/optim_cis/__init__.py @@ -0,0 +1 @@ +"""CI optimization for single prompts.""" diff --git a/spd/scripts/optim_cis/config.py b/spd/scripts/optim_cis/config.py new file mode 100644 index 000000000..cb0acb920 --- /dev/null +++ b/spd/scripts/optim_cis/config.py @@ -0,0 +1,103 @@ +"""Configuration for CI optimization on single prompts.""" + +from typing import Literal, Self + +from pydantic import Field, NonNegativeFloat, PositiveFloat, PositiveInt, model_validator + +from spd.base_config import BaseConfig +from spd.configs import ImportanceMinimalityLossConfig, SamplingType +from spd.spd_types import Probability + + +class OptimCIConfig(BaseConfig): + """Configuration for optimizing CI values on a single prompt.""" + + seed: int = Field( + ..., + description="Random seed for reproducibility", + ) + # Model and prompt + wandb_path: str = Field( + ..., + description="Wandb path to load model from, e.g. 
'wandb:goodfire/spd/runs/jyo9duz5'", + ) + prompt: str = Field( + ..., + description="The prompt to optimize CI values for", + ) + label: str = Field( + ..., + description="The label to optimize CI values for", + ) + + # Optimization hyperparameters + lr: PositiveFloat = Field( + ..., + description="Learning rate for AdamW optimizer", + ) + steps: PositiveInt = Field( + ..., + description="Number of optimization steps", + ) + weight_decay: NonNegativeFloat = Field( + ..., + description="Weight decay for AdamW optimizer", + ) + lr_schedule: Literal["linear", "constant", "cosine", "exponential"] = Field( + ..., + description="Type of learning-rate schedule to apply", + ) + lr_exponential_halflife: PositiveFloat | None = Field( + ..., + description="Half-life parameter when using an exponential LR schedule", + ) + lr_warmup_pct: Probability = Field( + ..., + description="Fraction of total steps to linearly warm up the learning rate", + ) + log_freq: PositiveInt = Field( + ..., + description="Frequency of logging during optimization", + ) + + # Loss configs + imp_min_config: ImportanceMinimalityLossConfig = Field( + ..., + description="Configuration for the importance minimality loss", + ) + ce_loss_coeff: float = Field( + ..., + description="Coefficient for the CE loss", + ) + + # CI thresholds and sampling + ci_threshold: PositiveFloat = Field( + ..., + description="Threshold for considering a component alive in original CI values. 
" + "Only components with CI > ci_threshold will be optimized.", + ) + sampling: SamplingType = Field( + ..., + description="Sampling mode for stochastic losses: 'continuous' or 'binomial'", + ) + n_mask_samples: PositiveInt = Field( + ..., + description="Number of stochastic masks to sample for recon losses", + ) + output_loss_type: Literal["mse", "kl"] = Field( + ..., + description="Loss type for reconstruction: 'kl' for LMs, 'mse' for vectors", + ) + + ce_kl_rounding_threshold: float = Field( + ..., + description="Threshold for rounding CI values in CE/KL metric computation", + ) + + @model_validator(mode="after") + def validate_model(self) -> Self: + if self.lr_schedule == "exponential": + assert self.lr_exponential_halflife is not None, ( + "lr_exponential_halflife must be set if lr_schedule is exponential" + ) + return self diff --git a/spd/scripts/optim_cis/run_optim_cis.py b/spd/scripts/optim_cis/run_optim_cis.py new file mode 100644 index 000000000..096314873 --- /dev/null +++ b/spd/scripts/optim_cis/run_optim_cis.py @@ -0,0 +1,438 @@ +# %% +"""Optimize CI values for a single prompt while keeping component weights fixed.""" + +import json +from dataclasses import dataclass +from pathlib import Path + +import torch +import torch.nn.functional as F +import torch.optim as optim +from jaxtyping import Bool, Float +from torch import Tensor +from tqdm.auto import tqdm +from transformers import AutoTokenizer +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast + +from spd.configs import ImportanceMinimalityLossConfig +from spd.metrics import importance_minimality_loss +from spd.models.component_model import CIOutputs, ComponentModel, OutputWithCache +from spd.models.components import make_mask_infos +from spd.routing import AllLayersRouter +from spd.scripts.model_loading import load_model_from_wandb +from spd.scripts.optim_cis.config import OptimCIConfig +from spd.utils.component_utils import calc_ci_l_zero, 
calc_stochastic_component_mask_info +from spd.utils.general_utils import set_seed + + +@dataclass +class AliveComponentInfo: + """Info about which components are alive at each position for each layer.""" + + alive_masks: dict[str, Bool[Tensor, "1 seq C"]] # Per-layer masks of alive positions + alive_counts: dict[str, list[int]] # Number of alive components per position per layer + + +def compute_alive_info( + ci_lower_leaky: dict[str, Float[Tensor, "1 seq C"]], + ci_threshold: float, +) -> AliveComponentInfo: + """Compute which (position, component) pairs are alive based on initial CI values.""" + alive_masks: dict[str, Bool[Tensor, "1 seq C"]] = {} + alive_counts: dict[str, list[int]] = {} + + for layer_name, ci in ci_lower_leaky.items(): + mask = ci > ci_threshold + alive_masks[layer_name] = mask + # Count alive components per position: mask is [1, seq, C], sum over C + counts_per_pos = mask[0].sum(dim=-1) # [seq] + alive_counts[layer_name] = counts_per_pos.tolist() + + return AliveComponentInfo(alive_masks=alive_masks, alive_counts=alive_counts) + + +@dataclass +class OptimizableCIParams: + """Container for optimizable CI pre-sigmoid parameters.""" + + # List of pre-sigmoid tensors for alive positions at each sequence position + ci_pre_sigmoid: dict[str, list[Tensor]] # layer_name -> list of [alive_at_pos] values + alive_info: AliveComponentInfo + + def create_ci_outputs(self, model: ComponentModel, device: str) -> CIOutputs: + """Expand sparse pre-sigmoid values to full CI tensors and create CIOutputs.""" + pre_sigmoid: dict[str, Tensor] = {} + + for layer_name, mask in self.alive_info.alive_masks.items(): + # Create full tensors (default to 0 for non-alive positions) + full_pre_sigmoid = torch.zeros_like(mask, dtype=torch.float32, device=device) + + # Get pre-sigmoid list for this layer + layer_pre_sigmoid_list = self.ci_pre_sigmoid[layer_name] + + # For each position, place the values + seq_len = mask.shape[1] + for pos in range(seq_len): + pos_mask = mask[0, 
pos, :] # [C] + pos_pre_sigmoid = layer_pre_sigmoid_list[pos] # [alive_at_pos] + full_pre_sigmoid[0, pos, pos_mask] = pos_pre_sigmoid + + pre_sigmoid[layer_name] = full_pre_sigmoid + + return CIOutputs( + lower_leaky={k: model.lower_leaky_fn(v) for k, v in pre_sigmoid.items()}, + upper_leaky={k: model.upper_leaky_fn(v) for k, v in pre_sigmoid.items()}, + pre_sigmoid=pre_sigmoid, + ) + + def get_parameters(self) -> list[Tensor]: + """Get all optimizable parameters.""" + params: list[Tensor] = [] + for layer_pre_sigmoid_list in self.ci_pre_sigmoid.values(): + params.extend(layer_pre_sigmoid_list) + return params + + +def create_optimizable_ci_params( + alive_info: AliveComponentInfo, + initial_pre_sigmoid: dict[str, Tensor], +) -> OptimizableCIParams: + """Create optimizable CI parameters for alive positions. + + Creates parameters initialized from the initial pre-sigmoid values for each + (position, component) pair where initial CI > threshold. + """ + ci_pre_sigmoid: dict[str, list[Tensor]] = {} + + for layer_name, mask in alive_info.alive_masks.items(): + # Get initial pre-sigmoid values for this layer + layer_initial = initial_pre_sigmoid[layer_name] # [1, seq, C] + + # Create a tensor for each position + layer_pre_sigmoid_list: list[Tensor] = [] + seq_len = mask.shape[1] + for pos in range(seq_len): + pos_mask = mask[0, pos, :] # [C] + # Extract initial values for alive positions at this position + initial_values = layer_initial[0, pos, pos_mask].clone().detach() + initial_values.requires_grad_(True) + layer_pre_sigmoid_list.append(initial_values) + ci_pre_sigmoid[layer_name] = layer_pre_sigmoid_list + + return OptimizableCIParams( + ci_pre_sigmoid=ci_pre_sigmoid, + alive_info=alive_info, + ) + + +def compute_l0_stats( + ci_outputs: CIOutputs, + ci_alive_threshold: float, +) -> dict[str, float]: + """Compute L0 statistics for each layer.""" + stats: dict[str, float] = {} + for layer_name, layer_ci in ci_outputs.lower_leaky.items(): + l0_val = 
calc_ci_l_zero(layer_ci, ci_alive_threshold) + stats[f"l0/{layer_name}"] = l0_val + stats["l0/total"] = sum(stats.values()) + return stats + + +def compute_final_token_ce_kl( + model: ComponentModel, + batch: Tensor, + target_out: Tensor, + ci: dict[str, Tensor], + rounding_threshold: float, +) -> dict[str, float]: + """Compute CE and KL metrics for the final token only. + + Args: + model: The ComponentModel. + batch: Input tokens of shape [1, seq_len]. + target_out: Target model output logits of shape [1, seq_len, vocab]. + ci: Causal importance values (lower_leaky) per layer. + rounding_threshold: Threshold for rounding CI values to binary masks. + + Returns: + Dict with kl and ce_difference metrics for ci_masked, unmasked, and rounded_masked. + """ + assert batch.ndim == 2 and batch.shape[0] == 1, "Expected batch shape [1, seq_len]" + + # Get the label for CE (next token prediction at final position) + # The label is the token at the final position for the second-to-last logit prediction + # But since we're optimizing for CI on a single prompt, we use the final logit position + final_target_logits = target_out[0, -1, :] # [vocab] + + def kl_vs_target(logits: Tensor) -> float: + """KL divergence between predicted and target logits at final position.""" + final_logits = logits[0, -1, :] # [vocab] + target_probs = F.softmax(final_target_logits, dim=-1) + pred_log_probs = F.log_softmax(final_logits, dim=-1) + return F.kl_div(pred_log_probs, target_probs, reduction="sum").item() + + def ce_vs_target(logits: Tensor) -> float: + """CE between predicted logits and target's argmax at final position.""" + final_logits = logits[0, -1, :] # [vocab] + target_token = final_target_logits.argmax() + return F.cross_entropy(final_logits.unsqueeze(0), target_token.unsqueeze(0)).item() + + # Target model CE (baseline) + target_ce = ce_vs_target(target_out) + + # CI masked + ci_mask_infos = make_mask_infos(ci) + ci_masked_logits = model(batch, mask_infos=ci_mask_infos) + 
ci_masked_kl = kl_vs_target(ci_masked_logits) + ci_masked_ce = ce_vs_target(ci_masked_logits) + + # Unmasked (all components active) + unmasked_infos = make_mask_infos({k: torch.ones_like(v) for k, v in ci.items()}) + unmasked_logits = model(batch, mask_infos=unmasked_infos) + unmasked_kl = kl_vs_target(unmasked_logits) + unmasked_ce = ce_vs_target(unmasked_logits) + + # Rounded masked (binary masks based on threshold) + rounded_mask_infos = make_mask_infos( + {k: (v > rounding_threshold).float() for k, v in ci.items()} + ) + rounded_masked_logits = model(batch, mask_infos=rounded_mask_infos) + rounded_masked_kl = kl_vs_target(rounded_masked_logits) + rounded_masked_ce = ce_vs_target(rounded_masked_logits) + + return { + "kl_ci_masked": ci_masked_kl, + "kl_unmasked": unmasked_kl, + "kl_rounded_masked": rounded_masked_kl, + "ce_difference_ci_masked": ci_masked_ce - target_ce, + "ce_difference_unmasked": unmasked_ce - target_ce, + "ce_difference_rounded_masked": rounded_masked_ce - target_ce, + } + + +def optimize_ci_values( + model: ComponentModel, + tokens: Tensor, + label_token: int, + config: OptimCIConfig, + device: str, +) -> OptimizableCIParams: + """Optimize CI values for a single prompt. + + Args: + model: The ComponentModel (weights will be frozen). + tokens: Tokenized prompt of shape [1, seq_len]. + label_token: The token to optimize CI values for. + config: Optimization configuration. + device: Device to run on. + Returns: + The OptimizableCIParams object. 
+ """ + imp_min_coeff = config.imp_min_config.coeff + assert imp_min_coeff is not None, "Importance minimality loss coefficient must be set" + ce_loss_coeff = config.ce_loss_coeff + + # Freeze all model parameters + model.requires_grad_(False) + + # Get initial CI values from the model + with torch.no_grad(): + output_with_cache: OutputWithCache = model(tokens, cache_type="input") + initial_ci_outputs = model.calc_causal_importances( + pre_weight_acts=output_with_cache.cache, + sampling=config.sampling, + detach_inputs=False, + ) + target_out = output_with_cache.output.detach() + + # Compute alive info and create optimizable parameters + alive_info = compute_alive_info(initial_ci_outputs.lower_leaky, config.ci_threshold) + ci_params = create_optimizable_ci_params( + alive_info=alive_info, + initial_pre_sigmoid=initial_ci_outputs.pre_sigmoid, + ) + + # Log initial alive counts + total_alive = sum(sum(counts) for counts in alive_info.alive_counts.values()) + print(f"\nAlive components (CI > {config.ci_threshold}):") + for layer_name, counts in alive_info.alive_counts.items(): + layer_total = sum(counts) + print(f" {layer_name}: {layer_total} total across {len(counts)} positions") + print(f" Total: {total_alive}") + + weight_deltas = model.calc_weight_deltas() + + params = ci_params.get_parameters() + optimizer = optim.AdamW(params, lr=config.lr, weight_decay=config.weight_decay) + + for step in tqdm(range(config.steps), desc="Optimizing CI values"): + optimizer.zero_grad() + + # Create CI outputs from current parameters + ci_outputs = ci_params.create_ci_outputs(model, device) + + mask_infos = calc_stochastic_component_mask_info( + causal_importances=ci_outputs.lower_leaky, + component_mask_sampling=config.sampling, + weight_deltas=weight_deltas, + router=AllLayersRouter(), + ) + out = model(tokens, mask_infos=mask_infos) + + imp_min_loss = importance_minimality_loss( + ci_upper_leaky=ci_outputs.upper_leaky, + current_frac_of_training=step / config.steps, + 
pnorm=config.imp_min_config.pnorm, + eps=config.imp_min_config.eps, + p_anneal_start_frac=config.imp_min_config.p_anneal_start_frac, + p_anneal_final_p=config.imp_min_config.p_anneal_final_p, + p_anneal_end_frac=config.imp_min_config.p_anneal_end_frac, + ) + ce_loss = F.cross_entropy( + out[0, -1, :].unsqueeze(0), torch.tensor([label_token], device=device) + ) + total_loss = ce_loss_coeff * ce_loss + imp_min_coeff * imp_min_loss + # Get the output probability for the label_token in the final seq position + label_prob = F.softmax(out[0, -1, :], dim=-1)[label_token] + + if step % config.log_freq == 0 or step == config.steps - 1: + l0_stats = compute_l0_stats(ci_outputs, config.ci_threshold) + + # Compute CE/KL metrics for final token only + with torch.no_grad(): + ce_kl_stats = compute_final_token_ce_kl( + model=model, + batch=tokens, + target_out=target_out, + ci=ci_outputs.lower_leaky, + rounding_threshold=config.ce_kl_rounding_threshold, + ) + + # Also calculate the ci-masked label probability + with torch.no_grad(): + mask_infos = make_mask_infos(ci_outputs.lower_leaky, routing_masks="all") + out = model(tokens, mask_infos=mask_infos) + ci_masked_label_prob = F.softmax(out[0, -1, :], dim=-1)[label_token] + + log_terms = { + "imp_min_loss": imp_min_loss.item(), + "ce_loss": ce_loss.item(), + "total_loss": total_loss.item(), + "stoch_masked_label_prob": label_prob.item(), + "ci_masked_label_prob": ci_masked_label_prob.item(), + } + tqdm.write(f"\n--- Step {step} ---") + for name, value in log_terms.items(): + tqdm.write(f" {name}: {value:.6f}") + for name, value in l0_stats.items(): + tqdm.write(f" {name}: {value:.2f}") + for name, value in ce_kl_stats.items(): + tqdm.write(f" {name}: {value:.6f}") + + total_loss.backward() + optimizer.step() + + return ci_params + + +def get_out_dir() -> Path: + """Get the output directory for optimization results.""" + out_dir = Path(__file__).parent / "out" + out_dir.mkdir(parents=True, exist_ok=True) + return out_dir + + +# %% 
+if __name__ == "__main__": + config = OptimCIConfig( + seed=0, + # wandb_path="wandb:goodfire/spd/runs/jyo9duz5", # ss_gpt2_simple-1.25M (4L) + wandb_path="wandb:goodfire/spd/runs/33n6xjjt", # ss_gpt2_simple-1L + prompt="They walked hand in", + label="hand", + lr=1e-2, + weight_decay=0.0, + lr_schedule="cosine", + lr_exponential_halflife=None, + lr_warmup_pct=0.01, + steps=2000, + log_freq=500, + imp_min_config=ImportanceMinimalityLossConfig(coeff=1e-1, pnorm=0.3), + ce_loss_coeff=1, + ci_threshold=1e-6, + sampling="continuous", + n_mask_samples=1, + output_loss_type="kl", + ce_kl_rounding_threshold=0.5, + ) + + set_seed(config.seed) + + loaded = load_model_from_wandb(config.wandb_path) + model, run_config, device = loaded.model, loaded.config, loaded.device + + tokenizer = AutoTokenizer.from_pretrained(run_config.tokenizer_name) + assert isinstance(tokenizer, PreTrainedTokenizerFast), "Expected PreTrainedTokenizerFast" + + print(f"\nPrompt: {config.prompt!r}") + tokens = tokenizer.encode(config.prompt, return_tensors="pt", add_special_tokens=False) + assert isinstance(tokens, Tensor), "Expected Tensor" + tokens = tokens.to(device) + token_strings = [tokenizer.decode([t]) for t in tokens[0].tolist()] + print(f"Token strings: {token_strings}") + label_token_ids = tokenizer.encode(config.label, add_special_tokens=False) + assert len(label_token_ids) == 1, f"Expected single token for label, got {len(label_token_ids)}" + label_token = label_token_ids[0] + + # Run optimization + ci_params = optimize_ci_values( + model=model, + tokens=tokens, + label_token=label_token, + config=config, + device=device, + ) + + # Get final metrics + ci_outputs = ci_params.create_ci_outputs(model, device) + l0_stats = compute_l0_stats(ci_outputs, config.ci_threshold) + + with torch.no_grad(): + target_out = model(tokens) + ce_kl_stats = compute_final_token_ce_kl( + model=model, + batch=tokens, + target_out=target_out, + ci=ci_outputs.lower_leaky, + 
rounding_threshold=config.ce_kl_rounding_threshold, + ) + # Use ci-masked model to get final label probability + mask_infos = make_mask_infos(ci_outputs.lower_leaky, routing_masks="all") + out = model(tokens, mask_infos=mask_infos) + label_prob = F.softmax(out[0, -1, :], dim=-1)[label_token] + + final_metrics = {**l0_stats, **ce_kl_stats, "ci_masked_label_prob": label_prob.item()} + print(f"\nFinal metrics after {config.steps} steps:") + for name, value in final_metrics.items(): + print(f" {name}: {value:.6f}") + + out_dir = get_out_dir() + output_path = out_dir / f"optimized_ci_{loaded.wandb_id}.json" + + output_data = { + "samples": { + f"imp_min_{config.imp_min_config.coeff:.0e}": { + "prompt": config.prompt, + "ci_vals": {k: v[0].cpu().tolist() for k, v in ci_outputs.lower_leaky.items()}, + "metadata": { + "config": config.model_dump(), + "wandb_id": loaded.wandb_id, + }, + }, + }, + } + + with open(output_path, "w") as f: + json.dump(output_data, f, indent=2) + + print(f"\nSaved optimized CI values to {output_path}") diff --git a/spd/scripts/plot_global_attributions.py b/spd/scripts/plot_global_attributions.py new file mode 100644 index 000000000..5189ff874 --- /dev/null +++ b/spd/scripts/plot_global_attributions.py @@ -0,0 +1,324 @@ +"""Plot attribution graph from saved global attributions.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import matplotlib.pyplot as plt +import networkx as nx +import torch +from torch import Tensor + +from spd.scripts.model_loading import get_out_dir + +if TYPE_CHECKING: + Graph = nx.DiGraph[Any] +else: + Graph = nx.DiGraph + + +@dataclass +class LayerInfo: + """All display information for a layer type.""" + + name: str + color: str + y_offset: float + x_offset: float = 0.0 + legend_name: str | None = None # If different from name + + @property + def display_name(self) -> str: + return self.legend_name if self.legend_name is not 
None else self.name + + +# fmt: off +LAYER_INFOS: dict[str, LayerInfo] = { + "wte": LayerInfo("wte", color="#34495E", y_offset=-3.0), + "attn.q_proj": LayerInfo("attn.q_proj", color="#1f77b4", y_offset=-1.0, x_offset=-20.0, legend_name="q_proj"), + "attn.k_proj": LayerInfo("attn.k_proj", color="#2ca02c", y_offset=-1.0, x_offset=0.0, legend_name="k_proj"), + "attn.v_proj": LayerInfo("attn.v_proj", color="#9467bd", y_offset=-1.0, x_offset=20.0, legend_name="v_proj"), + "attn.o_proj": LayerInfo("attn.o_proj", color="#d62728", y_offset=0.0, legend_name="o_proj"), + "mlp.c_fc": LayerInfo("mlp.c_fc", color="#ff7f0e", y_offset=1.0, legend_name="c_fc"), + "mlp.down_proj": LayerInfo("mlp.down_proj", color="#8c564b", y_offset=2.0, legend_name="down_proj"), + "output": LayerInfo("output", color="#17A589", y_offset=4.0), +} +# fmt: on + + +def load_attributions( + out_dir: Path, wandb_id: str +) -> tuple[dict[tuple[str, str], Tensor], dict[str, list[int]]]: + """Load global attributions and reconstruct alive_indices from tensor shapes.""" + global_attributions: dict[tuple[str, str], Tensor] = torch.load( + out_dir / f"global_attributions_{wandb_id}.pt" + ) + + alive_indices: dict[str, list[int]] = {} + for (in_layer, out_layer), attr in global_attributions.items(): + n_alive_in, n_alive_out = attr.shape + if in_layer not in alive_indices: + alive_indices[in_layer] = list(range(n_alive_in)) + if out_layer not in alive_indices: + alive_indices[out_layer] = list(range(n_alive_out)) + + return global_attributions, alive_indices + + +def print_edge_statistics(global_attributions: dict[tuple[str, str], Tensor]) -> None: + """Print statistics about edges at various thresholds.""" + total_edges = sum(attr.numel() for attr in global_attributions.values()) + print(f"Total edges: {total_edges:,}") + + thresholds = [1, 0.6, 0.2, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-12, 1e-15] + for threshold in thresholds: + edges_above = sum((attr > threshold).sum().item() for attr in 
global_attributions.values()) + print(f"Edges > {threshold}: {edges_above:,}") + + +def build_layer_list(n_blocks: int) -> list[str]: + """Build full layer list in network order (wte -> blocks -> output).""" + all_layers = ["wte"] + for block_idx in range(n_blocks): + for layer_name in [k for k in LAYER_INFOS if k not in ["wte", "output"]]: + all_layers.append(f"h.{block_idx}.{layer_name}") + all_layers.append("output") + return all_layers + + +def find_nodes_with_edges( + global_attributions: dict[tuple[str, str], Tensor], + alive_indices: dict[str, list[int]], + edge_threshold: float, +) -> set[str]: + """Find all nodes that participate in edges above threshold.""" + nodes_with_edges: set[str] = set() + for (in_layer, out_layer), attr_tensor in global_attributions.items(): + in_alive = alive_indices.get(in_layer, []) + out_alive = alive_indices.get(out_layer, []) + for i, in_comp in enumerate(in_alive): + for j, out_comp in enumerate(out_alive): + if attr_tensor[i, j].item() > edge_threshold: + nodes_with_edges.add(f"{in_layer}:{in_comp}") + nodes_with_edges.add(f"{out_layer}:{out_comp}") + return nodes_with_edges + + +def parse_layer_name(layer: str, n_blocks: int) -> tuple[int, str]: + """Parse layer name to get block_idx and base layer_name.""" + if layer == "wte": + return 0, "wte" + elif layer == "output": + return n_blocks - 1, "output" + else: + parts = layer.split(".") + return int(parts[1]), ".".join(parts[2:]) + + +def build_graph( + all_layers: list[str], + global_attributions: dict[tuple[str, str], Tensor], + alive_indices: dict[str, list[int]], + nodes_with_edges: set[str], + n_blocks: int, + edge_threshold: float, + block_spacing: float = 6.0, +) -> tuple[Graph, dict[str, tuple[float, float]], list[float]]: + """Build the attribution graph with node positions and edge weights.""" + G: Graph = nx.DiGraph() + node_positions: dict[str, tuple[float, float]] = {} + + # Add nodes + for layer in all_layers: + block_idx, layer_name = parse_layer_name(layer, 
n_blocks) + info = LAYER_INFOS[layer_name] + + y_base = block_idx * block_spacing + info.y_offset + x_base = info.x_offset + + layer_alive = alive_indices.get(layer, []) + layer_nodes_with_edges = [ + (idx, comp) + for idx, comp in enumerate(layer_alive) + if f"{layer}:{comp}" in nodes_with_edges + ] + n_layer_nodes = len(layer_nodes_with_edges) + + for pos_idx, (_, local_idx) in enumerate(layer_nodes_with_edges): + node_id = f"{layer}:{local_idx}" + G.add_node(node_id, layer=layer, component=local_idx) + x = x_base + (pos_idx - n_layer_nodes / 2) * 0.25 + node_positions[node_id] = (x, y_base) + + # Add edges + edge_weights: list[float] = [] + for (in_layer, out_layer), attr_tensor in global_attributions.items(): + in_alive = alive_indices.get(in_layer, []) + out_alive = alive_indices.get(out_layer, []) + + for i, in_comp in enumerate(in_alive): + for j, out_comp in enumerate(out_alive): + weight = attr_tensor[i, j].item() + if weight > edge_threshold: + in_node = f"{in_layer}:{in_comp}" + out_node = f"{out_layer}:{out_comp}" + if in_node in G.nodes and out_node in G.nodes: + G.add_edge(in_node, out_node, weight=weight) + edge_weights.append(weight) + + return G, node_positions, edge_weights + + +def draw_nodes( + ax: plt.Axes, + G: Graph, + node_positions: dict[str, tuple[float, float]], + all_layers: list[str], +) -> None: + """Draw nodes grouped by layer.""" + for layer in all_layers: + if layer in ("wte", "output"): + layer_name = layer + else: + parts = layer.split(".") + layer_name = ".".join(parts[2:]) + color = LAYER_INFOS[layer_name].color + + layer_nodes = [n for n in G.nodes if G.nodes[n].get("layer") == layer] + if layer_nodes: + pos_subset = {n: node_positions[n] for n in layer_nodes} + nx.draw_networkx_nodes( + G, + pos_subset, + nodelist=layer_nodes, + node_color=color, + node_size=100, + alpha=0.8, + ax=ax, + ) + + +def draw_edges( + ax: plt.Axes, + G: Graph, + node_positions: dict[str, tuple[float, float]], + edge_weights: list[float], + n_buckets: 
int = 10, +) -> None: + """Draw edges batched by weight bucket for performance.""" + if not edge_weights: + return + + max_weight = max(edge_weights) + min_weight = min(edge_weights) + + edge_buckets: list[list[tuple[str, str]]] = [[] for _ in range(n_buckets)] + + for u, v, data in G.edges(data=True): + weight = data.get("weight", 0) + if max_weight > min_weight: + normalized = (weight - min_weight) / (max_weight - min_weight) + else: + normalized = 0.5 + bucket_idx = min(int(normalized * n_buckets), n_buckets - 1) + edge_buckets[bucket_idx].append((u, v)) + + for bucket_idx, bucket_edges in enumerate(edge_buckets): + if not bucket_edges: + continue + normalized = (bucket_idx + 0.5) / n_buckets + width = 0.2 + normalized * 2.0 + alpha = 0.3 + normalized * 0.5 + + nx.draw_networkx_edges( + G, + node_positions, + edgelist=bucket_edges, + width=width, + alpha=alpha, + edge_color="#666666", + arrows=True, + arrowsize=8, + connectionstyle="arc3,rad=0.1", + ax=ax, + ) + + +def add_legend(ax: plt.Axes) -> None: + """Add legend for layer types.""" + legend_elements = [ + plt.Line2D( + [0], + [0], + marker="o", + color="w", + markerfacecolor=info.color, + markersize=10, + label=info.display_name, + ) + for info in LAYER_INFOS.values() + ] + ax.legend(handles=legend_elements, loc="upper right", fontsize=8) + + +def plot_attribution_graph( + global_attributions: dict[tuple[str, str], Tensor], + alive_indices: dict[str, list[int]], + n_blocks: int, + edge_threshold: float, + output_path: Path, +) -> None: + """Create and save the attribution graph visualization.""" + print("\nPlotting attribution graph...") + + all_layers = build_layer_list(n_blocks) + nodes_with_edges = find_nodes_with_edges(global_attributions, alive_indices, edge_threshold) + G, node_positions, edge_weights = build_graph( + all_layers, global_attributions, alive_indices, nodes_with_edges, n_blocks, edge_threshold + ) + + print(f"Graph has {G.number_of_nodes()} nodes and {G.number_of_edges()} edges") + + 
fig, ax = plt.subplots(1, 1, figsize=(32, 12)) + + draw_nodes(ax, G, node_positions, all_layers) + draw_edges(ax, G, node_positions, edge_weights) + add_legend(ax) + + ax.set_title("Global Attribution Graph", fontsize=14, fontweight="bold") + ax.axis("off") + plt.tight_layout() + + fig.savefig(output_path, dpi=150, bbox_inches="tight") + print(f"Saved to {output_path}") + plt.close(fig) + + +if __name__ == "__main__": + # Configuration + # wandb_id = "jyo9duz5" # ss_gpt2_simple-1.25M (4L) + # wandb_id = "c0k3z78g" # ss_gpt2_simple-2L + wandb_id = "33n6xjjt" # ss_gpt2_simple-1L (New) + n_blocks = 1 + edge_threshold = 1e-1 + + out_dir = get_out_dir() + + global_attributions, alive_indices = load_attributions(out_dir, wandb_id) + print(f"Loaded attributions for {len(global_attributions)} layer pairs") + print(f"Total alive components: {sum(len(v) for v in alive_indices.values())}") + + print_edge_statistics(global_attributions) + + edge_threshold_str = f"{edge_threshold:.1e}".replace(".0", "") + output_path = out_dir / f"attribution_graph_{wandb_id}_edge_threshold_{edge_threshold_str}.png" + + plot_attribution_graph( + global_attributions, + alive_indices, + n_blocks, + edge_threshold, + output_path, + ) diff --git a/spd/scripts/plot_local_attributions.py b/spd/scripts/plot_local_attributions.py new file mode 100644 index 000000000..d812e5030 --- /dev/null +++ b/spd/scripts/plot_local_attributions.py @@ -0,0 +1,679 @@ +# %% +"""Plot local attribution graph from saved .pt file.""" + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import matplotlib.pyplot as plt +import numpy as np +import torch +from jaxtyping import Float +from matplotlib.collections import LineCollection +from torch import Tensor + +from spd.scripts.model_loading import get_out_dir + + +@dataclass +class PairAttribution: + source: str + target: str + attribution: Float[Tensor, "s_in trimmed_c_in s_out trimmed_c_out"] + trimmed_c_in_idxs: list[int] + 
trimmed_c_out_idxs: list[int] + is_kv_to_o_pair: bool + # Original alive masks (from model CI, before any optimization) + # Used to show "ghost" nodes that would have been active without CI optimization + # Shape: [n_seq, n_components] where True means alive at that (seq, component) + original_alive_mask_in: Float[Tensor, "seq C"] | None = None + original_alive_mask_out: Float[Tensor, "seq C"] | None = None + + +@dataclass +class NodeInfo: + """Information about a node in the attribution graph.""" + + layer: str + seq_pos: int + component_idx: int + x: float + y: float + importance: float + + +def get_layer_order() -> list[str]: + """Get the canonical ordering of sublayers within a block.""" + return [ + "wte", # Word token embeddings (first layer) + "attn.q_proj", + "attn.v_proj", + "attn.k_proj", + "attn.o_proj", + "mlp.c_fc", + "mlp.down_proj", + ] + + +# Sublayers that share a row (q, v, k displayed side by side) +QVK_SUBLAYERS = {"attn.q_proj", "attn.v_proj", "attn.k_proj"} + +# Column allocation for q, v, k: (n_cols, start_col) out of 12 total columns +# Layout: Q(2) | gap(1) | V(4) | gap(1) | K(4) = 12 total +QVK_LAYOUT: dict[str, tuple[int, int]] = { + "attn.q_proj": (2, 0), # 2 columns, starts at 0 + "attn.v_proj": (4, 3), # 4 columns, starts at 3 (1 col gap after q) + "attn.k_proj": (4, 8), # 4 columns, starts at 8 (1 col gap after v) +} +QVK_TOTAL_COLS = 12 + + +def get_layer_color(layer: str) -> str: + """Get color for a layer based on its type.""" + if layer == "wte": + return "#34495E" # Dark blue-gray for embeddings + + colors = { + "attn.q_proj": "#E67E22", # Orange + "attn.k_proj": "#27AE60", # Green + "attn.v_proj": "#F1C40F", # Yellow + "attn.o_proj": "#E74C3C", # Red + "mlp.c_fc": "#3498DB", # Blue + "mlp.down_proj": "#9B59B6", # Purple + } + for sublayer, color in colors.items(): + if sublayer in layer: + return color + return "#95A5A6" # Gray fallback + + +def parse_layer_name(layer: str) -> tuple[int, str]: + """Parse layer name into block 
index and sublayer type. + + E.g., "h.0.attn.q_proj" -> (0, "attn.q_proj") + "wte" -> (-1, "wte") + "output" -> (999, "output") + """ + if layer == "wte": + return -1, "wte" + if layer == "output": + return 999, "output" + + parts = layer.split(".") + block_idx = int(parts[1]) + sublayer = ".".join(parts[2:]) + return block_idx, sublayer + + +def compute_layer_y_positions( + attr_pairs: list[PairAttribution], +) -> dict[str, float]: + """Compute Y position for each layer. + + Layers are ordered by block, then by sublayer type within block. + q, v, k layers share the same row (same y position). + + Returns: + Dict mapping layer name to y position. + """ + # Collect all unique layers + layers = set() + for pair in attr_pairs: + layers.add(pair.source) + layers.add(pair.target) + + # Parse and sort + layer_order = get_layer_order() + parsed = [(layer, *parse_layer_name(layer)) for layer in layers] + + def sort_key(item: tuple[str, int, str]) -> tuple[int, int]: + _, block_idx, sublayer = item + sublayer_idx = layer_order.index(sublayer) if sublayer in layer_order else 999 + return (block_idx, sublayer_idx) + + sorted_layers = sorted(parsed, key=sort_key) + + # Assign Y positions, grouping q, v, k on the same row + y_positions = {} + current_y = 0.0 + prev_block_idx = None + prev_was_qvk = False + + for layer, block_idx, sublayer in sorted_layers: + is_qvk = sublayer in QVK_SUBLAYERS + + # Check if we should share y with previous layer + if is_qvk and prev_was_qvk and block_idx == prev_block_idx: + # Same row as previous q/v/k layer + y_positions[layer] = current_y + else: + # New row + if y_positions: # Not the first layer + current_y += 1.0 + y_positions[layer] = current_y + + prev_block_idx = block_idx + prev_was_qvk = is_qvk + + return y_positions + + +def compute_node_importances( + attr_pairs: list[PairAttribution], + n_seq: int, +) -> tuple[dict[str, Float[Tensor, "seq C"]], dict[str, Float[Tensor, "seq C"]]]: + """Compute importance values for nodes based on 
total attribution flow. + + Returns: + importances: Dict mapping layer -> tensor of shape [n_seq, max_component_idx+1]. + Importance is the sum of incoming and outgoing attributions. + original_alive_masks: Dict mapping layer -> bool tensor of shape [n_seq, max_component_idx+1]. + True means the component was alive at that (seq_pos, component) in original model CI. + """ + # First pass: determine max component index per layer (including original alive) + layer_max_c: dict[str, int] = {} + for pair in attr_pairs: + src_max = max(pair.trimmed_c_in_idxs) if pair.trimmed_c_in_idxs else 0 + tgt_max = max(pair.trimmed_c_out_idxs) if pair.trimmed_c_out_idxs else 0 + # Also consider original alive mask dimensions + if pair.original_alive_mask_in is not None: + src_max = max(src_max, pair.original_alive_mask_in.shape[1] - 1) + if pair.original_alive_mask_out is not None: + tgt_max = max(tgt_max, pair.original_alive_mask_out.shape[1] - 1) + layer_max_c[pair.source] = max(layer_max_c.get(pair.source, 0), src_max) + layer_max_c[pair.target] = max(layer_max_c.get(pair.target, 0), tgt_max) + + # Initialize importance tensors and original alive masks (on same device as attributions) + device = attr_pairs[0].attribution.device if attr_pairs else "cpu" + importances: dict[str, Float[Tensor, "seq C"]] = {} + original_alive_masks: dict[str, Float[Tensor, "seq C"]] = {} + for layer, max_c in layer_max_c.items(): + importances[layer] = torch.zeros(n_seq, max_c + 1, device=device) + original_alive_masks[layer] = torch.zeros(n_seq, max_c + 1, device=device, dtype=torch.bool) + + # Accumulate attribution magnitudes and original alive masks + for pair in attr_pairs: + attr = pair.attribution.abs() # [s_in, trimmed_c_in, s_out, trimmed_c_out] + + # Sum over output dimensions -> importance for source nodes + src_importance = attr.sum(dim=(2, 3)) # [s_in, trimmed_c_in] + for i, c_in in enumerate(pair.trimmed_c_in_idxs): + importances[pair.source][:, c_in] += src_importance[:, i] + + # Sum 
over input dimensions -> importance for target nodes + tgt_importance = attr.sum(dim=(0, 1)) # [s_out, trimmed_c_out] + for j, c_out in enumerate(pair.trimmed_c_out_idxs): + importances[pair.target][:, c_out] += tgt_importance[:, j] + + # Accumulate original alive masks (OR them together since multiple pairs may have same layer) + if pair.original_alive_mask_in is not None: + n_c = pair.original_alive_mask_in.shape[1] + original_alive_masks[pair.source][:, :n_c] |= pair.original_alive_mask_in + if pair.original_alive_mask_out is not None: + n_c = pair.original_alive_mask_out.shape[1] + original_alive_masks[pair.target][:, :n_c] |= pair.original_alive_mask_out + + return importances, original_alive_masks + + +def plot_local_graph( + attr_pairs: list[PairAttribution], + token_strings: list[str], + min_edge_weight: float = 0.001, + node_scale: float = 30.0, + edge_alpha_scale: float = 0.5, + figsize: tuple[float, float] | None = None, + max_grid_cols: int = 12, + output_token_labels: dict[int, str] | None = None, + output_prob_threshold: float | None = None, + output_probs_by_pos: dict[tuple[int, int], float] | None = None, +) -> plt.Figure: + """Plot the local attribution graph. + + Args: + attr_pairs: List of PairAttribution objects from compute_local_attributions. + token_strings: List of token strings for x-axis labels. + min_edge_weight: Minimum edge weight to display. + node_scale: Fixed size for all nodes. + edge_alpha_scale: Scale factor for edge transparency. + figsize: Figure size (width, height). Auto-computed if None. + max_grid_cols: Maximum number of columns in the grid per layer. + output_token_labels: Dict mapping output component indices to token strings. + output_prob_threshold: Threshold used for filtering output probabilities. + output_probs_by_pos: Dict mapping (seq_pos, component_idx) -> probability for output layer. + + Returns: + Matplotlib figure. 
+ """ + n_seq = len(token_strings) + + # Compute node importances and original alive masks (per-position) + importances, original_alive_masks = compute_node_importances(attr_pairs, n_seq) + + # Compute layout + layer_y = compute_layer_y_positions(attr_pairs) + + # Auto-compute figure size + if figsize is None: + total_height = max(layer_y.values()) + figsize = (max(16, n_seq * 2), max(8, total_height * 1.2)) + + fig, ax = plt.subplots(figsize=figsize, facecolor="white") + ax.set_facecolor("#FAFAFA") + + # Collect all nodes and their positions + nodes: list[NodeInfo] = [] + ghost_nodes: list[NodeInfo] = [] # Nodes originally alive at this position but now dead + node_lookup: dict[tuple[str, int, int], NodeInfo] = {} # (layer, seq, comp) -> NodeInfo + + # X spacing: spread tokens across the plot + x_positions = np.linspace(0.1, 0.9, n_seq) + + # Grid layout parameters + col_spacing = 0.012 # Horizontal spacing between columns in grid + row_spacing = 0.08 # Vertical spacing between rows in grid + + for layer, y_center in layer_y.items(): + if layer not in importances: + continue + + layer_imp = importances[layer] # [n_seq, max_c+1] + alive_mask = layer_imp > 0 + + # Get original alive mask for this layer (per-position) + layer_original_mask = original_alive_masks.get(layer) + + # Find all components that are alive at ANY sequence position (current) + all_alive_components = torch.where(alive_mask.any(dim=0))[0].tolist() + + # Find all components that were originally alive at ANY sequence position + originally_alive_any_pos: list[int] = [] + if layer_original_mask is not None: + originally_alive_any_pos = torch.where(layer_original_mask.any(dim=0))[0].tolist() + + # Combine for layout purposes + all_components_for_layout = sorted( + set(all_alive_components) | set(originally_alive_any_pos) + ) + n_components = len(all_components_for_layout) + + if n_components == 0: + continue + + # Check if this is a q/v/k layer that shares a row + _, sublayer = parse_layer_name(layer) 
+ is_qvk = sublayer in QVK_SUBLAYERS + is_output_layer = layer == "output" + + if is_qvk: + # Use the allocated columns for this sublayer + layer_max_cols, start_col = QVK_LAYOUT[sublayer] + else: + layer_max_cols = max_grid_cols + start_col = 0 + + # Calculate grid dimensions for this layer + n_rows = (n_components + layer_max_cols - 1) // layer_max_cols + n_cols = min(n_components, layer_max_cols) + + # Process each sequence position separately + for s in range(n_seq): + # Center the grid at this sequence position + x_base = x_positions[s] + + if is_output_layer: + # For output layer: only show components active at THIS position + active_at_pos = [c for c in all_alive_components if alive_mask[s, c]] + n_active = len(active_at_pos) + if n_active == 0: + continue + + n_rows_pos = (n_active + max_grid_cols - 1) // max_grid_cols + n_cols_pos = min(n_active, max_grid_cols) + + # Use full width of max_grid_cols for output layer to spread out labels + max_width = (max_grid_cols - 1) * col_spacing + output_col_spacing = max_width / max(n_cols_pos - 1, 1) if n_cols_pos > 1 else 0 + + for local_idx, c in enumerate(active_at_pos): + col = local_idx % max_grid_cols + row = local_idx // max_grid_cols + + x_offset = (col - (n_cols_pos - 1) / 2) * output_col_spacing + y_offset = (row - (n_rows_pos - 1) / 2) * row_spacing + + x = x_base + x_offset + y = y_center + y_offset + + imp = layer_imp[s, c].item() + node = NodeInfo( + layer=layer, + seq_pos=s, + component_idx=c, + x=x, + y=y, + importance=imp, + ) + nodes.append(node) + node_lookup[(layer, s, c)] = node + else: + # For other layers: arrange all components in grid (including ghost nodes) + for local_idx, c in enumerate(all_components_for_layout): + col = local_idx % layer_max_cols + row = local_idx // layer_max_cols + + if is_qvk: + # Position within the allocated horizontal segment + # Center of this sublayer's segment within the total QVK row + segment_center = start_col + layer_max_cols / 2 + total_center = 
QVK_TOTAL_COLS / 2 + # Offset from center of entire row + segment_offset = (segment_center - total_center) * col_spacing + # Position within segment, centered + x_offset = segment_offset + (col - (n_cols - 1) / 2) * col_spacing + else: + # Position within grid, centered on sequence position + x_offset = (col - (n_cols - 1) / 2) * col_spacing + + y_offset = (row - (n_rows - 1) / 2) * row_spacing + + x = x_base + x_offset + y = y_center + y_offset + + imp = layer_imp[s, c].item() if c < layer_imp.shape[1] else 0.0 + node = NodeInfo( + layer=layer, + seq_pos=s, + component_idx=c, + x=x, + y=y, + importance=imp, + ) + + # Check if this is a ghost node at THIS position + # Ghost = originally alive at this (s, c) but not currently alive at this (s, c) + is_currently_alive_here = ( + alive_mask[s, c].item() if c < alive_mask.shape[1] else False + ) + is_originally_alive_here = ( + layer_original_mask is not None + and c < layer_original_mask.shape[1] + and layer_original_mask[s, c].item() + ) + is_ghost = is_originally_alive_here and not is_currently_alive_here + + if is_ghost: + ghost_nodes.append(node) + else: + nodes.append(node) + node_lookup[(layer, s, c)] = node + + # Collect edges + edges: list[tuple[NodeInfo, NodeInfo, float]] = [] + + for pair in attr_pairs: + attr = pair.attribution # [s_in, trimmed_c_in, s_out, trimmed_c_out] + + for i, c_in in enumerate(pair.trimmed_c_in_idxs): + for j, c_out in enumerate(pair.trimmed_c_out_idxs): + for s_in in range(attr.shape[0]): + for s_out in range(attr.shape[2]): + weight = attr[s_in, i, s_out, j].abs().item() + if weight < min_edge_weight: + continue + + src_key = (pair.source, s_in, c_in) + tgt_key = (pair.target, s_out, c_out) + + if src_key in node_lookup and tgt_key in node_lookup: + src_node = node_lookup[src_key] + tgt_node = node_lookup[tgt_key] + edges.append((src_node, tgt_node, weight)) + + edges_by_target: dict[tuple[str, int, int], list[tuple[NodeInfo, NodeInfo, float]]] = {} + for src, tgt, w in edges: + 
key = (tgt.layer, tgt.seq_pos, tgt.component_idx) + if key not in edges_by_target: + edges_by_target[key] = [] + edges_by_target[key].append((src, tgt, w)) + + sorted_edges = [] + for target_edges in edges_by_target.values(): + sorted_edges.extend(sorted(target_edges, key=lambda e: e[2], reverse=True)) + edges = sorted_edges + + # Normalize edge weights for alpha + if edges: + edge_weights = [e[2] for e in edges] + max_edge = max(edge_weights) + if max_edge > 0: + edges = [(s, t, w / max_edge) for s, t, w in edges] + + # Track which nodes have edges + nodes_with_edges = set() + for src, tgt, _ in edges: + nodes_with_edges.add((src.layer, src.seq_pos, src.component_idx)) + nodes_with_edges.add((tgt.layer, tgt.seq_pos, tgt.component_idx)) + + # Draw edges + if edges: + lines = [] + alphas = [] + for src, tgt, w in edges: + lines.append([(src.x, src.y), (tgt.x, tgt.y)]) + alphas.append(w * edge_alpha_scale) + + lc = LineCollection( + lines, + colors=[(0.5, 0.5, 0.5, a) for a in alphas], + linewidths=0.5, + zorder=1, + ) + ax.add_collection(lc) + + # Draw ghost nodes first (so they appear behind regular nodes) + # Ghost nodes = originally alive (in model CI) but not currently alive (in optimized CI) + for node in ghost_nodes: + ax.scatter( + node.x, + node.y, + s=node_scale, + c="#909090", # Darker gray than nodes-without-edges + edgecolors="white", + linewidths=0.5, + zorder=1.5, # Behind regular nodes + alpha=0.5, + ) + + # Draw regular nodes + for node in nodes: + node_key = (node.layer, node.seq_pos, node.component_idx) + has_edges = node_key in nodes_with_edges + + if has_edges: + color = get_layer_color(node.layer) + alpha = 0.9 + else: + color = "#D3D3D3" + alpha = 0.3 + + ax.scatter( + node.x, + node.y, + s=node_scale, + c=color, + edgecolors="white", + linewidths=0.5, + zorder=2, + alpha=alpha, + ) + + # Add token label and probability for output layer nodes + if node.layer == "output" and output_token_labels is not None and has_edges: + token_label = 
output_token_labels.get(node.component_idx, "") + if token_label: + # Build label with probability if available + label_text = repr(token_label)[1:-1] # Strip quotes but show escape chars + if output_probs_by_pos is not None: + prob = output_probs_by_pos.get((node.seq_pos, node.component_idx)) + if prob is not None: + label_text = f"({prob:.2f})\n{label_text}" + ax.annotate( + label_text, + (node.x, node.y), + xytext=(0, 6), + textcoords="offset points", + fontsize=6, + ha="center", + va="bottom", + alpha=0.8, + ) + + # Configure axes + total_height = max(layer_y.values()) + ax.set_xlim(0, 1) + # Extra top margin to fit output token labels with probabilities + ax.set_ylim(-0.5, total_height + 1.0) + + # X-axis: token labels + ax.set_xticks(x_positions) + ax.set_xticklabels(token_strings, rotation=45, ha="right", fontsize=9) + ax.xaxis.set_ticks_position("bottom") + + # Y-axis: layer labels (group q/v/k into single label) + layer_names_sorted = sorted(layer_y.keys(), key=lambda x: layer_y[x]) + # Deduplicate y positions and create combined labels for q/v/k rows + y_to_layers: dict[float, list[str]] = {} + for layer in layer_names_sorted: + y = layer_y[layer] + if y not in y_to_layers: + y_to_layers[y] = [] + y_to_layers[y].append(layer) + + layer_centers = sorted(y_to_layers.keys()) + layer_labels = [] + for y in layer_centers: + layers_at_y = y_to_layers[y] + if len(layers_at_y) > 1: + # Multiple layers at same y (q/v/k row) + # Extract block prefix and combine sublayer names + block_idx, _ = parse_layer_name(layers_at_y[0]) + sublayers = [parse_layer_name(layer)[1] for layer in layers_at_y] + # Order: q, v, k + ordered = [] + for sub in ["attn.q_proj", "attn.v_proj", "attn.k_proj"]: + if sub in sublayers: + ordered.append(sub.split(".")[-1]) + label = f"h.{block_idx}\n" + "/".join(ordered) + elif layers_at_y[0] == "output" and output_prob_threshold is not None: + label = f"output\n(prob>{output_prob_threshold})" + else: + label = layers_at_y[0].replace(".", "\n", 
1) + layer_labels.append(label) + + ax.set_yticks(layer_centers) + ax.set_yticklabels(layer_labels, fontsize=9) + + # Add horizontal lines to separate layers + for y in layer_y.values(): + ax.axhline(y=y - 0.5, color="gray", linestyle="--", linewidth=0.5, alpha=0.3) + + # Grid + ax.grid(False) + + # Title + ax.set_title("Local Attribution Graph", fontsize=14, fontweight="bold", pad=10) + + # Legend for layer colors + layer_order = get_layer_order() + legend_elements = [] + for sublayer in reversed(layer_order): + color = get_layer_color(sublayer) + legend_elements.append( + plt.scatter([], [], c=color, s=50, label=sublayer, edgecolors="white") + ) + # Add ghost node to legend if there are any + if ghost_nodes: + legend_elements.append( + plt.scatter([], [], c="#909090", s=50, label="ghost (orig. alive)", edgecolors="white") + ) + ax.legend( + handles=legend_elements, + loc="upper left", + bbox_to_anchor=(1.01, 1), + fontsize=8, + framealpha=0.9, + ) + + plt.tight_layout() + return fig + + +def load_and_plot( + pt_path: Path, + output_path: Path | None = None, + **plot_kwargs: Any, +) -> plt.Figure: + """Load attributions from .pt file and create plot. + + Args: + pt_path: Path to the saved .pt file. + output_path: Optional path to save the figure. + **plot_kwargs: Additional kwargs passed to plot_local_graph. + + Returns: + Matplotlib figure. 
+ """ + data = torch.load(pt_path, weights_only=False) + attr_pairs: list[PairAttribution] = data["attr_pairs"] + token_strings: list[str] = data["token_strings"] + output_token_labels: dict[int, str] | None = data.get("output_token_labels") + output_prob_threshold: float | None = data.get("output_prob_threshold") + output_probs_by_pos: dict[tuple[int, int], float] | None = data.get("output_probs_by_pos") + + print(f"Loaded attributions from {pt_path}") + print(f" Prompt: {data.get('prompt', 'N/A')!r}") + print(f" Tokens: {token_strings}") + print(f" Number of layer pairs: {len(attr_pairs)}") + + fig = plot_local_graph( + attr_pairs=attr_pairs, + token_strings=token_strings, + output_token_labels=output_token_labels, + output_prob_threshold=output_prob_threshold, + output_probs_by_pos=output_probs_by_pos, + **plot_kwargs, + ) + + if output_path is not None: + fig.savefig(output_path, dpi=150, bbox_inches="tight", facecolor="white") + print(f"Saved figure to {output_path}") + + return fig + + +# %% +if __name__ == "__main__": + # Configuration + # wandb_id = "33n6xjjt" # ss_gpt2_simple-1L (new) + # wandb_id = "c0k3z78g" # ss_gpt2_simple-2L + wandb_id = "jyo9duz5" # ss_gpt2_simple-1.25M (4L) + + out_dir = get_out_dir() + pt_path = out_dir / f"local_attributions_{wandb_id}.pt" + + if not pt_path.exists(): + raise FileNotFoundError( + f"Local attributions file not found: {pt_path}\n" + "Run calc_local_attributions.py first to generate the data." + ) + + output_path = out_dir / f"local_attribution_graph_{wandb_id}.png" + + fig = load_and_plot( + pt_path=pt_path, + output_path=output_path, + min_edge_weight=0.0001, + node_scale=30.0, + edge_alpha_scale=0.7, + )