Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding padding at the input when necessary #342

Merged
merged 46 commits into from
Jan 29, 2025
Merged
Changes from 1 commit
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
842941b
Adding padding at the input when necessary
Joao-L-S-Almeida Dec 23, 2024
6528861
patch_size as a explicit argument for PixelWiseModel
Joao-L-S-Almeida Jan 2, 2025
8376a5e
logging
Joao-L-S-Almeida Jan 2, 2025
8fa3bba
Cropping image
Joao-L-S-Almeida Jan 2, 2025
6fb8c95
cropping image for scaler model
Joao-L-S-Almeida Jan 2, 2025
5f37ba7
patch_size could be None
Joao-L-S-Almeida Jan 3, 2025
5cb27dc
Adapting the Clay factory to support patch_size and minor adjusts
Joao-L-S-Almeida Jan 3, 2025
ba43134
Trying to reduce the cost of these tests
Joao-L-S-Almeida Jan 3, 2025
9c26eab
pad_images must be in utils.py
Joao-L-S-Almeida Jan 3, 2025
c4fd736
Cropping images could be a necessary operation
Joao-L-S-Almeida Jan 6, 2025
b70a368
The cropping must be placed before the head in case of scalar models
Joao-L-S-Almeida Jan 6, 2025
ecca3aa
Creating extra images for tests
Joao-L-S-Almeida Jan 6, 2025
e09cd79
Minor changes
Joao-L-S-Almeida Jan 6, 2025
6fdf1b7
img_size also could be necessary
Joao-L-S-Almeida Jan 6, 2025
6178b47
conditional cropping
Joao-L-S-Almeida Jan 6, 2025
8cb6d26
config for testing nondivisible images
Joao-L-S-Almeida Jan 6, 2025
cb62e56
minor adjusts
Joao-L-S-Almeida Jan 6, 2025
a0cac1c
minor adjusts
Joao-L-S-Almeida Jan 6, 2025
ca881b4
Input files to be used for testing the padding for non-divisible images
Joao-L-S-Almeida Jan 6, 2025
62fa305
minor changes
Joao-L-S-Almeida Jan 6, 2025
1c409e8
more tests
Joao-L-S-Almeida Jan 6, 2025
0d79f8e
merging
Joao-L-S-Almeida Jan 6, 2025
42c3d98
merging
Joao-L-S-Almeida Jan 6, 2025
fd1599f
merging
Joao-L-S-Almeida Jan 6, 2025
41af8f7
merging
Joao-L-S-Almeida Jan 6, 2025
dde31bb
merging
Joao-L-S-Almeida Jan 6, 2025
e04e53e
argument not used
Joao-L-S-Almeida Jan 6, 2025
8cb2ff9
merging
Joao-L-S-Almeida Jan 8, 2025
6b9bb5a
Merge branch 'main' into add/pad
Joao-L-S-Almeida Jan 9, 2025
ed74fb5
This opration should not be here
Joao-L-S-Almeida Jan 9, 2025
4f54e12
merging with main
Joao-L-S-Almeida Jan 17, 2025
fce754d
wrong identation
Joao-L-S-Almeida Jan 17, 2025
f3dc433
Simplified padding code
blumenstiel Jan 20, 2025
4e0fcf2
Fix clay padding
blumenstiel Jan 20, 2025
2bb57b2
Remove padding from prithvi
blumenstiel Jan 20, 2025
fac50f0
Moving this search
Joao-L-S-Almeida Jan 20, 2025
0d64030
Limiting version for jsonargparse
Joao-L-S-Almeida Jan 20, 2025
23b5924
4.35.0
Joao-L-S-Almeida Jan 20, 2025
6786612
Cropping the image when necessary
Joao-L-S-Almeida Jan 21, 2025
f4fa6a0
Merge branch 'add/pad' of github.com:IBM/terratorch into add/pad
Joao-L-S-Almeida Jan 21, 2025
bc68b87
tests no more required
Joao-L-S-Almeida Jan 21, 2025
5265a30
merging with main
Joao-L-S-Almeida Jan 24, 2025
17a608e
Updating model name
Joao-L-S-Almeida Jan 24, 2025
1ff8938
Fixing indent
Joao-L-S-Almeida Jan 24, 2025
b0a4780
Removing output_size
Joao-L-S-Almeida Jan 24, 2025
0dc1e95
indent
Joao-L-S-Almeida Jan 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
pad_images must be in utils.py
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Joao-L-S-Almeida committed Jan 3, 2025
commit 9c26eab4ea979c40f4bcf9b7da831d3e2f7efc8d
15 changes: 1 addition & 14 deletions terratorch/models/backbones/prithvi_vit.py
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@
from terratorch.models.backbones.select_patch_embed_weights import select_patch_embed_weights
from terratorch.datasets.utils import generate_bands_intervals
from terratorch.models.backbones.prithvi_mae import PrithviViT, PrithviMAE
from terratorch.models.utils import pad_images

logger = logging.getLogger(__name__)

@@ -153,20 +154,6 @@ def checkpoint_filter_fn_mae(

return state_dict


def pad_images(imgs: Tensor,patch_size: int, padding:str) -> Tensor:
p = patch_size
# h, w = imgs.shape[3], imgs.shape[4]
t, h, w = imgs.shape[-3:]
h_pad, w_pad = (p - h % p) % p, (p - w % p) % p # Ensure padding is within bounds
if h_pad > 0 or w_pad > 0:
imgs = torch.stack([
nn.functional.pad(img, (0, w_pad, 0, h_pad), mode=padding)
for img in imgs # Apply per image to avoid NotImplementedError from torch.nn.functional.pad
])
return imgs


def _create_prithvi(
variant: str,
pretrained: bool = False, # noqa: FBT001, FBT002
2 changes: 1 addition & 1 deletion terratorch/models/pixel_wise_model.py
Original file line number Diff line number Diff line change
@@ -8,7 +8,7 @@

from terratorch.models.heads import RegressionHead, SegmentationHead
from terratorch.models.model import AuxiliaryHeadWithDecoderWithoutInstantiatedHead, Model, ModelOutput
from terratorch.models.backbones.prithvi_vit import pad_images
from terratorch.models.utils import pad_images

def freeze_module(module: nn.Module):
for param in module.parameters():
2 changes: 1 addition & 1 deletion terratorch/models/scalar_output_model.py
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@

from terratorch.models.heads import ClassificationHead
from terratorch.models.model import AuxiliaryHeadWithDecoderWithoutInstantiatedHead, Model, ModelOutput

from terratorch.models.utils import pad_images

def freeze_module(module: nn.Module):
for param in module.parameters():
16 changes: 16 additions & 0 deletions terratorch/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import torch
from torch import nn, Tensor

class DecoderNotFoundError(Exception):
pass

@@ -11,3 +14,16 @@ def extract_prefix_keys(d: dict, prefix: str) -> dict:
remaining_dict[k] = v

return extracted_dict, remaining_dict

def pad_images(imgs: Tensor,patch_size: int, padding:str) -> Tensor:
p = patch_size
# h, w = imgs.shape[3], imgs.shape[4]
t, h, w = imgs.shape[-3:]
h_pad, w_pad = (p - h % p) % p, (p - w % p) % p # Ensure padding is within bounds
if h_pad > 0 or w_pad > 0:
imgs = torch.stack([
nn.functional.pad(img, (0, w_pad, 0, h_pad), mode=padding)
for img in imgs # Apply per image to avoid NotImplementedError from torch.nn.functional.pad
])
return imgs