diff --git a/.gitignore b/.gitignore index b84f21934..288d13597 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,9 @@ spd/scripts/sweep_params.yaml docs/coverage/** notebooks/** +# Script outputs (generated files, often large) +scripts/outputs/ + **/out/ neuronpedia_outputs/ .env diff --git a/spd/configs.py b/spd/configs.py index 935d31b37..82282550c 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -13,7 +13,34 @@ from spd.base_config import BaseConfig from spd.log import logger -from spd.spd_types import CiFnType, ModelPath, Probability +from spd.spd_types import GlobalCiFnType, LayerwiseCiFnType, ModelPath, Probability + + +class LayerwiseCiConfig(BaseConfig): + """Configuration for layerwise CI functions (one per layer).""" + + mode: Literal["layerwise"] = "layerwise" + fn_type: LayerwiseCiFnType = Field( + ..., description="Type of layerwise CI function: mlp, vector_mlp, or shared_mlp" + ) + hidden_dims: list[NonNegativeInt] = Field( + ..., description="Hidden dimensions for the CI function MLP" + ) + + +class GlobalCiConfig(BaseConfig): + """Configuration for global CI function (single function for all layers).""" + + mode: Literal["global"] = "global" + fn_type: GlobalCiFnType = Field( + ..., description="Type of global CI function: global_shared_mlp" + ) + hidden_dims: list[NonNegativeInt] = Field( + ..., description="Hidden dimensions for the global CI function MLP" + ) + + +CiConfig = LayerwiseCiConfig | GlobalCiConfig class ScheduleConfig(BaseConfig): @@ -397,13 +424,10 @@ class Config(BaseConfig): ..., description="Number of stochastic masks to sample when using stochastic recon losses", ) - ci_fn_type: CiFnType = Field( - default="vector_mlp", - description="Type of causal importance function used to calculate the causal importance.", - ) - ci_fn_hidden_dims: list[NonNegativeInt] = Field( - default=[8], - description="Hidden dimensions for the causal importance function used to calculate the causal importance", + ci_config: CiConfig = Field( + 
..., + description="Configuration for the causal importance function. " + "Use LayerwiseCiConfig for per-layer CI functions or GlobalCiConfig for a single global CI function.", ) sampling: SamplingType = Field( default="continuous", @@ -614,8 +638,6 @@ def microbatch_size(self) -> PositiveInt: "pretrained_model_name_hf": "pretrained_model_name", "recon_coeff": "ci_recon_coeff", "recon_layerwise_coeff": "ci_recon_layerwise_coeff", - "gate_type": "ci_fn_type", - "gate_hidden_dims": "ci_fn_hidden_dims", } @model_validator(mode="before") diff --git a/spd/experiments/ih/ih_config.yaml b/spd/experiments/ih/ih_config.yaml index b9f98f12e..1e3baf297 100644 --- a/spd/experiments/ih/ih_config.yaml +++ b/spd/experiments/ih/ih_config.yaml @@ -34,8 +34,10 @@ stochastic_recon_layerwise_coeff: 1 importance_minimality_coeff: 1e-2 pnorm: 0.1 output_loss_type: kl -ci_fn_type: "vector_mlp" -ci_fn_hidden_dims: [128] +ci_config: + mode: layerwise + fn_type: vector_mlp + hidden_dims: [128] n_examples_until_dead: 10000 diff --git a/spd/experiments/lm/gpt2_config.yaml b/spd/experiments/lm/gpt2_config.yaml index 0c7a3f567..8699dc68a 100644 --- a/spd/experiments/lm/gpt2_config.yaml +++ b/spd/experiments/lm/gpt2_config.yaml @@ -7,8 +7,10 @@ wandb_run_name_prefix: "" # --- General --- seed: 0 n_mask_samples: 1 -ci_fn_type: "vector_mlp" -ci_fn_hidden_dims: [12] +ci_config: + mode: layerwise + fn_type: vector_mlp + hidden_dims: [12] sigmoid_type: "leaky_hard" module_info: - module_pattern: "transformer.h.1.attn.c_attn" diff --git a/spd/experiments/lm/ss_gpt2_config.yaml b/spd/experiments/lm/ss_gpt2_config.yaml index 1d0e6760c..8259b4a1c 100644 --- a/spd/experiments/lm/ss_gpt2_config.yaml +++ b/spd/experiments/lm/ss_gpt2_config.yaml @@ -6,8 +6,10 @@ wandb_run_name_prefix: "" # --- General --- seed: 0 n_mask_samples: 1 -ci_fn_type: "vector_mlp" -ci_fn_hidden_dims: [12] +ci_config: + mode: layerwise + fn_type: vector_mlp + hidden_dims: [12] sigmoid_type: "leaky_hard" module_info: - module_pattern: 
"transformer.h.1.mlp.c_fc" diff --git a/spd/experiments/lm/ss_gpt2_simple-1L.yaml b/spd/experiments/lm/ss_gpt2_simple-1L.yaml index 646602970..b7cf3ba94 100644 --- a/spd/experiments/lm/ss_gpt2_simple-1L.yaml +++ b/spd/experiments/lm/ss_gpt2_simple-1L.yaml @@ -6,8 +6,10 @@ wandb_run_name_prefix: "" # --- General --- seed: 1 n_mask_samples: 1 -ci_fn_type: "shared_mlp" -ci_fn_hidden_dims: [550] +ci_config: + mode: layerwise + fn_type: shared_mlp + hidden_dims: [550] sigmoid_type: "leaky_hard" module_info: - module_pattern: "h.*.mlp.c_fc" diff --git a/spd/experiments/lm/ss_gpt2_simple-2L.yaml b/spd/experiments/lm/ss_gpt2_simple-2L.yaml index e1d8b8bcd..3cd8ce959 100644 --- a/spd/experiments/lm/ss_gpt2_simple-2L.yaml +++ b/spd/experiments/lm/ss_gpt2_simple-2L.yaml @@ -6,8 +6,10 @@ wandb_run_name_prefix: "" # --- General --- seed: 1 n_mask_samples: 1 -ci_fn_type: "shared_mlp" -ci_fn_hidden_dims: [550] +ci_config: + mode: layerwise + fn_type: shared_mlp + hidden_dims: [550] sigmoid_type: "leaky_hard" module_info: - module_pattern: "h.*.mlp.c_fc" diff --git a/spd/experiments/lm/ss_gpt2_simple_config.yaml b/spd/experiments/lm/ss_gpt2_simple_config.yaml index b6c4d11bb..af810f349 100644 --- a/spd/experiments/lm/ss_gpt2_simple_config.yaml +++ b/spd/experiments/lm/ss_gpt2_simple_config.yaml @@ -6,8 +6,10 @@ wandb_run_name_prefix: "" # --- General --- seed: 0 n_mask_samples: 1 -ci_fn_type: "shared_mlp" -ci_fn_hidden_dims: [1000] +ci_config: + mode: layerwise + fn_type: shared_mlp + hidden_dims: [1000] sigmoid_type: "leaky_hard" module_info: - module_pattern: "h.*.mlp.c_fc" diff --git a/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml b/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml index d4640ea3b..23dc85bf2 100644 --- a/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml +++ b/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml @@ -6,8 +6,10 @@ wandb_run_name_prefix: "" # --- General --- seed: 0 n_mask_samples: 1 -ci_fn_type: "vector_mlp" -ci_fn_hidden_dims: [12] 
+ci_config: + mode: layerwise + fn_type: vector_mlp + hidden_dims: [12] sigmoid_type: "leaky_hard" module_info: - module_pattern: "h.*.mlp.c_fc" diff --git a/spd/experiments/lm/ss_llama_simple-1L.yaml b/spd/experiments/lm/ss_llama_simple-1L.yaml index db10d7fe2..d580e2629 100644 --- a/spd/experiments/lm/ss_llama_simple-1L.yaml +++ b/spd/experiments/lm/ss_llama_simple-1L.yaml @@ -3,9 +3,10 @@ wandb_run_name: null wandb_run_name_prefix: '' seed: 0 n_mask_samples: 1 -ci_fn_type: shared_mlp -ci_fn_hidden_dims: -- 550 +ci_config: + mode: layerwise + fn_type: shared_mlp + hidden_dims: [550] sampling: continuous sigmoid_type: leaky_hard module_info: diff --git a/spd/experiments/lm/ss_llama_simple-2L.yaml b/spd/experiments/lm/ss_llama_simple-2L.yaml index 6517ac7d7..92aed19a0 100644 --- a/spd/experiments/lm/ss_llama_simple-2L.yaml +++ b/spd/experiments/lm/ss_llama_simple-2L.yaml @@ -3,9 +3,10 @@ wandb_run_name: null wandb_run_name_prefix: '' seed: 0 n_mask_samples: 1 -ci_fn_type: shared_mlp -ci_fn_hidden_dims: -- 550 +ci_config: + mode: layerwise + fn_type: shared_mlp + hidden_dims: [550] sampling: continuous sigmoid_type: leaky_hard module_info: diff --git a/spd/experiments/lm/ss_llama_simple_config.yaml b/spd/experiments/lm/ss_llama_simple_config.yaml index 95ca435ab..e40d2a96a 100644 --- a/spd/experiments/lm/ss_llama_simple_config.yaml +++ b/spd/experiments/lm/ss_llama_simple_config.yaml @@ -6,8 +6,10 @@ wandb_run_name_prefix: "" # --- General --- seed: 0 n_mask_samples: 1 -ci_fn_type: "shared_mlp" -ci_fn_hidden_dims: [1000] +ci_config: + mode: layerwise + fn_type: shared_mlp + hidden_dims: [1000] sigmoid_type: "leaky_hard" module_info: - module_pattern: "h.*.mlp.gate_proj" diff --git a/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml b/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml index 3974c5ae4..35e4c1895 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml @@ -3,9 +3,10 @@ wandb_run_name: null 
wandb_run_name_prefix: '' seed: 0 n_mask_samples: 1 -ci_fn_type: shared_mlp -ci_fn_hidden_dims: -- 550 +ci_config: + mode: layerwise + fn_type: shared_mlp + hidden_dims: [550] sampling: continuous sigmoid_type: leaky_hard module_info: diff --git a/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml b/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml index 2bacab1c2..91ad8dc3a 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml @@ -3,9 +3,10 @@ wandb_run_name: null wandb_run_name_prefix: '' seed: 0 n_mask_samples: 1 -ci_fn_type: shared_mlp -ci_fn_hidden_dims: -- 1250 +ci_config: + mode: layerwise + fn_type: shared_mlp + hidden_dims: [1250] sampling: continuous sigmoid_type: leaky_hard module_info: diff --git a/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml b/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml index 708797303..b808a16c5 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml @@ -3,9 +3,10 @@ wandb_run_name: null wandb_run_name_prefix: '' seed: 0 n_mask_samples: 1 -ci_fn_type: shared_mlp -ci_fn_hidden_dims: -- 550 +ci_config: + mode: layerwise + fn_type: shared_mlp + hidden_dims: [550] sampling: continuous sigmoid_type: leaky_hard module_info: diff --git a/spd/experiments/lm/ss_llama_simple_mlp.yaml b/spd/experiments/lm/ss_llama_simple_mlp.yaml index 10c7de294..38b9786e1 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp.yaml @@ -3,9 +3,10 @@ wandb_run_name: null wandb_run_name_prefix: '' seed: 0 n_mask_samples: 1 -ci_fn_type: shared_mlp -ci_fn_hidden_dims: -- 800 +ci_config: + mode: layerwise + fn_type: shared_mlp + hidden_dims: [800] sampling: continuous sigmoid_type: leaky_hard module_info: diff --git a/spd/experiments/lm/ts_config.yaml b/spd/experiments/lm/ts_config.yaml index d05a49f02..b84014c2e 100644 --- a/spd/experiments/lm/ts_config.yaml +++ 
b/spd/experiments/lm/ts_config.yaml @@ -9,8 +9,10 @@ wandb_run_name_prefix: "" # --- General --- seed: 0 n_mask_samples: 1 -ci_fn_type: "vector_mlp" -ci_fn_hidden_dims: [8] +ci_config: + mode: layerwise + fn_type: vector_mlp + hidden_dims: [8] sigmoid_type: "leaky_hard" module_info: - module_pattern: "transformer.h.3.mlp.c_fc" diff --git a/spd/experiments/resid_mlp/resid_mlp1_config.yaml b/spd/experiments/resid_mlp/resid_mlp1_config.yaml index a3b75ba98..fe11241a0 100644 --- a/spd/experiments/resid_mlp/resid_mlp1_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp1_config.yaml @@ -7,8 +7,10 @@ wandb_run_name_prefix: "" # --- General --- seed: 0 n_mask_samples: 1 -ci_fn_type: "mlp" -ci_fn_hidden_dims: [16] +ci_config: + mode: layerwise + fn_type: mlp + hidden_dims: [16] sigmoid_type: "leaky_hard" module_info: - module_pattern: "layers.*.mlp_in" diff --git a/spd/experiments/resid_mlp/resid_mlp1_global_config.yaml b/spd/experiments/resid_mlp/resid_mlp1_global_config.yaml new file mode 100644 index 000000000..26fc626bf --- /dev/null +++ b/spd/experiments/resid_mlp/resid_mlp1_global_config.yaml @@ -0,0 +1,85 @@ +# ResidualMLP 1 layer - Global CI +# --- WandB --- +wandb_project: spd +wandb_run_name: null +wandb_run_name_prefix: "" + +# --- General --- +seed: 0 +n_mask_samples: 1 +ci_config: + mode: global + fn_type: global_shared_mlp + hidden_dims: [400, 300] +sigmoid_type: "leaky_hard" +module_info: + - module_pattern: "layers.*.mlp_in" + C: 100 + - module_pattern: "layers.*.mlp_out" + C: 100 +identity_module_info: null +use_delta_component: true + +# --- Loss config --- +loss_metric_configs: + - classname: "ImportanceMinimalityLoss" + coeff: 1e-5 + pnorm: 2.0 + beta: 0 + - classname: "StochasticReconLayerwiseLoss" + coeff: 1.0 + - classname: "StochasticReconLoss" + coeff: 1.0 +output_loss_type: mse + +# --- Training --- +batch_size: 2048 +eval_batch_size: 2048 +steps: 20_000 +lr_schedule: + start_val: 2e-3 + fn_type: constant + warmup_pct: 0.0 + +# --- Faithfulness 
Warmup --- +faithfulness_warmup_steps: 200 +faithfulness_warmup_lr: 0.01 +faithfulness_warmup_weight_decay: 0.1 + +# --- Logging & Saving --- +train_log_freq: 100 +eval_freq: 500 +n_eval_steps: 100 +slow_eval_freq: 5_000 +slow_eval_on_first_step: true +save_freq: null +ci_alive_threshold: 0.1 +n_examples_until_dead: 204_800 +eval_metric_configs: + - classname: "CIHistograms" + n_batches_accum: 5 + - classname: "ComponentActivationDensity" + - classname: "PermutedCIPlots" + identity_patterns: ["layers.*.mlp_in"] + dense_patterns: ["layers.*.mlp_out"] + - classname: "IdentityCIError" + identity_ci: + - layer_pattern: "layers.*.mlp_in" + n_features: 100 + dense_ci: + - layer_pattern: "layers.*.mlp_out" + k: 50 + - classname: "CI_L0" + groups: null + - classname: "CIMeanPerComponent" + - classname: "StochasticHiddenActsReconLoss" + +# --- Pretrained model info --- +pretrained_model_class: "spd.experiments.resid_mlp.models.ResidMLP" +pretrained_model_path: "wandb:goodfire/spd-pre-Sep-2025/runs/pziyck78" + +# --- Task Specific --- +task_config: + task_name: resid_mlp + feature_probability: 0.01 + data_generation_type: "at_least_zero_active" diff --git a/spd/experiments/resid_mlp/resid_mlp2_config.yaml b/spd/experiments/resid_mlp/resid_mlp2_config.yaml index 2e92bc4bc..f761e23df 100644 --- a/spd/experiments/resid_mlp/resid_mlp2_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp2_config.yaml @@ -7,9 +7,10 @@ wandb_run_name_prefix: "" # --- General --- seed: 0 n_mask_samples: 1 -ci_fn_type: shared_mlp -ci_fn_hidden_dims: - - 256 +ci_config: + mode: layerwise + fn_type: shared_mlp + hidden_dims: [256] sigmoid_type: leaky_hard module_info: - module_pattern: "layers.*.mlp_in" diff --git a/spd/experiments/resid_mlp/resid_mlp3_config.yaml b/spd/experiments/resid_mlp/resid_mlp3_config.yaml index ff249b92d..43b79e768 100644 --- a/spd/experiments/resid_mlp/resid_mlp3_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp3_config.yaml @@ -7,8 +7,10 @@ wandb_run_name_prefix: "" # 
--- General --- seed: 0 n_mask_samples: 1 -ci_fn_type: "mlp" -ci_fn_hidden_dims: [128] +ci_config: + mode: layerwise + fn_type: mlp + hidden_dims: [128] sigmoid_type: "leaky_hard" module_info: - module_pattern: "layers.*.mlp_in" diff --git a/spd/experiments/tms/tms_40-10-id_config.yaml b/spd/experiments/tms/tms_40-10-id_config.yaml index c1f497059..80d655f5a 100644 --- a/spd/experiments/tms/tms_40-10-id_config.yaml +++ b/spd/experiments/tms/tms_40-10-id_config.yaml @@ -7,8 +7,10 @@ wandb_run_name_prefix: "" # --- General --- seed: 0 n_mask_samples: 1 -ci_fn_type: "mlp" -ci_fn_hidden_dims: [16] +ci_config: + mode: layerwise + fn_type: mlp + hidden_dims: [16] sigmoid_type: "leaky_hard" module_info: - module_pattern: "linear1" diff --git a/spd/experiments/tms/tms_40-10_config.yaml b/spd/experiments/tms/tms_40-10_config.yaml index 0d307d16e..c3a9aae13 100644 --- a/spd/experiments/tms/tms_40-10_config.yaml +++ b/spd/experiments/tms/tms_40-10_config.yaml @@ -7,8 +7,10 @@ wandb_run_name_prefix: "" # --- General --- seed: 0 n_mask_samples: 1 -ci_fn_type: "mlp" -ci_fn_hidden_dims: [16] +ci_config: + mode: layerwise + fn_type: mlp + hidden_dims: [16] sigmoid_type: "leaky_hard" module_info: - module_pattern: "linear1" diff --git a/spd/experiments/tms/tms_5-2-id_config.yaml b/spd/experiments/tms/tms_5-2-id_config.yaml index cbe02d36b..3043aa3ea 100644 --- a/spd/experiments/tms/tms_5-2-id_config.yaml +++ b/spd/experiments/tms/tms_5-2-id_config.yaml @@ -7,8 +7,10 @@ wandb_run_name_prefix: "" # --- General --- seed: 0 n_mask_samples: 1 -ci_fn_type: "mlp" -ci_fn_hidden_dims: [16] +ci_config: + mode: layerwise + fn_type: mlp + hidden_dims: [16] sigmoid_type: "leaky_hard" module_info: - module_pattern: "linear1" diff --git a/spd/experiments/tms/tms_5-2_config.yaml b/spd/experiments/tms/tms_5-2_config.yaml index 7b36cb18b..f760bb857 100644 --- a/spd/experiments/tms/tms_5-2_config.yaml +++ b/spd/experiments/tms/tms_5-2_config.yaml @@ -7,8 +7,10 @@ wandb_run_name_prefix: "" # --- 
General --- seed: 0 n_mask_samples: 1 -ci_fn_type: "mlp" -ci_fn_hidden_dims: [16] +ci_config: + mode: layerwise + fn_type: mlp + hidden_dims: [16] sigmoid_type: "leaky_hard" module_info: - module_pattern: "linear1" diff --git a/spd/models/component_model.py b/spd/models/component_model.py index aa923179b..5a81271ad 100644 --- a/spd/models/component_model.py +++ b/spd/models/component_model.py @@ -10,25 +10,48 @@ from torch.utils.hooks import RemovableHandle from transformers.pytorch_utils import Conv1D as RadfordConv1D -from spd.configs import Config, SamplingType +from spd.configs import CiConfig, Config, GlobalCiConfig, LayerwiseCiConfig, SamplingType from spd.identity_insertion import insert_identity_operations_ from spd.interfaces import LoadableModule, RunInfo from spd.models.components import ( Components, ComponentsMaskInfo, EmbeddingComponents, + GlobalCiFnWrapper, + GlobalSharedMLPCiFn, Identity, + LayerwiseCiFnWrapper, LinearComponents, MLPCiFn, VectorMLPCiFn, VectorSharedMLPCiFn, ) from spd.models.sigmoids import SIGMOID_TYPES, SigmoidType -from spd.spd_types import CiFnType, ModelPath -from spd.utils.general_utils import resolve_class, runtime_cast +from spd.spd_types import GlobalCiFnType, LayerwiseCiFnType, ModelPath +from spd.utils.general_utils import resolve_class from spd.utils.module_utils import ModulePathInfo, expand_module_patterns +def _validate_checkpoint_ci_config_compatibility( + state_dict: dict[str, Tensor], ci_config: CiConfig +) -> None: + """Validate that checkpoint CI weights match the config CI mode.""" + has_layerwise_ci_fns = any(k.startswith("ci_fn._ci_fns") for k in state_dict) + has_global_ci_fn = any(k.startswith("ci_fn._global_ci_fn") for k in state_dict) + + match ci_config: + case LayerwiseCiConfig(): + assert has_layerwise_ci_fns, ( + f"Config specifies layerwise CI but checkpoint has no ci_fn._ci_fns keys " + f"(has ci_fn._global_ci_fn: {has_global_ci_fn})" + ) + case GlobalCiConfig(): + assert has_global_ci_fn, ( + 
f"Config specifies global CI but checkpoint has no ci_fn._global_ci_fn keys " + f"(has ci_fn._ci_fns: {has_layerwise_ci_fns})" + ) + + @dataclass class SPDRunInfo(RunInfo[Config]): """Run info from training a ComponentModel (i.e. from an SPD run).""" @@ -75,8 +98,7 @@ def __init__( self, target_model: nn.Module, module_path_info: list[ModulePathInfo], - ci_fn_type: CiFnType, - ci_fn_hidden_dims: list[int], + ci_config: CiConfig, sigmoid_type: SigmoidType, pretrained_model_output_attr: str | None, ): @@ -101,15 +123,34 @@ def __init__( {k.replace(".", "-"): self.components[k] for k in sorted(self.components)} ) - self.ci_fns = ComponentModel._create_ci_fns( - target_model=target_model, - module_to_c=self.module_to_c, - ci_fn_type=ci_fn_type, - ci_fn_hidden_dims=ci_fn_hidden_dims, - ) - self._ci_fns = nn.ModuleDict( - {k.replace(".", "-"): self.ci_fns[k] for k in sorted(self.ci_fns)} - ) + match ci_config: + case LayerwiseCiConfig(): + raw_layerwise_ci_fns = { + path: ComponentModel._create_layerwise_ci_fn( + target_module=target_model.get_submodule(path), + C=C, + ci_fn_type=ci_config.fn_type, + ci_fn_hidden_dims=ci_config.hidden_dims, + ) + for path, C in self.module_to_c.items() + } + self.ci_fn = LayerwiseCiFnWrapper( + ci_fns=raw_layerwise_ci_fns, + components=self.components, + ci_fn_type=ci_config.fn_type, + ) + case GlobalCiConfig(): + raw_global_ci_fn = ComponentModel._create_global_ci_fn( + target_model=target_model, + module_to_c=self.module_to_c, + components=self.components, + ci_fn_type=ci_config.fn_type, + ci_fn_hidden_dims=ci_config.hidden_dims, + ) + self.ci_fn = GlobalCiFnWrapper( + global_ci_fn=raw_global_ci_fn, + components=self.components, + ) if sigmoid_type == "leaky_hard": self.lower_leaky_fn = SIGMOID_TYPES["lower_leaky_hard"] @@ -187,28 +228,39 @@ def _create_components( return components @staticmethod - def _create_ci_fn( + def _get_module_input_dim(target_module: nn.Module) -> int: + """Extract input dimension from a Linear-like module. 
+ + For embedding layers, this should not be called - handle them separately. + """ + match target_module: + case nn.Linear(): + return target_module.weight.shape[1] + case RadfordConv1D(): + return target_module.weight.shape[0] + case Identity(): + return target_module.d + case _: + raise ValueError( + f"Module {type(target_module)} not supported. " + "Embedding modules should be handled separately." + ) + + @staticmethod + def _create_layerwise_ci_fn( target_module: nn.Module, C: int, - ci_fn_type: CiFnType, + ci_fn_type: LayerwiseCiFnType, ci_fn_hidden_dims: list[int], ) -> nn.Module: - """Helper to create a causal importance function (ci_fn) based on ci_fn_type and module type.""" + """Helper to create a single layerwise CI function based on ci_fn_type and module type.""" if isinstance(target_module, nn.Embedding): assert ci_fn_type == "mlp", "Embedding modules only supported for ci_fn_type='mlp'" if ci_fn_type == "mlp": return MLPCiFn(C=C, hidden_dims=ci_fn_hidden_dims) - match target_module: - case nn.Linear(): - input_dim = target_module.weight.shape[1] - case RadfordConv1D(): - input_dim = target_module.weight.shape[0] - case Identity(): - input_dim = target_module.d - case _: - raise ValueError(f"Module {type(target_module)} not supported for {ci_fn_type=}") + input_dim = ComponentModel._get_module_input_dim(target_module) match ci_fn_type: case "vector_mlp": @@ -217,22 +269,35 @@ def _create_ci_fn( return VectorSharedMLPCiFn(C=C, input_dim=input_dim, hidden_dims=ci_fn_hidden_dims) @staticmethod - def _create_ci_fns( + def _create_global_ci_fn( target_model: nn.Module, module_to_c: dict[str, int], - ci_fn_type: CiFnType, + components: dict[str, Components], + ci_fn_type: GlobalCiFnType, ci_fn_hidden_dims: list[int], - ) -> dict[str, nn.Module]: - ci_fns: dict[str, nn.Module] = {} + ) -> GlobalSharedMLPCiFn: + """Create a global CI function that takes all layer activations as input.""" + # Build layer_configs: layer_name -> (input_dim, C) + layer_configs: 
dict[str, tuple[int, int]] = {} for target_module_path, target_module_c in module_to_c.items(): target_module = target_model.get_submodule(target_module_path) - ci_fns[target_module_path] = ComponentModel._create_ci_fn( - target_module=target_module, - C=target_module_c, - ci_fn_type=ci_fn_type, - ci_fn_hidden_dims=ci_fn_hidden_dims, - ) - return ci_fns + component = components[target_module_path] + + # For embeddings, global CI uses component acts (C dimensions) + # For linear-like modules, use the actual input dimension + if isinstance(target_module, nn.Embedding): + assert isinstance(component, EmbeddingComponents) + input_dim = component.C + else: + input_dim = ComponentModel._get_module_input_dim(target_module) + + layer_configs[target_module_path] = (input_dim, target_module_c) + + match ci_fn_type: + case "global_shared_mlp": + return GlobalSharedMLPCiFn( + layer_configs=layer_configs, hidden_dims=ci_fn_hidden_dims + ) def _extract_output(self, raw_output: Any) -> Tensor: """Extract the desired output from the model's raw output. 
@@ -459,10 +524,9 @@ def from_run_info(cls, run_info: RunInfo[Config]) -> "ComponentModel": comp_model = ComponentModel( target_model=target_model, module_path_info=module_path_info, - ci_fn_hidden_dims=config.ci_fn_hidden_dims, - ci_fn_type=config.ci_fn_type, - pretrained_model_output_attr=config.pretrained_model_output_attr, + ci_config=config.ci_config, sigmoid_type=config.sigmoid_type, + pretrained_model_output_attr=config.pretrained_model_output_attr, ) comp_model_weights = torch.load( @@ -471,6 +535,8 @@ def from_run_info(cls, run_info: RunInfo[Config]) -> "ComponentModel": handle_deprecated_state_dict_keys_(comp_model_weights) + _validate_checkpoint_ci_config_compatibility(comp_model_weights, config.ci_config) + comp_model.load_state_dict(comp_model_weights) return comp_model @@ -487,38 +553,33 @@ def calc_causal_importances( sampling: SamplingType, detach_inputs: bool = False, ) -> CIOutputs: - """Calculate causal importances. + """Calculate causal importances using the unified CI function interface. Args: pre_weight_acts: The activations before each layer in the target model. + sampling: The sampling type for stochastic masks. detach_inputs: Whether to detach the inputs to the causal importance function. Returns: - Tuple of (causal_importances, causal_importances_upper_leaky) dictionaries for each layer. + CIOutputs containing lower_leaky, upper_leaky, and pre_sigmoid CI values. """ + if detach_inputs: + pre_weight_acts = {k: v.detach() for k, v in pre_weight_acts.items()} + + ci_fn_outputs = self.ci_fn(pre_weight_acts) + return self._apply_sigmoid_to_ci_outputs(ci_fn_outputs, sampling) + + def _apply_sigmoid_to_ci_outputs( + self, + ci_fn_outputs: dict[str, Float[Tensor, "... 
C"]], + sampling: SamplingType, + ) -> CIOutputs: + """Apply sigmoid functions to CI function outputs.""" causal_importances_lower_leaky = {} causal_importances_upper_leaky = {} pre_sigmoid = {} - for target_module_name in pre_weight_acts: - input_activations = pre_weight_acts[target_module_name] - ci_fn = self.ci_fns[target_module_name] - - match ci_fn: - case MLPCiFn(): - ci_fn_input = self.components[target_module_name].get_component_acts( - input_activations - ) - case VectorMLPCiFn() | VectorSharedMLPCiFn(): - ci_fn_input = input_activations - case _: - raise ValueError(f"Unknown ci_fn type: {type(ci_fn)}") - - if detach_inputs: - ci_fn_input = ci_fn_input.detach() - - ci_fn_output = runtime_cast(Tensor, ci_fn(ci_fn_input)) - + for target_module_name, ci_fn_output in ci_fn_outputs.items(): if sampling == "binomial": ci_fn_output_for_lower_leaky = 1.05 * ci_fn_output - 0.05 * torch.rand_like( ci_fn_output @@ -527,11 +588,11 @@ def calc_causal_importances( ci_fn_output_for_lower_leaky = ci_fn_output lower_leaky_output = self.lower_leaky_fn(ci_fn_output_for_lower_leaky) - assert lower_leaky_output.all() <= 1.0 + assert (lower_leaky_output <= 1.0).all() causal_importances_lower_leaky[target_module_name] = lower_leaky_output upper_leaky_output = self.upper_leaky_fn(ci_fn_output) - assert upper_leaky_output.all() >= 0 + assert (upper_leaky_output >= 0).all() causal_importances_upper_leaky[target_module_name] = upper_leaky_output pre_sigmoid[target_module_name] = ci_fn_output diff --git a/spd/models/components.py b/spd/models/components.py index 921c42814..9c9c588ce 100644 --- a/spd/models/components.py +++ b/spd/models/components.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Literal, override +from typing import TYPE_CHECKING, Literal, override import einops import torch @@ -9,6 +9,9 @@ from spd.utils.module_utils import _NonlinearityType, init_param_ +if TYPE_CHECKING: + from spd.spd_types import 
LayerwiseCiFnType + class ParallelLinear(nn.Module): """C parallel linear layers""" @@ -110,6 +113,44 @@ def forward(self, x: Float[Tensor, "... d_in"]) -> Float[Tensor, "... C"]: return self.layers(x) +class GlobalSharedMLPCiFn(nn.Module): + """Global CI function that concatenates all layer activations and outputs CI for all layers.""" + + def __init__( + self, + layer_configs: dict[str, tuple[int, int]], # layer_name -> (input_dim, C) + hidden_dims: list[int], + ): + super().__init__() + + self.layer_order = sorted(layer_configs.keys()) + self.layer_configs = layer_configs + self.split_sizes = [layer_configs[name][1] for name in self.layer_order] + + total_input_dim = sum(input_dim for input_dim, _ in layer_configs.values()) + total_C = sum(C for _, C in layer_configs.values()) + + self.layers = nn.Sequential() + for i in range(len(hidden_dims)): + in_dim = total_input_dim if i == 0 else hidden_dims[i - 1] + output_dim = hidden_dims[i] + self.layers.append(Linear(in_dim, output_dim, nonlinearity="relu")) + self.layers.append(nn.GELU()) + final_dim = hidden_dims[-1] if len(hidden_dims) > 0 else total_input_dim + self.layers.append(Linear(final_dim, total_C, nonlinearity="linear")) + + @override + def forward( + self, + input_acts: dict[str, Float[Tensor, "... d_in"]], + ) -> dict[str, Float[Tensor, "... C"]]: + inputs_list = [input_acts[name] for name in self.layer_order] + concatenated = torch.cat(inputs_list, dim=-1) + output = self.layers(concatenated) + split_outputs = torch.split(output, self.split_sizes, dim=-1) + return {name: split_outputs[i] for i, name in enumerate(self.layer_order)} + + WeightDeltaAndMask = tuple[Float[Tensor, "d_out d_in"], Float[Tensor, "..."]] @@ -360,3 +401,81 @@ def make_mask_infos( ) return result + + +class LayerwiseCiFnWrapper(nn.Module): + """Wraps a dict of per-layer CI functions with a unified interface. + + Calls each layer's CI function independently on its corresponding input activations. 
+ """ + + def __init__( + self, + ci_fns: dict[str, nn.Module], + components: dict[str, Components], + ci_fn_type: "LayerwiseCiFnType", + ): + super().__init__() + self.layer_names = sorted(ci_fns.keys()) + self.components = components + self.ci_fn_type = ci_fn_type + + # Store as ModuleDict with "." replaced by "-" for state dict compatibility + self._ci_fns = nn.ModuleDict( + {name.replace(".", "-"): ci_fns[name] for name in self.layer_names} + ) + + @override + def forward( + self, + layer_acts: dict[str, Float[Tensor, "..."]], + ) -> dict[str, Float[Tensor, "... C"]]: + outputs: dict[str, Float[Tensor, "... C"]] = {} + + for layer_name in self.layer_names: + ci_fn = self._ci_fns[layer_name.replace(".", "-")] + input_acts = layer_acts[layer_name] + + # MLPCiFn expects component activations, others take raw input + if self.ci_fn_type == "mlp": + ci_fn_input = self.components[layer_name].get_component_acts(input_acts) + else: + ci_fn_input = input_acts + + outputs[layer_name] = ci_fn(ci_fn_input) + + return outputs + + +class GlobalCiFnWrapper(nn.Module): + """Wraps GlobalSharedMLPCiFn with a unified interface. + + Transforms embedding layer inputs to component activations before calling + the underlying global CI function. + """ + + def __init__( + self, + global_ci_fn: GlobalSharedMLPCiFn, + components: dict[str, Components], + ): + super().__init__() + self._global_ci_fn = global_ci_fn + self.components = components + + @override + def forward( + self, + layer_acts: dict[str, Float[Tensor, "..."]], + ) -> dict[str, Float[Tensor, "... 
C"]]: + transformed: dict[str, Float[Tensor, ...]] = {} + + for layer_name, acts in layer_acts.items(): + component = self.components[layer_name] + if isinstance(component, EmbeddingComponents): + # Embeddings pass token IDs; convert to component activations + transformed[layer_name] = component.get_component_acts(acts) + else: + transformed[layer_name] = acts + + return self._global_ci_fn(transformed) diff --git a/spd/registry.py b/spd/registry.py index fc3db7a33..536bbdbcd 100644 --- a/spd/registry.py +++ b/spd/registry.py @@ -37,42 +37,48 @@ class ExperimentConfig: decomp_script=Path("spd/experiments/tms/tms_decomposition.py"), config_path=Path("spd/experiments/tms/tms_5-2_config.yaml"), expected_runtime=4, - canonical_run="wandb:goodfire/spd/runs/nbejm03m", + canonical_run="wandb:goodfire/spd/runs/s-38e1a3e2", ), "tms_5-2-id": ExperimentConfig( task_name="tms", decomp_script=Path("spd/experiments/tms/tms_decomposition.py"), config_path=Path("spd/experiments/tms/tms_5-2-id_config.yaml"), expected_runtime=4, - canonical_run="wandb:goodfire/spd/runs/2orsxfx4", + canonical_run="wandb:goodfire/spd/runs/s-a1c0e9e2", ), "tms_40-10": ExperimentConfig( task_name="tms", decomp_script=Path("spd/experiments/tms/tms_decomposition.py"), config_path=Path("spd/experiments/tms/tms_40-10_config.yaml"), expected_runtime=5, - canonical_run="wandb:goodfire/spd/runs/nb25nhgw", + canonical_run="wandb:goodfire/spd/runs/s-7387fc20", ), "tms_40-10-id": ExperimentConfig( task_name="tms", decomp_script=Path("spd/experiments/tms/tms_decomposition.py"), config_path=Path("spd/experiments/tms/tms_40-10-id_config.yaml"), expected_runtime=5, - canonical_run="wandb:goodfire/spd/runs/eobwic8t", + canonical_run="wandb:goodfire/spd/runs/s-2a2b5a57", ), "resid_mlp1": ExperimentConfig( task_name="resid_mlp", decomp_script=Path("spd/experiments/resid_mlp/resid_mlp_decomposition.py"), config_path=Path("spd/experiments/resid_mlp/resid_mlp1_config.yaml"), expected_runtime=3, - 
canonical_run="wandb:goodfire/spd/runs/0d2lld8j", + canonical_run="wandb:goodfire/spd/runs/s-62fce8c4", + ), + "resid_mlp1_global": ExperimentConfig( + task_name="resid_mlp", + decomp_script=Path("spd/experiments/resid_mlp/resid_mlp_decomposition.py"), + config_path=Path("spd/experiments/resid_mlp/resid_mlp1_global_config.yaml"), + expected_runtime=3, ), "resid_mlp2": ExperimentConfig( task_name="resid_mlp", decomp_script=Path("spd/experiments/resid_mlp/resid_mlp_decomposition.py"), config_path=Path("spd/experiments/resid_mlp/resid_mlp2_config.yaml"), expected_runtime=5, - canonical_run="wandb:goodfire/spd/runs/q9uydy18", + canonical_run="wandb:goodfire/spd/runs/s-a9ad193d", ), "resid_mlp3": ExperimentConfig( task_name="resid_mlp", diff --git a/spd/run_spd.py b/spd/run_spd.py index 82245a126..e2d1bd981 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -152,10 +152,9 @@ def create_pgd_data_iter() -> ( model = ComponentModel( target_model=target_model, module_path_info=module_path_info, - ci_fn_type=config.ci_fn_type, - ci_fn_hidden_dims=config.ci_fn_hidden_dims, - pretrained_model_output_attr=config.pretrained_model_output_attr, + ci_config=config.ci_config, sigmoid_type=config.sigmoid_type, + pretrained_model_output_attr=config.pretrained_model_output_attr, ) if ln_stds is not None: @@ -196,10 +195,10 @@ def create_pgd_data_iter() -> ( tgt.V.data = src.U.data.T component_params: list[torch.nn.Parameter] = [] - ci_fn_params: list[torch.nn.Parameter] = [] for name in component_model.target_module_paths: component_params.extend(component_model.components[name].parameters()) - ci_fn_params.extend(component_model.ci_fns[name].parameters()) + + ci_fn_params = list(component_model.ci_fn.parameters()) assert len(component_params) > 0, "No parameters found in components to optimize" diff --git a/spd/spd_types.py b/spd/spd_types.py index 012dc554a..894ef7c2b 100644 --- a/spd/spd_types.py +++ b/spd/spd_types.py @@ -45,9 +45,7 @@ def validate_path(v: str | Path) -> str | 
Path: Path, BeforeValidator(to_root_path), PlainSerializer(lambda x: str(from_root_path(x))) ] - Probability = Annotated[float, Ge(0), Le(1)] - TaskName = Literal["tms", "resid_mlp", "lm", "ih"] - -CiFnType = Literal["mlp", "vector_mlp", "shared_mlp"] +LayerwiseCiFnType = Literal["mlp", "vector_mlp", "shared_mlp"] +GlobalCiFnType = Literal["global_shared_mlp"] diff --git a/spd/utils/logging_utils.py b/spd/utils/logging_utils.py index fd39afeeb..47649ca82 100644 --- a/spd/utils/logging_utils.py +++ b/spd/utils/logging_utils.py @@ -56,14 +56,13 @@ def get_grad_norms_dict( comp_grad_norm_sq_sum += param_grad_sum_sq ci_fn_grad_norm_sq_sum: Float[Tensor, ""] = torch.zeros((), device=device) - for target_module_path, ci_fn in component_model.ci_fns.items(): - for local_param_name, local_param in ci_fn.named_parameters(): - ci_fn_grad = runtime_cast(Tensor, local_param.grad) - ci_fn_grad_sum_sq = ci_fn_grad.pow(2).sum() - key = f"ci_fns/{target_module_path}.{local_param_name}" - assert key not in out, f"Key {key} already exists in grad norms log" - out[key] = ci_fn_grad_sum_sq.sqrt().item() - ci_fn_grad_norm_sq_sum += ci_fn_grad_sum_sq + for local_param_name, local_param in component_model.ci_fn.named_parameters(): + ci_fn_grad = runtime_cast(Tensor, local_param.grad) + ci_fn_grad_sum_sq = ci_fn_grad.pow(2).sum() + key = f"ci_fns/{local_param_name}" + assert key not in out, f"Key {key} already exists in grad norms log" + out[key] = ci_fn_grad_sum_sq.sqrt().item() + ci_fn_grad_norm_sq_sum += ci_fn_grad_sum_sq out["summary/components"] = comp_grad_norm_sq_sum.sqrt().item() out["summary/ci_fns"] = ci_fn_grad_norm_sq_sum.sqrt().item() diff --git a/tests/app/test_server_api.py b/tests/app/test_server_api.py index 4647f6a14..b12c2097d 100644 --- a/tests/app/test_server_api.py +++ b/tests/app/test_server_api.py @@ -21,7 +21,13 @@ from spd.app.backend.routers import runs as runs_router from spd.app.backend.server import app from spd.app.backend.state import HarvestCache, 
RunState, StateManager -from spd.configs import Config, LMTaskConfig, ModulePatternInfoConfig, ScheduleConfig +from spd.configs import ( + Config, + LayerwiseCiConfig, + LMTaskConfig, + ModulePatternInfoConfig, + ScheduleConfig, +) from spd.models.component_model import ComponentModel from spd.utils.module_utils import expand_module_patterns @@ -82,8 +88,7 @@ def app_with_state(): config = Config( n_mask_samples=1, - ci_fn_type="shared_mlp", - ci_fn_hidden_dims=[16], + ci_config=LayerwiseCiConfig(fn_type="shared_mlp", hidden_dims=[16]), sampling="continuous", sigmoid_type="leaky_hard", module_info=[ @@ -115,8 +120,7 @@ def app_with_state(): model = ComponentModel( target_model=target_model, module_path_info=module_path_info, - ci_fn_type=config.ci_fn_type, - ci_fn_hidden_dims=config.ci_fn_hidden_dims, + ci_config=config.ci_config, pretrained_model_output_attr=config.pretrained_model_output_attr, sigmoid_type=config.sigmoid_type, ) diff --git a/tests/clustering/test_run_clustering_happy_path.py b/tests/clustering/test_run_clustering_happy_path.py index 3ee626bd0..16c27a802 100644 --- a/tests/clustering/test_run_clustering_happy_path.py +++ b/tests/clustering/test_run_clustering_happy_path.py @@ -19,7 +19,7 @@ def test_run_clustering_happy_path(monkeypatch: Any): monkeypatch.setattr("spd.utils.run_utils.SPD_OUT_DIR", temp_path) config = ClusteringRunConfig( - model_path="wandb:goodfire/spd/runs/zxbu57pt", # An ss_llama run + model_path="wandb:goodfire/spd/runs/s-a9ad193d", # A resid_mlp2 run batch_size=4, dataset_seed=0, ensemble_id=None, @@ -38,6 +38,6 @@ def test_run_clustering_happy_path(monkeypatch: Any): plot=100, artifact=100, ), - dataset_streaming=True, # tests in CI very slow without this, see https://github.com/goodfire-ai/spd/pull/199 + dataset_streaming=False, # resid_mlp doesn't support streaming ) main(config) diff --git a/tests/metrics/fixtures.py b/tests/metrics/fixtures.py index fa32cc1e3..ce594ed76 100644 --- a/tests/metrics/fixtures.py +++ 
b/tests/metrics/fixtures.py @@ -7,6 +7,7 @@ from jaxtyping import Float from torch import Tensor +from spd.configs import LayerwiseCiConfig from spd.models.component_model import ComponentModel from spd.utils.module_utils import ModulePathInfo @@ -56,8 +57,7 @@ def make_one_layer_component_model(weight: Float[Tensor, "d_out d_in"]) -> Compo comp_model = ComponentModel( target_model=target, module_path_info=[ModulePathInfo(module_path="fc", C=1)], - ci_fn_hidden_dims=[2], - ci_fn_type="mlp", + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) @@ -93,8 +93,7 @@ def make_two_layer_component_model( ModulePathInfo(module_path="fc1", C=1), ModulePathInfo(module_path="fc2", C=1), ], - ci_fn_hidden_dims=[2], - ci_fn_type="mlp", + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) diff --git a/tests/scripts_run/test_grid_search.py b/tests/scripts_run/test_grid_search.py index 7b04fddf9..0a1fd7701 100644 --- a/tests/scripts_run/test_grid_search.py +++ b/tests/scripts_run/test_grid_search.py @@ -323,6 +323,11 @@ def test_tms_config_with_loss_sweep(self): "C": 10, "n_mask_samples": 1, "target_module_patterns": ["linear1"], + "ci_config": { + "mode": "layerwise", + "fn_type": "mlp", + "hidden_dims": [16], + }, "loss_metric_configs": [ { "classname": "ImportanceMinimalityLoss", @@ -378,6 +383,11 @@ def test_lm_config_with_loss_sweep(self): "C": 10, "n_mask_samples": 1, "target_module_patterns": ["transformer"], + "ci_config": { + "mode": "layerwise", + "fn_type": "vector_mlp", + "hidden_dims": [12], + }, "loss_metric_configs": [ { "classname": "ImportanceMinimalityLoss", @@ -444,6 +454,11 @@ def test_full_sweep_workflow(self): "C": 10, "n_mask_samples": 1, "target_module_patterns": ["linear1"], + "ci_config": { + "mode": "layerwise", + "fn_type": "mlp", + "hidden_dims": [16], + }, "loss_metric_configs": [ { "classname": 
"ImportanceMinimalityLoss", diff --git a/tests/test_component_model.py b/tests/test_component_model.py index 7f89a7cea..572af9023 100644 --- a/tests/test_component_model.py +++ b/tests/test_component_model.py @@ -10,7 +10,9 @@ from spd.configs import ( Config, + GlobalCiConfig, ImportanceMinimalityLossConfig, + LayerwiseCiConfig, ModulePatternInfoConfig, ScheduleConfig, TMSTaskConfig, @@ -24,6 +26,8 @@ from spd.models.components import ( ComponentsMaskInfo, EmbeddingComponents, + GlobalCiFnWrapper, + GlobalSharedMLPCiFn, LinearComponents, MLPCiFn, ParallelLinear, @@ -89,8 +93,7 @@ def test_correct_parameters_require_grad(): ModulePathInfo(module_path="conv1d1", C=10), ModulePathInfo(module_path="conv1d2", C=5), ], - ci_fn_type="mlp", - ci_fn_hidden_dims=[4], + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[4]), pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) @@ -142,8 +145,7 @@ def test_from_run_info(): ModulePatternInfoConfig(module_pattern="conv1d2", C=4), ], identity_module_info=[ModulePatternInfoConfig(module_pattern="linear1", C=4)], - ci_fn_type="mlp", - ci_fn_hidden_dims=[4], + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[4]), batch_size=1, steps=1, lr_schedule=ScheduleConfig(start_val=1e-3), @@ -173,8 +175,7 @@ def test_from_run_info(): cm = ComponentModel( target_model=target_model, module_path_info=module_path_info, - ci_fn_type=config.ci_fn_type, - ci_fn_hidden_dims=config.ci_fn_hidden_dims, + ci_config=config.ci_config, pretrained_model_output_attr=config.pretrained_model_output_attr, sigmoid_type=config.sigmoid_type, ) @@ -280,8 +281,7 @@ def test_full_weight_delta_matches_target_behaviour(): cm = ComponentModel( target_model=target_model, module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], - ci_fn_type="mlp", - ci_fn_hidden_dims=[4], + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[4]), pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) @@ -312,8 +312,7 @@ def 
test_input_cache_captures_pre_weight_input(): cm = ComponentModel( target_model=target_model, module_path_info=[ModulePathInfo(module_path=p, C=2) for p in target_module_paths], - ci_fn_type="mlp", - ci_fn_hidden_dims=[2], + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) @@ -347,8 +346,7 @@ def test_weight_deltas(): cm = ComponentModel( target_model=target_model, module_path_info=[ModulePathInfo(module_path=p, C=3) for p in target_module_paths], - ci_fn_type="mlp", - ci_fn_hidden_dims=[2], + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) @@ -382,8 +380,7 @@ def forward(self, x: Tensor) -> Tensor: cm = ComponentModel( target_model=model, module_path_info=[ModulePathInfo(module_path="linear", C=C)], - ci_fn_type="mlp", - ci_fn_hidden_dims=[2], + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) @@ -438,8 +435,7 @@ def forward(self, x: Tensor) -> Tensor: cm = ComponentModel( target_model=model, module_path_info=[ModulePathInfo(module_path="linear.pre_identity", C=C)], - ci_fn_type="mlp", - ci_fn_hidden_dims=[2], + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) @@ -488,8 +484,7 @@ def forward(self, x: Tensor) -> Tensor: cm = ComponentModel( target_model=model, module_path_info=[ModulePathInfo(module_path="linear", C=C)], - ci_fn_type="mlp", - ci_fn_hidden_dims=[2], + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) @@ -525,3 +520,795 @@ def forward(self, x: Tensor) -> Tensor: # but it should be the same for the second example (where it's not routed to components) assert torch.allclose(cm_routed_out[1], target_out[1]) + + +def test_checkpoint_ci_config_mismatch_global_to_layerwise(): + 
"""Test that loading a global CI checkpoint with layerwise config fails.""" + target_model = SimpleTestModel() + target_model.eval() + target_model.requires_grad_(False) + + with tempfile.TemporaryDirectory() as tmp_dir: + base_dir = Path(tmp_dir) + base_model_dir = base_dir / "test_model" + base_model_dir.mkdir(parents=True, exist_ok=True) + comp_model_dir = base_dir / "comp_model" + comp_model_dir.mkdir(parents=True, exist_ok=True) + + base_model_path = base_model_dir / "model.pth" + save_file(target_model.state_dict(), base_model_path) + + # Create and save a component model with GLOBAL CI + config_global = Config( + pretrained_model_class="tests.test_component_model.SimpleTestModel", + pretrained_model_path=base_model_path, + pretrained_model_name=None, + module_info=[ + ModulePatternInfoConfig(module_pattern="linear1", C=4), + ModulePatternInfoConfig(module_pattern="linear2", C=4), + ], + ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[4]), + batch_size=1, + steps=1, + lr_schedule=ScheduleConfig(start_val=1e-3), + n_eval_steps=1, + eval_batch_size=1, + eval_freq=1, + slow_eval_freq=1, + loss_metric_configs=[ImportanceMinimalityLossConfig(coeff=1.0, pnorm=1.0, beta=0.5)], + n_examples_until_dead=1, + output_loss_type="mse", + train_log_freq=1, + n_mask_samples=1, + task_config=TMSTaskConfig( + task_name="tms", + feature_probability=0.5, + data_generation_type="exactly_one_active", + ), + ) + + module_path_info = expand_module_patterns(target_model, config_global.all_module_info) + cm_global = ComponentModel( + target_model=target_model, + module_path_info=module_path_info, + ci_config=config_global.ci_config, + pretrained_model_output_attr=config_global.pretrained_model_output_attr, + sigmoid_type=config_global.sigmoid_type, + ) + + # Save global CI checkpoint + global_checkpoint_path = comp_model_dir / "global_model.pth" + save_file(cm_global.state_dict(), global_checkpoint_path) + save_file(config_global.model_dump(mode="json"), 
comp_model_dir / "final_config.yaml") + + # Now try to load it with LAYERWISE config - should fail + config_layerwise = Config( + pretrained_model_class="tests.test_component_model.SimpleTestModel", + pretrained_model_path=base_model_path, + pretrained_model_name=None, + module_info=[ + ModulePatternInfoConfig(module_pattern="linear1", C=4), + ModulePatternInfoConfig(module_pattern="linear2", C=4), + ], + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[4]), + batch_size=1, + steps=1, + lr_schedule=ScheduleConfig(start_val=1e-3), + n_eval_steps=1, + eval_batch_size=1, + eval_freq=1, + slow_eval_freq=1, + loss_metric_configs=[ImportanceMinimalityLossConfig(coeff=1.0, pnorm=1.0, beta=0.5)], + n_examples_until_dead=1, + output_loss_type="mse", + train_log_freq=1, + n_mask_samples=1, + task_config=TMSTaskConfig( + task_name="tms", + feature_probability=0.5, + data_generation_type="exactly_one_active", + ), + ) + + # Override the checkpoint path and config in the directory + save_file(config_layerwise.model_dump(mode="json"), comp_model_dir / "final_config.yaml") + + cm_run_info = SPDRunInfo.from_path(global_checkpoint_path) + # Update config to layerwise after loading run_info + cm_run_info.config = config_layerwise + + with pytest.raises( + AssertionError, + match="Config specifies layerwise CI but checkpoint has no ci_fn._ci_fns keys", + ): + ComponentModel.from_run_info(cm_run_info) + + +def test_checkpoint_ci_config_mismatch_layerwise_to_global(): + """Test that loading a layerwise CI checkpoint with global config fails.""" + target_model = SimpleTestModel() + target_model.eval() + target_model.requires_grad_(False) + + with tempfile.TemporaryDirectory() as tmp_dir: + base_dir = Path(tmp_dir) + base_model_dir = base_dir / "test_model" + base_model_dir.mkdir(parents=True, exist_ok=True) + comp_model_dir = base_dir / "comp_model" + comp_model_dir.mkdir(parents=True, exist_ok=True) + + base_model_path = base_model_dir / "model.pth" + 
save_file(target_model.state_dict(), base_model_path) + + # Create and save a component model with LAYERWISE CI + config_layerwise = Config( + pretrained_model_class="tests.test_component_model.SimpleTestModel", + pretrained_model_path=base_model_path, + pretrained_model_name=None, + module_info=[ + ModulePatternInfoConfig(module_pattern="linear1", C=4), + ModulePatternInfoConfig(module_pattern="linear2", C=4), + ], + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[4]), + batch_size=1, + steps=1, + lr_schedule=ScheduleConfig(start_val=1e-3), + n_eval_steps=1, + eval_batch_size=1, + eval_freq=1, + slow_eval_freq=1, + loss_metric_configs=[ImportanceMinimalityLossConfig(coeff=1.0, pnorm=1.0, beta=0.5)], + n_examples_until_dead=1, + output_loss_type="mse", + train_log_freq=1, + n_mask_samples=1, + task_config=TMSTaskConfig( + task_name="tms", + feature_probability=0.5, + data_generation_type="exactly_one_active", + ), + ) + + module_path_info = expand_module_patterns(target_model, config_layerwise.all_module_info) + cm_layerwise = ComponentModel( + target_model=target_model, + module_path_info=module_path_info, + ci_config=config_layerwise.ci_config, + pretrained_model_output_attr=config_layerwise.pretrained_model_output_attr, + sigmoid_type=config_layerwise.sigmoid_type, + ) + + # Save layerwise CI checkpoint + layerwise_checkpoint_path = comp_model_dir / "layerwise_model.pth" + save_file(cm_layerwise.state_dict(), layerwise_checkpoint_path) + save_file(config_layerwise.model_dump(mode="json"), comp_model_dir / "final_config.yaml") + + # Now try to load it with GLOBAL config - should fail + config_global = Config( + pretrained_model_class="tests.test_component_model.SimpleTestModel", + pretrained_model_path=base_model_path, + pretrained_model_name=None, + module_info=[ + ModulePatternInfoConfig(module_pattern="linear1", C=4), + ModulePatternInfoConfig(module_pattern="linear2", C=4), + ], + ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[4]), 
+ batch_size=1, + steps=1, + lr_schedule=ScheduleConfig(start_val=1e-3), + n_eval_steps=1, + eval_batch_size=1, + eval_freq=1, + slow_eval_freq=1, + loss_metric_configs=[ImportanceMinimalityLossConfig(coeff=1.0, pnorm=1.0, beta=0.5)], + n_examples_until_dead=1, + output_loss_type="mse", + train_log_freq=1, + n_mask_samples=1, + task_config=TMSTaskConfig( + task_name="tms", + feature_probability=0.5, + data_generation_type="exactly_one_active", + ), + ) + + # Override the checkpoint path and config in the directory + save_file(config_global.model_dump(mode="json"), comp_model_dir / "final_config.yaml") + + cm_run_info = SPDRunInfo.from_path(layerwise_checkpoint_path) + # Update config to global after loading run_info + cm_run_info.config = config_global + + with pytest.raises( + AssertionError, + match="Config specifies global CI but checkpoint has no ci_fn._global_ci_fn keys", + ): + ComponentModel.from_run_info(cm_run_info) + + +# ============================================================================= +# Global CI Function Tests +# ============================================================================= + + +@pytest.mark.parametrize("hidden_dims", [[], [8], [16, 8]]) +def test_global_shared_mlp_ci_fn_shapes_and_values(hidden_dims: list[int]): + """Test GlobalSharedMLPCiFn produces correct output shapes and valid values.""" + layer_configs = { + "layer1": (10, 5), # (input_dim, C) + "layer2": (20, 3), + "layer3": (15, 7), + } + ci_fn = GlobalSharedMLPCiFn(layer_configs=layer_configs, hidden_dims=hidden_dims) + + inputs = { + "layer1": torch.randn(BATCH_SIZE, 10), + "layer2": torch.randn(BATCH_SIZE, 20), + "layer3": torch.randn(BATCH_SIZE, 15), + } + outputs = ci_fn(inputs) + + # Check shapes + assert outputs["layer1"].shape == (BATCH_SIZE, 5) + assert outputs["layer2"].shape == (BATCH_SIZE, 3) + assert outputs["layer3"].shape == (BATCH_SIZE, 7) + + # Check values are valid (not NaN, not Inf) + for name, out in outputs.items(): + assert 
torch.isfinite(out).all(), f"Output {name} contains NaN or Inf" + + +def test_global_shared_mlp_ci_fn_sorted_layer_order(): + """Test that GlobalSharedMLPCiFn uses sorted layer ordering for determinism.""" + layer_configs = { + "z_layer": (5, 2), + "a_layer": (10, 3), + "m_layer": (8, 4), + } + + ci_fn = GlobalSharedMLPCiFn(layer_configs=layer_configs, hidden_dims=[16]) + + # Layer order should be sorted alphabetically for deterministic concat/split + assert ci_fn.layer_order == ["a_layer", "m_layer", "z_layer"] + assert ci_fn.split_sizes == [3, 4, 2] # C values in sorted order + + +def test_global_shared_mlp_ci_fn_different_inputs_produce_different_outputs(): + """Test that different inputs produce different CI outputs (not constant function).""" + layer_configs = { + "layer1": (10, 5), + "layer2": (8, 3), + } + ci_fn = GlobalSharedMLPCiFn(layer_configs=layer_configs, hidden_dims=[16]) + + inputs1 = { + "layer1": torch.randn(BATCH_SIZE, 10), + "layer2": torch.randn(BATCH_SIZE, 8), + } + inputs2 = { + "layer1": torch.randn(BATCH_SIZE, 10), + "layer2": torch.randn(BATCH_SIZE, 8), + } + + outputs1 = ci_fn(inputs1) + outputs2 = ci_fn(inputs2) + + # Outputs should differ for different inputs + assert not torch.allclose(outputs1["layer1"], outputs2["layer1"]) + assert not torch.allclose(outputs1["layer2"], outputs2["layer2"]) + + +def test_global_shared_mlp_ci_fn_gradient_flow(): + """Test that gradients flow through GlobalSharedMLPCiFn and are meaningful.""" + layer_configs = { + "layer1": (10, 5), + "layer2": (8, 3), + } + ci_fn = GlobalSharedMLPCiFn(layer_configs=layer_configs, hidden_dims=[16]) + + inputs = { + "layer1": torch.randn(BATCH_SIZE, 10, requires_grad=True), + "layer2": torch.randn(BATCH_SIZE, 8, requires_grad=True), + } + outputs = ci_fn(inputs) + + loss = torch.stack([out.sum() for out in outputs.values()]).sum() + loss.backward() + + # Check gradients exist for inputs and are meaningful + assert inputs["layer1"].grad is not None + assert 
inputs["layer2"].grad is not None + assert torch.isfinite(inputs["layer1"].grad).all() + assert torch.isfinite(inputs["layer2"].grad).all() + assert inputs["layer1"].grad.abs().sum() > 0, "Input gradients should be non-zero" + assert inputs["layer2"].grad.abs().sum() > 0, "Input gradients should be non-zero" + + # Check gradients exist for parameters and are meaningful + for name, param in ci_fn.named_parameters(): + assert param.grad is not None, f"Param {name} has no gradient" + assert torch.isfinite(param.grad).all(), f"Param {name} has NaN/Inf gradient" + assert param.grad.abs().sum() > 0, f"Param {name} has zero gradient" + + +def test_global_shared_mlp_ci_fn_with_seq_dim(): + """Test GlobalSharedMLPCiFn with sequence dimension produces valid outputs.""" + seq_len = 5 + layer_configs = { + "layer1": (10, 4), + "layer2": (8, 3), + } + ci_fn = GlobalSharedMLPCiFn(layer_configs=layer_configs, hidden_dims=[16]) + + inputs = { + "layer1": torch.randn(BATCH_SIZE, seq_len, 10), + "layer2": torch.randn(BATCH_SIZE, seq_len, 8), + } + outputs = ci_fn(inputs) + + # Check shapes + assert outputs["layer1"].shape == (BATCH_SIZE, seq_len, 4) + assert outputs["layer2"].shape == (BATCH_SIZE, seq_len, 3) + + # Check values are valid + for name, out in outputs.items(): + assert torch.isfinite(out).all(), f"Output {name} contains NaN or Inf" + + +def test_global_shared_mlp_ci_fn_single_component(): + """Test GlobalSharedMLPCiFn with C=1 edge case.""" + layer_configs = { + "layer1": (10, 1), + "layer2": (8, 1), + } + ci_fn = GlobalSharedMLPCiFn(layer_configs=layer_configs, hidden_dims=[4]) + + inputs = { + "layer1": torch.randn(BATCH_SIZE, 10), + "layer2": torch.randn(BATCH_SIZE, 8), + } + outputs = ci_fn(inputs) + + assert outputs["layer1"].shape == (BATCH_SIZE, 1) + assert outputs["layer2"].shape == (BATCH_SIZE, 1) + assert torch.isfinite(outputs["layer1"]).all() + assert torch.isfinite(outputs["layer2"]).all() + + +def test_global_shared_mlp_ci_fn_single_layer(): + """Test 
GlobalSharedMLPCiFn with single layer edge case.""" + layer_configs = {"only_layer": (10, 5)} + ci_fn = GlobalSharedMLPCiFn(layer_configs=layer_configs, hidden_dims=[8]) + + inputs = {"only_layer": torch.randn(BATCH_SIZE, 10)} + outputs = ci_fn(inputs) + + assert outputs["only_layer"].shape == (BATCH_SIZE, 5) + assert torch.isfinite(outputs["only_layer"]).all() + + +def test_component_model_with_global_ci(): + """Test ComponentModel instantiation and forward with global CI config.""" + target_model = tiny_target() + + target_module_paths = ["mlp", "out"] + C = 4 + cm = ComponentModel( + target_model=target_model, + module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], + ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), + pretrained_model_output_attr=None, + sigmoid_type="leaky_hard", + ) + + assert isinstance(cm.ci_fn, GlobalCiFnWrapper) + assert isinstance(cm.ci_fn._global_ci_fn, GlobalSharedMLPCiFn) + + # Forward pass should work and match target + token_ids = torch.randint( + low=0, high=target_model.embed.num_embeddings, size=(BATCH_SIZE,), dtype=torch.long + ) + out = cm(token_ids) + torch.testing.assert_close(out, target_model(token_ids)) + + +def test_component_model_global_ci_calc_causal_importances(): + """Test causal importance calculation with global CI produces valid bounded outputs.""" + target_model = tiny_target() + + target_module_paths = ["mlp", "out"] + C = 4 + cm = ComponentModel( + target_model=target_model, + module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], + ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), + pretrained_model_output_attr=None, + sigmoid_type="leaky_hard", + ) + + token_ids = torch.randint( + low=0, high=target_model.embed.num_embeddings, size=(BATCH_SIZE,), dtype=torch.long + ) + + _, cache = cm(token_ids, cache_type="input") + + ci_outputs = cm.calc_causal_importances( + pre_weight_acts=cache, + 
sampling="continuous", + detach_inputs=False, + ) + + for path in target_module_paths: + # Check shapes + assert ci_outputs.lower_leaky[path].shape == (BATCH_SIZE, C) + assert ci_outputs.upper_leaky[path].shape == (BATCH_SIZE, C) + assert ci_outputs.pre_sigmoid[path].shape == (BATCH_SIZE, C) + + # Check bounds (leaky sigmoids allow values slightly outside [0, 1]) + # lower_leaky: bounded to [0, 1], can be negative with small leak + # upper_leaky: bounded to [0, inf), can exceed 1 with small leak + assert (ci_outputs.lower_leaky[path] >= 0).all(), f"{path} lower_leaky < 0" + assert (ci_outputs.lower_leaky[path] <= 1.0).all(), f"{path} lower_leaky > 1" + assert (ci_outputs.upper_leaky[path] >= 0).all(), f"{path} upper_leaky < 0" + # upper_leaky can exceed 1.0 due to leaky behavior (1 + alpha*(x-1) when x>1) + + # Check values are finite + assert torch.isfinite(ci_outputs.pre_sigmoid[path]).all(), f"{path} pre_sigmoid has NaN/Inf" + + +def test_component_model_global_ci_different_inputs_different_ci(): + """Test that different inputs produce different CI values (CI is input-dependent).""" + target_model = tiny_target() + + target_module_paths = ["mlp", "out"] + C = 4 + cm = ComponentModel( + target_model=target_model, + module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], + ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), + pretrained_model_output_attr=None, + sigmoid_type="leaky_hard", + ) + + # Two different token inputs + token_ids_1 = torch.tensor([0, 1], dtype=torch.long) + token_ids_2 = torch.tensor([2, 3], dtype=torch.long) + + _, cache1 = cm(token_ids_1, cache_type="input") + _, cache2 = cm(token_ids_2, cache_type="input") + + ci1 = cm.calc_causal_importances(cache1, sampling="continuous") + ci2 = cm.calc_causal_importances(cache2, sampling="continuous") + + # CI values should differ for different inputs + for path in target_module_paths: + assert not torch.allclose(ci1.pre_sigmoid[path], 
ci2.pre_sigmoid[path]), ( + f"CI for {path} should differ for different inputs" + ) + + +def test_component_model_global_ci_binomial_sampling(): + """Test global CI with binomial sampling produces valid binary masks.""" + target_model = tiny_target() + + target_module_paths = ["mlp", "out"] + C = 4 + cm = ComponentModel( + target_model=target_model, + module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], + ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), + pretrained_model_output_attr=None, + sigmoid_type="leaky_hard", + ) + + token_ids = torch.randint(0, target_model.embed.num_embeddings, size=(BATCH_SIZE,)) + _, cache = cm(token_ids, cache_type="input") + + ci = cm.calc_causal_importances(cache, sampling="binomial") + + for path in target_module_paths: + assert ci.lower_leaky[path].shape == (BATCH_SIZE, C) + assert torch.isfinite(ci.lower_leaky[path]).all() + assert torch.isfinite(ci.upper_leaky[path]).all() + + +def test_component_model_global_ci_with_embeddings(): + """Test global CI with embedding layers produces valid outputs.""" + target_model = tiny_target() + + target_module_paths = ["embed", "mlp", "out"] + C = 4 + cm = ComponentModel( + target_model=target_model, + module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], + ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), + pretrained_model_output_attr=None, + sigmoid_type="leaky_hard", + ) + + assert isinstance(cm.ci_fn, GlobalCiFnWrapper) + + token_ids = torch.randint( + low=0, high=target_model.embed.num_embeddings, size=(BATCH_SIZE,), dtype=torch.long + ) + + _, cache = cm(token_ids, cache_type="input") + + ci_outputs = cm.calc_causal_importances( + pre_weight_acts=cache, + sampling="continuous", + detach_inputs=False, + ) + + # Check all layers including embedding + for path in target_module_paths: + assert ci_outputs.lower_leaky[path].shape == (BATCH_SIZE, C) + assert 
def _global_ci_component_model(
    target_model: torch.nn.Module, module_paths: list[str], C: int
) -> ComponentModel:
    """Wrap ``target_model`` in a ComponentModel that uses one global shared-MLP CI function.

    Args:
        target_model: Model whose modules at ``module_paths`` are decomposed.
        module_paths: Module paths to decompose; each one gets ``C`` components.
        C: Number of components per decomposed module.
    """
    return ComponentModel(
        target_model=target_model,
        module_path_info=[ModulePathInfo(module_path=p, C=C) for p in module_paths],
        ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]),
        pretrained_model_output_attr=None,
        sigmoid_type="leaky_hard",
    )


def _random_token_ids(vocab_size: int) -> torch.Tensor:
    """Return ``BATCH_SIZE`` random token ids drawn uniformly from ``[0, vocab_size)``."""
    return torch.randint(low=0, high=vocab_size, size=(BATCH_SIZE,), dtype=torch.long)


def _forward_with_uniform_masks(
    cm: ComponentModel,
    token_ids: torch.Tensor,
    module_paths: list[str],
    C: int,
    mask_value: float,
):
    """Run ``cm`` with every component mask filled with ``mask_value``.

    The weight-delta mask is always all ones, so the weight deltas returned by
    ``cm.calc_weight_deltas()`` are passed through unmasked for every module.
    """
    weight_deltas = cm.calc_weight_deltas()
    component_masks = {name: torch.full((BATCH_SIZE, C), mask_value) for name in module_paths}
    weight_deltas_and_masks = {
        name: (weight_deltas[name], torch.ones(BATCH_SIZE)) for name in module_paths
    }
    mask_infos = make_mask_infos(component_masks, weight_deltas_and_masks=weight_deltas_and_masks)
    return cm(token_ids, mask_infos=mask_infos)


def test_component_model_global_ci_gradient_flow():
    """Test gradient flow through global CI - gradients are non-zero and finite."""
    target_model = tiny_target()
    cm = _global_ci_component_model(target_model, ["mlp", "out"], C=4)
    token_ids = _random_token_ids(target_model.embed.num_embeddings)

    _, cache = cm(token_ids, cache_type="input")

    ci_outputs = cm.calc_causal_importances(
        pre_weight_acts=cache,
        sampling="continuous",
        detach_inputs=False,
    )

    # Sum every layer's CI values into one scalar so a single backward pass reaches all outputs.
    ci_loss = torch.stack([ci.sum() for ci in ci_outputs.lower_leaky.values()]).sum()
    ci_loss.backward()

    # Check that global CI function has meaningful gradients
    assert isinstance(cm.ci_fn, GlobalCiFnWrapper)
    for name, param in cm.ci_fn._global_ci_fn.named_parameters():
        assert param.grad is not None, f"Param {name} has no gradient"
        assert torch.isfinite(param.grad).all(), f"Param {name} has NaN/Inf gradient"
        assert param.grad.abs().sum() > 0, f"Param {name} has zero gradient"


def test_component_model_global_ci_detach_inputs_blocks_gradients():
    """Test that the CI function's own parameters still get gradients with detach_inputs=True.

    Despite the test name, nothing here asserts that any gradient is blocked: the only check
    is that every parameter of the global CI function has a populated ``.grad`` after
    backpropagating a CI loss computed with ``detach_inputs=True``.
    """
    target_model = tiny_target()
    cm = _global_ci_component_model(target_model, ["mlp", "out"], C=4)
    token_ids = _random_token_ids(target_model.embed.num_embeddings)

    _, cache = cm(token_ids, cache_type="input")

    # With detach_inputs=True, gradients should still flow to CI fn params
    # but from the CI loss, not from upstream
    # NOTE(review): the "not from upstream" half is not verified by this test - confirm
    # against calc_causal_importances before relying on it.
    ci_outputs = cm.calc_causal_importances(
        pre_weight_acts=cache,
        sampling="continuous",
        detach_inputs=True,
    )

    ci_loss = torch.stack([ci.sum() for ci in ci_outputs.lower_leaky.values()]).sum()
    ci_loss.backward()

    # CI function should still get gradients (from its own computation)
    assert isinstance(cm.ci_fn, GlobalCiFnWrapper)
    for name, param in cm.ci_fn._global_ci_fn.named_parameters():
        assert param.grad is not None, f"Param {name} has no gradient with detach_inputs=True"


def test_component_model_global_ci_masking_zeros():
    """Test that all-zero component masks change the output relative to all-one masks.

    Only inequality of the two outputs is asserted; the test does not check that the
    zero-masked output equals any particular reference value.
    """
    target_model = tiny_target()
    target_module_paths = ["mlp", "out"]
    C = 4
    cm = _global_ci_component_model(target_model, target_module_paths, C=C)
    token_ids = _random_token_ids(target_model.embed.num_embeddings)

    # Reference run: every component fully on.
    out_ones = _forward_with_uniform_masks(cm, token_ids, target_module_paths, C, 1.0)
    # Every component fully off; the weight deltas are still applied (their mask stays 1).
    out_zeros = _forward_with_uniform_masks(cm, token_ids, target_module_paths, C, 0.0)

    # Outputs should differ
    assert not torch.allclose(out_ones, out_zeros), (
        "Zero masks should produce different output than one masks"
    )


def test_component_model_global_ci_partial_masking():
    """Test that a uniform 0.5 component mask yields a finite (no NaN/Inf) output.

    This is a smoke test: it does not check that the output lies between the outputs of
    the all-zeros and all-ones masks.
    """
    target_model = tiny_target()
    target_module_paths = ["mlp", "out"]
    C = 4
    cm = _global_ci_component_model(target_model, target_module_paths, C=C)
    token_ids = _random_token_ids(target_model.embed.num_embeddings)

    # Partial mask (0.5 for all)
    out_partial = _forward_with_uniform_masks(cm, token_ids, target_module_paths, C, 0.5)

    # Should produce valid output
    assert torch.isfinite(out_partial).all(), "Partial masking produced NaN/Inf"


def test_component_model_global_ci_weight_deltas_all_ones_matches_target():
    """Test that all-ones mask with weight deltas matches target model output."""
    target_model = tiny_target()
    target_module_paths = ["mlp", "out"]
    C = 4
    cm = _global_ci_component_model(target_model, target_module_paths, C=C)
    token_ids = _random_token_ids(target_model.embed.num_embeddings)

    out = _forward_with_uniform_masks(cm, token_ids, target_module_paths, C, 1.0)

    torch.testing.assert_close(out, target_model(token_ids))


def test_global_ci_save_and_load():
    """Test saving and loading a model with global CI preserves functionality.

    A ComponentModel with a global CI function is written to a temporary directory together
    with its config, reloaded through ``SPDRunInfo`` / ``ComponentModel.from_run_info``, and
    compared with the original on: state-dict keys and values, the global CI function's
    layer order and parameters, and the CI function's outputs on random activations.
    """
    target_model = SimpleTestModel()
    target_model.eval()
    target_model.requires_grad_(False)

    with tempfile.TemporaryDirectory() as tmp_dir:
        base_dir = Path(tmp_dir)
        base_model_dir = base_dir / "test_model"
        base_model_dir.mkdir(parents=True, exist_ok=True)
        comp_model_dir = base_dir / "comp_model"
        comp_model_dir.mkdir(parents=True, exist_ok=True)

        base_model_path = base_model_dir / "model.pth"
        save_file(target_model.state_dict(), base_model_path)

        config = Config(
            pretrained_model_class="tests.test_component_model.SimpleTestModel",
            pretrained_model_path=base_model_path,
            pretrained_model_name=None,
            module_info=[
                ModulePatternInfoConfig(module_pattern="linear1", C=4),
                ModulePatternInfoConfig(module_pattern="linear2", C=4),
            ],
            ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[8]),
            batch_size=1,
            steps=1,
            lr_schedule=ScheduleConfig(start_val=1e-3),
            n_eval_steps=1,
            eval_batch_size=1,
            eval_freq=1,
            slow_eval_freq=1,
            loss_metric_configs=[ImportanceMinimalityLossConfig(coeff=1.0, pnorm=1.0, beta=0.5)],
            n_examples_until_dead=1,
            output_loss_type="mse",
            train_log_freq=1,
            n_mask_samples=1,
            task_config=TMSTaskConfig(
                task_name="tms",
                feature_probability=0.5,
                data_generation_type="exactly_one_active",
            ),
        )

        module_path_info = expand_module_patterns(target_model, config.all_module_info)
        cm = ComponentModel(
            target_model=target_model,
            module_path_info=module_path_info,
            ci_config=config.ci_config,
            pretrained_model_output_attr=config.pretrained_model_output_attr,
            sigmoid_type=config.sigmoid_type,
        )

        assert isinstance(cm.ci_fn, GlobalCiFnWrapper)

        save_file(cm.state_dict(), comp_model_dir / "model.pth")
        save_file(config.model_dump(mode="json"), comp_model_dir / "final_config.yaml")

        # Load and verify
        cm_run_info = SPDRunInfo.from_path(comp_model_dir / "model.pth")
        cm_loaded = ComponentModel.from_run_info(cm_run_info)

        assert isinstance(cm_loaded.ci_fn, GlobalCiFnWrapper)

        # Verify state dict matches. Compare the key sets first: iterating only the loaded
        # keys would not notice an entry that exists in the original but not in the reload.
        original_state = cm.state_dict()
        loaded_state = cm_loaded.state_dict()
        assert loaded_state.keys() == original_state.keys(), (
            "Loaded state dict keys differ from the saved model's keys"
        )
        for key, loaded_value in loaded_state.items():
            torch.testing.assert_close(loaded_value, original_state[key])

        # Verify global CI function weights specifically
        global_ci_fn = cm.ci_fn._global_ci_fn
        global_ci_fn_loaded = cm_loaded.ci_fn._global_ci_fn
        assert global_ci_fn_loaded.layer_order == global_ci_fn.layer_order
        for p1, p2 in zip(global_ci_fn.parameters(), global_ci_fn_loaded.parameters(), strict=True):
            torch.testing.assert_close(p1, p2)

        # Verify global CI function produces same outputs
        # The first entry of each layer config is used as the activation width for that layer.
        test_acts = {
            name: torch.randn(BATCH_SIZE, global_ci_fn.layer_configs[name][0])
            for name in global_ci_fn.layer_order
        }
        outputs_orig = global_ci_fn(test_acts)
        outputs_loaded = global_ci_fn_loaded(test_acts)
        for name in global_ci_fn.layer_order:
            torch.testing.assert_close(outputs_orig[name], outputs_loaded[name])
"fn_type": "vector_mlp", + "hidden_dims": [2], + }, "sigmoid_type": "leaky_hard", "target_module_patterns": ["model.layers.0.mlp.gate_proj"], # --- Loss metrics --- diff --git a/tests/test_gpt2.py b/tests/test_gpt2.py index dedf0ba58..cb6d3ec05 100644 --- a/tests/test_gpt2.py +++ b/tests/test_gpt2.py @@ -8,6 +8,7 @@ Config, FaithfulnessLossConfig, ImportanceMinimalityLossConfig, + LayerwiseCiConfig, LMTaskConfig, ModulePatternInfoConfig, ScheduleConfig, @@ -35,8 +36,7 @@ def test_gpt_2_decomposition_happy_path(tmp_path: Path) -> None: # General seed=0, n_mask_samples=1, - ci_fn_type="vector_mlp", - ci_fn_hidden_dims=[128], + ci_config=LayerwiseCiConfig(fn_type="vector_mlp", hidden_dims=[128]), module_info=[ ModulePatternInfoConfig(module_pattern="transformer.h.2.attn.c_attn", C=10), ModulePatternInfoConfig(module_pattern="transformer.h.3.mlp.c_fc", C=10), diff --git a/tests/test_ih_transformer.py b/tests/test_ih_transformer.py index 2e926eae1..d5af3882b 100644 --- a/tests/test_ih_transformer.py +++ b/tests/test_ih_transformer.py @@ -8,6 +8,7 @@ FaithfulnessLossConfig, IHTaskConfig, ImportanceMinimalityLossConfig, + LayerwiseCiConfig, ModulePatternInfoConfig, ScheduleConfig, StochasticHiddenActsReconLossConfig, @@ -50,8 +51,7 @@ def test_ih_transformer_decomposition_happy_path(tmp_path: Path) -> None: # General seed=0, n_mask_samples=1, - ci_fn_type="vector_mlp", - ci_fn_hidden_dims=[128], + ci_config=LayerwiseCiConfig(fn_type="vector_mlp", hidden_dims=[128]), module_info=[ ModulePatternInfoConfig(module_pattern="blocks.*.attn.q_proj", C=10), ModulePatternInfoConfig(module_pattern="blocks.*.attn.k_proj", C=10), diff --git a/tests/test_resid_mlp.py b/tests/test_resid_mlp.py index dfe771e00..7f3d2e3a5 100644 --- a/tests/test_resid_mlp.py +++ b/tests/test_resid_mlp.py @@ -4,6 +4,7 @@ Config, FaithfulnessLossConfig, ImportanceMinimalityLossConfig, + LayerwiseCiConfig, ModulePatternInfoConfig, ResidMLPTaskConfig, ScheduleConfig, @@ -43,8 +44,7 @@ def 
test_resid_mlp_decomposition_happy_path(tmp_path: Path) -> None: # General seed=0, n_mask_samples=1, - ci_fn_type="mlp", - ci_fn_hidden_dims=[8], + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[8]), loss_metric_configs=[ ImportanceMinimalityLossConfig( coeff=3e-3, diff --git a/tests/test_spd_losses.py b/tests/test_spd_losses.py index 69a546f6e..4d50dd1d5 100644 --- a/tests/test_spd_losses.py +++ b/tests/test_spd_losses.py @@ -5,7 +5,7 @@ from jaxtyping import Float from torch import Tensor -from spd.configs import UniformKSubsetRoutingConfig +from spd.configs import LayerwiseCiConfig, UniformKSubsetRoutingConfig from spd.metrics import ( ci_masked_recon_layerwise_loss, ci_masked_recon_loss, @@ -40,8 +40,7 @@ def _make_component_model(weight: Float[Tensor, "d_out d_in"]) -> ComponentModel comp_model = ComponentModel( target_model=target, module_path_info=[ModulePathInfo(module_path="fc", C=1)], - ci_fn_hidden_dims=[2], - ci_fn_type="mlp", + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), pretrained_model_output_attr=None, sigmoid_type="leaky_hard", ) diff --git a/tests/test_tms.py b/tests/test_tms.py index bbbcec4cb..81648880e 100644 --- a/tests/test_tms.py +++ b/tests/test_tms.py @@ -8,6 +8,7 @@ Config, FaithfulnessLossConfig, ImportanceMinimalityLossConfig, + LayerwiseCiConfig, ModulePatternInfoConfig, ScheduleConfig, StochasticReconLayerwiseLossConfig, @@ -47,8 +48,7 @@ def test_tms_decomposition_happy_path(tmp_path: Path) -> None: # General seed=0, n_mask_samples=1, - ci_fn_type="mlp", - ci_fn_hidden_dims=[8], + ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[8]), module_info=[ ModulePatternInfoConfig(module_pattern="linear1", C=10), ModulePatternInfoConfig(module_pattern="linear2", C=10),