Update happy path tests to use default configs #290
Open: lee-goodfire wants to merge 46 commits into main from feature/default_tests (base: main).
Changes from all commits (46):

- b93b9d6 (leesharkey) Geometric similarity comparison made consistent with other evals and …
- cd5fda2 (leesharkey) Replaced mean max cosine sim with mean max ABS cosine sim
- 61d3408 (leesharkey) Configs for geom comparison runs
- 63c85f0 (leesharkey) Merge remote-tracking branch 'origin/main' into feature/geom_sim_compar
- 770a5c5 (leesharkey) Minor modifications to make PR-ready
- 49ba925 (leesharkey) Merge remote-tracking branch 'origin/main' into feature/geom_sim_compar
- 364198e (leesharkey) Update seed to be consistent with other configs again
- 57c2c76 (leesharkey) Cleaned up some comments and other bits
- 2e7752d (leesharkey) Major update of PR following review: Now implemented as script rather…
- 4fbf807 (leesharkey) Merge remote-tracking branch 'origin/main' into feature/geom_sim_compar
- 98a6620 (leesharkey) Updated registry to delete old obselete experiments
- bede346 (leesharkey) Merge branch 'main' into feature/geom_sim_compar
- acc04f1 (leesharkey) Merge branch 'main' into feature/geom_sim_compar
- 62bd77e (leesharkey) Reorganized compare_models into subdirectory and cleaned up config code
- b84814a (leesharkey) Merging
- 5173a6a (leesharkey) Updated README.md
- 181cac8 (leesharkey) Added some example models to the config
- 8db7559 (leesharkey) Getting rid of newline
- 0d05f0a (leesharkey) Minor changes to make the PR mergeable
- 8767194 (leesharkey) Merge branch 'main' of https://github.com/goodfire-ai/spd
- 019eb2d (leesharkey) Merge branch 'main' of https://github.com/goodfire-ai/spd
- b935b4c (leesharkey) Merge branch 'main' of https://github.com/goodfire-ai/spd
- 3d1edeb (leesharkey) Merge branch 'main' of github.com:goodfire-ai/spd
- 1dd738d (leesharkey) Merge branch 'main' of github.com:goodfire-ai/spd
- 956f3d4 (leesharkey) Merge branch 'main' of github.com:goodfire-ai/spd
- f7ad411 (leesharkey) Merge branch 'main' of github.com:goodfire-ai/spd
- ade1377 (leesharkey) Merge branch 'main' of github.com:goodfire-ai/spd
- 08875a9 (leesharkey) Merge branch 'main' of github.com:goodfire-ai/spd
- 7ca7037 (leesharkey) Merge branch 'main' of github.com:goodfire-ai/spd
- cbbdb61 (leesharkey) Merge branch 'main' of github.com:goodfire-ai/spd
- 267deb6 (leesharkey) Merge branch 'main' of github.com:goodfire-ai/spd
- f49e9e0 (leesharkey) Merge branch 'main' of github.com:goodfire-ai/spd
- 22f7cfc (leesharkey) Merge branch 'main' of github.com:goodfire-ai/spd
- ab5346d (leesharkey) Merge branch 'main' of github.com:goodfire-ai/spd
- 7cb528f (leesharkey) Merge branch 'main' of github.com:goodfire-ai/spd
- 01d1b6b (leesharkey) Merge branch 'main' of github.com:goodfire-ai/spd
- a78fdc5 (leesharkey) Merge branch 'main' of github.com:goodfire-ai/spd
- 296a8d2 (leesharkey) Merge branch 'main' of github.com:goodfire-ai/spd
- acf0574 (leesharkey) Merge branch 'main' of github.com:goodfire-ai/spd
- 917163f (leesharkey) Fix test_resid_mlp_decomposition_happy_path config mismatch
- e029348 (leesharkey) Update happy path tests to use default configs
- 8c2d008 (leesharkey) Fix ih_config.yaml: Replace deprecated loss coefficients with loss_me…
- b43f70e (leesharkey) Fix test_gpt_2_decomposition_happy_path for new config loading approach
- 65c5346 (leesharkey) Address Dan's PR review comments on test files
- b3bd2c6 (leesharkey) Remove redundant comment in test_tms.py
- ca7ad57 (leesharkey) Merge remote-tracking branch 'origin/main' into feature/default_tests
This file was deleted.
New test file added by this PR (@@ -0,0 +1,123 @@):

```python
import pytest

from spd.configs import Config
from spd.data import DatasetConfig, create_data_loader
from spd.experiments.lm.configs import LMTaskConfig
from spd.identity_insertion import insert_identity_operations_
from spd.registry import get_experiment_config_file_contents
from spd.run_spd import optimize
from spd.utils.general_utils import resolve_class, set_seed
from spd.utils.run_utils import apply_nested_updates

# Config-specific test parameters for different GPT2 configurations
GPT2_CONFIG_PARAMS = {
    "ss_gpt2_simple": {
        # Uses simple_stories_train.models.gpt2_simple.GPT2Simple (wandb-hosted model)
        "target_module_patterns": ["h.2.attn.q_proj", "h.3.mlp.c_fc"],
        "identity_module_patterns": ["h.1.attn.q_proj"],
    },
    "ss_gpt2": {
        # Uses transformers.GPT2LMHeadModel (HuggingFace transformers library)
        "target_module_patterns": ["transformer.h.1.mlp.c_fc"],
        "identity_module_patterns": None,
    },
}


@pytest.mark.slow
@pytest.mark.parametrize("experiment_name", ["ss_gpt2_simple", "ss_gpt2"])
def test_gpt2_decomposition_happy_path(experiment_name: str) -> None:
    """Test that SPD decomposition works on different GPT-2 configurations.

    Tests both:
    - ss_gpt2_simple: Uses simple_stories_train GPT2Simple model (wandb-hosted)
    - ss_gpt2: Uses transformers.GPT2LMHeadModel (HuggingFace transformers library)
    """
    set_seed(0)
    device = "cpu"

    config_params = GPT2_CONFIG_PARAMS[experiment_name]
    base_config_dict = get_experiment_config_file_contents(experiment_name)
    test_overrides = {
        "wandb_project": None,
        "C": 10,
        "steps": 2,
        "batch_size": 4,
        "eval_batch_size": 1,
        "train_log_freq": 50,
        "n_examples_until_dead": 999,
        "task_config.max_seq_len": 8,
        "task_config.train_data_split": "train[:100]",
        "task_config.eval_data_split": "test[100:200]",
        "target_module_patterns": config_params["target_module_patterns"],
        "identity_module_patterns": config_params["identity_module_patterns"],
        "eval_metric_configs": [],  # Disable eval metrics to avoid layer matching issues
    }
    config_dict = apply_nested_updates(base_config_dict, test_overrides)
    config = Config(**config_dict)

    assert isinstance(config.task_config, LMTaskConfig), "task_config not LMTaskConfig"
    pretrained_model_class = resolve_class(config.pretrained_model_class)
    assert hasattr(pretrained_model_class, "from_pretrained"), (
        f"Model class {pretrained_model_class} should have a `from_pretrained` method"
    )
    assert config.pretrained_model_name is not None

    # Handle simple_stories_train models specially (they use from_run_info)
    if config.pretrained_model_class.startswith("simple_stories_train"):
        from simple_stories_train.run_info import RunInfo as SSRunInfo

        run_info = SSRunInfo.from_path(config.pretrained_model_name)
        assert hasattr(pretrained_model_class, "from_run_info")
        target_model = pretrained_model_class.from_run_info(run_info)  # pyright: ignore[reportAttributeAccessIssue]
    else:
        target_model = pretrained_model_class.from_pretrained(config.pretrained_model_name)  # pyright: ignore[reportAttributeAccessIssue]
    target_model.eval()

    if config.identity_module_patterns is not None:
        insert_identity_operations_(target_model, identity_patterns=config.identity_module_patterns)

    train_data_config = DatasetConfig(
        name=config.task_config.dataset_name,
        hf_tokenizer_path=config.tokenizer_name,
        split=config.task_config.train_data_split,
        n_ctx=config.task_config.max_seq_len,
        is_tokenized=config.task_config.is_tokenized,
        streaming=config.task_config.streaming,
        column_name=config.task_config.column_name,
        seed=None,
    )

    train_loader, _tokenizer = create_data_loader(
        dataset_config=train_data_config,
        batch_size=config.batch_size,
        buffer_size=config.task_config.buffer_size,
        global_seed=config.seed,
    )

    eval_data_config = DatasetConfig(
        name=config.task_config.dataset_name,
        hf_tokenizer_path=config.tokenizer_name,
        split=config.task_config.eval_data_split,
        n_ctx=config.task_config.max_seq_len,
        is_tokenized=config.task_config.is_tokenized,
        streaming=config.task_config.streaming,
        column_name=config.task_config.column_name,
        seed=None,
    )
    eval_loader, _ = create_data_loader(
        dataset_config=eval_data_config,
        batch_size=config.batch_size,
        buffer_size=config.task_config.buffer_size,
        global_seed=config.seed + 1,
    )

    optimize(
        target_model=target_model,
        config=config,
        device=device,
        train_loader=train_loader,
        eval_loader=eval_loader,
        n_eval_steps=config.n_eval_steps,
        out_dir=None,
    )
```
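The test overrides above use dotted keys such as `task_config.max_seq_len` to reach into nested parts of the config dict. The real helper is `spd.utils.run_utils.apply_nested_updates`; as a rough sketch of the idea (a hypothetical reimplementation, not the project's actual code), dotted-key overrides might be applied like this:

```python
from copy import deepcopy
from typing import Any


def apply_nested_updates_sketch(base: dict[str, Any], updates: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of `base` with dotted-key overrides applied.

    A key like "task_config.max_seq_len" descends into base["task_config"]
    and sets "max_seq_len" there. Hypothetical stand-in for
    spd.utils.run_utils.apply_nested_updates.
    """
    result = deepcopy(base)
    for dotted_key, value in updates.items():
        *parents, leaf = dotted_key.split(".")
        node = result
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return result


base = {"steps": 1000, "task_config": {"max_seq_len": 512, "streaming": True}}
overrides = {"steps": 2, "task_config.max_seq_len": 8}
updated = apply_nested_updates_sketch(base, overrides)
```

Working on a deep copy keeps the base config (loaded from the experiment's default YAML) untouched, so the same dict can seed several parametrized test cases.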
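The test also resolves `config.pretrained_model_class`, a dotted import path such as `transformers.GPT2LMHeadModel`, to an actual class before calling `from_pretrained`. A minimal sketch of that pattern, assuming `resolve_class` (in `spd.utils.general_utils`) does standard dotted-path resolution; this stand-in is hypothetical and demonstrated on a stdlib class:

```python
import importlib


def resolve_class_sketch(path: str) -> type:
    """Resolve a dotted path like "collections.OrderedDict" to a class.

    Hypothetical stand-in for spd.utils.general_utils.resolve_class:
    split off the attribute name, import the module, fetch the attribute.
    """
    module_path, _, class_name = path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


cls = resolve_class_sketch("collections.OrderedDict")
```

Storing the class path as a string in YAML keeps the config serializable while still letting the test assert on class capabilities (e.g. the `from_pretrained` / `from_run_info` checks above).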
Review comment:

Before this change, this actually tested ss_gpt2_config, which uses the transformers-library GPT-2 architecture and our custom model that was uploaded to Hugging Face. Now it tests ss_gpt2_simple, which uses the simplestories gpt2_simple architecture and our custom model that was uploaded to wandb.

I think it's fine to keep this test, but I'd rename it to tests/test_ss_gpt2_simple.py. But I do want to keep one test that has a pretrained_model_class from the transformers library, like the old one had. You may be able to test both configurations in this one test file with mark.parametrize or just a for loop. If doing that, the file might be named test_gpt2_configurations.py or something like that.

It would be nice to test the raw GPT-2 too, i.e. the one in gpt2_config.yaml, but you should check its runtime first. Not worth it if it adds 5+ seconds to the slow tests.
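The parametrization the reviewer suggests is the pattern the new test file adopts. As a minimal sketch of the mechanic (with a dummy check standing in for the real decomposition run, and hypothetical config dicts):

```python
import pytest

# Hypothetical per-experiment parameters; the real test keys these off
# the experiment registry's default configs.
CONFIGS = {
    "ss_gpt2_simple": {"identity_module_patterns": ["h.1.attn.q_proj"]},
    "ss_gpt2": {"identity_module_patterns": None},
}


@pytest.mark.parametrize("experiment_name", sorted(CONFIGS))
def test_gpt2_configurations(experiment_name: str) -> None:
    # pytest invokes this once per experiment name; the real test builds
    # a Config from the experiment's default YAML and runs optimize().
    params = CONFIGS[experiment_name]
    assert "identity_module_patterns" in params
```

One parametrized function keeps the two configurations in a single file, which is what makes the suggested `test_gpt2_configurations.py` naming natural.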
Reply:

Renamed test_gpt2.py → test_gpt2_configs.py and added a parametrized test covering both, so we now test both the new config and a transformers-library model like the old test had.

Re: the raw gpt2 config: tested the runtime, and it times out downloading the openwebtext dataset (8.8M examples), so it wasn't included.