Changes from all commits
46 commits
b93b9d6
Geometric similarity comparison made consistent with other evals and …
leesharkey Sep 16, 2025
cd5fda2
Replaced mean max cosine sim with mean max ABS cosine sim
leesharkey Sep 17, 2025
61d3408
Configs for geom comparison runs
leesharkey Sep 17, 2025
63c85f0
Merge remote-tracking branch 'origin/main' into feature/geom_sim_compar
leesharkey Sep 17, 2025
770a5c5
Minor modifications to make PR-ready
leesharkey Sep 17, 2025
49ba925
Merge remote-tracking branch 'origin/main' into feature/geom_sim_compar
leesharkey Sep 17, 2025
364198e
Update seed to be consistent with other configs again
leesharkey Sep 17, 2025
57c2c76
Cleaned up some comments and other bits
leesharkey Sep 18, 2025
2e7752d
Major update of PR following review: Now implemented as script rather…
leesharkey Sep 18, 2025
4fbf807
Merge remote-tracking branch 'origin/main' into feature/geom_sim_compar
leesharkey Sep 18, 2025
98a6620
Updated registry to delete old obsolete experiments
leesharkey Sep 18, 2025
bede346
Merge branch 'main' into feature/geom_sim_compar
leesharkey Sep 18, 2025
acc04f1
Merge branch 'main' into feature/geom_sim_compar
leesharkey Sep 22, 2025
62bd77e
Reorganized compare_models into subdirectory and cleaned up config code
leesharkey Sep 22, 2025
b84814a
Merging
leesharkey Sep 22, 2025
5173a6a
Updated README.md
leesharkey Sep 22, 2025
181cac8
Added some example models to the config
leesharkey Sep 22, 2025
8db7559
Getting rid of newline
leesharkey Sep 22, 2025
0d05f0a
Minor changes to make the PR mergeable
leesharkey Sep 23, 2025
8767194
Merge branch 'main' of https://github.com/goodfire-ai/spd
leesharkey Sep 23, 2025
019eb2d
Merge branch 'main' of https://github.com/goodfire-ai/spd
leesharkey Sep 24, 2025
b935b4c
Merge branch 'main' of https://github.com/goodfire-ai/spd
leesharkey Sep 29, 2025
3d1edeb
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Sep 30, 2025
1dd738d
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Oct 3, 2025
956f3d4
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Oct 5, 2025
f7ad411
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Oct 6, 2025
ade1377
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Oct 7, 2025
08875a9
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Oct 13, 2025
7ca7037
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Oct 22, 2025
cbbdb61
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Oct 22, 2025
267deb6
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Oct 28, 2025
f49e9e0
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Nov 5, 2025
22f7cfc
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Nov 12, 2025
ab5346d
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Nov 14, 2025
7cb528f
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Nov 20, 2025
01d1b6b
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Nov 21, 2025
a78fdc5
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Nov 24, 2025
296a8d2
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Dec 2, 2025
acf0574
Merge branch 'main' of github.com:goodfire-ai/spd
leesharkey Dec 5, 2025
917163f
Fix test_resid_mlp_decomposition_happy_path config mismatch
leesharkey Dec 5, 2025
e029348
Update happy path tests to use default configs
leesharkey Dec 5, 2025
8c2d008
Fix ih_config.yaml: Replace deprecated loss coefficients with loss_me…
leesharkey Dec 5, 2025
b43f70e
Fix test_gpt_2_decomposition_happy_path for new config loading approach
leesharkey Dec 5, 2025
65c5346
Address Dan's PR review comments on test files
leesharkey Dec 18, 2025
b3bd2c6
Remove redundant comment in test_tms.py
leesharkey Dec 18, 2025
ca7ad57
Merge remote-tracking branch 'origin/main' into feature/default_tests
leesharkey Dec 18, 2025
17 changes: 10 additions & 7 deletions spd/experiments/ih/ih_config.yaml
@@ -22,13 +22,16 @@ target_module_patterns: [
   "blocks.*.attn.out_proj",
 ]
 
-faithfulness_coeff: 100
-ci_recon_coeff: 1
-stochastic_recon_coeff: 1
-ci_recon_layerwise_coeff: null
-stochastic_recon_layerwise_coeff: 1
-importance_minimality_coeff: 1e-2
-pnorm: 0.1
+loss_metric_configs:
+  - classname: "ImportanceMinimalityLoss"
+    coeff: 1e-2
+    pnorm: 0.1
+  - classname: "CIMaskedReconLoss"
+    coeff: 1.0
+  - classname: "StochasticReconLoss"
+    coeff: 1.0
+  - classname: "StochasticReconLayerwiseLoss"
+    coeff: 1.0
 output_loss_type: kl
 ci_fn_type: "vector_mlp"
 ci_fn_hidden_dims: [128]
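The change above moves from one top-level coefficient per loss to a single `loss_metric_configs` list, where each entry names a loss class, its coefficient, and any class-specific parameters (such as `pnorm`). A minimal sketch of how such entries could be parsed into structured loss terms — the helper and dataclass names here are illustrative assumptions, not spd's actual implementation:

```python
from dataclasses import dataclass, field

@dataclass
class LossTerm:
    """One weighted loss term parsed from a loss_metric_configs entry."""
    classname: str
    coeff: float
    extra: dict = field(default_factory=dict)  # class-specific params, e.g. pnorm

def parse_loss_metric_configs(entries: list[dict]) -> list[LossTerm]:
    """Turn raw YAML dicts into LossTerm objects, keeping unknown keys as extras."""
    terms = []
    for entry in entries:
        extra = {k: v for k, v in entry.items() if k not in {"classname", "coeff"}}
        terms.append(LossTerm(entry["classname"], float(entry["coeff"]), extra))
    return terms

# Entries mirroring the YAML diff above
configs = [
    {"classname": "ImportanceMinimalityLoss", "coeff": 1e-2, "pnorm": 0.1},
    {"classname": "CIMaskedReconLoss", "coeff": 1.0},
]
terms = parse_loss_metric_configs(configs)
```

One advantage of the list form is that per-loss parameters like `pnorm` live next to the loss that uses them, rather than floating at the top level of the config.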
150 changes: 0 additions & 150 deletions tests/test_gpt2.py

This file was deleted.

123 changes: 123 additions & 0 deletions tests/test_gpt2_configs.py
Collaborator:
Before the change this actually tested ss_gpt2_config, which uses the transformers library gpt2 architecture and our custom model that was uploaded to huggingface.

Now it tests ss_gpt2_simple, which uses the simplestories gpt2_simple architecture and our custom model that was uploaded to wandb.

I think it's fine to keep this test, but I'd rename it to tests/test_ss_gpt2_simple.py.

But I do want to keep one test that has a pretrained_model_class from the transformers library, like the old one had. You may be able to test both configurations in this one test file with pytest.mark.parametrize or just a for loop. If doing that, the name of the file might be test_gpt2_configurations.py or something like that.

It would be nice to test the raw gpt2 too, i.e. the one in gpt2_config.yaml. But you should check the runtime of that first. Not worth it if it adds 5+ seconds to the slow tests.

Contributor (Author):

Renamed test_gpt2.py → test_gpt2_configs.py and added parametrized test covering both:

  • ss_gpt2_simple (simple_stories_train GPT2Simple, wandb-hosted)
  • ss_gpt2 (transformers.GPT2LMHeadModel from HuggingFace)

So we now test both the new config and a transformers library model like the old test had.

Re: raw gpt2 config - tested the runtime and it times out downloading the openwebtext dataset (8.8M examples), so didn't include it.

@@ -0,0 +1,123 @@
import pytest

from spd.configs import Config
from spd.data import DatasetConfig, create_data_loader
from spd.experiments.lm.configs import LMTaskConfig
from spd.identity_insertion import insert_identity_operations_
from spd.registry import get_experiment_config_file_contents
from spd.run_spd import optimize
from spd.utils.general_utils import resolve_class, set_seed
from spd.utils.run_utils import apply_nested_updates

# Config-specific test parameters for different GPT2 configurations
GPT2_CONFIG_PARAMS = {
    "ss_gpt2_simple": {
        # Uses simple_stories_train.models.gpt2_simple.GPT2Simple (wandb-hosted model)
        "target_module_patterns": ["h.2.attn.q_proj", "h.3.mlp.c_fc"],
        "identity_module_patterns": ["h.1.attn.q_proj"],
    },
    "ss_gpt2": {
        # Uses transformers.GPT2LMHeadModel (HuggingFace transformers library)
        "target_module_patterns": ["transformer.h.1.mlp.c_fc"],
        "identity_module_patterns": None,
    },
}


@pytest.mark.slow
@pytest.mark.parametrize("experiment_name", ["ss_gpt2_simple", "ss_gpt2"])
def test_gpt2_decomposition_happy_path(experiment_name: str) -> None:
    """Test that SPD decomposition works on different GPT-2 configurations.

    Tests both:
    - ss_gpt2_simple: Uses simple_stories_train GPT2Simple model (wandb-hosted)
    - ss_gpt2: Uses transformers.GPT2LMHeadModel (HuggingFace transformers library)
    """
    set_seed(0)
    device = "cpu"

    config_params = GPT2_CONFIG_PARAMS[experiment_name]
    base_config_dict = get_experiment_config_file_contents(experiment_name)
    test_overrides = {
        "wandb_project": None,
        "C": 10,
        "steps": 2,
        "batch_size": 4,
        "eval_batch_size": 1,
        "train_log_freq": 50,
        "n_examples_until_dead": 999,
        "task_config.max_seq_len": 8,
        "task_config.train_data_split": "train[:100]",
        "task_config.eval_data_split": "test[100:200]",
        "target_module_patterns": config_params["target_module_patterns"],
        "identity_module_patterns": config_params["identity_module_patterns"],
        "eval_metric_configs": [],  # Disable eval metrics to avoid layer matching issues
    }
    config_dict = apply_nested_updates(base_config_dict, test_overrides)
    config = Config(**config_dict)

    assert isinstance(config.task_config, LMTaskConfig), "task_config not LMTaskConfig"
    pretrained_model_class = resolve_class(config.pretrained_model_class)
    assert hasattr(pretrained_model_class, "from_pretrained"), (
        f"Model class {pretrained_model_class} should have a `from_pretrained` method"
    )
    assert config.pretrained_model_name is not None

    # Handle simple_stories_train models specially (they use from_run_info)
    if config.pretrained_model_class.startswith("simple_stories_train"):
        from simple_stories_train.run_info import RunInfo as SSRunInfo

        run_info = SSRunInfo.from_path(config.pretrained_model_name)
        assert hasattr(pretrained_model_class, "from_run_info")
        target_model = pretrained_model_class.from_run_info(run_info)  # pyright: ignore[reportAttributeAccessIssue]
    else:
        target_model = pretrained_model_class.from_pretrained(config.pretrained_model_name)  # pyright: ignore[reportAttributeAccessIssue]
    target_model.eval()

    if config.identity_module_patterns is not None:
        insert_identity_operations_(target_model, identity_patterns=config.identity_module_patterns)

    train_data_config = DatasetConfig(
        name=config.task_config.dataset_name,
        hf_tokenizer_path=config.tokenizer_name,
        split=config.task_config.train_data_split,
        n_ctx=config.task_config.max_seq_len,
        is_tokenized=config.task_config.is_tokenized,
        streaming=config.task_config.streaming,
        column_name=config.task_config.column_name,
        seed=None,
    )

    train_loader, _tokenizer = create_data_loader(
        dataset_config=train_data_config,
        batch_size=config.batch_size,
        buffer_size=config.task_config.buffer_size,
        global_seed=config.seed,
    )

    eval_data_config = DatasetConfig(
        name=config.task_config.dataset_name,
        hf_tokenizer_path=config.tokenizer_name,
        split=config.task_config.eval_data_split,
        n_ctx=config.task_config.max_seq_len,
        is_tokenized=config.task_config.is_tokenized,
        streaming=config.task_config.streaming,
        column_name=config.task_config.column_name,
        seed=None,
    )
    eval_loader, _ = create_data_loader(
        dataset_config=eval_data_config,
        batch_size=config.batch_size,
        buffer_size=config.task_config.buffer_size,
        global_seed=config.seed + 1,
    )

    optimize(
        target_model=target_model,
        config=config,
        device=device,
        train_loader=train_loader,
        eval_loader=eval_loader,
        n_eval_steps=config.n_eval_steps,
        out_dir=None,
    )
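The test's overrides use dotted keys such as `"task_config.max_seq_len"` to reach into nested config dicts. A minimal sketch of what `apply_nested_updates` plausibly does — the helper name `apply_dotted_updates` below is hypothetical, and spd's actual implementation may differ:

```python
from copy import deepcopy

def apply_dotted_updates(base: dict, overrides: dict) -> dict:
    """Return a copy of `base` with dotted-key overrides applied.

    E.g. {"task_config.max_seq_len": 8} sets base["task_config"]["max_seq_len"]
    without mutating the original dict.
    """
    result = deepcopy(base)
    for dotted_key, value in overrides.items():
        node = result
        *parents, leaf = dotted_key.split(".")
        for part in parents:
            node = node.setdefault(part, {})  # create intermediate dicts as needed
        node[leaf] = value
    return result

base = {"steps": 1000, "task_config": {"max_seq_len": 512, "streaming": True}}
updated = apply_dotted_updates(base, {"steps": 2, "task_config.max_seq_len": 8})
```

Copying before updating keeps the base config file contents reusable across parametrized test cases, which matters here since both `ss_gpt2_simple` and `ss_gpt2` start from registry-loaded dicts.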