diff --git a/src/terratorch/tasks/segmentation_tasks.py b/src/terratorch/tasks/segmentation_tasks.py index d5724d68..7070b9c5 100644 --- a/src/terratorch/tasks/segmentation_tasks.py +++ b/src/terratorch/tasks/segmentation_tasks.py @@ -152,11 +152,13 @@ def configure_losses(self) -> None: f"Jaccard loss does not support ignore_index, but found non-None value of {ignore_index}." ) raise RuntimeError(exception_message) - self.criterion = smp.losses.JaccardLoss(mode="multiclass", classes=self.hparams["num_classes"]) + self.criterion = smp.losses.JaccardLoss(mode="multiclass") elif loss == "focal": self.criterion = smp.losses.FocalLoss("multiclass", ignore_index=ignore_index, normalized=True) + elif loss == "dice": + self.criterion = smp.losses.DiceLoss("multiclass", ignore_index=ignore_index) else: - exception_message = f"Loss type '{loss}' is not valid. Currently, supports 'ce', 'jaccard' or 'focal' loss." + exception_message = f"Loss type '{loss}' is not valid. Currently, supports 'ce', 'jaccard', 'dice' or 'focal' loss." raise ValueError(exception_message) def configure_metrics(self) -> None: diff --git a/tests/test_prithvi_tasks.py b/tests/test_prithvi_tasks.py index 22a19642..66a911eb 100644 --- a/tests/test_prithvi_tasks.py +++ b/tests/test_prithvi_tasks.py @@ -22,7 +22,8 @@ def model_input() -> torch.Tensor: @pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"]) @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"]) -def test_create_segmentation_task(backbone, decoder, model_factory: PrithviModelFactory): +@pytest.mark.parametrize("loss", ["ce", "jaccard", "focal", "dice"]) +def test_create_segmentation_task(backbone, decoder, loss, model_factory: PrithviModelFactory): SemanticSegmentationTask( { "backbone": backbone, @@ -33,12 +34,14 @@ def test_create_segmentation_task(backbone, decoder, model_factory: PrithviModel "num_classes": NUM_CLASSES, }, model_factory, + loss=loss ) @pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"]) @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"]) -def test_create_regression_task(backbone, decoder, model_factory: PrithviModelFactory): +@pytest.mark.parametrize("loss", ["mae", "rmse", "huber"]) +def test_create_regression_task(backbone, decoder, loss, model_factory: PrithviModelFactory): PixelwiseRegressionTask( { "backbone": backbone, @@ -48,12 +51,14 @@ def test_create_regression_task(backbone, decoder, model_factory: PrithviModelFa "pretrained": False, }, model_factory, + loss=loss ) @pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"]) @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"]) -def test_create_classification_task(backbone, decoder, model_factory: PrithviModelFactory): +@pytest.mark.parametrize("loss", ["ce", "bce", "jaccard", "focal"]) +def test_create_classification_task(backbone, decoder, loss, model_factory: PrithviModelFactory): ClassificationTask( { "backbone": backbone, @@ -64,4 +69,5 @@ def test_create_classification_task(backbone, decoder, model_factory: PrithviMod "num_classes": NUM_CLASSES, }, model_factory, + loss=loss )