diff --git a/src/terratorch/models/prithvi_model_factory.py b/src/terratorch/models/prithvi_model_factory.py index ad424084..a84a35a8 100644 --- a/src/terratorch/models/prithvi_model_factory.py +++ b/src/terratorch/models/prithvi_model_factory.py @@ -32,7 +32,7 @@ def build_model( backbone: str | nn.Module, decoder: str | nn.Module, bands: list[HLSBands | int], - in_channels: int = int | None, # this should be removed, can be derived from bands. But it is a breaking change + in_channels: int | None = None, # this should be removed, can be derived from bands. But it is a breaking change num_classes: int | None = None, pretrained: bool = True, # noqa: FBT001, FBT002 num_frames: int = 1, diff --git a/tests/test_prithvi_model_factory.py b/tests/test_prithvi_model_factory.py index 6d76fa44..ad7f53a5 100644 --- a/tests/test_prithvi_model_factory.py +++ b/tests/test_prithvi_model_factory.py @@ -13,6 +13,7 @@ NUM_CLASSES = 2 EXPECTED_SEGMENTATION_OUTPUT_SHAPE = (1, NUM_CLASSES, 224, 224) EXPECTED_REGRESSION_OUTPUT_SHAPE = (1, 224, 224) +EXPECTED_CLASSIFICATION_OUTPUT_SHAPE = (1, NUM_CLASSES) @pytest.fixture(scope="session") @@ -25,6 +26,37 @@ def model_input() -> torch.Tensor: return torch.ones((1, NUM_CHANNELS, 224, 224)) @pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"]) +def test_create_classification_model(backbone, model_factory: PrithviModelFactory, model_input): + model = model_factory.build_model( + "classification", + backbone=backbone, + decoder="IdentityDecoder", + in_channels=NUM_CHANNELS, + bands=PRETRAINED_BANDS, + pretrained=False, + num_classes=NUM_CLASSES, + ) + model.eval() + + with torch.no_grad(): + assert model(model_input).output.shape == EXPECTED_CLASSIFICATION_OUTPUT_SHAPE + +@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"]) +def test_create_classification_model_no_in_channels(backbone, model_factory: PrithviModelFactory, model_input): + model = model_factory.build_model( + "classification", + backbone=backbone, + decoder="IdentityDecoder", + bands=PRETRAINED_BANDS, + pretrained=False, + num_classes=NUM_CLASSES, + ) + model.eval() + + with torch.no_grad(): + assert model(model_input).output.shape == EXPECTED_CLASSIFICATION_OUTPUT_SHAPE + +@pytest.mark.parametrize("backbone", ["prithvi_vit_100"]) @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"]) def test_create_segmentation_model(backbone, decoder, model_factory: PrithviModelFactory, model_input): model = model_factory.build_model( @@ -41,8 +73,24 @@ def test_create_segmentation_model(backbone, decoder, model_factory: PrithviMode with torch.no_grad(): assert model(model_input).output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE +@pytest.mark.parametrize("backbone", ["prithvi_vit_100"]) +@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"]) +def test_create_segmentation_model_no_in_channels(backbone, decoder, model_factory: PrithviModelFactory, model_input): + model = model_factory.build_model( + "segmentation", + backbone=backbone, + decoder=decoder, + bands=PRETRAINED_BANDS, + pretrained=False, + num_classes=NUM_CLASSES, + ) + model.eval() -@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"]) + with torch.no_grad(): + assert model(model_input).output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE + + +@pytest.mark.parametrize("backbone", ["prithvi_vit_100"]) @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"]) def test_create_segmentation_model_with_aux_heads(backbone, decoder, model_factory: PrithviModelFactory, model_input): aux_heads_name = ["first_aux", "second_aux"] @@ -67,7 +115,7 @@ def test_create_segmentation_model_with_aux_heads(backbone, decoder, model_facto assert output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE -@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"]) +@pytest.mark.parametrize("backbone", ["prithvi_vit_100"]) @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"]) def test_create_regression_model(backbone, decoder, model_factory: PrithviModelFactory, model_input): model = model_factory.build_model( @@ -83,8 +131,22 @@ def test_create_regression_model(backbone, decoder, model_factory: PrithviModelF with torch.no_grad(): assert model(model_input).output.shape == EXPECTED_REGRESSION_OUTPUT_SHAPE +@pytest.mark.parametrize("backbone", ["prithvi_vit_100"]) +@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"]) +def test_create_regression_model_no_in_channels(backbone, decoder, model_factory: PrithviModelFactory, model_input): + model = model_factory.build_model( + "regression", + backbone=backbone, + decoder=decoder, + bands=PRETRAINED_BANDS, + pretrained=False, + ) + model.eval() + + with torch.no_grad(): + assert model(model_input).output.shape == EXPECTED_REGRESSION_OUTPUT_SHAPE -@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"]) +@pytest.mark.parametrize("backbone", ["prithvi_vit_100"]) @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"]) def test_create_regression_model_with_aux_heads(backbone, decoder, model_factory: PrithviModelFactory, model_input): aux_heads_name = ["first_aux", "second_aux"] @@ -108,14 +170,14 @@ def test_create_regression_model_with_aux_heads(backbone, decoder, model_factory assert output.shape == EXPECTED_REGRESSION_OUTPUT_SHAPE -@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"]) +@pytest.mark.parametrize("backbone", ["prithvi_vit_100"]) @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"]) def test_create_model_with_extra_bands(backbone, decoder, model_factory: PrithviModelFactory): model = model_factory.build_model( "segmentation", backbone=backbone, decoder=decoder, - in_channels=NUM_CHANNELS, + in_channels=NUM_CHANNELS + 1, bands=[*PRETRAINED_BANDS, 7], # add an extra band pretrained=False, num_classes=NUM_CLASSES,