diff --git a/spd/metrics/pgd_utils.py b/spd/metrics/pgd_utils.py index bb3c5f090..4fd9bf2c1 100644 --- a/spd/metrics/pgd_utils.py +++ b/spd/metrics/pgd_utils.py @@ -153,6 +153,12 @@ def calc_multibatch_pgd_masked_recon_loss( return final_loss / final_n_examples +def _is_shared_mask(adv_sources: dict[str, Tensor]) -> bool: + """Check if adv_sources have singleton batch dims (shared across batch).""" + first_adv = next(iter(adv_sources.values())) + return all(d == 1 for d in first_adv.shape[:-1]) + + def _forward_with_adv_sources( model: ComponentModel, batch: Int[Tensor, "..."] | Float[Tensor, "..."], @@ -164,23 +170,50 @@ def _forward_with_adv_sources( output_loss_type: Literal["mse", "kl"], batch_dims: tuple[int, ...], ): - expanded_adv_sources = {k: v.expand(*batch_dims, -1) for k, v in adv_sources.items()} - adv_sources_components: dict[str, Float[Tensor, "*batch_dims C"]] - match weight_deltas: - case None: - weight_deltas_and_masks = None - adv_sources_components = expanded_adv_sources - case dict(): - weight_deltas_and_masks = { - k: (weight_deltas[k], expanded_adv_sources[k][..., -1]) for k in weight_deltas - } - adv_sources_components = {k: v[..., :-1] for k, v in expanded_adv_sources.items()} - - mask_infos = make_mask_infos( - component_masks=_interpolate_component_mask(ci, adv_sources_components), - weight_deltas_and_masks=weight_deltas_and_masks, - routing_masks=routing_masks, - ) + # When mask is shared across batch (singleton batch dims), skip expansion and CI interpolation. + # This enables weight-masking optimization in LinearComponents.forward(). 
+ mask_is_shared = _is_shared_mask(adv_sources) + + if mask_is_shared: + # Keep masks with singleton batch dims for weight-masking optimization + adv_sources_components: dict[str, Float[Tensor, ...]] + match weight_deltas: + case None: + weight_deltas_and_masks = None + adv_sources_components = adv_sources + case dict(): + weight_deltas_and_masks = { + k: (weight_deltas[k], adv_sources[k][..., -1].expand(*batch_dims)) + for k in weight_deltas + } + adv_sources_components = {k: v[..., :-1] for k, v in adv_sources.items()} + + # Skip CI interpolation for shared masks - use adv_sources directly as masks. + # This is valid because: (1) CI interpolation produces per-example masks which defeats + # the shared optimization, (2) for PGD we want to find worst-case masks independent of CI. + mask_infos = make_mask_infos( + component_masks=adv_sources_components, + weight_deltas_and_masks=weight_deltas_and_masks, + routing_masks=routing_masks, + ) + else: + expanded_adv_sources = {k: v.expand(*batch_dims, -1) for k, v in adv_sources.items()} + match weight_deltas: + case None: + weight_deltas_and_masks = None + adv_sources_components = expanded_adv_sources + case dict(): + weight_deltas_and_masks = { + k: (weight_deltas[k], expanded_adv_sources[k][..., -1]) for k in weight_deltas + } + adv_sources_components = {k: v[..., :-1] for k, v in expanded_adv_sources.items()} + + mask_infos = make_mask_infos( + component_masks=_interpolate_component_mask(ci, adv_sources_components), + weight_deltas_and_masks=weight_deltas_and_masks, + routing_masks=routing_masks, + ) + out = model(batch, mask_infos=mask_infos) sum_loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=output_loss_type) diff --git a/spd/models/components.py b/spd/models/components.py index 7686e4cbf..a05b60a28 100644 --- a/spd/models/components.py +++ b/spd/models/components.py @@ -182,6 +182,10 @@ def weight(self) -> Float[Tensor, "d_out d_in"]: def get_inner_acts(self, x: Float[Tensor, "... 
d_in"]) -> Float[Tensor, "... C"]: return einops.einsum(x, self.V, "... d_in, d_in C -> ... C") + def _is_shared_mask(self, mask: Tensor) -> bool: + """Check if mask is shared across batch (all dims except last are singleton).""" + return all(d == 1 for d in mask.shape[:-1]) + @override def forward( self, @@ -194,7 +198,9 @@ def forward( Args: x: Input tensor - mask: Tensor which masks parameter components. + mask: Tensor which masks parameter components. If the mask has singleton dimensions + for all batch dims (i.e., shape [1, ..., 1, C]), an optimized weight-masking path + is used that precomputes masked weights instead of masking per-element. weight_delta_and_mask: Optional tuple of tensors containing: 0: the weight differences between the target model and summed component weights 1: mask over the weight delta component for each sample @@ -202,16 +208,32 @@ def forward( Returns: output: The summed output across all components """ - component_acts = self.get_inner_acts(x) - if component_acts_cache is not None: - component_acts_cache["pre_detach"] = component_acts - component_acts = component_acts.detach().requires_grad_(True) - component_acts_cache["post_detach"] = component_acts - - if mask is not None: - component_acts = component_acts * mask + # Optimized path: when mask is shared across batch and we don't need component_acts_cache, + # precompute masked weights and do a single matmul instead of two matmuls with intermediate. + # This is more efficient because: (1) avoids storing batch*seq*C intermediate tensor, + # (2) single fused matmul is faster than two separate matmuls. + use_weight_masking = ( + mask is not None and component_acts_cache is None and self._is_shared_mask(mask) + ) - out = einops.einsum(component_acts, self.U, "... C, C d_out -> ... 
d_out") + if use_weight_masking: + assert mask is not None # for type checker + squeezed_mask = mask.view(self.C) + # V @ diag(mask) @ U = (V * mask) @ U + masked_V = self.V * squeezed_mask + W = einops.einsum(masked_V, self.U, "d_in C, C d_out -> d_in d_out") + out = einops.einsum(x, W, "... d_in, d_in d_out -> ... d_out") + else: + component_acts = self.get_inner_acts(x) + if component_acts_cache is not None: + component_acts_cache["pre_detach"] = component_acts + component_acts = component_acts.detach().requires_grad_(True) + component_acts_cache["post_detach"] = component_acts + + if mask is not None: + component_acts = component_acts * mask + + out = einops.einsum(component_acts, self.U, "... C, C d_out -> ... d_out") if weight_delta_and_mask is not None: weight_delta, weight_delta_mask = weight_delta_and_mask @@ -252,6 +274,10 @@ def weight(self) -> Float[Tensor, "vocab_size embedding_dim"]: def get_inner_acts(self, x: Int[Tensor, "..."]) -> Float[Tensor, "... C"]: return self.V[x] + def _is_shared_mask(self, mask: Tensor) -> bool: + """Check if mask is shared across batch (all dims except last are singleton).""" + return all(d == 1 for d in mask.shape[:-1]) + @override def forward( self, @@ -264,7 +290,9 @@ def forward( Args: x: Input tensor of token indices - mask: Tensor which masks parameter components. May be boolean or float. + mask: Tensor which masks parameter components. If the mask has singleton dimensions + for all batch dims (i.e., shape [1, ..., 1, C]), an optimized path is used that + precomputes the masked embedding table instead of masking per-element. weight_delta_and_mask: Optional tuple of tensors containing: 0: the weight differences between the target model and summed component weights 1: mask over the weight delta component for each sample @@ -272,17 +300,33 @@ def forward( """ assert x.dtype == torch.long, "x must be an integer tensor" - component_acts: Float[Tensor, "... 
C"] = self.get_inner_acts(x) - - if component_acts_cache is not None: - component_acts_cache["pre_detach"] = component_acts - component_acts = component_acts.detach().requires_grad_(True) - component_acts_cache["post_detach"] = component_acts - - if mask is not None: - component_acts = component_acts * mask + # Optimized path: when mask is shared across batch and we don't need component_acts_cache, + # precompute masked embedding table and index into it instead of per-element masking. + use_weight_masking = ( + mask is not None and component_acts_cache is None and self._is_shared_mask(mask) + ) - out = einops.einsum(component_acts, self.U, "... C, C embedding_dim -> ... embedding_dim") + if use_weight_masking: + assert mask is not None # for type checker + squeezed_mask = mask.view(self.C) + # (V * mask) @ U gives masked embedding table + masked_V = self.V * squeezed_mask + W = einops.einsum(masked_V, self.U, "vocab C, C d_emb -> vocab d_emb") + out = W[x] + else: + component_acts: Float[Tensor, "... C"] = self.get_inner_acts(x) + + if component_acts_cache is not None: + component_acts_cache["pre_detach"] = component_acts + component_acts = component_acts.detach().requires_grad_(True) + component_acts_cache["post_detach"] = component_acts + + if mask is not None: + component_acts = component_acts * mask + + out = einops.einsum( + component_acts, self.U, "... C, C embedding_dim -> ... 
def _mask_pair(C: int, batch_size: int, seq_len: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (shared, expanded) masks holding identical values.

    ``shared`` has singleton batch dims (shape ``[1, 1, C]``), which triggers
    the weight-masking fast path; ``expanded`` is the same values broadcast
    to full batch dims, which takes the standard per-element path.
    """
    shared = torch.rand(1, 1, C)
    return shared, shared.expand(batch_size, seq_len, C).clone()


def _assert_mask_grads_equivalent(module, x, shared_mask, expanded_mask) -> None:
    """Backprop a sum loss through both mask layouts and check equivalence.

    The shared mask's gradient must equal the expanded mask's gradient
    summed over the batch dims (the standard broadcasting adjoint).
    """
    mask_shared = shared_mask.clone().requires_grad_(True)
    mask_expanded = expanded_mask.clone().requires_grad_(True)
    module(x, mask=mask_shared).sum().backward()
    module(x, mask=mask_expanded).sum().backward()
    assert mask_expanded.grad is not None
    expected_mask_grad = mask_expanded.grad.sum(dim=(0, 1), keepdim=True)
    torch.testing.assert_close(mask_shared.grad, expected_mask_grad)


def test_shared_mask_optimization_linear():
    """Shared-mask (weight-masking) path matches per-element masking for
    LinearComponents, in both outputs and gradients."""
    batch_size, seq_len, d_in, d_out, C = 8, 16, 32, 64, 10
    linear = LinearComponents(C=C, d_in=d_in, d_out=d_out, bias=None)
    x = torch.randn(batch_size, seq_len, d_in)
    shared_mask, expanded_mask = _mask_pair(C, batch_size, seq_len)

    # Forward outputs must agree between the two paths.
    out_shared = linear(x, mask=shared_mask)
    out_expanded = linear(x, mask=expanded_mask)
    torch.testing.assert_close(out_shared, out_expanded)

    # Gradients wrt x must agree between the two paths.
    x_shared = x.clone().requires_grad_(True)
    x_expanded = x.clone().requires_grad_(True)
    linear(x_shared, mask=shared_mask.clone().requires_grad_(True)).sum().backward()
    linear(x_expanded, mask=expanded_mask.clone().requires_grad_(True)).sum().backward()
    torch.testing.assert_close(x_shared.grad, x_expanded.grad)

    # Grad wrt the shared mask is the batch-summed grad of the expanded mask.
    _assert_mask_grads_equivalent(linear, x, shared_mask, expanded_mask)


def test_shared_mask_optimization_embedding():
    """Shared-mask (masked-embedding-table) path matches per-element masking
    for EmbeddingComponents, in both outputs and mask gradients.

    No x-gradient check here: x holds integer token indices, which carry
    no gradient.
    """
    batch_size, seq_len, vocab_size, embedding_dim, C = 8, 16, 100, 64, 10
    embedding = EmbeddingComponents(C=C, vocab_size=vocab_size, embedding_dim=embedding_dim)
    x = torch.randint(0, vocab_size, (batch_size, seq_len))
    shared_mask, expanded_mask = _mask_pair(C, batch_size, seq_len)

    # Forward outputs must agree between the two paths.
    out_shared = embedding(x, mask=shared_mask)
    out_expanded = embedding(x, mask=expanded_mask)
    torch.testing.assert_close(out_shared, out_expanded)

    # Grad wrt the shared mask is the batch-summed grad of the expanded mask.
    _assert_mask_grads_equivalent(embedding, x, shared_mask, expanded_mask)