From 0342fa1ac044e102eb29dffacc5a065ef1259263 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Fri, 28 Jun 2024 09:43:05 -0300 Subject: [PATCH 1/2] Updating patch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 818d9fe0..49a2229d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = 'setuptools.build_meta' [project] name = "terratorch" -version = "0.99.0" +version = "0.99.1" description = "TerraTorch - A model training toolkit for geospatial tasks" license = { "text" = "Apache License, Version 2.0" } readme = "README.md" @@ -144,7 +144,7 @@ exclude_lines = [ ] [tool.bumpver] -current_version = "0.99.0" +current_version = "0.99.1" version_pattern = "MAJOR.MINOR.PATCH[PYTAGNUM]" commit_message = "Bump version {old_version} -> {new_version}" commit = true From 0401bf8f6c2774d48a01632ba5b4f90d9d76eee4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Fri, 28 Jun 2024 09:43:16 -0300 Subject: [PATCH 2/2] Testing Swin backbones MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- tests/test_backbones.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_backbones.py b/tests/test_backbones.py index 6b8c36b9..760039cb 100644 --- a/tests/test_backbones.py +++ b/tests/test_backbones.py @@ -1,6 +1,7 @@ import pytest import timm import torch +import importlib import terratorch # noqa: F401 @@ -49,7 +50,16 @@ def test_vit_models_accept_multitemporal(model_name, input_224_multitemporal): backbone = timm.create_model(model_name, pretrained=False, num_frames=NUM_FRAMES) backbone(input_224_multitemporal) +# Swin IS NOT on HuggingFace +@pytest.mark.parametrize("model_name", ["prithvi_swin_L", "prithvi_swin_B"]) +def test_swin_instantiation(model_name): + base_module = "terratorch.models.backbones.prithvi_swin" + module = importlib.import_module(base_module) + model_class = getattr(module, model_name) + model = model_class(pretrained=False, pretrained_bands=[0,1,2,3,4,5,6,7,8,9], + bands=[1,2,3,4,5,6]) + #def test_swin_models_accept_non_divisible_by_patch_size(input_386): # backbone = timm.create_model("prithvi_swin_90_us", pretrained=False, num_frames=NUM_FRAMES) # backbone(input_386)