Skip to content

Commit

Permalink
Removing output_size
Browse files Browse the repository at this point in the history
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
  • Loading branch information
Joao-L-S-Almeida committed Jan 24, 2025
1 parent 1ff8938 commit b0a4780
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 24 deletions.
7 changes: 0 additions & 7 deletions terratorch/models/encoder_decoder_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,6 @@ def build_model(
backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_")
backbone = _get_backbone(backbone, **backbone_kwargs)

# The image can be optionally cropped to a final format when necessary
output_size = backbone_kwargs.get("output_size", None)

# If patch size is not provided in the config or by the model, it might lead to errors due to irregular images.
patch_size = backbone_kwargs.get("patch_size", None)

Expand Down Expand Up @@ -184,7 +181,6 @@ def build_model(
head_kwargs,
patch_size=patch_size,
padding=padding,
output_size=output_size,
necks=neck_list,
decoder_includes_head=decoder_includes_head,
rescale=rescale,
Expand Down Expand Up @@ -212,7 +208,6 @@ def build_model(
head_kwargs,
patch_size=patch_size,
padding=padding,
output_size=output_size,
necks=neck_list,
decoder_includes_head=decoder_includes_head,
rescale=rescale,
Expand All @@ -227,7 +222,6 @@ def _build_appropriate_model(
head_kwargs: dict,
patch_size: int | list | None,
padding: str,
output_size: List[int] | None = None,
decoder_includes_head: bool = False,
necks: list[Neck] | None = None,
rescale: bool = True, # noqa: FBT001, FBT002
Expand All @@ -245,7 +239,6 @@ def _build_appropriate_model(
head_kwargs,
patch_size=patch_size,
padding=padding,
output_size=output_size,
decoder_includes_head=decoder_includes_head,
neck=neck_module,
rescale=rescale,
Expand Down
8 changes: 0 additions & 8 deletions terratorch/models/pixel_wise_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def __init__(
patch_size: int = None,
padding: str = None,
decoder_includes_head: bool = False,
output_size: List[int] | None = None,
auxiliary_heads: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None = None,
neck: nn.Module | None = None,
rescale: bool = True, # noqa: FBT002, FBT001
Expand All @@ -44,7 +43,6 @@ def __init__(
decoder (nn.Module): Decoder to be used
head_kwargs (dict): Arguments to be passed at instantiation of the head.
decoder_includes_head (bool): Whether the decoder already incldes a head. If true, a head will not be added. Defaults to False.
output_size (List[int]): The size of the epxected output/target tensor. It is used to crop the output before returning it.
auxiliary_heads (list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None, optional): List of
AuxiliaryHeads with heads to be instantiated. Defaults to None.
neck (nn.Module | None): Module applied between backbone and decoder.
Expand Down Expand Up @@ -77,9 +75,6 @@ def __init__(
self.rescale = rescale
self.patch_size = patch_size
self.padding = padding
self.output_size = output_size
self.reference_top = 0
self.reference_left = 0

def freeze_encoder(self):
freeze_module(self.encoder)
Expand Down Expand Up @@ -140,9 +135,6 @@ def _get_size(x):
aux_output = aux_output[..., :image_size[0], :image_size[1]]
aux_outputs[name] = aux_output

# Cropping when necessary
if self.output_size:
mask = transforms.functional.crop(mask, self.reference_left, self.reference_left, *image_size)

return ModelOutput(output=mask, auxiliary_heads=aux_outputs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ data:
- 2
- 1
- 0
train_data_root: tests/resources/inputs_extra
train_label_data_root: tests/resources/inputs_extra
val_data_root: tests/resources/inputs_extra
val_label_data_root: tests/resources/inputs_extra
test_data_root: tests/resources/inputs_extra
test_label_data_root: tests/resources/inputs_extra
train_data_root: tests/resources/inputs
train_label_data_root: tests/resources/inputs
val_data_root: tests/resources/inputs
val_label_data_root: tests/resources/inputs
test_data_root: tests/resources/inputs
test_label_data_root: tests/resources/inputs
img_grep: "regression*input*.tif"
label_grep: "regression*label*.tif"
means:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,6 @@ model:
backbone_drop_path_rate: 0.3
# backbone_window_size: 8
backbone_patch_size: 13
backbone_output_size:
- 224
- 224
decoder_channels: 64
num_frames: 1
in_channels: 6
Expand Down

0 comments on commit b0a4780

Please sign in to comment.