From 2842f2d062b9a07c5f2ea1fa753bcba5c2b29679 Mon Sep 17 00:00:00 2001 From: Carlos Gomes <carlosmiguel.gomes@live.com.pt> Date: Thu, 30 May 2024 15:10:22 +0200 Subject: [PATCH 1/2] add dice loss, add tests for loss instantiation Signed-off-by: Carlos Gomes <carlosmiguel.gomes@live.com.pt> --- src/terratorch/tasks/segmentation_tasks.py | 4 +++- tests/test_prithvi_tasks.py | 12 +++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/terratorch/tasks/segmentation_tasks.py b/src/terratorch/tasks/segmentation_tasks.py index d5724d68..6e7fe8cf 100644 --- a/src/terratorch/tasks/segmentation_tasks.py +++ b/src/terratorch/tasks/segmentation_tasks.py @@ -152,9 +152,11 @@ def configure_losses(self) -> None: f"Jaccard loss does not support ignore_index, but found non-None value of {ignore_index}." ) raise RuntimeError(exception_message) - self.criterion = smp.losses.JaccardLoss(mode="multiclass", classes=self.hparams["num_classes"]) + self.criterion = smp.losses.JaccardLoss(mode="multiclass") elif loss == "focal": self.criterion = smp.losses.FocalLoss("multiclass", ignore_index=ignore_index, normalized=True) + elif loss == "dice": + self.criterion = smp.losses.DiceLoss("multiclass", ignore_index=ignore_index) else: exception_message = f"Loss type '{loss}' is not valid. Currently, supports 'ce', 'jaccard' or 'focal' loss." raise ValueError(exception_message) diff --git a/tests/test_prithvi_tasks.py b/tests/test_prithvi_tasks.py index 22a19642..66a911eb 100644 --- a/tests/test_prithvi_tasks.py +++ b/tests/test_prithvi_tasks.py @@ -22,7 +22,8 @@ def model_input() -> torch.Tensor: @pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"]) @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"]) -def test_create_segmentation_task(backbone, decoder, model_factory: PrithviModelFactory): +@pytest.mark.parametrize("loss", ["ce", "jaccard", "focal", "dice"]) +def test_create_segmentation_task(backbone, decoder, loss, model_factory: PrithviModelFactory): SemanticSegmentationTask( { "backbone": backbone, @@ -33,12 +34,14 @@ def test_create_segmentation_task(backbone, decoder, model_factory: PrithviModel "num_classes": NUM_CLASSES, }, model_factory, + loss=loss ) @pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"]) @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"]) -def test_create_regression_task(backbone, decoder, model_factory: PrithviModelFactory): +@pytest.mark.parametrize("loss", ["mae", "rmse", "huber"]) +def test_create_regression_task(backbone, decoder, loss, model_factory: PrithviModelFactory): PixelwiseRegressionTask( { "backbone": backbone, @@ -48,12 +51,14 @@ def test_create_regression_task(backbone, decoder, model_factory: PrithviModelFa "pretrained": False, }, model_factory, + loss=loss ) @pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"]) @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"]) -def test_create_classification_task(backbone, decoder, model_factory: PrithviModelFactory): +@pytest.mark.parametrize("loss", ["ce", "bce", "jaccard", "focal"]) +def test_create_classification_task(backbone, decoder, loss, model_factory: PrithviModelFactory): ClassificationTask( { "backbone": backbone, @@ -64,4 +69,5 @@ def test_create_classification_task(backbone, decoder, model_factory: PrithviMod "num_classes": NUM_CLASSES, }, model_factory, + loss=loss ) From 1ab3d1f86e3f8791572b28aaf818c93a9a9d5d41 Mon Sep 17 00:00:00 2001 From: Carlos Gomes <carlosmiguel.gomes@live.com.pt> Date: Thu, 30 May 2024 15:16:21 +0200 Subject: [PATCH 2/2] edit exception message to include dice loss Signed-off-by: Carlos Gomes <carlosmiguel.gomes@live.com.pt> --- src/terratorch/tasks/segmentation_tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/terratorch/tasks/segmentation_tasks.py b/src/terratorch/tasks/segmentation_tasks.py index 6e7fe8cf..7070b9c5 100644 --- a/src/terratorch/tasks/segmentation_tasks.py +++ b/src/terratorch/tasks/segmentation_tasks.py @@ -158,7 +158,7 @@ def configure_losses(self) -> None: elif loss == "dice": self.criterion = smp.losses.DiceLoss("multiclass", ignore_index=ignore_index) else: - exception_message = f"Loss type '{loss}' is not valid. Currently, supports 'ce', 'jaccard' or 'focal' loss." + exception_message = f"Loss type '{loss}' is not valid. Currently, supports 'ce', 'jaccard', 'dice' or 'focal' loss." raise ValueError(exception_message) def configure_metrics(self) -> None: