diff --git a/terratorch/tasks/regression_tasks.py b/terratorch/tasks/regression_tasks.py index 9849b0b2..0686933d 100644 --- a/terratorch/tasks/regression_tasks.py +++ b/terratorch/tasks/regression_tasks.py @@ -355,7 +355,29 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None y = batch["mask"] other_keys = batch.keys() - {"image", "mask", "filename"} rest = {k: batch[k] for k in other_keys} - model_output: ModelOutput = self(x, **rest) + + def model_forward(x, **kwargs): + return self(x, **kwargs).output + + # When the input sample cannot be fit on memory for some reason + # the tiled inference is automatically invoked. + try: + model_output: ModelOutput = self(x, **rest) + except RuntimeError: + logger.info("\n The input sample could not run in a full format. Using tiled inference.") + looger.info("Notice that the tiled inference WON'T produce the exactly same result as the full inference.") + if self.tiled_inference_parameters: + y_hat: Tensor = tiled_inference( + model_forward, + x, + 1, + self.tiled_inference_parameters, + **rest, + ) + model_output = ModelOutput(output=y_hat) + else: + raise Exception("You need to define a configuration for the tiled inference.") + if dataloader_idx >= len(self.test_loss_handler): msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names." raise ValueError(msg) @@ -384,12 +406,12 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T other_keys = batch.keys() - {"image", "mask", "filename"} rest = {k: batch[k] for k in other_keys} - def model_forward(x): + def model_forward(x, **kwargs): return self(x).output if self.tiled_inference_parameters: # TODO: tiled inference does not work with additional input data (**rest) - y_hat: Tensor = tiled_inference(model_forward, x, 1, self.tiled_inference_parameters) + y_hat: Tensor = tiled_inference(model_forward, x, 1, self.tiled_inference_parameters, **rest) else: y_hat: Tensor = self(x, **rest).output return y_hat, file_names diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py index 882f5c86..460c3e00 100644 --- a/terratorch/tasks/segmentation_tasks.py +++ b/terratorch/tasks/segmentation_tasks.py @@ -18,6 +18,7 @@ from terratorch.tasks.optimizer_factory import optimizer_factory from terratorch.tasks.tiled_inference import TiledInferenceParameters, tiled_inference from terratorch.tasks.base_task import TerraTorchTask +from terratorch.models.model import ModelOutput BATCH_IDX_FOR_VALIDATION_PLOTTING = 10 @@ -266,7 +267,29 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None rest = {k: batch[k] for k in other_keys} - model_output: ModelOutput = self(x, **rest) + + def model_forward(x, **kwargs): + return self(x, **kwargs).output + + # When the input sample cannot be fit on memory for some reason + # the tiled inference is automatically invoked. + try: + model_output: ModelOutput = self(x, **rest) + except RuntimeError: + logger.info("\n The input sample could not run in a full format. Using tiled inference.") + looger.info("Notice that the tiled inference WON'T produce the exactly same result as the full inference.") + if self.tiled_inference_parameters: + y_hat: Tensor = tiled_inference( + model_forward, + x, + self.hparams["model_args"]["num_classes"], + self.tiled_inference_parameters, + **rest, + ) + model_output = ModelOutput(output=y_hat) + else: + raise Exception("You need to define a configuration for the tiled inference.") + if dataloader_idx >= len(self.test_loss_handler): msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names." raise ValueError(msg) @@ -345,18 +368,16 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T rest = {k: batch[k] for k in other_keys} - model_output: ModelOutput = self(x, **rest) - - def model_forward(x): - return self(x).output + def model_forward(x, **kwargs): + return self(x, **kwargs).output if self.tiled_inference_parameters: y_hat: Tensor = tiled_inference( - # TODO: tiled inference does not work with additional input data (**rest) model_forward, x, self.hparams["model_args"]["num_classes"], self.tiled_inference_parameters, + **rest, ) else: y_hat: Tensor = self(x, **rest).output diff --git a/terratorch/tasks/tiled_inference.py b/terratorch/tasks/tiled_inference.py index 28d93140..1e130089 100644 --- a/terratorch/tasks/tiled_inference.py +++ b/terratorch/tasks/tiled_inference.py @@ -46,6 +46,7 @@ def tiled_inference( input_batch: torch.Tensor, out_channels: int, inference_parameters: TiledInferenceParameters, + **kwargs ) -> torch.Tensor: """ Like divide an image into (potentially) overlapping tiles and perform inference on them. @@ -163,7 +164,7 @@ def tiled_inference( end = min(len(coordinates_and_inputs), start + process_batch_size) batch = coordinates_and_inputs[start:end] tensor_input = torch.stack([b.input_data for b in batch], dim=0) - output = model_forward(tensor_input) + output = model_forward(tensor_input, **kwargs) output = [output[i] for i in range(len(batch))] for batch_input, predicted in zip(batch, output, strict=True): if batch_input.output_crop is not None: diff --git a/tests/resources/configs/manufactured-finetune_prithvi_swin_B_segmentation.yaml b/tests/resources/configs/manufactured-finetune_prithvi_swin_B_segmentation.yaml index 586aa6b1..2d0ebd79 100644 --- a/tests/resources/configs/manufactured-finetune_prithvi_swin_B_segmentation.yaml +++ b/tests/resources/configs/manufactured-finetune_prithvi_swin_B_segmentation.yaml @@ -1,7 +1,7 @@ # lightning.pytorch==2.1.1 seed_everything: 42 trainer: - accelerator: cpu + accelerator: auto strategy: auto devices: auto num_nodes: 1 @@ -118,12 +118,12 @@ model: model_factory: PrithviModelFactory # uncomment this block for tiled inference - # tiled_inference_parameters: - # h_crop: 224 - # h_stride: 192 - # w_crop: 224 - # w_stride: 192 - # average_patches: true + tiled_inference_parameters: + h_crop: 224 + h_stride: 192 + w_crop: 224 + w_stride: 192 + average_patches: true optimizer: class_path: torch.optim.AdamW init_args: