diff --git a/pyproject.toml b/pyproject.toml
index adb75a3d..818d9fe0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,7 +39,8 @@ dependencies = [
     "lightly>=1.4.25",
     "h5py>=3.10.0",
     "geobench>=1.0.0",
-    "mlflow>=2.12.1"
+    "mlflow>=2.12.1",
+    "lightning<=2.2.5"
 ]
 
 [project.optional-dependencies]
diff --git a/src/terratorch/models/backbones/__init__.py b/src/terratorch/models/backbones/__init__.py
index 9f91a5ea..50c59bd5 100644
--- a/src/terratorch/models/backbones/__init__.py
+++ b/src/terratorch/models/backbones/__init__.py
@@ -1,6 +1,3 @@
 # import so they get registered
-from terratorch.models.backbones.prithvi_vit import TemporalViTEncoder
-
-__all__ = ["TemporalViTEncoder"]
-__all__ = ["TemporalViTEncoder"]
-__all__ = ["TemporalViTEncoder"]
+import terratorch.models.backbones.prithvi_vit
+import terratorch.models.backbones.prithvi_swin
\ No newline at end of file
diff --git a/src/terratorch/models/backbones/prithvi_swin.py b/src/terratorch/models/backbones/prithvi_swin.py
index 71aea43a..d83edac0 100644
--- a/src/terratorch/models/backbones/prithvi_swin.py
+++ b/src/terratorch/models/backbones/prithvi_swin.py
@@ -36,15 +36,6 @@ def _cfg(file: Path = "", **kwargs) -> dict:
         **kwargs,
     }
 
-default_cfgs = generate_default_cfgs(
-    {
-        "prithvi_swin_90_us": {
-            "hf_hub_id": "ibm-nasa-geospatial/Prithvi-100M",
-            "hf_hub_filename": "Prithvi_100M.pt"
-        }
-    }
-)
-
 def convert_weights_swin2mmseg(ckpt):
     # from https://github.com/open-mmlab/mmsegmentation/blob/main/tools/model_converters/swin2mmseg.py
     new_ckpt = OrderedDict()
@@ -215,37 +206,6 @@ def prepare_features_for_image_model(x):
     return model
 
 
-@register_model
-def prithvi_swin_90_us(
-    pretrained: bool = False,  # noqa: FBT002, FBT001
-    pretrained_bands: list[HLSBands] | None = None,
-    bands: list[int] | None = None,
-    **kwargs,
-) -> MMSegSwinTransformer:
-    """Prithvi Swin 90M"""
-    if pretrained_bands is None:
-        pretrained_bands = PRETRAINED_BANDS
-    if bands is None:
-        bands = pretrained_bands
-        logging.info(
-            f"Model bands not passed. Assuming bands are ordered in the same way as {PRETRAINED_BANDS}.\
-            Pretrained patch_embed layer may be misaligned with current bands"
-        )
-
-    model_args = {
-        "patch_size": 4,
-        "window_size": 7,
-        "embed_dim": 128,
-        "depths": (2, 2, 18, 2),
-        "in_chans": 6,
-        "num_heads": (4, 8, 16, 32),
-    }
-    transformer = _create_swin_mmseg_transformer(
-        "prithvi_swin_90_us", pretrained_bands, bands, pretrained=pretrained, **dict(model_args, **kwargs)
-    )
-    return transformer
-
-
 @register_model
 def prithvi_swin_B(
     pretrained: bool = False,  # noqa: FBT002, FBT001
diff --git a/tests/test_backbones.py b/tests/test_backbones.py
index 17b3f6ca..6b8c36b9 100644
--- a/tests/test_backbones.py
+++ b/tests/test_backbones.py
@@ -28,7 +28,7 @@ def input_386():
     return torch.ones((1, NUM_CHANNELS, 386, 386))
 
 
-@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300"]) #["prithvi_swin_90_us", "prithvi_vit_100", "prithvi_vit_300"])
+@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
 @pytest.mark.parametrize("test_input", ["input_224", "input_512"])
 def test_can_create_backbones_from_timm(model_name, test_input, request):
     backbone = timm.create_model(model_name, pretrained=False)
@@ -36,7 +36,7 @@ def test_can_create_backbones_from_timm(model_name, test_input, request):
     backbone(input_tensor)
 
 
-@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300"])
+@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
 @pytest.mark.parametrize("test_input", ["input_224", "input_512"])
 def test_can_create_backbones_from_timm_features_only(model_name, test_input, request):
     backbone = timm.create_model(model_name, pretrained=False, features_only=True)
diff --git a/tests/test_prithvi_tasks.py b/tests/test_prithvi_tasks.py
index 66a911eb..ca5280a3 100644
--- a/tests/test_prithvi_tasks.py
+++ b/tests/test_prithvi_tasks.py
@@ -20,7 +20,7 @@ def model_input() -> torch.Tensor:
     return torch.ones((1, NUM_CHANNELS, 224, 224))
 
 
-@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
+@pytest.mark.parametrize("backbone",["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
 @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
 @pytest.mark.parametrize("loss", ["ce", "jaccard", "focal", "dice"])
 def test_create_segmentation_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
@@ -38,7 +38,7 @@ def test_create_segmentation_task(backbone, decoder, loss, model_factory: Prithv
     )
 
 
-@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
+@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
 @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
 @pytest.mark.parametrize("loss", ["mae", "rmse", "huber"])
 def test_create_regression_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
@@ -55,7 +55,7 @@ def test_create_regression_task(backbone, decoder, loss, model_factory: PrithviM
     )
 
 
-@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
+@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
 @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
 @pytest.mark.parametrize("loss", ["ce", "bce", "jaccard", "focal"])
 def test_create_classification_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
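For context, a minimal usage sketch of what these changes exercise: importing terratorch.models.backbones runs the @register_model registrations, after which prithvi_swin_B can be built through timm the same way the updated tests do. This is illustrative only, not part of the patch; the 6-channel, 224x224 input is an assumption based on the test fixtures and the in_chans=6 of the removed prithvi_swin_90_us entry point.

# Illustrative sketch, not part of the patch. Assumes a 6-band 224x224 input.
import timm
import torch

import terratorch.models.backbones  # noqa: F401  # importing registers the prithvi_* backbones with timm

backbone = timm.create_model("prithvi_swin_B", pretrained=False)
output = backbone(torch.ones((1, 6, 224, 224)))  # mirrors the test fixture shape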