diff --git a/scripts/BENCHMARK_RESULTS.md b/scripts/BENCHMARK_RESULTS.md new file mode 100644 index 000000000..2fa684f7c --- /dev/null +++ b/scripts/BENCHMARK_RESULTS.md @@ -0,0 +1,68 @@ +# torch.compile Benchmark Results for SPD + +Benchmarks comparing eager mode vs `torch.compile()` for SPD's masked forward/backward passes. + +## Summary + +- **Isolated LinearComponents**: No benefit from torch.compile (0-22% slower) +- **Full ComponentModel**: 30-35% speedup at larger batch sizes + +## Full Model Results (SS Llama Simple) + +Model: 4-layer Llama with 28 decomposed modules, C=1200 components each. + +| Batch | Seq Len | Compile Mode | Eager | Compiled | Speedup | +|-------|---------|--------------|-------|----------|---------| +| 16 | 128 | reduce-overhead | 10.56ms | 9.64ms | **1.09x (9.5%)** | +| 32 | 256 | reduce-overhead | 12.61ms | 11.20ms | **1.13x (12.6%)** | +| 64 | 256 | reduce-overhead | 20.70ms | 15.96ms | **1.30x (30%)** | +| 128 | 256 | reduce-overhead | 36.19ms | 28.34ms | **1.28x (28%)** | +| 64 | 256 | max-autotune | 20.70ms | 15.38ms | **1.35x (35%)** | +
+## Isolated LinearComponents Results + +Testing just the core operation: `out = (x @ V * mask) @ U` + +| Test | Dimensions | Compile Mode | Speedup | +|------|------------|--------------|---------| +| LinearComponents | d=512, C=512 | reduce-overhead | 1.00x (no benefit) | +| LinearComponents | d=2048, C=2048 | reduce-overhead | 1.03x (2.6%) | +| Pure function | d=512, C=512 | reduce-overhead | 0.85x (15% slower) | +| Pure function | d=4096, C=4096 | reduce-overhead | 0.96x (4% slower) | +| FP16 | d=512, C=512 | reduce-overhead | 0.78x (22% slower) | + +## Why the Difference? 
+ +### Full model benefits because: +- Many operations can be fused (28 masked layers + attention + activations + norms) +- More compute per kernel launch amortizes dispatch overhead +- torch.compile can optimize the entire forward/backward graph + +### Isolated LinearComponents don't benefit because: +- The operation `(x @ V * mask) @ U` is just 2 matmuls + 1 multiply +- cuBLAS is already highly optimized for matmuls +- No fusion opportunities between matmuls +- Dispatch overhead exceeds any micro-optimization gains + +## Recommendations + +1. **Use torch.compile for full SPD training** with `mode="reduce-overhead"` or `mode="max-autotune"` +2. **Use larger batch sizes** (64+) to maximize compile benefits +3. **Don't bother compiling isolated components** - eager is faster or equivalent +4. **Budget for warmup time** - first few steps are slow due to compilation (~24s for reduce-overhead, ~70s for max-autotune) + +## Reproduce + +```bash +# Full model benchmark +python scripts/benchmark_full_model.py --batch_size 64 --seq_len 256 + +# Isolated LinearComponents benchmark +python scripts/benchmark_linear_components.py components --C 512 +python scripts/benchmark_linear_components.py pure --d_in 512 --d_out 512 --C 512 +``` + +## Hardware + +- GPU: NVIDIA (CUDA) +- PyTorch with torch.compile (inductor backend) diff --git a/scripts/benchmark_full_model.py b/scripts/benchmark_full_model.py new file mode 100644 index 000000000..9b97de000 --- /dev/null +++ b/scripts/benchmark_full_model.py @@ -0,0 +1,209 @@ +"""Benchmark torch.compile on a full ComponentModel (e.g., SS Llama). + +Tests the masked forward/backward pass on a real model rather than isolated components. 
+""" + +import time +from pathlib import Path + +import fire +import torch +import torch.nn as nn +from simple_stories_train.run_info import RunInfo as SSRunInfo +from torch import Tensor + +from spd.configs import Config +from spd.models.component_model import ComponentModel +from spd.models.components import make_mask_infos +from spd.utils.general_utils import resolve_class, set_seed +from spd.utils.module_utils import expand_module_patterns + +torch.set_float32_matmul_precision("high") + + +def load_model_and_config( + config_path: str = "spd/experiments/lm/ss_llama_simple_config.yaml", +) -> tuple[nn.Module, Config]: + """Load the target model and config.""" + config = Config.from_file(Path(config_path)) + + pretrained_model_class = resolve_class(config.pretrained_model_class) + assert config.pretrained_model_name is not None + + if config.pretrained_model_class.startswith("simple_stories_train"): + run_info = SSRunInfo.from_path(config.pretrained_model_name) + target_model = pretrained_model_class.from_run_info(run_info) + else: + target_model = pretrained_model_class.from_pretrained(config.pretrained_model_name) + + target_model.eval() + target_model.requires_grad_(False) + + return target_model, config + + +def benchmark_full_model( + config_path: str = "spd/experiments/lm/ss_llama_simple_config.yaml", + batch_size: int = 32, + seq_len: int = 256, + steps: int = 50, + warmup: int = 10, + compile_mode: str = "reduce-overhead", +) -> None: + """Benchmark ComponentModel forward/backward with full model. 
+ + Args: + config_path: Path to experiment config + batch_size: Batch size + seq_len: Sequence length + steps: Number of benchmark steps + warmup: Number of warmup steps + compile_mode: torch.compile mode + """ + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Device: {device}") + print(f"Config: {config_path}") + print(f"batch_size={batch_size}, seq_len={seq_len}") + print(f"Warmup: {warmup}, Steps: {steps}") + print(f"Compile mode: {compile_mode}") + print() + + set_seed(42) + + print("Loading model...") + target_model, config = load_model_and_config(config_path) + print(f"Model loaded: {type(target_model).__name__}") + + results = {} + + for use_compile in [False, True]: + mode_name = "compiled" if use_compile else "eager" + print(f"\n{'=' * 60}") + print(f"Running: {mode_name}") + print(f"{'=' * 60}") + + # Reload target model fresh each time (ComponentModel modifies it) + target_model, config = load_model_and_config(config_path) + + # Create fresh ComponentModel + module_path_info = expand_module_patterns(target_model, config.all_module_info) + + model = ComponentModel( + target_model=target_model, + module_path_info=module_path_info, + ci_fn_type=config.ci_fn_type, + ci_fn_hidden_dims=config.ci_fn_hidden_dims, + pretrained_model_output_attr=config.pretrained_model_output_attr, + sigmoid_type=config.sigmoid_type, + ) + model.to(device) + + print(f"ComponentModel created with {len(model.target_module_paths)} modules:") + for path in model.target_module_paths[:5]: + print(f" - {path} (C={model.module_to_c[path]})") + if len(model.target_module_paths) > 5: + print(f" ... 
and {len(model.target_module_paths) - 5} more") + + if use_compile: + mode = None if compile_mode == "default" else compile_mode + print(f"\nCompiling model (mode={mode})...") + compile_start = time.perf_counter() + model = torch.compile(model, fullgraph=False, mode=mode) # type: ignore + print(f"torch.compile() call took {time.perf_counter() - compile_start:.2f}s") + + component_model = model._orig_mod if use_compile else model # type: ignore + + # Setup optimizer + params = [] + for name in component_model.target_module_paths: + params.extend(component_model.components[name].parameters()) + params.extend(component_model.ci_fns[name].parameters()) + optimizer = torch.optim.AdamW(params, lr=1e-4) + + # Create sample mask + def create_masks() -> dict[str, Tensor]: + masks = {} + for name in component_model.target_module_paths: + C = component_model.module_to_c[name] + masks[name] = torch.rand(batch_size, seq_len, C, device=device) + return masks + + # Warmup + print(f"\nWarming up ({warmup} steps)...") + warmup_start = time.perf_counter() + for i in range(warmup): + optimizer.zero_grad() + + # Random token input + x = torch.randint(0, 1000, (batch_size, seq_len), device=device) + masks = create_masks() + mask_infos = make_mask_infos(masks) + + # Forward pass with mask + out = model(x, mask_infos=mask_infos) # type: ignore + + # Simple loss + loss = out.mean() + loss.backward() + optimizer.step() + + if (i + 1) % 5 == 0: + print(f" Warmup step {i + 1}/{warmup}") + + warmup_time = time.perf_counter() - warmup_start + print(f"Warmup took {warmup_time:.2f}s ({warmup_time / warmup * 1000:.1f}ms/step)") + + # Benchmark + print(f"\nBenchmarking ({steps} steps)...") + if device == "cuda": + torch.cuda.synchronize() + + step_times = [] + for i in range(steps): + x = torch.randint(0, 1000, (batch_size, seq_len), device=device) + masks = create_masks() + mask_infos = make_mask_infos(masks) + + if device == "cuda": + torch.cuda.synchronize() + step_start = time.perf_counter() + 
+ optimizer.zero_grad() + out = model(x, mask_infos=mask_infos) # type: ignore + loss = out.mean() + loss.backward() + optimizer.step() + + if device == "cuda": + torch.cuda.synchronize() + step_times.append(time.perf_counter() - step_start) + + avg_time = sum(step_times) / len(step_times) + min_time = min(step_times) + max_time = max(step_times) + + results[mode_name] = {"avg": avg_time, "min": min_time, "max": max_time} + + print(f"\n{mode_name.upper()} Results:") + print(f" Avg: {avg_time * 1000:.2f}ms") + print(f" Min: {min_time * 1000:.2f}ms") + print(f" Max: {max_time * 1000:.2f}ms") + + del model, component_model, optimizer + if device == "cuda": + torch.cuda.empty_cache() + + # Summary + print(f"\n{'=' * 60}") + print("SUMMARY") + print(f"{'=' * 60}") + eager_time = results["eager"]["avg"] + compiled_time = results["compiled"]["avg"] + speedup = eager_time / compiled_time + print(f"Eager: {eager_time * 1000:.2f}ms") + print(f"Compiled: {compiled_time * 1000:.2f}ms") + print(f"Speedup: {speedup:.2f}x ({(speedup - 1) * 100:.1f}%)") + + +if __name__ == "__main__": + fire.Fire(benchmark_full_model) diff --git a/scripts/benchmark_linear_components.py b/scripts/benchmark_linear_components.py new file mode 100644 index 000000000..6ce80b2a2 --- /dev/null +++ b/scripts/benchmark_linear_components.py @@ -0,0 +1,727 @@ +"""Minimal benchmark for LinearComponents masked forward/backward pass. 
+ +Tests torch.compile() efficiency on the core LinearComponents operation: + out = (x @ V * mask) @ U +""" + +import time + +import fire +import torch +import torch.nn as nn +from torch import Tensor + +from spd.models.components import LinearComponents + +torch.set_float32_matmul_precision("high") + + +def benchmark_linear_components( + d_in: int = 512, + d_out: int = 512, + C: int = 100, + batch_size: int = 64, + seq_len: int = 256, + steps: int = 200, + warmup: int = 50, + compile_mode: str = "default", +) -> None: + """Benchmark LinearComponents masked forward and backward pass. + + Args: + d_in: Input dimension + d_out: Output dimension + C: Number of components + batch_size: Batch size + seq_len: Sequence length + steps: Number of benchmark steps + warmup: Number of warmup steps + compile_mode: torch.compile mode (default, reduce-overhead, max-autotune) + """ + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Device: {device}") + print(f"d_in={d_in}, d_out={d_out}, C={C}, batch_size={batch_size}, seq_len={seq_len}") + print(f"Warmup: {warmup}, Steps: {steps}") + print() + + results = {} + + for use_compile in [False, True]: + mode_name = "compiled" if use_compile else "eager" + print(f"{'=' * 60}") + print(f"Running: {mode_name}") + print(f"{'=' * 60}") + + # Create LinearComponents + components = LinearComponents(C=C, d_in=d_in, d_out=d_out, bias=None) + components.to(device) + + # Define the forward function we want to benchmark + def forward_backward( + x: Tensor, mask: Tensor, components: nn.Module, target: Tensor + ) -> Tensor: + out = components(x, mask=mask) + loss = (out - target).pow(2).mean() + return loss + + if use_compile: + mode = None if compile_mode == "default" else compile_mode + print(f"Compiling (mode={mode})...") + compile_start = time.perf_counter() + forward_backward = torch.compile(forward_backward, fullgraph=True, mode=mode) + print(f"torch.compile() call took {time.perf_counter() - compile_start:.2f}s") + + # Setup 
optimizer + optimizer = torch.optim.AdamW(components.parameters(), lr=1e-4) + + # Warmup + print(f"Warming up ({warmup} steps)...") + for _ in range(warmup): + optimizer.zero_grad() + x = torch.randn(batch_size, seq_len, d_in, device=device) + mask = torch.rand(batch_size, seq_len, C, device=device) + target = torch.randn(batch_size, seq_len, d_out, device=device) + + loss = forward_backward(x, mask, components, target) + loss.backward() + optimizer.step() + + # Benchmark + print(f"Benchmarking ({steps} steps)...") + if device == "cuda": + torch.cuda.synchronize() + + step_times = [] + for _ in range(steps): + x = torch.randn(batch_size, seq_len, d_in, device=device) + mask = torch.rand(batch_size, seq_len, C, device=device) + target = torch.randn(batch_size, seq_len, d_out, device=device) + + if device == "cuda": + torch.cuda.synchronize() + step_start = time.perf_counter() + + optimizer.zero_grad() + loss = forward_backward(x, mask, components, target) + loss.backward() + optimizer.step() + + if device == "cuda": + torch.cuda.synchronize() + step_times.append(time.perf_counter() - step_start) + + avg_time = sum(step_times) / len(step_times) + min_time = min(step_times) + max_time = max(step_times) + + results[mode_name] = {"avg": avg_time, "min": min_time, "max": max_time} + + print(f"\n{mode_name.upper()} Results:") + print(f" Avg: {avg_time * 1000:.3f}ms") + print(f" Min: {min_time * 1000:.3f}ms") + print(f" Max: {max_time * 1000:.3f}ms") + print() + + del components, optimizer + if device == "cuda": + torch.cuda.empty_cache() + + # Summary + print(f"{'=' * 60}") + print("SUMMARY") + print(f"{'=' * 60}") + eager_time = results["eager"]["avg"] + compiled_time = results["compiled"]["avg"] + speedup = eager_time / compiled_time + print(f"Eager: {eager_time * 1000:.3f}ms") + print(f"Compiled: {compiled_time * 1000:.3f}ms") + print(f"Speedup: {speedup:.2f}x ({(speedup - 1) * 100:.1f}%)") + + +def benchmark_raw_einsum( + d_in: int = 512, + d_out: int = 512, + C: int 
= 100, + batch_size: int = 64, + seq_len: int = 256, + steps: int = 200, + warmup: int = 50, +) -> None: + """Benchmark raw einsum operations (equivalent to LinearComponents) to isolate overhead. + + The core operation is: out = (x @ V * mask) @ U + """ + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Device: {device}") + print(f"d_in={d_in}, d_out={d_out}, C={C}, batch_size={batch_size}, seq_len={seq_len}") + print() + print("Testing RAW einsum: out = (x @ V * mask) @ U") + print() + + results = {} + + for use_compile in [False, True]: + mode_name = "compiled" if use_compile else "eager" + print(f"{'=' * 60}") + print(f"Running: {mode_name}") + print(f"{'=' * 60}") + + # Create parameters directly + V = nn.Parameter(torch.randn(d_in, C, device=device) * 0.02) + U = nn.Parameter(torch.randn(C, d_out, device=device) * 0.02) + + def forward_backward(x: Tensor, mask: Tensor, V: Tensor, U: Tensor, target: Tensor) -> Tensor: + inner = x @ V # (batch, seq, C) + masked = inner * mask + out = masked @ U # (batch, seq, d_out) + loss = (out - target).pow(2).mean() + return loss + + if use_compile: + print("Compiling...") + compile_start = time.perf_counter() + forward_backward = torch.compile(forward_backward, fullgraph=True) + print(f"torch.compile() call took {time.perf_counter() - compile_start:.2f}s") + + optimizer = torch.optim.AdamW([V, U], lr=1e-4) + + # Warmup + print(f"Warming up ({warmup} steps)...") + for _ in range(warmup): + optimizer.zero_grad() + x = torch.randn(batch_size, seq_len, d_in, device=device) + mask = torch.rand(batch_size, seq_len, C, device=device) + target = torch.randn(batch_size, seq_len, d_out, device=device) + + loss = forward_backward(x, mask, V, U, target) + loss.backward() + optimizer.step() + + # Benchmark + print(f"Benchmarking ({steps} steps)...") + if device == "cuda": + torch.cuda.synchronize() + + step_times = [] + for _ in range(steps): + x = torch.randn(batch_size, seq_len, d_in, device=device) + mask = 
torch.rand(batch_size, seq_len, C, device=device) + target = torch.randn(batch_size, seq_len, d_out, device=device) + + if device == "cuda": + torch.cuda.synchronize() + step_start = time.perf_counter() + + optimizer.zero_grad() + loss = forward_backward(x, mask, V, U, target) + loss.backward() + optimizer.step() + + if device == "cuda": + torch.cuda.synchronize() + step_times.append(time.perf_counter() - step_start) + + avg_time = sum(step_times) / len(step_times) + results[mode_name] = {"avg": avg_time} + + print(f"{mode_name.upper()}: {avg_time * 1000:.3f}ms") + print() + + del V, U, optimizer + if device == "cuda": + torch.cuda.empty_cache() + + speedup = results["eager"]["avg"] / results["compiled"]["avg"] + print(f"Speedup: {speedup:.2f}x ({(speedup - 1) * 100:.1f}%)") + + +def benchmark_no_mask( + d_in: int = 512, + d_out: int = 512, + C: int = 100, + batch_size: int = 64, + seq_len: int = 256, + steps: int = 200, + warmup: int = 50, +) -> None: + """Benchmark LinearComponents WITHOUT mask to see if masking is the bottleneck.""" + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Device: {device}") + print(f"d_in={d_in}, d_out={d_out}, C={C}, batch_size={batch_size}, seq_len={seq_len}") + print() + print("Testing LinearComponents WITHOUT mask") + print() + + results = {} + + for use_compile in [False, True]: + mode_name = "compiled" if use_compile else "eager" + print(f"{'=' * 60}") + print(f"Running: {mode_name}") + print(f"{'=' * 60}") + + components = LinearComponents(C=C, d_in=d_in, d_out=d_out, bias=None) + components.to(device) + + def forward_backward(x: Tensor, components: nn.Module, target: Tensor) -> Tensor: + out = components(x, mask=None) + loss = (out - target).pow(2).mean() + return loss + + if use_compile: + print("Compiling...") + compile_start = time.perf_counter() + forward_backward = torch.compile(forward_backward, fullgraph=True) + print(f"torch.compile() call took {time.perf_counter() - compile_start:.2f}s") + + optimizer 
= torch.optim.AdamW(components.parameters(), lr=1e-4) + + # Warmup + print(f"Warming up ({warmup} steps)...") + for _ in range(warmup): + optimizer.zero_grad() + x = torch.randn(batch_size, seq_len, d_in, device=device) + target = torch.randn(batch_size, seq_len, d_out, device=device) + + loss = forward_backward(x, components, target) + loss.backward() + optimizer.step() + + # Benchmark + print(f"Benchmarking ({steps} steps)...") + if device == "cuda": + torch.cuda.synchronize() + + step_times = [] + for _ in range(steps): + x = torch.randn(batch_size, seq_len, d_in, device=device) + target = torch.randn(batch_size, seq_len, d_out, device=device) + + if device == "cuda": + torch.cuda.synchronize() + step_start = time.perf_counter() + + optimizer.zero_grad() + loss = forward_backward(x, components, target) + loss.backward() + optimizer.step() + + if device == "cuda": + torch.cuda.synchronize() + step_times.append(time.perf_counter() - step_start) + + avg_time = sum(step_times) / len(step_times) + results[mode_name] = {"avg": avg_time} + + print(f"{mode_name.upper()}: {avg_time * 1000:.3f}ms") + print() + + del components, optimizer + if device == "cuda": + torch.cuda.empty_cache() + + speedup = results["eager"]["avg"] / results["compiled"]["avg"] + print(f"Speedup: {speedup:.2f}x ({(speedup - 1) * 100:.1f}%)") + + +def benchmark_forward_only( + d_in: int = 512, + d_out: int = 512, + C: int = 100, + batch_size: int = 64, + seq_len: int = 256, + steps: int = 500, + warmup: int = 100, + compile_mode: str = "default", +) -> None: + """Benchmark LinearComponents forward pass only (no backward).""" + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Device: {device}") + print(f"d_in={d_in}, d_out={d_out}, C={C}, batch_size={batch_size}, seq_len={seq_len}") + print(f"Compile mode: {compile_mode}") + print() + print("Testing FORWARD ONLY (no backward)") + print() + + results = {} + + for use_compile in [False, True]: + mode_name = "compiled" if use_compile 
else "eager" + print(f"{'=' * 60}") + print(f"Running: {mode_name}") + print(f"{'=' * 60}") + + components = LinearComponents(C=C, d_in=d_in, d_out=d_out, bias=None) + components.to(device) + + def forward_fn(x: Tensor, mask: Tensor, components: nn.Module) -> Tensor: + return components(x, mask=mask) + + if use_compile: + mode = None if compile_mode == "default" else compile_mode + print(f"Compiling (mode={mode})...") + compile_start = time.perf_counter() + forward_fn = torch.compile(forward_fn, fullgraph=True, mode=mode) + print(f"torch.compile() call took {time.perf_counter() - compile_start:.2f}s") + + # Fixed mask for fair comparison + mask = torch.rand(batch_size, seq_len, C, device=device) + + # Warmup + print(f"Warming up ({warmup} steps)...") + with torch.no_grad(): + for _ in range(warmup): + x = torch.randn(batch_size, seq_len, d_in, device=device) + _ = forward_fn(x, mask, components) + + # Benchmark + print(f"Benchmarking ({steps} steps)...") + if device == "cuda": + torch.cuda.synchronize() + + step_times = [] + with torch.no_grad(): + for _ in range(steps): + x = torch.randn(batch_size, seq_len, d_in, device=device) + + if device == "cuda": + torch.cuda.synchronize() + step_start = time.perf_counter() + + _ = forward_fn(x, mask, components) + + if device == "cuda": + torch.cuda.synchronize() + step_times.append(time.perf_counter() - step_start) + + avg_time = sum(step_times) / len(step_times) + results[mode_name] = {"avg": avg_time} + + print(f"{mode_name.upper()}: {avg_time * 1000:.3f}ms") + print() + + del components + if device == "cuda": + torch.cuda.empty_cache() + + speedup = results["eager"]["avg"] / results["compiled"]["avg"] + print(f"Speedup: {speedup:.2f}x ({(speedup - 1) * 100:.1f}%)") + + +class MaskedLinearComponents(nn.Module): + """Wrapper that stores mask as a module property for simpler forward signature.""" + + def __init__(self, C: int, d_in: int, d_out: int): + super().__init__() + self.components = LinearComponents(C=C, 
d_in=d_in, d_out=d_out, bias=None) + self.mask: Tensor | None = None + + def set_mask(self, mask: Tensor) -> None: + self.mask = mask + + def forward(self, x: Tensor) -> Tensor: + return self.components(x, mask=self.mask) + + +def benchmark_module_compile( + d_in: int = 512, + d_out: int = 512, + C: int = 512, + batch_size: int = 64, + seq_len: int = 256, + steps: int = 200, + warmup: int = 50, + compile_mode: str = "reduce-overhead", +) -> None: + """Benchmark compiling the module directly instead of a function wrapper.""" + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Device: {device}") + print(f"d_in={d_in}, d_out={d_out}, C={C}, batch_size={batch_size}, seq_len={seq_len}") + print(f"Compile mode: {compile_mode}") + print() + print("Testing MODULE compile (mask as property)") + print() + + results = {} + + for use_compile in [False, True]: + mode_name = "compiled" if use_compile else "eager" + print(f"{'=' * 60}") + print(f"Running: {mode_name}") + print(f"{'=' * 60}") + + model = MaskedLinearComponents(C=C, d_in=d_in, d_out=d_out) + model.to(device) + + if use_compile: + mode = None if compile_mode == "default" else compile_mode + print(f"Compiling module (mode={mode})...") + compile_start = time.perf_counter() + model = torch.compile(model, fullgraph=True, mode=mode) # type: ignore + print(f"torch.compile() call took {time.perf_counter() - compile_start:.2f}s") + + base_model = model._orig_mod if use_compile else model # type: ignore + optimizer = torch.optim.AdamW(base_model.components.parameters(), lr=1e-4) + + # Warmup with fixed mask + print(f"Warming up ({warmup} steps)...") + warmup_mask = torch.rand(batch_size, seq_len, C, device=device) + base_model.set_mask(warmup_mask) + for _ in range(warmup): + optimizer.zero_grad() + x = torch.randn(batch_size, seq_len, d_in, device=device) + target = torch.randn(batch_size, seq_len, d_out, device=device) + + out = model(x) + loss = (out - target).pow(2).mean() + loss.backward() + 
optimizer.step() + + # Benchmark with FIXED mask (no recompilation) + print(f"Benchmarking ({steps} steps) with FIXED mask...") + if device == "cuda": + torch.cuda.synchronize() + + fixed_mask = torch.rand(batch_size, seq_len, C, device=device) + base_model.set_mask(fixed_mask) + + step_times = [] + for _ in range(steps): + x = torch.randn(batch_size, seq_len, d_in, device=device) + target = torch.randn(batch_size, seq_len, d_out, device=device) + + if device == "cuda": + torch.cuda.synchronize() + step_start = time.perf_counter() + + optimizer.zero_grad() + out = model(x) + loss = (out - target).pow(2).mean() + loss.backward() + optimizer.step() + + if device == "cuda": + torch.cuda.synchronize() + step_times.append(time.perf_counter() - step_start) + + avg_time = sum(step_times) / len(step_times) + results[mode_name] = {"avg": avg_time} + + print(f"{mode_name.upper()}: {avg_time * 1000:.3f}ms") + print() + + del model, base_model, optimizer + if device == "cuda": + torch.cuda.empty_cache() + + speedup = results["eager"]["avg"] / results["compiled"]["avg"] + print(f"Speedup: {speedup:.2f}x ({(speedup - 1) * 100:.1f}%)") + + +def benchmark_pure_function( + d_in: int = 512, + d_out: int = 512, + C: int = 512, + batch_size: int = 64, + seq_len: int = 256, + steps: int = 200, + warmup: int = 50, + compile_mode: str = "reduce-overhead", +) -> None: + """Benchmark a pure function (no module state) with explicit parameters.""" + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Device: {device}") + print(f"d_in={d_in}, d_out={d_out}, C={C}, batch_size={batch_size}, seq_len={seq_len}") + print(f"Compile mode: {compile_mode}") + print() + print("Testing PURE FUNCTION: out = (x @ V * mask) @ U") + print() + + def masked_linear(x: Tensor, V: Tensor, U: Tensor, mask: Tensor) -> Tensor: + """Pure function: (x @ V * mask) @ U""" + inner = x @ V + masked = inner * mask + return masked @ U + + results = {} + + for use_compile in [False, True]: + mode_name = 
"compiled" if use_compile else "eager" + print(f"{'=' * 60}") + print(f"Running: {mode_name}") + print(f"{'=' * 60}") + + V = nn.Parameter(torch.randn(d_in, C, device=device) * 0.02) + U = nn.Parameter(torch.randn(C, d_out, device=device) * 0.02) + + fn = masked_linear + if use_compile: + mode = None if compile_mode == "default" else compile_mode + print(f"Compiling (mode={mode})...") + compile_start = time.perf_counter() + fn = torch.compile(masked_linear, fullgraph=True, mode=mode) + print(f"torch.compile() call took {time.perf_counter() - compile_start:.2f}s") + + optimizer = torch.optim.AdamW([V, U], lr=1e-4) + + # Warmup + print(f"Warming up ({warmup} steps)...") + for _ in range(warmup): + optimizer.zero_grad() + x = torch.randn(batch_size, seq_len, d_in, device=device) + mask = torch.rand(batch_size, seq_len, C, device=device) + target = torch.randn(batch_size, seq_len, d_out, device=device) + + out = fn(x, V, U, mask) + loss = (out - target).pow(2).mean() + loss.backward() + optimizer.step() + + # Benchmark + print(f"Benchmarking ({steps} steps)...") + if device == "cuda": + torch.cuda.synchronize() + + step_times = [] + for _ in range(steps): + x = torch.randn(batch_size, seq_len, d_in, device=device) + mask = torch.rand(batch_size, seq_len, C, device=device) + target = torch.randn(batch_size, seq_len, d_out, device=device) + + if device == "cuda": + torch.cuda.synchronize() + step_start = time.perf_counter() + + optimizer.zero_grad() + out = fn(x, V, U, mask) + loss = (out - target).pow(2).mean() + loss.backward() + optimizer.step() + + if device == "cuda": + torch.cuda.synchronize() + step_times.append(time.perf_counter() - step_start) + + avg_time = sum(step_times) / len(step_times) + results[mode_name] = {"avg": avg_time} + + print(f"{mode_name.upper()}: {avg_time * 1000:.3f}ms") + print() + + del V, U, optimizer + if device == "cuda": + torch.cuda.empty_cache() + + speedup = results["eager"]["avg"] / results["compiled"]["avg"] + print(f"Speedup: 
{speedup:.2f}x ({(speedup - 1) * 100:.1f}%)") + + +def benchmark_fp16( + d_in: int = 512, + d_out: int = 512, + C: int = 512, + batch_size: int = 64, + seq_len: int = 256, + steps: int = 200, + warmup: int = 50, + compile_mode: str = "reduce-overhead", +) -> None: + """Benchmark with fp16 to see if tensor cores + compile helps.""" + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Device: {device}") + print(f"d_in={d_in}, d_out={d_out}, C={C}, batch_size={batch_size}, seq_len={seq_len}") + print(f"Compile mode: {compile_mode}") + print() + print("Testing FP16: out = (x @ V * mask) @ U") + print() + + def masked_linear(x: Tensor, V: Tensor, U: Tensor, mask: Tensor) -> Tensor: + inner = x @ V + masked = inner * mask + return masked @ U + + results = {} + + for use_compile in [False, True]: + mode_name = "compiled" if use_compile else "eager" + print(f"{'=' * 60}") + print(f"Running: {mode_name}") + print(f"{'=' * 60}") + + V = nn.Parameter(torch.randn(d_in, C, device=device, dtype=torch.float16) * 0.02) + U = nn.Parameter(torch.randn(C, d_out, device=device, dtype=torch.float16) * 0.02) + + fn = masked_linear + if use_compile: + mode = None if compile_mode == "default" else compile_mode + print(f"Compiling (mode={mode})...") + compile_start = time.perf_counter() + fn = torch.compile(masked_linear, fullgraph=True, mode=mode) + print(f"torch.compile() call took {time.perf_counter() - compile_start:.2f}s") + + optimizer = torch.optim.AdamW([V, U], lr=1e-4) + + # Warmup + print(f"Warming up ({warmup} steps)...") + for _ in range(warmup): + optimizer.zero_grad() + x = torch.randn(batch_size, seq_len, d_in, device=device, dtype=torch.float16) + mask = torch.rand(batch_size, seq_len, C, device=device, dtype=torch.float16) + target = torch.randn(batch_size, seq_len, d_out, device=device, dtype=torch.float16) + + out = fn(x, V, U, mask) + loss = (out - target).pow(2).mean() + loss.backward() + optimizer.step() + + # Benchmark + print(f"Benchmarking ({steps} 
steps)...") + if device == "cuda": + torch.cuda.synchronize() + + step_times = [] + for _ in range(steps): + x = torch.randn(batch_size, seq_len, d_in, device=device, dtype=torch.float16) + mask = torch.rand(batch_size, seq_len, C, device=device, dtype=torch.float16) + target = torch.randn(batch_size, seq_len, d_out, device=device, dtype=torch.float16) + + if device == "cuda": + torch.cuda.synchronize() + step_start = time.perf_counter() + + optimizer.zero_grad() + out = fn(x, V, U, mask) + loss = (out - target).pow(2).mean() + loss.backward() + optimizer.step() + + if device == "cuda": + torch.cuda.synchronize() + step_times.append(time.perf_counter() - step_start) + + avg_time = sum(step_times) / len(step_times) + results[mode_name] = {"avg": avg_time} + + print(f"{mode_name.upper()}: {avg_time * 1000:.3f}ms") + print() + + del V, U, optimizer + if device == "cuda": + torch.cuda.empty_cache() + + speedup = results["eager"]["avg"] / results["compiled"]["avg"] + print(f"Speedup: {speedup:.2f}x ({(speedup - 1) * 100:.1f}%)") + + +if __name__ == "__main__": + fire.Fire( + { + "components": benchmark_linear_components, + "einsum": benchmark_raw_einsum, + "no_mask": benchmark_no_mask, + "forward": benchmark_forward_only, + "module": benchmark_module_compile, + "pure": benchmark_pure_function, + "fp16": benchmark_fp16, + } + ) diff --git a/scripts/benchmark_toy.py b/scripts/benchmark_toy.py new file mode 100644 index 000000000..95d3372cb --- /dev/null +++ b/scripts/benchmark_toy.py @@ -0,0 +1,457 @@ +"""Minimal toy benchmark to isolate torch.compile() performance. + +This tests a single linear layer with component decomposition to understand +why torch.compile() isn't providing more speedup. 
+""" + +import time +from typing import override + +import fire +import torch +import torch.nn as nn +from torch import Tensor + +from spd.models.component_model import ComponentModel +from spd.models.components import ComponentsMaskInfo, make_mask_infos +from spd.utils.module_utils import ModulePathInfo + +torch.set_float32_matmul_precision("high") + + +class ToyModel(nn.Module): + """Simple model with a single linear layer.""" + + def __init__(self, d_in: int, d_out: int): + super().__init__() + self.linear = nn.Linear(d_in, d_out, bias=False) + + @override + def forward(self, x: Tensor) -> Tensor: + return self.linear(x) + + +def benchmark_toy( + d_in: int = 512, + d_out: int = 512, + C: int = 100, + batch_size: int = 64, + seq_len: int = 256, + steps: int = 100, + warmup: int = 20, +) -> None: + """Benchmark a toy model with component decomposition. + + Args: + d_in: Input dimension + d_out: Output dimension + C: Number of components + batch_size: Batch size + seq_len: Sequence length + steps: Number of benchmark steps + warmup: Number of warmup steps + """ + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Device: {device}") + print(f"d_in={d_in}, d_out={d_out}, C={C}, batch_size={batch_size}, seq_len={seq_len}") + print(f"Warmup: {warmup}, Steps: {steps}") + print() + + results = {} + + for use_compile in [False, True]: + mode_name = "compiled" if use_compile else "eager" + print(f"{'=' * 60}") + print(f"Running: {mode_name}") + print(f"{'=' * 60}") + + # Create fresh model + target_model = ToyModel(d_in, d_out) + target_model.eval() + target_model.requires_grad_(False) + + model = ComponentModel( + target_model=target_model, + module_path_info=[ModulePathInfo(module_path="linear", C=C)], + ci_fn_type="mlp", + ci_fn_hidden_dims=[32], + pretrained_model_output_attr=None, + sigmoid_type="leaky_hard", + ) + model.to(device) + + if use_compile: + print("Compiling model...") + compile_start = time.perf_counter() + model = torch.compile(model, 
fullgraph=False) # type: ignore + print(f"torch.compile() call took {time.perf_counter() - compile_start:.2f}s") + + component_model = model._orig_mod if use_compile else model # type: ignore + + # Setup optimizer + params = list(component_model.components["linear"].parameters()) + params += list(component_model.ci_fns["linear"].parameters()) + optimizer = torch.optim.AdamW(params, lr=1e-4) + + # Warmup + print(f"Warming up ({warmup} steps)...") + warmup_start = time.perf_counter() + for _ in range(warmup): + optimizer.zero_grad() + + # Random input + x = torch.randn(batch_size, seq_len, d_in, device=device) + + # Forward with caching + out, cache = model(x, cache_type="input") # type: ignore + + # Calculate CI + ci = component_model.calc_causal_importances( + pre_weight_acts=cache, + detach_inputs=False, + sampling="continuous", + ) + + # Create binary mask from CI + mask = (ci.lower_leaky["linear"] > 0.5).float() + mask_infos = make_mask_infos({"linear": mask}) + + # Forward with mask + masked_out = model(x, mask_infos=mask_infos) # type: ignore + + # Simple MSE loss + loss = (masked_out - out.detach()).pow(2).mean() + loss.backward() + optimizer.step() + + warmup_time = time.perf_counter() - warmup_start + print(f"Warmup took {warmup_time:.2f}s ({warmup_time / warmup * 1000:.1f}ms/step)") + + # Benchmark + print(f"Benchmarking ({steps} steps)...") + torch.cuda.synchronize() if device == "cuda" else None + + step_times = [] + for _ in range(steps): + step_start = time.perf_counter() + + optimizer.zero_grad() + x = torch.randn(batch_size, seq_len, d_in, device=device) + + out, cache = model(x, cache_type="input") # type: ignore + + ci = component_model.calc_causal_importances( + pre_weight_acts=cache, + detach_inputs=False, + sampling="continuous", + ) + + mask = (ci.lower_leaky["linear"] > 0.5).float() + mask_infos = make_mask_infos({"linear": mask}) + + masked_out = model(x, mask_infos=mask_infos) # type: ignore + + loss = (masked_out - 
out.detach()).pow(2).mean() + loss.backward() + optimizer.step() + + if device == "cuda": + torch.cuda.synchronize() + step_times.append(time.perf_counter() - step_start) + + avg_time = sum(step_times) / len(step_times) + min_time = min(step_times) + max_time = max(step_times) + + results[mode_name] = { + "avg": avg_time, + "min": min_time, + "max": max_time, + } + + print(f"\n{mode_name.upper()} Results:") + print(f" Avg: {avg_time * 1000:.2f}ms") + print(f" Min: {min_time * 1000:.2f}ms") + print(f" Max: {max_time * 1000:.2f}ms") + print() + + del model, component_model, optimizer + torch.cuda.empty_cache() if device == "cuda" else None + + # Summary + print(f"{'=' * 60}") + print("SUMMARY") + print(f"{'=' * 60}") + eager_time = results["eager"]["avg"] + compiled_time = results["compiled"]["avg"] + speedup = eager_time / compiled_time + print(f"Eager: {eager_time * 1000:.2f}ms") + print(f"Compiled: {compiled_time * 1000:.2f}ms") + print(f"Speedup: {speedup:.2f}x ({(speedup - 1) * 100:.1f}%)") + + +def benchmark_raw_linear( + d_in: int = 512, + d_out: int = 512, + batch_size: int = 64, + seq_len: int = 256, + steps: int = 100, + warmup: int = 20, +) -> None: + """Benchmark a raw linear layer (no SPD) to establish baseline.""" + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Device: {device}") + print(f"d_in={d_in}, d_out={d_out}, batch_size={batch_size}, seq_len={seq_len}") + print(f"Warmup: {warmup}, Steps: {steps}") + print() + print("This tests a RAW linear layer without SPD to establish baseline.") + print() + + results = {} + + for use_compile in [False, True]: + mode_name = "compiled" if use_compile else "eager" + print(f"{'=' * 60}") + print(f"Running: {mode_name}") + print(f"{'=' * 60}") + + model = nn.Linear(d_in, d_out, bias=False).to(device) + + if use_compile: + print("Compiling model...") + model = torch.compile(model, fullgraph=False) # type: ignore + + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) + + # Warmup + 
print(f"Warming up ({warmup} steps)...") + for _ in range(warmup): + optimizer.zero_grad() + x = torch.randn(batch_size, seq_len, d_in, device=device) + out = model(x) + loss = out.pow(2).mean() + loss.backward() + optimizer.step() + + # Benchmark + print(f"Benchmarking ({steps} steps)...") + torch.cuda.synchronize() if device == "cuda" else None + + step_times = [] + for _ in range(steps): + step_start = time.perf_counter() + + optimizer.zero_grad() + x = torch.randn(batch_size, seq_len, d_in, device=device) + out = model(x) + loss = out.pow(2).mean() + loss.backward() + optimizer.step() + + if device == "cuda": + torch.cuda.synchronize() + step_times.append(time.perf_counter() - step_start) + + avg_time = sum(step_times) / len(step_times) + results[mode_name] = {"avg": avg_time} + + print(f"{mode_name.upper()}: {avg_time * 1000:.2f}ms") + print() + + del model, optimizer + torch.cuda.empty_cache() if device == "cuda" else None + + speedup = results["eager"]["avg"] / results["compiled"]["avg"] + print(f"Speedup: {speedup:.2f}x ({(speedup - 1) * 100:.1f}%)") + + +def benchmark_forward_only( + d_in: int = 512, + d_out: int = 512, + C: int = 100, + batch_size: int = 64, + seq_len: int = 256, + steps: int = 200, + warmup: int = 50, +) -> None: + """Benchmark just the forward pass (no backward) to isolate compilation benefit.""" + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Device: {device}") + print(f"d_in={d_in}, d_out={d_out}, C={C}, batch_size={batch_size}, seq_len={seq_len}") + print(f"Warmup: {warmup}, Steps: {steps}") + print() + print("Testing FORWARD ONLY (no backward)") + print() + + results = {} + + for use_compile in [False, True]: + mode_name = "compiled" if use_compile else "eager" + print(f"{'=' * 60}") + print(f"Running: {mode_name}") + print(f"{'=' * 60}") + + # Create fresh model + target_model = ToyModel(d_in, d_out) + target_model.eval() + target_model.requires_grad_(False) + + model = ComponentModel( + 
target_model=target_model, + module_path_info=[ModulePathInfo(module_path="linear", C=C)], + ci_fn_type="mlp", + ci_fn_hidden_dims=[32], + pretrained_model_output_attr=None, + sigmoid_type="leaky_hard", + ) + model.to(device) + + if use_compile: + print("Compiling model...") + compile_start = time.perf_counter() + model = torch.compile(model, fullgraph=False) # type: ignore + print(f"torch.compile() call took {time.perf_counter() - compile_start:.2f}s") + + # Prepare fixed mask + mask = torch.ones(batch_size, seq_len, C, device=device) + mask_infos = make_mask_infos({"linear": mask}) + + # Warmup + print(f"Warming up ({warmup} steps)...") + with torch.no_grad(): + for _ in range(warmup): + x = torch.randn(batch_size, seq_len, d_in, device=device) + _ = model(x, mask_infos=mask_infos) # type: ignore + + # Benchmark + print(f"Benchmarking ({steps} steps)...") + torch.cuda.synchronize() if device == "cuda" else None + + step_times = [] + with torch.no_grad(): + for _ in range(steps): + x = torch.randn(batch_size, seq_len, d_in, device=device) + + torch.cuda.synchronize() if device == "cuda" else None + step_start = time.perf_counter() + + _ = model(x, mask_infos=mask_infos) # type: ignore + + torch.cuda.synchronize() if device == "cuda" else None + step_times.append(time.perf_counter() - step_start) + + avg_time = sum(step_times) / len(step_times) + results[mode_name] = {"avg": avg_time} + + print(f"{mode_name.upper()}: {avg_time * 1000:.3f}ms") + print() + + del model + torch.cuda.empty_cache() if device == "cuda" else None + + speedup = results["eager"]["avg"] / results["compiled"]["avg"] + print(f"Speedup: {speedup:.2f}x ({(speedup - 1) * 100:.1f}%)") + + +def benchmark_masked_module_direct( + d_in: int = 512, + d_out: int = 512, + C: int = 100, + batch_size: int = 64, + seq_len: int = 256, + steps: int = 200, + warmup: int = 50, +) -> None: + """Benchmark MaskedModule directly to isolate its performance.""" + device = "cuda" if torch.cuda.is_available() else "cpu" 
+ print(f"Device: {device}") + print(f"d_in={d_in}, d_out={d_out}, C={C}, batch_size={batch_size}, seq_len={seq_len}") + print() + print("Testing MaskedModule.forward() DIRECTLY") + print() + + from spd.models.components import LinearComponents + from spd.models.masked_module import MaskedModule + + results = {} + + for use_compile in [False, True]: + mode_name = "compiled" if use_compile else "eager" + print(f"{'=' * 60}") + print(f"Running: {mode_name}") + print(f"{'=' * 60}") + + # Create MaskedModule directly + base = nn.Linear(d_in, d_out, bias=False) + base.requires_grad_(False) + components = LinearComponents(C=C, d_in=d_in, d_out=d_out, bias=None) + + masked_module = MaskedModule( + module_name="test", + base=base, + components=components, + ) + masked_module.to(device) + + # Set up state for active forward + mask = torch.ones(batch_size, seq_len, C, device=device) + mask_info = ComponentsMaskInfo(component_mask=mask) + masked_module.set_runtime_state( + active=True, + mask_info=mask_info, + cache_type="none", + cache=None, + ) + + if use_compile: + print("Compiling MaskedModule...") + compile_start = time.perf_counter() + masked_module = torch.compile(masked_module, fullgraph=False) # type: ignore + print(f"torch.compile() call took {time.perf_counter() - compile_start:.2f}s") + + # Warmup + print(f"Warming up ({warmup} steps)...") + with torch.no_grad(): + for _ in range(warmup): + x = torch.randn(batch_size, seq_len, d_in, device=device) + _ = masked_module(x) + + # Benchmark + print(f"Benchmarking ({steps} steps)...") + torch.cuda.synchronize() if device == "cuda" else None + + step_times = [] + with torch.no_grad(): + for _ in range(steps): + x = torch.randn(batch_size, seq_len, d_in, device=device) + + torch.cuda.synchronize() if device == "cuda" else None + step_start = time.perf_counter() + + _ = masked_module(x) + + torch.cuda.synchronize() if device == "cuda" else None + step_times.append(time.perf_counter() - step_start) + + avg_time = 
sum(step_times) / len(step_times) + results[mode_name] = {"avg": avg_time} + + print(f"{mode_name.upper()}: {avg_time * 1000:.3f}ms") + print() + + del masked_module + torch.cuda.empty_cache() if device == "cuda" else None + + speedup = results["eager"]["avg"] / results["compiled"]["avg"] + print(f"Speedup: {speedup:.2f}x ({(speedup - 1) * 100:.1f}%)") + + +if __name__ == "__main__": + fire.Fire( + { + "toy": benchmark_toy, + "raw": benchmark_raw_linear, + "forward": benchmark_forward_only, + "masked": benchmark_masked_module_direct, + } + ) diff --git a/spd/configs.py b/spd/configs.py index f57305208..2a91bccb6 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -404,6 +404,11 @@ class Config(BaseConfig): default="leaky_hard", description="Type of sigmoid to use for causal importance calculation", ) + torch_compile: bool = Field( + default=False, + description="Whether to use torch.compile() for the forward pass. " + "Can provide speedups but may increase compilation time on first steps.", + ) module_info: list[ModulePatternInfoConfig] = Field( ..., description="List of module patterns with C values specifying which modules to decompose. " diff --git a/spd/identity_insertion.py b/spd/identity_insertion.py index 6995693cc..f31d035e6 100644 --- a/spd/identity_insertion.py +++ b/spd/identity_insertion.py @@ -17,6 +17,14 @@ from spd.models.components import Identity +def _is_identity_or_masked_identity(module: nn.Module) -> bool: + """Check if module is Identity or MaskedModule wrapping Identity.""" + if isinstance(module, Identity): + return True + # Avoid circular import by checking attribute instead of importing MaskedModule + return hasattr(module, "base") and isinstance(module.base, Identity) + + def pre_id_hook( mod: nn.Module, args: tuple[Any, ...], @@ -27,10 +35,14 @@ def pre_id_hook( # simple for now. 
assert not kwargs, f"Expected no kwargs, got {kwargs.keys()}" assert hasattr(mod, "pre_identity"), f"Module {mod} has no pre_identity attribute" - assert isinstance(mod.pre_identity, Identity), ( - f"Module {mod} pre_identity is not an Identity layer" + pre_identity = mod.pre_identity + assert isinstance(pre_identity, nn.Module), ( + f"pre_identity is not a Module: {type(pre_identity)}" + ) + assert _is_identity_or_masked_identity(pre_identity), ( + f"Module {mod} pre_identity is not an Identity layer (or MaskedModule wrapping Identity)" ) - return (mod.pre_identity(args[0]),), {} + return (pre_identity(args[0]),), {} def insert_identity_operations_( diff --git a/spd/models/component_model.py b/spd/models/component_model.py index 1c7a5a5c6..a67353ec5 100644 --- a/spd/models/component_model.py +++ b/spd/models/component_model.py @@ -1,13 +1,10 @@ -from collections.abc import Callable, Generator, Sequence -from contextlib import contextmanager +from collections.abc import Sequence from dataclasses import dataclass -from functools import partial from typing import Any, Literal, NamedTuple, overload, override import torch from jaxtyping import Float, Int from torch import Tensor, nn -from torch.utils.hooks import RemovableHandle from transformers.pytorch_utils import Conv1D as RadfordConv1D from spd.configs import Config, SamplingType @@ -23,6 +20,7 @@ VectorMLPCiFn, VectorSharedMLPCiFn, ) +from spd.models.masked_module import MaskedModule from spd.models.sigmoids import SIGMOID_TYPES, SigmoidType from spd.spd_types import CiFnType, ModelPath from spd.utils.general_utils import resolve_class, runtime_cast @@ -93,14 +91,11 @@ def __init__( self.module_to_c = {info.module_path: info.C for info in module_path_info} self.target_module_paths = list(self.module_to_c.keys()) + # Create trainable components and CI functions against the *original* target modules. + # We patch the target model only after this, to avoid CI fn creation seeing wrappers. 
self.components = ComponentModel._create_components( - target_model=target_model, - module_to_c=self.module_to_c, - ) - self._components = nn.ModuleDict( - {k.replace(".", "-"): self.components[k] for k in sorted(self.components)} + target_model=target_model, module_to_c=self.module_to_c ) - self.ci_fns = ComponentModel._create_ci_fns( target_model=target_model, module_to_c=self.module_to_c, @@ -111,6 +106,22 @@ def __init__( {k.replace(".", "-"): self.ci_fns[k] for k in sorted(self.ci_fns)} ) + # Patch the target model in-place: wrap each decomposed module with MaskedModule. + # IMPORTANT: Process paths in order of decreasing depth (deepest first) so that + # wrapping a parent module doesn't prevent access to its children. + # E.g., wrap "linear1.pre_identity" before "linear1". + self._masked_modules: dict[str, MaskedModule] = {} + sorted_paths = sorted(self.target_module_paths, key=lambda p: p.count("."), reverse=True) + for module_path in sorted_paths: + base_module = target_model.get_submodule(module_path) + masked = MaskedModule( + module_name=module_path, + base=base_module, + components=self.components[module_path], + ) + _set_submodule_by_path(target_model, module_path, masked) + self._masked_modules[module_path] = masked + if sigmoid_type == "leaky_hard": self.lower_leaky_fn = SIGMOID_TYPES["lower_leaky_hard"] self.upper_leaky_fn = SIGMOID_TYPES["upper_leaky_hard"] @@ -121,6 +132,8 @@ def __init__( def target_weight(self, module_name: str) -> Float[Tensor, "rows cols"]: target_module = self.target_model.get_submodule(module_name) + if isinstance(target_module, MaskedModule): + target_module = target_module.base match target_module: case RadfordConv1D(): @@ -295,134 +308,54 @@ def forward( cache_type: Literal["component_acts", "input", "none"] = "none", **kwargs: Any, ) -> Tensor | OutputWithCache: - """Forward pass with optional component replacement and/or input caching. + """Forward pass with optional component replacement and/or caching, without hooks. 
- This method handles the following 4 cases: - 1. mask_infos is None and cache_type is "none": Regular forward pass. - 2. mask_infos is None and cache_type is "input" or "component_acts": Forward pass with - caching on all modules in self.target_module_paths. - 3. mask_infos is not None and cache_type is "input" or "component_acts": Forward pass with - component replacement and caching on the modules provided in mask_infos. - 4. mask_infos is not None and cache_type is "none": Forward pass with component replacement - on the modules provided in mask_infos and no caching. - - Args: - mask_infos: Dictionary mapping module names to ComponentsMaskInfo. - If provided, those modules will be replaced with their components. - cache_type: If "input" or "component_acts", cache the inputs or component acts to the - modules provided in mask_infos. If "none", no caching is done. If mask_infos is None, - cache the inputs or component acts to all modules in self.target_module_paths. - - Returns: - OutputWithCache object if cache_type is "input" or "component_acts", otherwise the - model output tensor. + Semantics match the previous hook-based implementation: + - `mask_infos is None and cache_type == "none"`: pure target model forward. + - `mask_infos is None and cache_type != "none"`: cache on all decomposed modules. + - `mask_infos is not None`: activate exactly those modules for component replacement. + If also caching, cache only on the modules provided in `mask_infos`. """ - if mask_infos is None and cache_type == "none": - # No hooks needed. Do a regular forward pass of the target model. 
- return self._extract_output(self.target_model(*args, **kwargs)) - - cache: dict[str, Tensor] = {} - hooks: dict[str, Callable[..., Any]] = {} - - hook_module_names = list(mask_infos.keys()) if mask_infos else self.target_module_paths - - for module_name in hook_module_names: - mask_info = mask_infos[module_name] if mask_infos else None - components = self.components[module_name] if mask_info else None - - hooks[module_name] = partial( - self._components_and_cache_hook, - module_name=module_name, - components=components, - mask_info=mask_info, - cache_type=cache_type, - cache=cache, + cache: dict[str, Tensor] | None = {} if cache_type != "none" else None + + if cache_type == "none": + cache_names: set[str] = set() + else: + cache_names = ( + set(mask_infos.keys()) if mask_infos is not None else set(self.target_module_paths) ) - with self._attach_forward_hooks(hooks): - raw_out = self.target_model(*args, **kwargs) + active_infos = mask_infos if mask_infos is not None else {} + + for name in self.target_module_paths: + masked = self._masked_modules[name] + is_active = name in active_infos + is_caching = name in cache_names + masked.set_runtime_state( + active=is_active, + mask_info=(active_infos[name] if is_active else None), + cache_type=(cache_type if is_caching else "none"), + cache=(cache if is_caching else None), + ) + raw_out = self.target_model(*args, **kwargs) out = self._extract_output(raw_out) - match cache_type: - case "input" | "component_acts": - return OutputWithCache(output=out, cache=cache) - case "none": - return out + if cache_type == "none": + return out + assert cache is not None + return OutputWithCache(output=out, cache=cache) - def _components_and_cache_hook( - self, - _module: nn.Module, - args: list[Any], - kwargs: dict[Any, Any], - output: Any, - module_name: str, - components: Components | None, - mask_info: ComponentsMaskInfo | None, - cache_type: Literal["component_acts", "input", "none"], - cache: dict[str, Tensor], - ) -> Any | None: - 
"""Unified hook function that handles both component replacement and caching. + def validate_masked_module_state(self) -> None: + """Validate that all MaskedModules have consistent runtime state. - Args: - module: The module being hooked - args: Module forward args - kwargs: Module forward kwargs - output: Module forward output - module_name: Name of the module in the target model - components: Component replacement (if using components) - mask_info: Mask information (if using components) - cache_type: Whether to cache the component acts, input, or none - cache: Cache dictionary to populate (if cache_type is not None) + Call this for debugging after set_runtime_state() has been called on all modules + but before the forward pass. Useful for catching state configuration bugs. - Returns: - If using components: modified output (or None to keep original) - If not using components: None (keeps original output) + Raises: + AssertionError: If any MaskedModule has inconsistent state. """ - assert len(args) == 1, "Expected 1 argument" - assert len(kwargs) == 0, "Expected no keyword arguments" - x = args[0] - assert isinstance(x, Tensor), "Expected input tensor" - - if cache_type == "input": - cache[module_name] = x - - if components is not None and mask_info is not None: - assert isinstance(output, Tensor), ( - f"Only supports single-tensor outputs, got {type(output)}" - ) - - component_acts_cache = {} if cache_type == "component_acts" else None - components_out = components( - x, - mask=mask_info.component_mask, - weight_delta_and_mask=mask_info.weight_delta_and_mask, - component_acts_cache=component_acts_cache, - ) - if component_acts_cache is not None: - for k, v in component_acts_cache.items(): - cache[f"{module_name}_{k}"] = v - - if mask_info.routing_mask == "all": - return components_out - - return torch.where(mask_info.routing_mask[..., None], components_out, output) - - # No component replacement - keep original output - return None - - @contextmanager - def 
_attach_forward_hooks(self, hooks: dict[str, Callable[..., Any]]) -> Generator[None]: - """Context manager to temporarily attach forward hooks to the target model.""" - handles: list[RemovableHandle] = [] - for module_name, hook in hooks.items(): - target_module = self.target_model.get_submodule(module_name) - handle = target_module.register_forward_hook(hook, with_kwargs=True) - handles.append(handle) - try: - yield - finally: - for handle in handles: - handle.remove() + for masked in self._masked_modules.values(): + masked.validate_state() @classmethod @override @@ -572,6 +505,36 @@ def handle_deprecated_state_dict_keys_(state_dict: dict[str, Tensor]) -> None: ) # module path has "." replaced with "-" new_key = f"_components.{target_module_path.replace('.', '-')}.{new_key.split('.')[-1]}" + # If we now store components under target_model..components.{U,V}, map legacy keys. + if new_key.startswith("_components."): + # _components.. -> target_model..components. + parts = new_key.split(".") + if len(parts) == 3 and parts[0] == "_components" and parts[2] in {"U", "V"}: + dotted_path = parts[1].replace("-", ".") + new_key = f"target_model.{dotted_path}.components.{parts[2]}" # replace if modified if new_key != key: state_dict[new_key] = state_dict.pop(key) + + +def _set_submodule_by_path(root: nn.Module, module_path: str, new_module: nn.Module) -> None: + """Set `root.` to `new_module`, supporting ModuleList/Sequential numeric segments.""" + if module_path == "": + raise ValueError("module_path cannot be empty") + parts = module_path.split(".") + parent: nn.Module = root + for part in parts[:-1]: + parent = _get_child_module(parent, part) + last = parts[-1] + if last.isdigit(): + # Parent is a container type (ModuleList/Sequential) that supports indexing + parent[int(last)] = new_module # pyright: ignore[reportIndexIssue] + else: + setattr(parent, last, new_module) + + +def _get_child_module(parent: nn.Module, part: str) -> nn.Module: + if part.isdigit(): + # Parent 
is a container type (ModuleList/Sequential) that supports indexing + return parent[int(part)] # pyright: ignore[reportIndexIssue] + return parent.get_submodule(part) diff --git a/spd/models/components.py b/spd/models/components.py index 7686e4cbf..d6f32c337 100644 --- a/spd/models/components.py +++ b/spd/models/components.py @@ -176,11 +176,11 @@ def __init__( @override def weight(self) -> Float[Tensor, "d_out d_in"]: """(V @ U).T. Transposed to match nn.Linear which uses (d_out, d_in)""" - return einops.einsum(self.V, self.U, "d_in C, C d_out -> d_out d_in") + return (self.V @ self.U).T @override def get_inner_acts(self, x: Float[Tensor, "... d_in"]) -> Float[Tensor, "... C"]: - return einops.einsum(x, self.V, "... d_in, d_in C -> ... C") + return x @ self.V @override def forward( @@ -211,18 +211,17 @@ def forward( if mask is not None: component_acts = component_acts * mask - out = einops.einsum(component_acts, self.U, "... C, C d_out -> ... d_out") + out = component_acts @ self.U if weight_delta_and_mask is not None: weight_delta, weight_delta_mask = weight_delta_and_mask - unmasked_delta_out = einops.einsum(x, weight_delta, "... d_in, d_out d_in -> ... d_out") + # weight_delta is (d_out, d_in), so transpose for matmul: x @ weight_delta.T + unmasked_delta_out = x @ weight_delta.T assert unmasked_delta_out.shape[:-1] == weight_delta_mask.shape - out += einops.einsum( - weight_delta_mask, unmasked_delta_out, "..., ... d_out -> ... d_out" - ) + out = out + weight_delta_mask.unsqueeze(-1) * unmasked_delta_out if self.bias is not None: - out += self.bias + out = out + self.bias return out @@ -244,9 +243,7 @@ def __init__( @override def weight(self) -> Float[Tensor, "vocab_size embedding_dim"]: """V @ U""" - return einops.einsum( - self.V, self.U, "vocab_size C, C embedding_dim -> vocab_size embedding_dim" - ) + return self.V @ self.U @override def get_inner_acts(self, x: Int[Tensor, "..."]) -> Float[Tensor, "... 
C"]: @@ -282,15 +279,13 @@ def forward( if mask is not None: component_acts = component_acts * mask - out = einops.einsum(component_acts, self.U, "... C, C embedding_dim -> ... embedding_dim") + out = component_acts @ self.U if weight_delta_and_mask is not None: weight_delta, weight_delta_mask = weight_delta_and_mask unmasked_delta_out = weight_delta[x] assert unmasked_delta_out.shape[:-1] == weight_delta_mask.shape - out += einops.einsum( - weight_delta_mask, unmasked_delta_out, "..., ... embedding_dim -> ... embedding_dim" - ) + out = out + weight_delta_mask.unsqueeze(-1) * unmasked_delta_out return out diff --git a/spd/models/masked_module.py b/spd/models/masked_module.py new file mode 100644 index 000000000..21794156b --- /dev/null +++ b/spd/models/masked_module.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +from typing import Literal, override + +import torch +from jaxtyping import Bool, Float +from torch import Tensor, nn + +from spd.models.components import Components, ComponentsMaskInfo + + +class MaskedModule(nn.Module): + """Wraps a frozen base module with trainable SPD Components, without forward hooks. + + This module is designed to replace an `nn.Linear` / `nn.Embedding` / `Conv1D` / `Identity` leaf + module inside an arbitrary PyTorch model. + + Why this design (vs forward hooks): + The previous hook-based approach used `register_forward_hook()` with dynamic + registration/removal via context managers. This was incompatible with `torch.compile()` + because dynamic hook management breaks graph tracing. This module-patching approach + keeps the module structure static, making it compatible with compilation. + + Usage: + `ComponentModel.forward(...)` MUST call `set_runtime_state()` on each MaskedModule + BEFORE executing the target model's forward pass. 
This configures: + - active: whether to use components (True) or pass through to base module (False) + - mask_info: component mask, routing mask, and optional weight-delta mask + - cache_type: whether to cache inputs or component activations + - cache: dictionary to populate with cached values + + After configuring state, the target model is executed normally, and this wrapper + applies the component replacement logic in its forward() method. + + Note on torch.compile(): + The mutable state pattern (setting attributes before forward) works with torch.compile(), + but may cause recompilation if the pattern of `active` flags changes frequently across + forward calls. For best performance, try to maintain consistent activation patterns. + """ + + def __init__(self, *, module_name: str, base: nn.Module, components: Components) -> None: + super().__init__() + self.module_name = module_name + self.base = base + self.components = components + + # Per-forward runtime state, set by ComponentModel.forward(...) before execution. + self.active: bool = False + self.mask_info: ComponentsMaskInfo | None = None + self.cache_type: Literal["none", "input", "component_acts"] = "none" + self.cache: dict[str, Tensor] | None = None + + @override + def __getattr__(self, name: str) -> Tensor | nn.Module: + """Forward attribute access to base module for nested submodule compatibility. + + This allows code to access nested submodules (e.g., model.linear1.pre_identity) + even after linear1 has been wrapped with MaskedModule. 
+ """ + # nn.Module stores _modules in self._modules dict, access it directly to avoid recursion + _modules = object.__getattribute__(self, "_modules") + if "base" in _modules: + base = _modules["base"] + if hasattr(base, name): + return getattr(base, name) + # Fall back to nn.Module's __getattr__ for standard behavior + return super().__getattr__(name) + + def set_runtime_state( + self, + *, + active: bool, + mask_info: ComponentsMaskInfo | None, + cache_type: Literal["none", "input", "component_acts"], + cache: dict[str, Tensor] | None, + ) -> None: + self.active = active + self.mask_info = mask_info + self.cache_type = cache_type + self.cache = cache + + def validate_state(self) -> None: + """Validate that runtime state is consistent. Call this for debugging. + + Raises: + AssertionError: If state is inconsistent (e.g., active=True but mask_info=None). + """ + if self.active: + assert self.mask_info is not None, ( + f"MaskedModule '{self.module_name}': active=True but mask_info is None. " + "set_runtime_state() must be called with mask_info when active=True." + ) + if self.cache_type != "none": + assert self.cache is not None, ( + f"MaskedModule '{self.module_name}': cache_type='{self.cache_type}' but cache is None. " + "set_runtime_state() must be called with a cache dict when caching is enabled." + ) + + @override + def forward(self, x: Tensor) -> Tensor: # type: ignore[override] + if self.cache_type == "input": + assert self.cache is not None + self.cache[self.module_name] = x + + if not self.active: + return self.base(x) + + assert self.mask_info is not None + mask_info = self.mask_info + + component_acts_cache: dict[str, Float[Tensor, "... 
C"]] | None = ( + {} if self.cache_type == "component_acts" else None + ) + + components_out = self.components( + x, + mask=mask_info.component_mask, + weight_delta_and_mask=mask_info.weight_delta_and_mask, + component_acts_cache=component_acts_cache, + ) + + if component_acts_cache is not None: + assert self.cache is not None + for k, v in component_acts_cache.items(): + self.cache[f"{self.module_name}_{k}"] = v + + if mask_info.routing_mask == "all": + return components_out + + base_out = self.base(x) + routing_mask: Bool[Tensor, ...] = mask_info.routing_mask + return torch.where(routing_mask[..., None], components_out, base_out) diff --git a/spd/run_spd.py b/spd/run_spd.py index 82245a126..ad1fd5e96 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -163,9 +163,17 @@ def create_pgd_data_iter() -> ( replace_std_values_in_layernorm(model, ln_stds) model.to(device) + # Optionally compile the model for potential speedups + compiled_model: ComponentModel | nn.Module = model + if config.torch_compile: + logger.info("Compiling model with torch.compile()...") + # torch.compile returns a callable that wraps the model; type system doesn't understand this + compiled_model = torch.compile(model, fullgraph=False) # pyright: ignore[reportAssignmentType] + logger.info("Model compiled.") + # Wrap model with DDP if distributed dist_state = get_distributed_state() - wrapped_model: nn.Module = model + wrapped_model: nn.Module = compiled_model if dist_state is not None: if dist_state.backend == "nccl": device_id = dist_state.local_rank diff --git a/tests/test_component_model.py b/tests/test_component_model.py index 5787c2f07..f866d6eb2 100644 --- a/tests/test_component_model.py +++ b/tests/test_component_model.py @@ -1,6 +1,6 @@ import tempfile from pathlib import Path -from typing import Any, override +from typing import Any, Literal, override import pytest import torch @@ -31,6 +31,7 @@ VectorSharedMLPCiFn, make_mask_infos, ) +from spd.models.masked_module import MaskedModule 
from spd.spd_types import ModelPath from spd.utils.module_utils import ModulePathInfo, expand_module_patterns from spd.utils.run_utils import save_file @@ -100,6 +101,8 @@ def test_correct_parameters_require_grad(): assert components.V.requires_grad target_module = component_model.target_model.get_submodule(module_path) + if isinstance(target_module, MaskedModule): + target_module = target_module.base if isinstance(target_module, nn.Linear | RadfordConv1D): assert not target_module.weight.requires_grad @@ -286,8 +289,11 @@ def test_full_weight_delta_matches_target_behaviour(): sigmoid_type="leaky_hard", ) + embed_module = target_model.embed + assert isinstance(embed_module, MaskedModule) + assert isinstance(embed_module.base, nn.Embedding) token_ids = torch.randint( - low=0, high=target_model.embed.num_embeddings, size=(BATCH_SIZE,), dtype=torch.long + low=0, high=embed_module.base.num_embeddings, size=(BATCH_SIZE,), dtype=torch.long ) # WHEN we forward the component model with weight deltas and a weight delta mask of all 1s @@ -319,9 +325,12 @@ def test_input_cache_captures_pre_weight_input(): ) # WHEN we forward the component model with input caching + embed_module = target_model.embed + assert isinstance(embed_module, MaskedModule) + assert isinstance(embed_module.base, nn.Embedding) token_ids = torch.randint( low=0, - high=target_model.embed.num_embeddings, + high=embed_module.base.num_embeddings, size=(BATCH_SIZE,), dtype=torch.long, ) @@ -336,7 +345,9 @@ def test_input_cache_captures_pre_weight_input(): assert torch.equal(cache["embed"], token_ids) embed_out = target_model.embed(token_ids) - assert cache["mlp"].shape == (BATCH_SIZE, target_model.mlp.in_features) + mlp_module = target_model.mlp + assert isinstance(mlp_module, MaskedModule) + assert cache["mlp"].shape == (BATCH_SIZE, mlp_module.base.in_features) torch.testing.assert_close(cache["mlp"], embed_out) @@ -362,6 +373,13 @@ def test_weight_deltas(): def test_replacement_effects_fwd_pass(): + """Test 
that component forward pass replaces the base module correctly. + + With the MaskedModule design, calling cm() with mask_infos uses components, + while calling cm() without mask_infos uses the base module. We verify: + 1. When component weights match base weights, both produce the same output + 2. When base weights differ from component weights, outputs differ proportionally + """ d_in = 10 d_out = 20 C = 30 @@ -388,28 +406,37 @@ def forward(self, x: Tensor) -> Tensor: sigmoid_type="leaky_hard", ) - # WHEN we set the target model weights to be UV - model.linear.weight.copy_(cm.components["linear"].weight) + # WHEN we set the base model weights to match the component weights (V @ U) + assert isinstance(model.linear, MaskedModule) + assert isinstance(model.linear.base, nn.Linear) + model.linear.base.weight.copy_(cm.components["linear"].weight) # AND we use all components - input = torch.randn(BATCH_SIZE, d_in) + input_tensor = torch.randn(BATCH_SIZE, d_in) use_all_components = ComponentsMaskInfo(component_mask=torch.ones(BATCH_SIZE, C)) - # THEN the model output matches the component model output - model_out = model(input) - cm_out_with_all_components = cm(input, mask_infos={"linear": use_all_components}) - torch.testing.assert_close(model_out, cm_out_with_all_components) + # THEN the base model output matches the component model output + # (use cm() for both to ensure consistent state management) + base_out = cm(input_tensor) # No mask_infos -> uses base module + components_out = cm(input_tensor, mask_infos={"linear": use_all_components}) + torch.testing.assert_close(base_out, components_out) - # however, WHEN we double the values of the model weights - model.linear.weight.mul_(2) + # however, WHEN we double the base model weights (components unchanged) + model.linear.base.weight.mul_(2) - # THEN the component-only output should be 1/2 the model output - new_model_out = model(input) - new_cm_out_with_all_components = cm(input, mask_infos={"linear": 
use_all_components}) - torch.testing.assert_close(new_model_out, new_cm_out_with_all_components * 2) + # THEN the base output should be 2x the component output + new_base_out = cm(input_tensor) # Uses modified base weights + new_components_out = cm(input_tensor, mask_infos={"linear": use_all_components}) + torch.testing.assert_close(new_base_out, new_components_out * 2) def test_replacing_identity(): + """Test that identity insertion works with MaskedModule wrapping. + + With MaskedModule design, we must use cm() for all forwards to ensure + consistent state management. Calling model() directly after cm() would + use stale MaskedModule state. + """ d = 10 C = 20 @@ -445,27 +472,104 @@ def forward(self, x: Tensor) -> Tensor: ) # and a random input - input = torch.randn(BATCH_SIZE, d) + input_tensor = torch.randn(BATCH_SIZE, d) - # WHEN we forward with the model + # WHEN we forward with cm (no mask_infos -> uses base modules) # THEN it should just act as the identity - torch.testing.assert_close(model(input), input) - torch.testing.assert_close(cm(input), input) + cm_base_out = cm(input_tensor) + torch.testing.assert_close(cm_base_out, input_tensor) # WHEN we forward with the identity components use_all_components = ComponentsMaskInfo(component_mask=torch.ones(BATCH_SIZE, C)) - cm_components_out = cm(input, mask_infos={"linear.pre_identity": use_all_components}) + cm_components_out = cm(input_tensor, mask_infos={"linear.pre_identity": use_all_components}) + + # THEN it should modify the input (components have random init, not identity) + assert not torch.allclose(cm_components_out, input_tensor) + + # AND when we forward again without mask_infos, it should return to identity behavior + cm_base_out_again = cm(input_tensor) + torch.testing.assert_close(cm_base_out_again, input_tensor) + + +@pytest.mark.parametrize("cache_type", ["none", "input", "component_acts"]) +def test_torch_compile_masked_forward(cache_type: Literal["none", "input", "component_acts"]): + 
"""Verify that torch.compile() works with MaskedModule forward passes. + + This test ensures that the module-patching approach (replacing target modules with + MaskedModule wrappers) is compatible with torch.compile(). The previous hook-based + approach was incompatible because dynamic hook registration/removal breaks graph tracing. + """ + d_in = 8 + d_out = 6 + C = 4 + + class TwoLayerModel(nn.Module): + def __init__(self): + super().__init__() + self.layer1 = nn.Linear(d_in, d_out, bias=False) + self.layer2 = nn.Linear(d_out, d_out, bias=True) + + @override + def forward(self, x: Tensor) -> Tensor: + return self.layer2(torch.relu(self.layer1(x))) + + model = TwoLayerModel() + model.eval() + model.requires_grad_(False) - # THEN it should modify the input - assert not torch.allclose(cm_components_out, input) + cm = ComponentModel( + target_model=model, + module_path_info=[ + ModulePathInfo(module_path="layer1", C=C), + ModulePathInfo(module_path="layer2", C=C), + ], + ci_fn_type="mlp", + ci_fn_hidden_dims=[4], + pretrained_model_output_attr=None, + sigmoid_type="leaky_hard", + ) - # BUT the original model output should be unchanged - cm_target_out = cm(input) - assert torch.allclose(cm_target_out, model(input)) + # Compile the component model + compiled_cm = torch.compile(cm, fullgraph=False) + + input_tensor = torch.randn(BATCH_SIZE, d_in) + component_masks = { + "layer1": torch.ones(BATCH_SIZE, C), + "layer2": torch.ones(BATCH_SIZE, C), + } + mask_infos = make_mask_infos(component_masks) + + # Test 1: Compiled forward without components (pure target model) + if cache_type == "none": + eager_out = cm(input_tensor) + compiled_out = compiled_cm(input_tensor) + torch.testing.assert_close(compiled_out, eager_out) + + # Test 2: Compiled forward with components + if cache_type == "none": + eager_out = cm(input_tensor, mask_infos=mask_infos) + compiled_out = compiled_cm(input_tensor, mask_infos=mask_infos) + torch.testing.assert_close(compiled_out, eager_out) + else: + 
eager_out, eager_cache = cm(input_tensor, mask_infos=mask_infos, cache_type=cache_type) + compiled_out, compiled_cache = compiled_cm( + input_tensor, mask_infos=mask_infos, cache_type=cache_type + ) + torch.testing.assert_close(compiled_out, eager_out) + # Verify cache contents match + assert set(eager_cache.keys()) == set(compiled_cache.keys()) + for key in eager_cache: + torch.testing.assert_close(compiled_cache[key], eager_cache[key]) def test_routing(): + """Test that routing_mask correctly routes some positions to components and others to base. + + With MaskedModule design, we must use cm() for all forwards to ensure + consistent state management. Calling model() directly after cm() would + use stale MaskedModule state. + """ d = 10 C = 20 @@ -495,20 +599,20 @@ def forward(self, x: Tensor) -> Tensor: ) # and a random input - input = torch.randn(BATCH_SIZE, d) + input_tensor = torch.randn(BATCH_SIZE, d) - # WHEN we forward with the model + # WHEN we forward with cm (no mask_infos -> uses base modules) # THEN it should just act as the identity - torch.testing.assert_close(model(input), input) - torch.testing.assert_close(cm(input), input) + base_out = cm(input_tensor) + torch.testing.assert_close(base_out, input_tensor) - # WHEN we forward with the components + # WHEN we forward with the components (all positions routed to components) use_all_components = ComponentsMaskInfo(component_mask=torch.ones(BATCH_SIZE, C)) - cm_components_out = cm(input, mask_infos={"linear": use_all_components}) + cm_components_out = cm(input_tensor, mask_infos={"linear": use_all_components}) - # THEN it should modify the input - assert not torch.allclose(cm_components_out, input) + # THEN it should modify the input (components have random init, not identity) + assert not torch.allclose(cm_components_out, input_tensor) # but WHEN we forward with the components with routing: use_all_components_for_example_0 = ComponentsMaskInfo( @@ -516,12 +620,13 @@ def forward(self, x: Tensor) -> 
Tensor: routing_mask=torch.tensor([True, False]), # route to components only for example 0 ) - cm_routed_out = cm(input, mask_infos={"linear": use_all_components_for_example_0}) + cm_routed_out = cm(input_tensor, mask_infos={"linear": use_all_components_for_example_0}) - target_out = model(input) + # Get base output for comparison (must use cm() to ensure state is reset properly) + base_out_for_comparison = cm(input_tensor) # THEN the output should be different for the first example (where it's routed to components) - assert not torch.allclose(cm_routed_out[0], target_out[0]) + assert not torch.allclose(cm_routed_out[0], base_out_for_comparison[0]) # but it should be the same for the second example (where it's not routed to components) - assert torch.allclose(cm_routed_out[1], target_out[1]) + torch.testing.assert_close(cm_routed_out[1], base_out_for_comparison[1])