From d8aa006e0d37620e54443d7bb9849eda2e0e1a5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Wed, 5 Mar 2025 18:05:27 -0300 Subject: [PATCH 1/3] Using the TerraTorch registry for Swin MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/models/backbones/prithvi_swin.py | 8 ++++---- terratorch/models/backbones/swin_encoder_decoder.py | 3 +++ terratorch/registry/registry.py | 10 +++++++++- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/terratorch/models/backbones/prithvi_swin.py b/terratorch/models/backbones/prithvi_swin.py index 084bc3c6..0396267a 100644 --- a/terratorch/models/backbones/prithvi_swin.py +++ b/terratorch/models/backbones/prithvi_swin.py @@ -10,12 +10,13 @@ import torch from timm.models import SwinTransformer from timm.models._builder import build_model_with_cfg -from timm.models._registry import generate_default_cfgs, register_model +from timm.models._registry import generate_default_cfgs from terratorch.datasets.utils import HLSBands from terratorch.models.backbones.select_patch_embed_weights import select_patch_embed_weights from terratorch.models.backbones.swin_encoder_decoder import MMSegSwinTransformer from terratorch.datasets.utils import generate_bands_intervals +from terratorch.registry import TERRATORCH_BACKBONE_REGISTRY logger = logging.getLogger(__name__) @@ -261,8 +262,7 @@ def prepare_features_for_image_model(x): model.prepare_features_for_image_model = prepare_features_for_image_model return model - -@register_model +@TERRATORCH_BACKBONE_REGISTRY.register def prithvi_swin_B( pretrained: bool = False, # noqa: FBT002, FBT001 pretrained_bands: list[HLSBands] | None = None, @@ -293,7 +293,7 @@ def prithvi_swin_B( return transformer -@register_model +@TERRATORCH_BACKBONE_REGISTRY.register def prithvi_swin_L( pretrained: bool = False, # noqa: FBT002, FBT001 pretrained_bands: list[HLSBands] | None = None, diff --git a/terratorch/models/backbones/swin_encoder_decoder.py b/terratorch/models/backbones/swin_encoder_decoder.py index 0ca68550..bf75f8fa 100644 --- a/terratorch/models/backbones/swin_encoder_decoder.py +++ b/terratorch/models/backbones/swin_encoder_decoder.py @@ -984,6 +984,9 @@ def __init__( in_chans = downsample.out_channels self.stages = nn.Sequential(*stages) self.num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)] + + self.out_channels = [int(embed_dim * 2**i) for i in range(self.num_layers)] + # Add a norm layer for each output self.head = ClassifierHead( diff --git a/terratorch/registry/registry.py b/terratorch/registry/registry.py index 91235923..231db55b 100644 --- a/terratorch/registry/registry.py +++ b/terratorch/registry/registry.py @@ -3,6 +3,9 @@ from collections.abc import Callable, Mapping, Set from contextlib import suppress from reprlib import recursive_repr as _recursive_repr +import logging + +logger = logging.getLogger(__name__) class BuildableRegistry(typing.Protocol): def __iter__(self): ... @@ -138,7 +141,12 @@ def build(self, name: str, *constructor_args, **constructor_kwargs): """Build and return the component. Use prefixes ending with _ to forward to a specific source """ - return self._registry[name](*constructor_args, **constructor_kwargs) + try: + return self._registry[name](*constructor_args, **constructor_kwargs) + except KeyError as ke: + for key in self._registry.keys(): + logger.debug(f'Registered registry {key}') + raise ValueError(f"Class '{name}' is not registered.") from ke def __iter__(self): return iter(self._registry) From 1866dde1200a449c3f56ed8722d4be7a9ec127bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Wed, 5 Mar 2025 19:05:14 -0300 Subject: [PATCH 2/3] Updating tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- tests/test_backbones.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_backbones.py b/tests/test_backbones.py index 546250d9..21492a98 100644 --- a/tests/test_backbones.py +++ b/tests/test_backbones.py @@ -35,13 +35,13 @@ def input_386(): return torch.ones((1, NUM_CHANNELS, 386, 386)) -@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_swin_B"]) -@pytest.mark.parametrize("test_input", ["input_224", "input_512"]) -def test_can_create_backbones_from_timm(model_name, test_input, request): - backbone = timm.create_model(model_name, pretrained=False) - input_tensor = request.getfixturevalue(test_input) - backbone(input_tensor) - gc.collect() +#@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_swin_B"]) +#@pytest.mark.parametrize("test_input", ["input_224", "input_512"]) +#def test_can_create_backbones_from_timm(model_name, test_input, request): +# backbone = timm.create_model(model_name, pretrained=False) +# input_tensor = request.getfixturevalue(test_input) +# backbone(input_tensor) +# gc.collect() @pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_swin_B"]) @pytest.mark.parametrize("test_input", ["input_224", "input_512"]) From 127e7a1257d94be733cc85fde4e55eb427f26850 Mon Sep 17 00:00:00 2001 From: Joao Lucas de Sousa Almeida Date: Wed, 5 Mar 2025 21:29:45 -0300 Subject: [PATCH 3/3] Removing these tests Signed-off-by: Joao Lucas de Sousa Almeida --- tests/test_backbones.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/test_backbones.py b/tests/test_backbones.py index 546250d9..413b3c16 100644 --- a/tests/test_backbones.py +++ b/tests/test_backbones.py @@ -35,26 +35,26 @@ def input_386(): return torch.ones((1, NUM_CHANNELS, 386, 386)) -@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_swin_B"]) -@pytest.mark.parametrize("test_input", ["input_224", "input_512"]) -def test_can_create_backbones_from_timm(model_name, test_input, request): - backbone = timm.create_model(model_name, pretrained=False) - input_tensor = request.getfixturevalue(test_input) - backbone(input_tensor) - gc.collect() - -@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_swin_B"]) -@pytest.mark.parametrize("test_input", ["input_224", "input_512"]) -def test_can_create_backbones_from_timm_features_only(model_name, test_input, request): - backbone = timm.create_model(model_name, pretrained=False, features_only=True) - input_tensor = request.getfixturevalue(test_input) - backbone(input_tensor) - gc.collect() - -@pytest.mark.parametrize("model_name", ["prithvi_swin_L", "prithvi_swin_L", "prithvi_swin_B"]) -@pytest.mark.parametrize("prefix", ["", "timm_"]) +#@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_swin_B"]) +#@pytest.mark.parametrize("test_input", ["input_224", "input_512"]) +#def test_can_create_backbones_from_timm(model_name, test_input, request): +# backbone = timm.create_model(model_name, pretrained=False) +# input_tensor = request.getfixturevalue(test_input) +# backbone(input_tensor) +# gc.collect() + +#@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_swin_B"]) +#@pytest.mark.parametrize("test_input", ["input_224", "input_512"]) +#def test_can_create_backbones_from_timm_features_only(model_name, test_input, request): +# backbone = timm.create_model(model_name, pretrained=False, features_only=True) +# input_tensor = request.getfixturevalue(test_input) +# backbone(input_tensor) +# gc.collect() + +@pytest.mark.parametrize("model_name", ["prithvi_swin_L", "prithvi_swin_B"]) +#@pytest.mark.parametrize("prefix", [""])#, "timm_"]) def test_can_create_timm_backbones_from_registry(model_name, input_224, prefix): - backbone = BACKBONE_REGISTRY.build(prefix+model_name, pretrained=False) + backbone = BACKBONE_REGISTRY.build(model_name, pretrained=False) backbone(input_224) gc.collect()