From 2842f2d062b9a07c5f2ea1fa753bcba5c2b29679 Mon Sep 17 00:00:00 2001
From: Carlos Gomes <carlosmiguel.gomes@live.com.pt>
Date: Thu, 30 May 2024 15:10:22 +0200
Subject: [PATCH 1/2] add dice loss, add tests for loss instantiation

Signed-off-by: Carlos Gomes <carlosmiguel.gomes@live.com.pt>
---
 src/terratorch/tasks/segmentation_tasks.py |  4 +++-
 tests/test_prithvi_tasks.py                | 12 +++++++++---
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/src/terratorch/tasks/segmentation_tasks.py b/src/terratorch/tasks/segmentation_tasks.py
index d5724d68..6e7fe8cf 100644
--- a/src/terratorch/tasks/segmentation_tasks.py
+++ b/src/terratorch/tasks/segmentation_tasks.py
@@ -152,9 +152,11 @@ def configure_losses(self) -> None:
                     f"Jaccard loss does not support ignore_index, but found non-None value of {ignore_index}."
                 )
                 raise RuntimeError(exception_message)
-            self.criterion = smp.losses.JaccardLoss(mode="multiclass", classes=self.hparams["num_classes"])
+            self.criterion = smp.losses.JaccardLoss(mode="multiclass")
         elif loss == "focal":
             self.criterion = smp.losses.FocalLoss("multiclass", ignore_index=ignore_index, normalized=True)
+        elif loss == "dice":
+            self.criterion = smp.losses.DiceLoss("multiclass", ignore_index=ignore_index)
         else:
             exception_message = f"Loss type '{loss}' is not valid. Currently, supports 'ce', 'jaccard' or 'focal' loss."
             raise ValueError(exception_message)
diff --git a/tests/test_prithvi_tasks.py b/tests/test_prithvi_tasks.py
index 22a19642..66a911eb 100644
--- a/tests/test_prithvi_tasks.py
+++ b/tests/test_prithvi_tasks.py
@@ -22,7 +22,8 @@ def model_input() -> torch.Tensor:
 
 @pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
 @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
-def test_create_segmentation_task(backbone, decoder, model_factory: PrithviModelFactory):
+@pytest.mark.parametrize("loss", ["ce", "jaccard", "focal", "dice"])
+def test_create_segmentation_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
     SemanticSegmentationTask(
         {
             "backbone": backbone,
@@ -33,12 +34,14 @@ def test_create_segmentation_task(backbone, decoder, model_factory: PrithviModel
             "num_classes": NUM_CLASSES,
         },
         model_factory,
+        loss=loss
     )
 
 
 @pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
 @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
-def test_create_regression_task(backbone, decoder, model_factory: PrithviModelFactory):
+@pytest.mark.parametrize("loss", ["mae", "rmse", "huber"])
+def test_create_regression_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
     PixelwiseRegressionTask(
         {
             "backbone": backbone,
@@ -48,12 +51,14 @@ def test_create_regression_task(backbone, decoder, model_factory: PrithviModelFa
             "pretrained": False,
         },
         model_factory,
+        loss=loss
     )
 
 
 @pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
 @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
-def test_create_classification_task(backbone, decoder, model_factory: PrithviModelFactory):
+@pytest.mark.parametrize("loss", ["ce", "bce", "jaccard", "focal"])
+def test_create_classification_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
     ClassificationTask(
         {
             "backbone": backbone,
@@ -64,4 +69,5 @@ def test_create_classification_task(backbone, decoder, model_factory: PrithviMod
             "num_classes": NUM_CLASSES,
         },
         model_factory,
+        loss=loss
     )

From 1ab3d1f86e3f8791572b28aaf818c93a9a9d5d41 Mon Sep 17 00:00:00 2001
From: Carlos Gomes <carlosmiguel.gomes@live.com.pt>
Date: Thu, 30 May 2024 15:16:21 +0200
Subject: [PATCH 2/2] edit exception message to include dice loss

Signed-off-by: Carlos Gomes <carlosmiguel.gomes@live.com.pt>
---
 src/terratorch/tasks/segmentation_tasks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/terratorch/tasks/segmentation_tasks.py b/src/terratorch/tasks/segmentation_tasks.py
index 6e7fe8cf..7070b9c5 100644
--- a/src/terratorch/tasks/segmentation_tasks.py
+++ b/src/terratorch/tasks/segmentation_tasks.py
@@ -158,7 +158,7 @@ def configure_losses(self) -> None:
         elif loss == "dice":
             self.criterion = smp.losses.DiceLoss("multiclass", ignore_index=ignore_index)
         else:
-            exception_message = f"Loss type '{loss}' is not valid. Currently, supports 'ce', 'jaccard' or 'focal' loss."
+            exception_message = f"Loss type '{loss}' is not valid. Currently, supports 'ce', 'jaccard', 'dice' or 'focal' loss."
             raise ValueError(exception_message)
 
     def configure_metrics(self) -> None: