diff --git a/pyproject.toml b/pyproject.toml index 818d9fe0..49a2229d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = 'setuptools.build_meta' [project] name = "terratorch" -version = "0.99.0" +version = "0.99.1" description = "TerraTorch - A model training toolkit for geospatial tasks" license = { "text" = "Apache License, Version 2.0" } readme = "README.md" @@ -144,7 +144,7 @@ exclude_lines = [ ] [tool.bumpver] -current_version = "0.99.0" +current_version = "0.99.1" version_pattern = "MAJOR.MINOR.PATCH[PYTAGNUM]" commit_message = "Bump version {old_version} -> {new_version}" commit = true diff --git a/tests/test_backbones.py b/tests/test_backbones.py index 6b8c36b9..760039cb 100644 --- a/tests/test_backbones.py +++ b/tests/test_backbones.py @@ -1,6 +1,7 @@ import pytest import timm import torch +import importlib import terratorch # noqa: F401 @@ -49,7 +50,16 @@ def test_vit_models_accept_multitemporal(model_name, input_224_multitemporal): backbone = timm.create_model(model_name, pretrained=False, num_frames=NUM_FRAMES) backbone(input_224_multitemporal) +# Swin IS NOT on HuggingFace +@pytest.mark.parametrize("model_name", ["prithvi_swin_L", "prithvi_swin_B"]) +def test_swin_instantiation(model_name): + base_module = "terratorch.models.backbones.prithvi_swin" + module = importlib.import_module(base_module) + model_class = getattr(module, model_name) + model = model_class(pretrained=False, pretrained_bands=[0,1,2,3,4,5,6,7,8,9], + bands=[1,2,3,4,5,6]) + #def test_swin_models_accept_non_divisible_by_patch_size(input_386): # backbone = timm.create_model("prithvi_swin_90_us", pretrained=False, num_frames=NUM_FRAMES) # backbone(input_386)