Skip to content

Commit

Permalink
Merge pull request #393 from IBM/multiclass
Browse files Browse the repository at this point in the history
Allowing the segmentation task to output multiple class labels
  • Loading branch information
Joao-L-S-Almeida authored Feb 11, 2025
2 parents 3c3d59d + 2b6404b commit 8b8278c
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 9 deletions.
5 changes: 5 additions & 0 deletions terratorch/cli_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ def is_one_band(img):

def write_tiff(img_wrt, filename, metadata):

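# Set the band count in the metadata to the size of the array's first
# dimension, so the file is opened with as many bands as will be written.
# NOTE(review): this is read before the one-band check below adds a leading
# axis, so for a one-band array shape[0] may not be the band count - confirm.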
count = img_wrt.shape[0]
metadata['count'] = count

with rasterio.open(filename, "w", **metadata) as dest:
if is_one_band(img_wrt):
img_wrt = img_wrt[None]
Expand Down
13 changes: 12 additions & 1 deletion terratorch/tasks/segmentation_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
tiled_inference_parameters: TiledInferenceParameters = None,
test_dataloaders_names: list[str] | None = None,
lr_overrides: dict[str, float] | None = None,
output_most_probable: bool = True,
) -> None:
"""Constructor
Expand Down Expand Up @@ -112,6 +113,8 @@ def __init__(
lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
parameters. The key should be a substring of the parameter names (it will check the substring is
contained in the parameter name) and the value should be the new lr. Defaults to None.
output_most_probable (bool, optional): If True, inference outputs only the most probable class for each
prediction (argmax over dimension 1). If False, the model output is returned unchanged, so all classes
are included. Defaults to True.
"""
self.tiled_inference_parameters = tiled_inference_parameters
self.aux_loss = aux_loss
Expand All @@ -138,6 +141,12 @@ def __init__(
self.val_loss_handler = LossHandler(self.val_metrics.prefix)
self.monitor = f"{self.val_metrics.prefix}loss"
self.plot_on_val = int(plot_on_val)
self.output_most_probable = output_most_probable

if output_most_probable:
self.select_classes = lambda y: y.argmax(dim=1)
else:
self.select_classes = lambda y: y

def configure_losses(self) -> None:
"""Initialize the loss criterion.
Expand Down Expand Up @@ -351,5 +360,7 @@ def model_forward(x):
)
else:
y_hat: Tensor = self(x, **rest).output
y_hat = y_hat.argmax(dim=1)

y_hat = self.select_classes(y_hat)

return y_hat, file_names
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ data:
- 2
- 1
- 0
train_data_root: tests/
train_label_data_root: tests/
val_data_root: tests/
val_label_data_root: tests/
test_data_root: tests/
test_label_data_root: tests/
train_data_root: tests/resources/inputs
train_label_data_root: tests/resources/inputs
val_data_root: tests/resources/inputs
val_label_data_root: tests/resources/inputs
test_data_root: tests/resources/inputs
test_label_data_root: tests/resources/inputs
img_grep: "segmentation*input*.tif"
label_grep: "segmentation*label*.tif"
means:
Expand All @@ -83,8 +83,8 @@ model:
decoder: UperNetDecoder
pretrained: true
backbone: prithvi_swin_B
backbone_pretrained_cfg_overlay:
file: tests/prithvi_swin_B.pt
#backbone_pretrained_cfg_overlay:
#file: tests/prithvi_swin_B.pt
backbone_drop_path_rate: 0.3
# backbone_window_size: 8
decoder_channels: 256
Expand All @@ -99,6 +99,7 @@ model:
num_frames: 1
num_classes: 2
head_dropout: 0.5708022831486758
output_most_probable: false
loss: ce
#aux_heads:
# - name: aux_head
Expand Down

0 comments on commit 8b8278c

Please sign in to comment.