Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 50 additions & 17 deletions spd/metrics/pgd_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,12 @@ def calc_multibatch_pgd_masked_recon_loss(
return final_loss / final_n_examples


def _is_shared_mask(adv_sources: dict[str, Tensor]) -> bool:
"""Check if adv_sources have singleton batch dims (shared across batch)."""
first_adv = next(iter(adv_sources.values()))
return all(d == 1 for d in first_adv.shape[:-1])


def _forward_with_adv_sources(
model: ComponentModel,
batch: Int[Tensor, "..."] | Float[Tensor, "..."],
Expand All @@ -164,23 +170,50 @@ def _forward_with_adv_sources(
output_loss_type: Literal["mse", "kl"],
batch_dims: tuple[int, ...],
):
expanded_adv_sources = {k: v.expand(*batch_dims, -1) for k, v in adv_sources.items()}
adv_sources_components: dict[str, Float[Tensor, "*batch_dims C"]]
match weight_deltas:
case None:
weight_deltas_and_masks = None
adv_sources_components = expanded_adv_sources
case dict():
weight_deltas_and_masks = {
k: (weight_deltas[k], expanded_adv_sources[k][..., -1]) for k in weight_deltas
}
adv_sources_components = {k: v[..., :-1] for k, v in expanded_adv_sources.items()}

mask_infos = make_mask_infos(
component_masks=_interpolate_component_mask(ci, adv_sources_components),
weight_deltas_and_masks=weight_deltas_and_masks,
routing_masks=routing_masks,
)
# When mask is shared across batch (singleton batch dims), skip expansion and CI interpolation.
# This enables weight-masking optimization in LinearComponents.forward().
mask_is_shared = _is_shared_mask(adv_sources)

if mask_is_shared:
# Keep masks with singleton batch dims for weight-masking optimization
adv_sources_components: dict[str, Float[Tensor, ...]]
match weight_deltas:
case None:
weight_deltas_and_masks = None
adv_sources_components = adv_sources
case dict():
weight_deltas_and_masks = {
k: (weight_deltas[k], adv_sources[k][..., -1].expand(*batch_dims))
for k in weight_deltas
}
adv_sources_components = {k: v[..., :-1] for k, v in adv_sources.items()}

# Skip CI interpolation for shared masks - use adv_sources directly as masks.
# This is valid because: (1) CI interpolation produces per-example masks which defeats
# the shared optimization, (2) for PGD we want to find worst-case masks independent of CI.
mask_infos = make_mask_infos(
component_masks=adv_sources_components,
weight_deltas_and_masks=weight_deltas_and_masks,
routing_masks=routing_masks,
)
else:
expanded_adv_sources = {k: v.expand(*batch_dims, -1) for k, v in adv_sources.items()}
match weight_deltas:
case None:
weight_deltas_and_masks = None
adv_sources_components = expanded_adv_sources
case dict():
weight_deltas_and_masks = {
k: (weight_deltas[k], expanded_adv_sources[k][..., -1]) for k in weight_deltas
}
adv_sources_components = {k: v[..., :-1] for k, v in expanded_adv_sources.items()}

mask_infos = make_mask_infos(
component_masks=_interpolate_component_mask(ci, adv_sources_components),
weight_deltas_and_masks=weight_deltas_and_masks,
routing_masks=routing_masks,
)

out = model(batch, mask_infos=mask_infos)

sum_loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=output_loss_type)
Expand Down
86 changes: 65 additions & 21 deletions spd/models/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ def weight(self) -> Float[Tensor, "d_out d_in"]:
def get_inner_acts(self, x: Float[Tensor, "... d_in"]) -> Float[Tensor, "... C"]:
return einops.einsum(x, self.V, "... d_in, d_in C -> ... C")

def _is_shared_mask(self, mask: Tensor) -> bool:
    """Return True when *mask* is broadcast over the batch.

    "Shared" means every dim except the trailing component dim is singleton,
    i.e. the mask has shape ``[1, ..., 1, C]``.
    """
    return not any(d != 1 for d in mask.shape[:-1])

@override
def forward(
self,
Expand All @@ -194,24 +198,42 @@ def forward(

Args:
x: Input tensor
mask: Tensor which masks parameter components.
mask: Tensor which masks parameter components. If the mask has singleton dimensions
for all batch dims (i.e., shape [1, ..., 1, C]), an optimized weight-masking path
is used that precomputes masked weights instead of masking per-element.
weight_delta_and_mask: Optional tuple of tensors containing:
0: the weight differences between the target model and summed component weights
1: mask over the weight delta component for each sample
component_acts_cache: Cache dictionary to populate with component acts
Returns:
output: The summed output across all components
"""
component_acts = self.get_inner_acts(x)
if component_acts_cache is not None:
component_acts_cache["pre_detach"] = component_acts
component_acts = component_acts.detach().requires_grad_(True)
component_acts_cache["post_detach"] = component_acts

if mask is not None:
component_acts = component_acts * mask
# Optimized path: when mask is shared across batch and we don't need component_acts_cache,
# precompute masked weights and do a single matmul instead of two matmuls with intermediate.
# This is more efficient because: (1) avoids storing batch*seq*C intermediate tensor,
# (2) single fused matmul is faster than two separate matmuls.
use_weight_masking = (
mask is not None and component_acts_cache is None and self._is_shared_mask(mask)
)

out = einops.einsum(component_acts, self.U, "... C, C d_out -> ... d_out")
if use_weight_masking:
assert mask is not None # for type checker
squeezed_mask = mask.view(self.C)
# V @ diag(mask) @ U = (V * mask) @ U
masked_V = self.V * squeezed_mask
W = einops.einsum(masked_V, self.U, "d_in C, C d_out -> d_in d_out")
out = einops.einsum(x, W, "... d_in, d_in d_out -> ... d_out")
else:
component_acts = self.get_inner_acts(x)
if component_acts_cache is not None:
component_acts_cache["pre_detach"] = component_acts
component_acts = component_acts.detach().requires_grad_(True)
component_acts_cache["post_detach"] = component_acts

if mask is not None:
component_acts = component_acts * mask

out = einops.einsum(component_acts, self.U, "... C, C d_out -> ... d_out")

if weight_delta_and_mask is not None:
weight_delta, weight_delta_mask = weight_delta_and_mask
Expand Down Expand Up @@ -252,6 +274,10 @@ def weight(self) -> Float[Tensor, "vocab_size embedding_dim"]:
def get_inner_acts(self, x: Int[Tensor, "..."]) -> Float[Tensor, "... C"]:
return self.V[x]

def _is_shared_mask(self, mask: Tensor) -> bool:
    """Return True when *mask* has singleton batch dims (shape ``[1, ..., 1, C]``)."""
    # Product of the batch dims is 1 exactly when every batch dim is 1
    # (dims are non-negative; a zero dim correctly yields False).
    return mask.shape[:-1].numel() == 1

@override
def forward(
self,
Expand All @@ -264,25 +290,43 @@ def forward(

Args:
x: Input tensor of token indices
mask: Tensor which masks parameter components. May be boolean or float.
mask: Tensor which masks parameter components. If the mask has singleton dimensions
for all batch dims (i.e., shape [1, ..., 1, C]), an optimized path is used that
precomputes the masked embedding table instead of masking per-element.
weight_delta_and_mask: Optional tuple of tensors containing:
0: the weight differences between the target model and summed component weights
1: mask over the weight delta component for each sample
component_acts_cache: Cache dictionary to populate with component acts
"""
assert x.dtype == torch.long, "x must be an integer tensor"

component_acts: Float[Tensor, "... C"] = self.get_inner_acts(x)

if component_acts_cache is not None:
component_acts_cache["pre_detach"] = component_acts
component_acts = component_acts.detach().requires_grad_(True)
component_acts_cache["post_detach"] = component_acts

if mask is not None:
component_acts = component_acts * mask
# Optimized path: when mask is shared across batch and we don't need component_acts_cache,
# precompute masked embedding table and index into it instead of per-element masking.
use_weight_masking = (
mask is not None and component_acts_cache is None and self._is_shared_mask(mask)
)

out = einops.einsum(component_acts, self.U, "... C, C embedding_dim -> ... embedding_dim")
if use_weight_masking:
assert mask is not None # for type checker
squeezed_mask = mask.view(self.C)
# (V * mask) @ U gives masked embedding table
masked_V = self.V * squeezed_mask
W = einops.einsum(masked_V, self.U, "vocab C, C d_emb -> vocab d_emb")
out = W[x]
else:
component_acts: Float[Tensor, "... C"] = self.get_inner_acts(x)

if component_acts_cache is not None:
component_acts_cache["pre_detach"] = component_acts
component_acts = component_acts.detach().requires_grad_(True)
component_acts_cache["post_detach"] = component_acts

if mask is not None:
component_acts = component_acts * mask

out = einops.einsum(
component_acts, self.U, "... C, C embedding_dim -> ... embedding_dim"
)

if weight_delta_and_mask is not None:
weight_delta, weight_delta_mask = weight_delta_and_mask
Expand Down
85 changes: 85 additions & 0 deletions tests/test_component_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,3 +525,88 @@ def forward(self, x: Tensor) -> Tensor:

# but it should be the same for the second example (where it's not routed to components)
assert torch.allclose(cm_routed_out[1], target_out[1])


def test_shared_mask_optimization_linear():
    """Shared (singleton-batch) masks must match per-element masking for LinearComponents."""
    n_batch, n_seq = 8, 16
    dim_in, dim_out, n_components = 32, 64, 10

    component = LinearComponents(C=n_components, d_in=dim_in, d_out=dim_out, bias=None)

    inputs = torch.randn(n_batch, n_seq, dim_in)

    # A [1, 1, C] mask triggers the optimized weight-masking path...
    broadcast_mask = torch.rand(1, 1, n_components)
    # ...while the same values expanded to full batch dims use the standard path.
    full_mask = broadcast_mask.expand(n_batch, n_seq, n_components).clone()

    # Forward outputs must agree between the two paths.
    out_broadcast = component(inputs, mask=broadcast_mask)
    out_full = component(inputs, mask=full_mask)
    torch.testing.assert_close(out_broadcast, out_full)

    # Backward: gradients must also agree through both code paths.
    x_a = inputs.clone().requires_grad_(True)
    x_b = inputs.clone().requires_grad_(True)
    m_a = broadcast_mask.clone().requires_grad_(True)
    m_b = full_mask.clone().requires_grad_(True)

    component(x_a, mask=m_a).sum().backward()
    component(x_b, mask=m_b).sum().backward()

    # Input gradients are identical regardless of which path was taken.
    torch.testing.assert_close(x_a.grad, x_b.grad)

    # The shared mask's gradient accumulates (sums) over the batch dims of
    # the expanded mask's gradient, since it was broadcast to every example.
    assert m_b.grad is not None
    torch.testing.assert_close(m_a.grad, m_b.grad.sum(dim=(0, 1), keepdim=True))


def test_shared_mask_optimization_embedding():
    """Shared (singleton-batch) masks must match per-element masking for EmbeddingComponents."""
    n_batch, n_seq = 8, 16
    n_vocab, dim_emb, n_components = 100, 64, 10

    component = EmbeddingComponents(C=n_components, vocab_size=n_vocab, embedding_dim=dim_emb)

    token_ids = torch.randint(0, n_vocab, (n_batch, n_seq))

    # A [1, 1, C] mask triggers the optimized masked-embedding-table path...
    broadcast_mask = torch.rand(1, 1, n_components)
    # ...while the same values expanded to full batch dims use the standard path.
    full_mask = broadcast_mask.expand(n_batch, n_seq, n_components).clone()

    # Forward outputs must agree between the two paths.
    out_broadcast = component(token_ids, mask=broadcast_mask)
    out_full = component(token_ids, mask=full_mask)
    torch.testing.assert_close(out_broadcast, out_full)

    # Backward: mask gradients must be consistent between the paths.
    m_a = broadcast_mask.clone().requires_grad_(True)
    m_b = full_mask.clone().requires_grad_(True)

    component(token_ids, mask=m_a).sum().backward()
    component(token_ids, mask=m_b).sum().backward()

    # The shared mask's gradient accumulates (sums) over the batch dims of
    # the expanded mask's gradient, since it was broadcast to every example.
    assert m_b.grad is not None
    torch.testing.assert_close(m_a.grad, m_b.grad.sum(dim=(0, 1), keepdim=True))