From 56d77c5b23b9a7405da5a95f9b4ff9172804aa58 Mon Sep 17 00:00:00 2001
From: Misko
Date: Thu, 20 Feb 2025 11:55:35 -0800
Subject: [PATCH] add backbone override to hydra (#997)

* add backbone override to hydra

* make it harder to sneak fields into finetune config

* fix

* fix lint

* add a test to make sure override works

* fix override test by adding cleanup after assertion is caught
---
 src/fairchem/core/models/base.py          |  9 ++++++
 tests/core/e2e/test_e2e_finetune_hydra.py | 34 +++++++++++++++++++++++
 2 files changed, 43 insertions(+)

diff --git a/src/fairchem/core/models/base.py b/src/fairchem/core/models/base.py
index e6c3e08206..ce1a68af6c 100644
--- a/src/fairchem/core/models/base.py
+++ b/src/fairchem/core/models/base.py
@@ -251,6 +251,11 @@ def __init__(
         # if finetune_config is provided, then attempt to load the model from the given finetune checkpoint
         starting_model = None
         if finetune_config is not None:
+            # Make it hard to sneak more fields into finetune_config
+            assert (
+                len(set(finetune_config.keys()) - {"starting_checkpoint", "override"})
+                == 0
+            )
             starting_model: HydraModel = load_model_and_weights_from_checkpoint(
                 finetune_config["starting_checkpoint"]
             )
@@ -260,6 +265,10 @@ def __init__(
             assert isinstance(
                 starting_model, HydraModel
             ), "Can only finetune starting from other hydra models!"
+            # TODO this is a bit hacky to override attrs in the backbone
+            if "override" in finetune_config:
+                for key, value in finetune_config["override"].items():
+                    setattr(starting_model.backbone, key, value)
 
         if backbone is not None:
             backbone = copy.deepcopy(backbone)
diff --git a/tests/core/e2e/test_e2e_finetune_hydra.py b/tests/core/e2e/test_e2e_finetune_hydra.py
index 4dc2e2efc7..6f0618afb6 100644
--- a/tests/core/e2e/test_e2e_finetune_hydra.py
+++ b/tests/core/e2e/test_e2e_finetune_hydra.py
@@ -8,6 +8,7 @@
 import torch
 from test_e2e_commons import _run_main, oc20_lmdb_train_and_val_from_paths
 
+from fairchem.core.common import distutils
 from fairchem.core.scripts.convert_hydra_to_release import convert_fine_tune_checkpoint
 
 
@@ -283,6 +284,39 @@ def test_finetune_hydra_retain_backbone(tutorial_val_src):
     )
 
 
+def test_finetune_hydra_override(tutorial_val_src):
+    with tempfile.TemporaryDirectory() as orig_ckpt_dir:
+        starting_ckpt = make_checkpoint(orig_ckpt_dir, tutorial_val_src, 0)
+        # now finetune the model with the checkpoint from the first job
+        with tempfile.TemporaryDirectory() as ft_temp_dir:
+            ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml")
+            ck_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft.pt")
+            model_config = {
+                "name": "hydra",
+                "finetune_config": {
+                    "starting_checkpoint": starting_ckpt,
+                    "override": {"forward": None},
+                },
+                "heads": {
+                    "energy": {"module": "equiformer_v2_energy_head"},
+                    "forces": {"module": "equiformer_v2_force_head"},
+                },
+            }
+
+            # TODO add a better test for override when we get there
+            # for now just override .forward() with None
+            with pytest.raises(TypeError):
+                run_main_with_ft_hydra(
+                    tempdir=ft_temp_dir,
+                    yaml=ft_yml,
+                    data_src=tutorial_val_src,
+                    run_args={"seed": 1000},
+                    model_config=model_config,
+                    output_checkpoint=ck_ft_path,
+                )
+            distutils.cleanup()
+
+
 def test_finetune_hydra_data_only(tutorial_val_src):
     with tempfile.TemporaryDirectory() as orig_ckpt_dir:
         starting_ckpt = make_checkpoint(orig_ckpt_dir, tutorial_val_src, 0)
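
As a usage sketch (not part of the patch): with this change, a hydra finetune model config can point at a pretrained checkpoint and have its "override" entries applied to the loaded backbone via setattr. The test above only overrides .forward with None to force a TypeError; a more realistic config would set an actual backbone attribute. The attribute name and checkpoint path below are assumptions for illustration only and must match the chosen backbone.

    # Illustrative sketch only: each key/value in "override" is applied with
    # setattr() on the loaded backbone, so "avg_num_nodes" stands in for any
    # real attribute of the backbone being finetuned.
    model_config = {
        "name": "hydra",
        "finetune_config": {
            "starting_checkpoint": "/path/to/pretrained_checkpoint.pt",  # placeholder path
            "override": {"avg_num_nodes": 31.0},  # hypothetical attribute name
        },
        "heads": {
            "energy": {"module": "equiformer_v2_energy_head"},
            "forces": {"module": "equiformer_v2_force_head"},
        },
    }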