Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
c887c05
WIP global attribution calcs
danbraunai-goodfire Nov 24, 2025
975f1fc
First draft of calculation and plotting of attrs
danbraunai-goodfire Nov 25, 2025
8e40188
Make ci_mean alive threshold a hyperparam
danbraunai-goodfire Nov 25, 2025
8530335
Misc cleaning
danbraunai-goodfire Nov 26, 2025
4c45678
Add naive global attribution calc with for loops
danbraunai-goodfire Nov 26, 2025
03dd761
Speedups
danbraunai-goodfire Nov 26, 2025
f9db75c
Remove stray profiler comments
danbraunai-goodfire Nov 26, 2025
929bc65
More speedups
danbraunai-goodfire Nov 27, 2025
52d449a
remove double multiplication by ci weights in in_layers
danbraunai-goodfire Nov 27, 2025
04c093f
Normalise attrs and move to spd/scripts
danbraunai-goodfire Nov 27, 2025
85ee006
Simplify get_sources_by_target
danbraunai-goodfire Nov 27, 2025
6915a1f
Misc tweaks
danbraunai-goodfire Nov 27, 2025
2793043
Normalize over sum to output node
danbraunai-goodfire Nov 27, 2025
c471c29
Add more thresholds
danbraunai-goodfire Nov 27, 2025
4d7c0b3
Add local attribution calcs
danbraunai-goodfire Nov 28, 2025
344671b
Remove grad_outputs
danbraunai-goodfire Nov 28, 2025
563bbbb
Add new 1L model
danbraunai-goodfire Nov 28, 2025
05b9d5c
Merge branch 'main' into feature/global-attr
danbraunai-goodfire Nov 28, 2025
958f150
Allow diff w.r.t multiple inputs in local attr
danbraunai-goodfire Nov 28, 2025
c341c05
Misc removals
danbraunai-goodfire Nov 28, 2025
7d3fc5e
Add wte (NOTE: breaks calc_global_attributions)
danbraunai-goodfire Nov 28, 2025
3c0ebae
Add wte to calc_global_attributions.py
danbraunai-goodfire Nov 28, 2025
631013a
Add plot_local_attributions.py
danbraunai-goodfire Nov 28, 2025
7774afa
Tweak matplotlib plot
danbraunai-goodfire Nov 28, 2025
f580c83
Add output and fix masks
danbraunai-goodfire Nov 29, 2025
16353c9
Minor tweaks
danbraunai-goodfire Nov 29, 2025
48e406d
Cleanup plotting
danbraunai-goodfire Dec 1, 2025
56370b1
Show l0 for final seq position
danbraunai-goodfire Dec 1, 2025
6bf3c31
Use ci-masks and minor tweaks
danbraunai-goodfire Dec 1, 2025
89913de
Add optim_cis
danbraunai-goodfire Dec 1, 2025
13e2fdc
Merge branch 'main' into feature/global-attr
danbraunai-goodfire Dec 1, 2025
787ae37
Simplify losses in run_optim
danbraunai-goodfire Dec 2, 2025
9757306
Use stoch masks instead of ci masks in optim_cis
danbraunai-goodfire Dec 2, 2025
2dae809
Tweaks to run_optim_cis.py
danbraunai-goodfire Dec 2, 2025
46ad296
Minor clean
danbraunai-goodfire Dec 2, 2025
c31a67d
Support multiple prompts in calc_local_attributions
danbraunai-goodfire Dec 2, 2025
6aa6575
Misc cleaning
danbraunai-goodfire Dec 2, 2025
e0793a5
Misc minor cleaning
danbraunai-goodfire Dec 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 21 additions & 17 deletions spd/models/component_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,11 @@ class ComponentModel(LoadableModule):
`LlamaForCausalLM`, `AutoModelForCausalLM`) as long as its sub-module names
match the patterns you pass in `target_module_patterns`.

Forward passes support optional component replacement and/or input caching:
Forward passes support optional component replacement and/or caching:
- No args: Standard forward pass of the target model
- With mask_infos: Components replace the specified modules via forward hooks
- With cache_type="input": Input activations are cached for the specified modules
- With cache_type="component_acts": Component activations are cached for the specified modules
- Both can be used simultaneously for a component forward pass with caching

We register components and causal importance functions (ci_fns) as modules in this class in order to have them update
Expand Down Expand Up @@ -307,7 +308,7 @@ def __call__(
self,
*args: Any,
mask_infos: dict[str, ComponentsMaskInfo] | None = None,
cache_type: Literal["input"],
cache_type: Literal["component_acts", "input"],
**kwargs: Any,
) -> OutputWithCache: ...

Expand All @@ -329,31 +330,30 @@ def forward(
self,
*args: Any,
mask_infos: dict[str, ComponentsMaskInfo] | None = None,
cache_type: Literal["input", "none"] = "none",
cache_type: Literal["component_acts", "input", "none"] = "none",
**kwargs: Any,
) -> Tensor | OutputWithCache:
"""Forward pass with optional component replacement and/or input caching.

This method handles the following 4 cases:
1. mask_infos is None and cache_type is "none": Regular forward pass.
2. mask_infos is None and cache_type is "input": Forward pass with input caching on
all modules in self.target_module_paths.
3. mask_infos is not None and cache_type is "input": Forward pass with component replacement
and input caching on the modules provided in mask_infos.
2. mask_infos is None and cache_type is "input" or "component_acts": Forward pass with
caching on all modules in self.target_module_paths.
3. mask_infos is not None and cache_type is "input" or "component_acts": Forward pass with
component replacement and caching on the modules provided in mask_infos.
4. mask_infos is not None and cache_type is "none": Forward pass with component replacement
on the modules provided in mask_infos and no caching.

We use the same _components_and_cache_hook for cases 2, 3, and 4, and don't use any hooks
for case 1.

Args:
mask_infos: Dictionary mapping module names to ComponentsMaskInfo.
If provided, those modules will be replaced with their components.
cache_type: If "input", cache the inputs to the modules provided in mask_infos. If
mask_infos is None, cache the inputs to all modules in self.target_module_paths.
cache_type: If "input" or "component_acts", cache the inputs or component acts to the
modules provided in mask_infos. If "none", no caching is done. If mask_infos is None,
cache the inputs or component acts to all modules in self.target_module_paths.

Returns:
OutputWithCache object if cache_type is "input", otherwise the model output tensor.
OutputWithCache object if cache_type is "input" or "component_acts", otherwise the
model output tensor.
"""
if mask_infos is None and cache_type == "none":
# No hooks needed. Do a regular forward pass of the target model.
Expand Down Expand Up @@ -382,7 +382,7 @@ def forward(

out = self._extract_output(raw_out)
match cache_type:
case "input":
case "input" | "component_acts":
return OutputWithCache(output=out, cache=cache)
case "none":
return out
Expand All @@ -396,7 +396,7 @@ def _components_and_cache_hook(
module_name: str,
components: Components | None,
mask_info: ComponentsMaskInfo | None,
cache_type: Literal["input", "none"],
cache_type: Literal["component_acts", "input", "none"],
cache: dict[str, Tensor],
) -> Any | None:
"""Unified hook function that handles both component replacement and caching.
Expand All @@ -409,7 +409,7 @@ def _components_and_cache_hook(
module_name: Name of the module in the target model
components: Component replacement (if using components)
mask_info: Mask information (if using components)
cache_type: Whether to cache the input
cache_type: Whether to cache the component acts ("component_acts"), the inputs ("input"), or nothing ("none")
cache: Cache dictionary to populate (if cache_type is not "none")

Returns:
Expand All @@ -420,7 +420,6 @@ def _components_and_cache_hook(
assert len(kwargs) == 0, "Expected no keyword arguments"
x = args[0]
assert isinstance(x, Tensor), "Expected input tensor"
assert cache_type in ["input", "none"], "Expected cache_type to be 'input' or 'none'"

if cache_type == "input":
cache[module_name] = x
Expand All @@ -430,11 +429,16 @@ def _components_and_cache_hook(
f"Only supports single-tensor outputs, got {type(output)}"
)

component_acts_cache = {} if cache_type == "component_acts" else None
components_out = components(
x,
mask=mask_info.component_mask,
weight_delta_and_mask=mask_info.weight_delta_and_mask,
component_acts_cache=component_acts_cache,
)
if component_acts_cache is not None:
for k, v in component_acts_cache.items():
cache[f"{module_name}_{k}"] = v

if mask_info.routing_mask == "all":
return components_out
Expand Down
17 changes: 15 additions & 2 deletions spd/models/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def forward(
x: Float[Tensor, "... d_in"],
mask: Float[Tensor, "... C"] | None = None,
weight_delta_and_mask: WeightDeltaAndMask | None = None,
component_acts_cache: dict[str, Float[Tensor, "... C"]] | None = None,
) -> Float[Tensor, "... d_out"]:
"""Forward pass through V and U matrices.

Expand All @@ -199,13 +200,18 @@ def forward(
weight_delta_and_mask: Optional tuple of tensors containing:
0: the weight differences between the target model and summed component weights
1: mask over the weight delta component for each sample
component_acts_cache: Cache dictionary to populate with component acts
Returns:
output: The summed output across all components
"""
component_acts = self.get_inner_acts(x)
if component_acts_cache is not None:
component_acts_cache["pre_detach"] = component_acts
component_acts = component_acts.detach().requires_grad_(True)
component_acts_cache["post_detach"] = component_acts

if mask is not None:
component_acts *= mask
component_acts = component_acts * mask

out = einops.einsum(component_acts, self.U, "... C, C d_out -> ... d_out")

Expand Down Expand Up @@ -254,6 +260,7 @@ def forward(
x: Int[Tensor, "..."],
mask: Float[Tensor, "... C"] | None = None,
weight_delta_and_mask: WeightDeltaAndMask | None = None,
component_acts_cache: dict[str, Float[Tensor, "... C"]] | None = None,
) -> Float[Tensor, "... embedding_dim"]:
"""Forward through the embedding component using indexing instead of one-hot matmul.

Expand All @@ -263,13 +270,19 @@ def forward(
weight_delta_and_mask: Optional tuple of tensors containing:
0: the weight differences between the target model and summed component weights
1: mask over the weight delta component for each sample
component_acts_cache: Cache dictionary to populate with component acts
"""
assert x.dtype == torch.long, "x must be an integer tensor"

component_acts: Float[Tensor, "... C"] = self.get_inner_acts(x)

if component_acts_cache is not None:
component_acts_cache["pre_detach"] = component_acts
component_acts = component_acts.detach().requires_grad_(True)
component_acts_cache["post_detach"] = component_acts

if mask is not None:
component_acts *= mask
component_acts = component_acts * mask

out = einops.einsum(component_acts, self.U, "... C, C embedding_dim -> ... embedding_dim")

Expand Down
Loading