From 6d1f127efa499cacf2e908b9d24885d4b07b0316 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Thu, 6 Mar 2025 09:59:06 -0300 Subject: [PATCH 1/6] Trying to deal with possible errors during the test step of the segmentation task by invoking tiled inference MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/tasks/segmentation_tasks.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py index 6064cdf8..a94df376 100644 --- a/terratorch/tasks/segmentation_tasks.py +++ b/terratorch/tasks/segmentation_tasks.py @@ -18,6 +18,7 @@ from terratorch.tasks.optimizer_factory import optimizer_factory from terratorch.tasks.tiled_inference import TiledInferenceParameters, tiled_inference from terratorch.tasks.base_task import TerraTorchTask +from terratorch.models.model import ModelOutput BATCH_IDX_FOR_VALIDATION_PLOTTING = 10 @@ -266,7 +267,27 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None rest = {k: batch[k] for k in other_keys} - model_output: ModelOutput = self(x, **rest) + + def model_forward(x, **kwargs): + return self(x, **kwargs).output + + # When the input sample cannot be fit on memory for some reason + # the tiled inference is automatically invoked. + try: + model_output: ModelOutput = self(x, **rest) + except Exception: + if self.tiled_inference_parameters: + y_hat: Tensor = tiled_inference( + model_forward, + x, + self.hparams["model_args"]["num_classes"], + self.tiled_inference_parameters, + **rest, + ) + model_output = ModelOutput(mask=y_hat) + else: + raise Exception("You need to define a configuration for the tiled inference.") + if dataloader_idx >= len(self.test_loss_handler): msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names." raise ValueError(msg) From 4f9c4de6dcff680e69052fbe0d77a9f833e0213f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Thu, 6 Mar 2025 10:18:16 -0300 Subject: [PATCH 2/6] Doing the same for regression. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/tasks/regression_tasks.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/terratorch/tasks/regression_tasks.py b/terratorch/tasks/regression_tasks.py index 118298a0..4cd3bb99 100644 --- a/terratorch/tasks/regression_tasks.py +++ b/terratorch/tasks/regression_tasks.py @@ -355,7 +355,27 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None y = batch["mask"] other_keys = batch.keys() - {"image", "mask", "filename"} rest = {k: batch[k] for k in other_keys} - model_output: ModelOutput = self(x, **rest) + + def model_forward(x, **kwargs): + return self(x, **kwargs).output + + # When the input sample cannot be fit on memory for some reason + # the tiled inference is automatically invoked. + try: + model_output: ModelOutput = self(x, **rest) + except Exception: + if self.tiled_inference_parameters: + y_hat: Tensor = tiled_inference( + model_forward, + x, + 1, + self.tiled_inference_parameters, + **rest, + ) + model_output = ModelOutput(mask=y_hat) + else: + raise Exception("You need to define a configuration for the tiled inference.") + if dataloader_idx >= len(self.test_loss_handler): msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names." raise ValueError(msg) From aa5532259ea2b264fbf070049589268192c3438e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Thu, 6 Mar 2025 10:43:58 -0300 Subject: [PATCH 3/6] Adjusting the error directive MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/tasks/regression_tasks.py | 5 +++-- terratorch/tasks/segmentation_tasks.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/terratorch/tasks/regression_tasks.py b/terratorch/tasks/regression_tasks.py index 4cd3bb99..d9a8d2a4 100644 --- a/terratorch/tasks/regression_tasks.py +++ b/terratorch/tasks/regression_tasks.py @@ -363,7 +363,8 @@ def model_forward(x, **kwargs): # the tiled inference is automatically invoked. try: model_output: ModelOutput = self(x, **rest) - except Exception: + except RuntimeError: + logger.info("\n The input sample could not run in a full format. Using tiled inference.") if self.tiled_inference_parameters: y_hat: Tensor = tiled_inference( model_forward, @@ -372,7 +373,7 @@ def model_forward(x, **kwargs): self.tiled_inference_parameters, **rest, ) - model_output = ModelOutput(mask=y_hat) + model_output = ModelOutput(output=y_hat) else: raise Exception("You need to define a configuration for the tiled inference.") diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py index a94df376..e1ba0094 100644 --- a/terratorch/tasks/segmentation_tasks.py +++ b/terratorch/tasks/segmentation_tasks.py @@ -275,7 +275,8 @@ def model_forward(x, **kwargs): # the tiled inference is automatically invoked. try: model_output: ModelOutput = self(x, **rest) - except Exception: + except RuntimeError: + logger.info("\n The input sample could not run in a full format. Using tiled inference.") if self.tiled_inference_parameters: y_hat: Tensor = tiled_inference( model_forward, @@ -284,7 +285,7 @@ def model_forward(x, **kwargs): self.tiled_inference_parameters, **rest, ) - model_output = ModelOutput(mask=y_hat) + model_output = ModelOutput(output=y_hat) else: raise Exception("You need to define a configuration for the tiled inference.") From c199e9ad180c3be3d5d45abacdc5d63d93c7c7ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Thu, 6 Mar 2025 10:44:23 -0300 Subject: [PATCH 4/6] updating config MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- ...tured-finetune_prithvi_swin_B_segmentation.yaml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/resources/configs/manufactured-finetune_prithvi_swin_B_segmentation.yaml b/tests/resources/configs/manufactured-finetune_prithvi_swin_B_segmentation.yaml index 586aa6b1..2d0ebd79 100644 --- a/tests/resources/configs/manufactured-finetune_prithvi_swin_B_segmentation.yaml +++ b/tests/resources/configs/manufactured-finetune_prithvi_swin_B_segmentation.yaml @@ -1,7 +1,7 @@ # lightning.pytorch==2.1.1 seed_everything: 42 trainer: - accelerator: cpu + accelerator: auto strategy: auto devices: auto num_nodes: 1 @@ -118,12 +118,12 @@ model: model_factory: PrithviModelFactory # uncomment this block for tiled inference - # tiled_inference_parameters: - # h_crop: 224 - # h_stride: 192 - # w_crop: 224 - # w_stride: 192 - # average_patches: true + tiled_inference_parameters: + h_crop: 224 + h_stride: 192 + w_crop: 224 + w_stride: 192 + average_patches: true optimizer: class_path: torch.optim.AdamW init_args: From af8f22b91eb035bd6babc0b8baf00fc2bb663297 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Thu, 6 Mar 2025 12:03:23 -0300 Subject: [PATCH 5/6] Warnin g the user about the risks of tiled inference MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/tasks/regression_tasks.py | 1 + terratorch/tasks/segmentation_tasks.py | 1 + 2 files changed, 2 insertions(+) diff --git a/terratorch/tasks/regression_tasks.py b/terratorch/tasks/regression_tasks.py index d9a8d2a4..537b4d97 100644 --- a/terratorch/tasks/regression_tasks.py +++ b/terratorch/tasks/regression_tasks.py @@ -365,6 +365,7 @@ def model_forward(x, **kwargs): model_output: ModelOutput = self(x, **rest) except RuntimeError: logger.info("\n The input sample could not run in a full format. Using tiled inference.") + looger.info("Notice thaat the tiled inference WON'T produce the exactly same result as the full inference.") if self.tiled_inference_parameters: y_hat: Tensor = tiled_inference( model_forward, diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py index e1ba0094..61e6bece 100644 --- a/terratorch/tasks/segmentation_tasks.py +++ b/terratorch/tasks/segmentation_tasks.py @@ -277,6 +277,7 @@ def model_forward(x, **kwargs): model_output: ModelOutput = self(x, **rest) except RuntimeError: logger.info("\n The input sample could not run in a full format. Using tiled inference.") + looger.info("Notice thaat the tiled inference WON'T produce the exactly same result as the full inference.") if self.tiled_inference_parameters: y_hat: Tensor = tiled_inference( model_forward, From f96f7e9713b3f56bd55c6de68aaf4ca52891193d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Thu, 6 Mar 2025 12:34:55 -0300 Subject: [PATCH 6/6] Fixing typos MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/tasks/regression_tasks.py | 2 +- terratorch/tasks/segmentation_tasks.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/terratorch/tasks/regression_tasks.py b/terratorch/tasks/regression_tasks.py index 537b4d97..0686933d 100644 --- a/terratorch/tasks/regression_tasks.py +++ b/terratorch/tasks/regression_tasks.py @@ -365,7 +365,7 @@ def model_forward(x, **kwargs): model_output: ModelOutput = self(x, **rest) except RuntimeError: logger.info("\n The input sample could not run in a full format. Using tiled inference.") - looger.info("Notice thaat the tiled inference WON'T produce the exactly same result as the full inference.") + looger.info("Notice that the tiled inference WON'T produce the exactly same result as the full inference.") if self.tiled_inference_parameters: y_hat: Tensor = tiled_inference( model_forward, diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py index 61e6bece..460c3e00 100644 --- a/terratorch/tasks/segmentation_tasks.py +++ b/terratorch/tasks/segmentation_tasks.py @@ -277,7 +277,7 @@ def model_forward(x, **kwargs): model_output: ModelOutput = self(x, **rest) except RuntimeError: logger.info("\n The input sample could not run in a full format. Using tiled inference.") - looger.info("Notice thaat the tiled inference WON'T produce the exactly same result as the full inference.") + looger.info("Notice that the tiled inference WON'T produce the exactly same result as the full inference.") if self.tiled_inference_parameters: y_hat: Tensor = tiled_inference( model_forward,