Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding padding at the input when necessary #342

Merged
merged 46 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
842941b
Adding padding at the input when necessary
Joao-L-S-Almeida Dec 23, 2024
6528861
patch_size as an explicit argument for PixelWiseModel
Joao-L-S-Almeida Jan 2, 2025
8376a5e
logging
Joao-L-S-Almeida Jan 2, 2025
8fa3bba
Cropping image
Joao-L-S-Almeida Jan 2, 2025
6fb8c95
cropping image for scaler model
Joao-L-S-Almeida Jan 2, 2025
5f37ba7
patch_size could be None
Joao-L-S-Almeida Jan 3, 2025
5cb27dc
Adapting the Clay factory to support patch_size and minor adjusts
Joao-L-S-Almeida Jan 3, 2025
ba43134
Trying to reduce the cost of these tests
Joao-L-S-Almeida Jan 3, 2025
9c26eab
pad_images must be in utils.py
Joao-L-S-Almeida Jan 3, 2025
c4fd736
Cropping images could be a necessary operation
Joao-L-S-Almeida Jan 6, 2025
b70a368
The cropping must be placed before the head in case of scalar models
Joao-L-S-Almeida Jan 6, 2025
ecca3aa
Creating extra images for tests
Joao-L-S-Almeida Jan 6, 2025
e09cd79
Minor changes
Joao-L-S-Almeida Jan 6, 2025
6fdf1b7
img_size also could be necessary
Joao-L-S-Almeida Jan 6, 2025
6178b47
conditional cropping
Joao-L-S-Almeida Jan 6, 2025
8cb6d26
config for testing nondivisible images
Joao-L-S-Almeida Jan 6, 2025
cb62e56
minor adjusts
Joao-L-S-Almeida Jan 6, 2025
a0cac1c
minor adjusts
Joao-L-S-Almeida Jan 6, 2025
ca881b4
Input files to be used for testing the padding for non-divisible images
Joao-L-S-Almeida Jan 6, 2025
62fa305
minor changes
Joao-L-S-Almeida Jan 6, 2025
1c409e8
more tests
Joao-L-S-Almeida Jan 6, 2025
0d79f8e
merging
Joao-L-S-Almeida Jan 6, 2025
42c3d98
merging
Joao-L-S-Almeida Jan 6, 2025
fd1599f
merging
Joao-L-S-Almeida Jan 6, 2025
41af8f7
merging
Joao-L-S-Almeida Jan 6, 2025
dde31bb
merging
Joao-L-S-Almeida Jan 6, 2025
e04e53e
argument not used
Joao-L-S-Almeida Jan 6, 2025
8cb2ff9
merging
Joao-L-S-Almeida Jan 8, 2025
6b9bb5a
Merge branch 'main' into add/pad
Joao-L-S-Almeida Jan 9, 2025
ed74fb5
This operation should not be here
Joao-L-S-Almeida Jan 9, 2025
4f54e12
merging with main
Joao-L-S-Almeida Jan 17, 2025
fce754d
wrong indentation
Joao-L-S-Almeida Jan 17, 2025
f3dc433
Simplified padding code
blumenstiel Jan 20, 2025
4e0fcf2
Fix clay padding
blumenstiel Jan 20, 2025
2bb57b2
Remove padding from prithvi
blumenstiel Jan 20, 2025
fac50f0
Moving this search
Joao-L-S-Almeida Jan 20, 2025
0d64030
Limiting version for jsonargparse
Joao-L-S-Almeida Jan 20, 2025
23b5924
4.35.0
Joao-L-S-Almeida Jan 20, 2025
6786612
Cropping the image when necessary
Joao-L-S-Almeida Jan 21, 2025
f4fa6a0
Merge branch 'add/pad' of github.com:IBM/terratorch into add/pad
Joao-L-S-Almeida Jan 21, 2025
bc68b87
tests no more required
Joao-L-S-Almeida Jan 21, 2025
5265a30
merging with main
Joao-L-S-Almeida Jan 24, 2025
17a608e
Updating model name
Joao-L-S-Almeida Jan 24, 2025
1ff8938
Fixing indent
Joao-L-S-Almeida Jan 24, 2025
b0a4780
Removing output_size
Joao-L-S-Almeida Jan 24, 2025
0dc1e95
indent
Joao-L-S-Almeida Jan 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions examples/scripts/create_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from PIL import Image
import os
import random
import numpy as np
import tifffile as tiff
from argparse import ArgumentParser
from osgeo import gdal
from osgeo import osr

# Command-line interface: the source TIFF, the destination directory and the
# number of padded copies to generate (defaults to 2).
parser = ArgumentParser()
parser.add_argument("--input_file")
parser.add_argument("--output_dir")
parser.add_argument("--n_copies", type=int, default=2)

args = parser.parse_args()
input_file = args.input_file
output_dir = args.output_dir
n_copies = args.n_copies

# Upper bound for the random pad width; only referenced by the commented-out
# `random.randint` call below, so it currently has no effect.
pad_limit = 4

# config
# NOTE(review): NO_DATA and SPATIAL_REFERENCE_SYSTEM_WKID are not used in the
# lines visible here -- confirm whether they are still needed.
GDAL_DATA_TYPE = gdal.GDT_Int32
GEOTIFF_DRIVER_NAME = r'GTiff'
NO_DATA = 15
SPATIAL_REFERENCE_SYSTEM_WKID = 4326

# Produce `n_copies` enlarged versions of the input image, one output file each.
for c in range(n_copies):

# Border width added on both sides of every axis except the last one.
# Fixed to 3; the random draw is disabled.
pad = 3#random.randint(1, pad_limit)
filename = os.path.split(input_file)[-1]
# Output name is the input name with "_<copy index>" inserted before ".tif".
output_file = os.path.join(output_dir, filename.replace(".tif", f"_{c}.tif"))
print(pad)
imarray = tiff.imread(input_file)
im_shape = imarray.shape
# Grow every axis but the last by 2*pad and keep the last axis unchanged
# (assumes a channels-last layout such as (H, W, bands) -- TODO confirm).
im_shape_ext = tuple([i+2*pad for i in list(im_shape[:-1])]) + (im_shape[-1],)
#print(im_shape_ext)
# np.zeros defaults to float64, while the raster below is declared as Int32.
output = np.zeros(im_shape_ext)
#print(output.shape)
# Copy the original image into the centre, leaving a zero-valued border of
# width `pad` (the three-index slice requires a 3-D input array).
output[pad:-pad, pad:-pad, :] = imarray
#print(output.shape)
#tiff.imwrite(output_file, output)

# create driver
driver = gdal.GetDriverByName(GEOTIFF_DRIVER_NAME)

# Create the destination raster: x size from axis 1, y size from axis 0 and
# band count from the last axis of `output`.
# NOTE(review): no pixel data is written to `output_raster` in the lines
# visible here -- confirm that the array is actually stored.
output_raster = driver.Create(output_file,
output.shape[1],
output.shape[0],
output.shape[-1],
eType = GDAL_DATA_TYPE)

14 changes: 0 additions & 14 deletions terratorch/models/backbones/prithvi_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,20 +157,6 @@ def checkpoint_filter_fn_mae(

return state_dict


def pad_images(imgs: Tensor, patch_size: int, padding: str) -> Tensor:
    """Pad images on the bottom and right so H and W are multiples of `patch_size`.

    Args:
        imgs: Batch of images. The last two dimensions are treated as height
            and width; the tensor is iterated along its first dimension and
            every item is padded separately.
        patch_size: Patch edge length that height and width must be divisible by.
        padding: Padding mode forwarded to `torch.nn.functional.pad`
            (e.g. "reflect" or "constant").

    Returns:
        `imgs` itself when no padding is required, otherwise a new tensor
        whose last two dimensions are rounded up to the next multiple of
        `patch_size`.
    """
    h, w = imgs.shape[-2:]
    # The outer modulo keeps the padding at 0 when the size is already divisible.
    h_pad = (patch_size - h % patch_size) % patch_size
    w_pad = (patch_size - w % patch_size) % patch_size
    if h_pad == 0 and w_pad == 0:
        return imgs
    # Apply per image to avoid NotImplementedError from torch.nn.functional.pad
    return torch.stack(
        [nn.functional.pad(img, (0, w_pad, 0, h_pad), mode=padding) for img in imgs]
    )


def _create_prithvi(
variant: str,
pretrained: bool = False, # noqa: FBT001, FBT002
Expand Down
4 changes: 2 additions & 2 deletions terratorch/models/backbones/select_patch_embed_weights.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright contributors to the Terratorch project


import logging
import warnings

Expand All @@ -14,6 +13,7 @@ def patch_embed_weights_are_compatible(model_patch_embed: torch.Tensor, checkpoi
# check all dimensions are the same except for channel dimension
if len(model_patch_embed.shape) != len(checkpoint_patch_embed.shape):
return False

model_shape = [model_patch_embed.shape[i] for i in range(len(model_patch_embed.shape)) if i != 1]
checkpoint_shape = [checkpoint_patch_embed.shape[i] for i in range(len(checkpoint_patch_embed.shape)) if i != 1]
return model_shape == checkpoint_shape
Expand Down Expand Up @@ -68,4 +68,4 @@ def select_patch_embed_weights(

state_dict[proj_key] = temp_weight

return state_dict
return state_dict
romeokienzler marked this conversation as resolved.
Show resolved Hide resolved
24 changes: 21 additions & 3 deletions terratorch/models/clay_model_factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import importlib
import sys
from collections.abc import Callable
import logging

import timm
import torch
Expand Down Expand Up @@ -108,6 +109,7 @@ def build_model(

# Path for accessing the model source code.
self.syspath_kwarg = "model_sys_path"
backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_")
romeokienzler marked this conversation as resolved.
Show resolved Hide resolved

# TODO: support auxiliary heads
if not isinstance(backbone, nn.Module):
Expand All @@ -120,8 +122,6 @@ def build_model(
msg = f"Task {task} not supported. Please choose one of {SUPPORTED_TASKS}"
raise NotImplementedError(msg)

backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_")

# Trying to find the model on HuggingFace.
try:
backbone: nn.Module = timm.create_model(
Expand All @@ -143,6 +143,16 @@ def build_model(
backbone: nn.Module = Embedder(ckpt_path=checkpoint_path, **backbone_kwargs)
print("Model Clay was successfully restored.")

# If patch size is not provided in the config or by the model, it might lead to errors due to irregular images.
patch_size = backbone_kwargs.get("patch_size", None)
if patch_size is None:
# Infer patch size from model by checking all backbone modules
for module in backbone.modules():
if hasattr(module, "patch_size"):
patch_size = module.patch_size
break
padding = backbone_kwargs.get("padding", "reflect")

# allow decoder to be a module passed directly
decoder_cls = _get_decoder(decoder)
decoder_kwargs, kwargs = extract_prefix_keys(kwargs, "decoder_")
Expand All @@ -157,7 +167,7 @@ def build_model(
head_kwargs["num_classes"] = num_classes
if aux_decoders is None:
return _build_appropriate_model(
task, backbone, decoder, head_kwargs, prepare_features_for_image_model, rescale=rescale
task, backbone, decoder, head_kwargs, prepare_features_for_image_model, patch_size=patch_size, padding=padding, rescale=rescale
)

to_be_aux_decoders: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] = []
Expand Down Expand Up @@ -186,6 +196,8 @@ def build_model(
decoder,
head_kwargs,
prepare_features_for_image_model,
patch_size=patch_size,
padding=padding,
rescale=rescale,
auxiliary_heads=to_be_aux_decoders,
)
Expand All @@ -197,6 +209,8 @@ def _build_appropriate_model(
decoder: nn.Module,
head_kwargs: dict,
prepare_features_for_image_model: Callable,
patch_size: int | list | None,
padding: str,
rescale: bool = True, # noqa: FBT001, FBT002
auxiliary_heads: dict | None = None,
):
Expand All @@ -206,6 +220,8 @@ def _build_appropriate_model(
backbone,
decoder,
head_kwargs,
patch_size=patch_size,
padding=padding,
rescale=rescale,
auxiliary_heads=auxiliary_heads,
)
Expand All @@ -215,6 +231,8 @@ def _build_appropriate_model(
backbone,
decoder,
head_kwargs,
patch_size=patch_size,
padding=padding,
auxiliary_heads=auxiliary_heads,
)

Expand Down
27 changes: 25 additions & 2 deletions terratorch/models/encoder_decoder_factory.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright contributors to the Terratorch project


from typing import List
romeokienzler marked this conversation as resolved.
Show resolved Hide resolved
import warnings

import logging
from torch import nn

from terratorch.models.model import (
Expand Down Expand Up @@ -65,6 +65,8 @@ def _check_all_args_used(kwargs):
msg = f"arguments {kwargs} were passed but not used."
raise ValueError(msg)

def _get_argument_from_instance(model, name):
return getattr(model._timm_module.patch_embed, name)[-1]

@MODEL_FACTORY_REGISTRY.register
class EncoderDecoderFactory(ModelFactory):
Expand Down Expand Up @@ -128,6 +130,17 @@ def build_model(
backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_")
backbone = _get_backbone(backbone, **backbone_kwargs)

# If patch size is not provided in the config or by the model, it might lead to errors due to irregular images.
romeokienzler marked this conversation as resolved.
Show resolved Hide resolved
patch_size = backbone_kwargs.get("patch_size", None)

if patch_size is None:
# Infer patch size from model by checking all backbone modules
for module in backbone.modules():
if hasattr(module, "patch_size"):
patch_size = module.patch_size
break
padding = backbone_kwargs.get("padding", "reflect")

if peft_config is not None:
if not backbone_kwargs.get("pretrained", False):
msg = (
Expand Down Expand Up @@ -166,6 +179,8 @@ def build_model(
backbone,
decoder,
head_kwargs,
patch_size=patch_size,
padding=padding,
necks=neck_list,
decoder_includes_head=decoder_includes_head,
rescale=rescale,
Expand All @@ -191,6 +206,8 @@ def build_model(
backbone,
decoder,
head_kwargs,
patch_size=patch_size,
padding=padding,
necks=neck_list,
decoder_includes_head=decoder_includes_head,
rescale=rescale,
Expand All @@ -203,6 +220,8 @@ def _build_appropriate_model(
backbone: nn.Module,
decoder: nn.Module,
head_kwargs: dict,
patch_size: int | list | None,
padding: str,
decoder_includes_head: bool = False,
necks: list[Neck] | None = None,
rescale: bool = True, # noqa: FBT001, FBT002
Expand All @@ -218,6 +237,8 @@ def _build_appropriate_model(
backbone,
decoder,
head_kwargs,
patch_size=patch_size,
padding=padding,
decoder_includes_head=decoder_includes_head,
neck=neck_module,
rescale=rescale,
Expand All @@ -229,6 +250,8 @@ def _build_appropriate_model(
backbone,
decoder,
head_kwargs,
patch_size=patch_size,
padding=padding,
decoder_includes_head=decoder_includes_head,
neck=neck_module,
auxiliary_heads=auxiliary_heads,
Expand Down
49 changes: 32 additions & 17 deletions terratorch/models/pixel_wise_model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# Copyright contributors to the Terratorch project

from typing import List
import logging
import torch
import torch.nn.functional as F # noqa: N812
import torchvision.transforms as transforms
from segmentation_models_pytorch.base import SegmentationModel
from torch import nn

from terratorch.models.heads import RegressionHead, SegmentationHead
from terratorch.models.model import AuxiliaryHeadWithDecoderWithoutInstantiatedHead, Model, ModelOutput

from terratorch.models.utils import pad_images

def freeze_module(module: nn.Module):
for param in module.parameters():
Expand All @@ -26,6 +28,8 @@ def __init__(
encoder: nn.Module,
decoder: nn.Module,
head_kwargs: dict,
patch_size: int = None,
padding: str = None,
decoder_includes_head: bool = False,
auxiliary_heads: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None = None,
neck: nn.Module | None = None,
Expand Down Expand Up @@ -69,6 +73,8 @@ def __init__(

self.neck = neck
self.rescale = rescale
self.patch_size = patch_size
self.padding = padding

def freeze_encoder(self):
freeze_module(self.encoder)
Expand All @@ -77,10 +83,6 @@ def freeze_decoder(self):
freeze_module(self.decoder)
freeze_module(self.head)

# TODO: do this properly
def check_input_shape(self, x: torch.Tensor) -> bool: # noqa: ARG002
return True

@staticmethod
def _check_for_single_channel_and_squeeze(x):
if x.shape[1] == 1:
Expand All @@ -89,19 +91,27 @@ def _check_for_single_channel_and_squeeze(x):

def forward(self, x: torch.Tensor, **kwargs) -> ModelOutput:
"""Sequentially pass `x` through model`s encoder, decoder and heads"""
self.check_input_shape(x)
if isinstance(x, torch.Tensor):
input_size = x.shape[-2:]
elif hasattr(kwargs, 'image_size'):
input_size = kwargs['image_size']
elif isinstance(x, dict):
# Multimodal input is passed as dict
input_size = list(x.values())[0].shape[-2:]
else:
ValueError('Could not infer input shape.')

def _get_size(x):
if isinstance(x, torch.Tensor):
return x.shape[-2:]
elif isinstance(x, dict):
# Multimodal input in passed as dict (Assuming first modality to be an image)
return list(x.values())[0].shape[-2:]
elif hasattr(kwargs, 'image_size'):
return kwargs['image_size']
else:
ValueError('Could not infer image shape.')

image_size = _get_size(x)
if isinstance(x, torch.Tensor) and self.patch_size:
# Only works for single image modalities
x = pad_images(x, self.patch_size, self.padding)
input_size = _get_size(x)

features = self.encoder(x, **kwargs)

## only for backwards compatibility with pre-neck times.
# only for backwards compatibility with pre-neck times.
if self.neck:
prepare = self.neck
else:
Expand All @@ -114,13 +124,18 @@ def forward(self, x: torch.Tensor, **kwargs) -> ModelOutput:
if self.rescale and mask.shape[-2:] != input_size:
mask = F.interpolate(mask, size=input_size, mode="bilinear")
mask = self._check_for_single_channel_and_squeeze(mask)
mask = mask[..., :image_size[0], :image_size[1]]

aux_outputs = {}
for name, decoder in self.aux_heads.items():
aux_output = decoder([f.clone() for f in features])
if self.rescale and aux_output.shape[-2:] != input_size:
aux_output = F.interpolate(aux_output, size=input_size, mode="bilinear")
aux_output = self._check_for_single_channel_and_squeeze(aux_output)
aux_output = aux_output[..., :image_size[0], :image_size[1]]
aux_outputs[name] = aux_output


return ModelOutput(output=mask, auxiliary_heads=aux_outputs)
Joao-L-S-Almeida marked this conversation as resolved.
Show resolved Hide resolved

def _get_head(self, task: str, input_embed_dim: int, head_kwargs):
Expand Down
Loading
Loading