Fix/swin instantiation #25

Merged: 4 commits, Jun 28, 2024
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -39,7 +39,8 @@ dependencies = [
"lightly>=1.4.25",
"h5py>=3.10.0",
"geobench>=1.0.0",
"mlflow>=2.12.1"
"mlflow>=2.12.1",
"lightning<=2.2.5"
]

[project.optional-dependencies]
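The new "lightning<=2.2.5" entry only caps the Lightning version; it adds no feature. A minimal sketch for confirming an existing environment satisfies that pin (the constraint string comes from the diff above, the check itself is illustrative):

# Illustrative: verify the installed lightning version against the new pin.
from importlib.metadata import version
from packaging.specifiers import SpecifierSet

installed = version("lightning")  # e.g. "2.2.5"
if installed not in SpecifierSet("<=2.2.5"):  # constraint taken from pyproject.toml above
    raise RuntimeError(f"lightning {installed} does not satisfy <=2.2.5")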
7 changes: 2 additions & 5 deletions src/terratorch/models/backbones/__init__.py
@@ -1,6 +1,3 @@
 # import so they get registered
-from terratorch.models.backbones.prithvi_vit import TemporalViTEncoder
-
-__all__ = ["TemporalViTEncoder"]
-__all__ = ["TemporalViTEncoder"]
-__all__ = ["TemporalViTEncoder"]
+import terratorch.models.backbones.prithvi_vit
+import terratorch.models.backbones.prithvi_swin
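Note: these imports exist only for their side effect. Each backbone module decorates its constructors with timm's register_model at import time, so importing the module is what makes the prithvi_vit and prithvi_swin models visible to timm.create_model. A rough sketch of that pattern, using an illustrative model name rather than the real terratorch factories:

import timm
import torch
from timm.models import register_model
from torch import nn

@register_model
def toy_backbone(pretrained: bool = False, **kwargs) -> nn.Module:
    # illustrative stand-in for the real prithvi_* factory functions
    return nn.Conv2d(6, 16, kernel_size=3, padding=1)

# once the defining module has been imported, the name resolves in timm's registry
model = timm.create_model("toy_backbone", pretrained=False)
out = model(torch.ones((1, 6, 224, 224)))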
40 changes: 0 additions & 40 deletions src/terratorch/models/backbones/prithvi_swin.py
@@ -36,15 +36,6 @@ def _cfg(file: Path = "", **kwargs) -> dict:
         **kwargs,
     }

-default_cfgs = generate_default_cfgs(
-    {
-        "prithvi_swin_90_us": {
-            "hf_hub_id": "ibm-nasa-geospatial/Prithvi-100M",
-            "hf_hub_filename": "Prithvi_100M.pt"
-        }
-    }
-)
-
 def convert_weights_swin2mmseg(ckpt):
     # from https://github.com/open-mmlab/mmsegmentation/blob/main/tools/model_converters/swin2mmseg.py
     new_ckpt = OrderedDict()
@@ -215,37 +206,6 @@ def prepare_features_for_image_model(x):
     return model


-@register_model
-def prithvi_swin_90_us(
-    pretrained: bool = False,  # noqa: FBT002, FBT001
-    pretrained_bands: list[HLSBands] | None = None,
-    bands: list[int] | None = None,
-    **kwargs,
-) -> MMSegSwinTransformer:
-    """Prithvi Swin 90M"""
-    if pretrained_bands is None:
-        pretrained_bands = PRETRAINED_BANDS
-    if bands is None:
-        bands = pretrained_bands
-        logging.info(
-            f"Model bands not passed. Assuming bands are ordered in the same way as {PRETRAINED_BANDS}.\
-            Pretrained patch_embed layer may be misaligned with current bands"
-        )
-
-    model_args = {
-        "patch_size": 4,
-        "window_size": 7,
-        "embed_dim": 128,
-        "depths": (2, 2, 18, 2),
-        "in_chans": 6,
-        "num_heads": (4, 8, 16, 32),
-    }
-    transformer = _create_swin_mmseg_transformer(
-        "prithvi_swin_90_us", pretrained_bands, bands, pretrained=pretrained, **dict(model_args, **kwargs)
-    )
-    return transformer
-
-
 @register_model
 def prithvi_swin_B(
     pretrained: bool = False,  # noqa: FBT002, FBT001
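With prithvi_swin_90_us and its default_cfgs entry removed, prithvi_swin_B is the Swin backbone that stays registered. A hedged sketch of instantiating it the way the tests below do (6 input channels and 224x224 crops are assumed from the test fixtures, not verified here):

import timm
import torch

import terratorch.models.backbones  # noqa: F401  (importing registers the backbones, per the __init__.py change above)

backbone = timm.create_model("prithvi_swin_B", pretrained=False)
out = backbone(torch.ones((1, 6, 224, 224)))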
4 changes: 2 additions & 2 deletions tests/test_backbones.py
@@ -28,15 +28,15 @@ def input_386():
     return torch.ones((1, NUM_CHANNELS, 386, 386))


-@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300"]) #["prithvi_swin_90_us", "prithvi_vit_100", "prithvi_vit_300"])
+@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
 @pytest.mark.parametrize("test_input", ["input_224", "input_512"])
 def test_can_create_backbones_from_timm(model_name, test_input, request):
     backbone = timm.create_model(model_name, pretrained=False)
     input_tensor = request.getfixturevalue(test_input)
     backbone(input_tensor)


-@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300"])
+@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
 @pytest.mark.parametrize("test_input", ["input_224", "input_512"])
 def test_can_create_backbones_from_timm_features_only(model_name, test_input, request):
     backbone = timm.create_model(model_name, pretrained=False, features_only=True)
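The features_only test goes through timm's feature-extraction wrapper, which returns one feature map per stage instead of a single output. A minimal sketch of that usage (again assuming 6 bands and a 224x224 input; the per-stage shapes are whatever the Swin configuration produces and are not asserted here):

import timm
import torch

import terratorch.models.backbones  # noqa: F401  (registers prithvi_swin_B)

backbone = timm.create_model("prithvi_swin_B", pretrained=False, features_only=True)
features = backbone(torch.ones((1, 6, 224, 224)))
for f in features:
    print(f.shape)  # one multi-scale feature map per stage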
6 changes: 3 additions & 3 deletions tests/test_prithvi_tasks.py
@@ -20,7 +20,7 @@ def model_input() -> torch.Tensor:
     return torch.ones((1, NUM_CHANNELS, 224, 224))


-@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
+@pytest.mark.parametrize("backbone",["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
 @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
 @pytest.mark.parametrize("loss", ["ce", "jaccard", "focal", "dice"])
 def test_create_segmentation_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
@@ -38,7 +38,7 @@ def test_create_segmentation_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
     )


-@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
+@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
 @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
 @pytest.mark.parametrize("loss", ["mae", "rmse", "huber"])
 def test_create_regression_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
@@ -55,7 +55,7 @@ def test_create_regression_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
     )


-@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
+@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
 @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
 @pytest.mark.parametrize("loss", ["ce", "bce", "jaccard", "focal"])
 def test_create_classification_task(backbone, decoder, loss, model_factory: PrithviModelFactory):