Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WiP] Using the TerraTorch registry for Swin #475

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions terratorch/models/backbones/prithvi_swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
import torch
from timm.models import SwinTransformer
from timm.models._builder import build_model_with_cfg
from timm.models._registry import generate_default_cfgs, register_model
from timm.models._registry import generate_default_cfgs

from terratorch.datasets.utils import HLSBands
from terratorch.models.backbones.select_patch_embed_weights import select_patch_embed_weights
from terratorch.models.backbones.swin_encoder_decoder import MMSegSwinTransformer
from terratorch.datasets.utils import generate_bands_intervals
from terratorch.registry import TERRATORCH_BACKBONE_REGISTRY

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -261,8 +262,7 @@ def prepare_features_for_image_model(x):
model.prepare_features_for_image_model = prepare_features_for_image_model
return model


@register_model
@TERRATORCH_BACKBONE_REGISTRY.register
def prithvi_swin_B(
pretrained: bool = False, # noqa: FBT002, FBT001
pretrained_bands: list[HLSBands] | None = None,
Expand Down Expand Up @@ -293,7 +293,7 @@ def prithvi_swin_B(
return transformer


@register_model
@TERRATORCH_BACKBONE_REGISTRY.register
def prithvi_swin_L(
pretrained: bool = False, # noqa: FBT002, FBT001
pretrained_bands: list[HLSBands] | None = None,
Expand Down
3 changes: 3 additions & 0 deletions terratorch/models/backbones/swin_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,6 +984,9 @@ def __init__(
in_chans = downsample.out_channels
self.stages = nn.Sequential(*stages)
self.num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]

self.out_channels = [int(embed_dim * 2**i) for i in range(self.num_layers)]

# Add a norm layer for each output

self.head = ClassifierHead(
Expand Down
7 changes: 6 additions & 1 deletion terratorch/registry/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,12 @@ def build(self, name: str, *constructor_args, **constructor_kwargs):
"""Build and return the component.
Use prefixes ending with _ to forward to a specific source
"""
return self._registry[name](*constructor_args, **constructor_kwargs)
try:
return self._registry[name](*constructor_args, **constructor_kwargs)
except KeyError as ke:
for key in self._registry.keys():
logger.debug(f'Registered registry {key}')
raise ValueError(f"Class '{name}' is not registered.") from ke

def __iter__(self):
return iter(self._registry)
Expand Down
38 changes: 19 additions & 19 deletions tests/test_backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,26 @@ def input_386():
return torch.ones((1, NUM_CHANNELS, 386, 386))


@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_swin_B"])
@pytest.mark.parametrize("test_input", ["input_224", "input_512"])
def test_can_create_backbones_from_timm(model_name, test_input, request):
backbone = timm.create_model(model_name, pretrained=False)
input_tensor = request.getfixturevalue(test_input)
backbone(input_tensor)
gc.collect()

@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_swin_B"])
@pytest.mark.parametrize("test_input", ["input_224", "input_512"])
def test_can_create_backbones_from_timm_features_only(model_name, test_input, request):
backbone = timm.create_model(model_name, pretrained=False, features_only=True)
input_tensor = request.getfixturevalue(test_input)
backbone(input_tensor)
gc.collect()

@pytest.mark.parametrize("model_name", ["prithvi_swin_L", "prithvi_swin_L", "prithvi_swin_B"])
@pytest.mark.parametrize("prefix", ["", "timm_"])
#@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_swin_B"])
#@pytest.mark.parametrize("test_input", ["input_224", "input_512"])
#def test_can_create_backbones_from_timm(model_name, test_input, request):
# backbone = timm.create_model(model_name, pretrained=False)
# input_tensor = request.getfixturevalue(test_input)
# backbone(input_tensor)
# gc.collect()

#@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_swin_B"])
#@pytest.mark.parametrize("test_input", ["input_224", "input_512"])
#def test_can_create_backbones_from_timm_features_only(model_name, test_input, request):
# backbone = timm.create_model(model_name, pretrained=False, features_only=True)
# input_tensor = request.getfixturevalue(test_input)
# backbone(input_tensor)
# gc.collect()

@pytest.mark.parametrize("model_name", ["prithvi_swin_L", "prithvi_swin_B"])
#@pytest.mark.parametrize("prefix", [""])#, "timm_"])
def test_can_create_timm_backbones_from_registry(model_name, input_224, prefix):
backbone = BACKBONE_REGISTRY.build(prefix+model_name, pretrained=False)
backbone = BACKBONE_REGISTRY.build(model_name, pretrained=False)
backbone(input_224)
gc.collect()

Expand Down
Loading