Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved metadata #240

Closed
wants to merge 40 commits into from
Closed
Changes from 1 commit
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
21338d3
Template for Clay V1 model (#221)
srmsoumya Apr 18, 2024
74c65ad
Remove Pixelify from Decoder
Apr 19, 2024
71645e9
Merge branch 'main' into dev
Apr 19, 2024
d5ed36f
Merge branch 'dev' into multinode-training
Apr 19, 2024
b6f31ab
Modify lightning module for ClayMAE
Apr 19, 2024
e6a276e
Use lightning config.yaml
Apr 20, 2024
7b31a09
Modify dataset & datamodule to read metadata.yaml & npz files instead…
Apr 20, 2024
261cd14
Modify the model to handle EODataset
Apr 21, 2024
e0e31a6
Pass metadata path to both ClayDataModule & ClayMAEModule
Apr 21, 2024
3bcf7e8
Add model variants tiny, small, medium, large
Apr 21, 2024
801ddea
Lr 1e-3 to 1e-5
Apr 21, 2024
fa9edcf
Add teacher encoder
Apr 21, 2024
d0d9544
Pass band order as argument to get rgb for teacher model. Replace tra…
Apr 21, 2024
9587232
Add 0.75 weight to reconstruction loss
Apr 21, 2024
880cbf7
Pass rgb indices in the metadata.yaml file
Apr 22, 2024
afe1460
Don't add bias for decoder side of Dynamic EMbedding
Apr 23, 2024
40af48a
Freeze the teacher model on start of every train epoch
Apr 23, 2024
b68c5f9
pre-commit fix lint errors
weiji14 Apr 23, 2024
a317860
:heavy_plus_sign: Add timm
weiji14 Apr 24, 2024
4aa9fc0
:mute: Silence set_float32_matmul_precision tip and a print statement
weiji14 Apr 24, 2024
1e4ea81
Modify config & datamodule to run on multi-gpu mode
Apr 24, 2024
80a0778
add rec_loss & rep_loss to the logger
Apr 25, 2024
c4c8e8d
Add a temporary env file for v1 runs
Apr 25, 2024
701f412
Fix lr to 1e-5
Apr 25, 2024
7f8e633
Add multinode sbatch script
Apr 25, 2024
6df5aca
Sampler to load data as batches of sensor (#233)
srmsoumya Apr 26, 2024
cc605c6
Add a sampler to load multi sensor data
Apr 26, 2024
11fe5a8
Write collate, Sampler returns one element, use BatchSampler
Apr 26, 2024
819c25e
Simplify names of landsat
Apr 26, 2024
757132f
Sampler returns batches of input, use_distributed_sampler: False in l…
Apr 27, 2024
9628f86
Add support to provide images of different size as input to the model
Apr 27, 2024
f0efcb4
Add jsonargparse[signatures]>=4.27.7 to env yaml, required by lightni…
Apr 30, 2024
8f69719
Improved metadata
yellowcap Apr 30, 2024
0b15e83
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 30, 2024
8f6386e
Update configs/metadata.yaml
yellowcap Apr 30, 2024
698b692
Update configs/metadata.yaml
yellowcap Apr 30, 2024
e04338b
Update configs/metadata.yaml
yellowcap Apr 30, 2024
4c0c846
Update statistics on each band for each platform
yellowcap Apr 30, 2024
6107496
Fixed metadata by excluding nodata
yellowcap Apr 30, 2024
5e9ccfe
Replace legacy `np.random.shuffle` call with `np.random.Generator`
weiji14 Apr 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Modify dataset & datamodule to read metadata.yaml & npz files instead…
… of tifs
SRM committed Apr 20, 2024
commit 7b31a099d6aee03a6f57a2833069d20361dbb5b1
69 changes: 69 additions & 0 deletions configs/metadata.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
s2a_l2a:
gsd: 20
bands:
mean:
blue: 1200
green: 1400
red: 1500
rededge1: 1800
rededge2: 1900
rededge3: 2000
nir: 2200
nir08: 2300
swir16: 2400
swir22: 2500
std:
blue: 200
green: 250
red: 300
rededge1: 330
rededge2: 340
rededge3: 360
nir: 400
nir08: 420
swir16: 450
swir22: 480
wavelength:
blue: 0.4905
green: 0.5605
red: 0.665
rededge1: 0.7055
rededge2: 0.7405
rededge3: 0.783
nir: 0.8425
nir08: 0.865
swir16: 1.61
swir22: 2.19
naip:
gsd: 1.0
bands:
mean:
blue: 100
green: 100
red: 100
nir: 180
std:
blue: 30
green: 30
red: 30
nir: 45
wavelength:
blue: 0.48
green: 0.53
red: 0.66
nir: 0.86
linz:
gsd: 0.1
bands:
mean:
blue: 100
green: 100
red: 100
std:
blue: 25
green: 25
red: 25
wavelength:
blue: 0.48
green: 0.53
red: 0.66
102 changes: 60 additions & 42 deletions src/datamodule.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
import math
import os
import random
import yaml
from pathlib import Path
from typing import List, Literal

@@ -14,11 +15,9 @@
import rasterio
import torch
import torchdata
from box import Box
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import v2

os.environ["GDAL_DISABLE_READDIR_ON_OPEN"] = "EMPTY_DIR"
os.environ["GDAL_HTTP_MERGE_CONSECUTIVE_RANGES"] = "YES"
from torchvision import transforms


# %%
@@ -106,70 +105,89 @@ def __len__(self):
return len(self.chips_path)


class ClayDataModule(L.LightningDataModule):
MEAN = [
1369.03,
1597.68,
1741.10,
2053.58,
2569.82,
2763.01,
2858.43,
2893.86,
2303.00,
1807.79,
0.026,
0.118,
499.46,
]

STD = [
2026.96,
2011.88,
2146.35,
2138.96,
2003.27,
1962.45,
2016.38,
1917.12,
1679.88,
1568.06,
0.118,
0.873,
880.35,
]
class EODataset(Dataset):
    """Reads different Earth Observation data sources from a directory.

    Each sample is an ``.npz`` chip on disk. Per-band normalization
    statistics for the selected platform are read once, at construction
    time, from a YAML metadata file.
    """

    def __init__(
        self, chips_path: List[Path], platform: str, metadata_path: str
    ) -> None:
        """Load the platform metadata and build the normalization transform.

        Args:
            chips_path: Paths of the ``.npz`` chips served by this dataset.
            platform: Top-level key of the metadata YAML to use.
            metadata_path: Path to the metadata YAML file.

        Raises:
            KeyError: If ``platform`` is not a top-level key of the YAML.
        """
        super().__init__()
        self.chips_path = chips_path
        self.platform = platform

        # Use a context manager so the metadata file handle is closed
        # deterministically instead of being left to the garbage collector.
        with open(metadata_path, "r") as metadata_file:
            all_metadata = yaml.safe_load(metadata_file)
        self.metadata = Box(all_metadata[platform])

        # Statistics are passed in the order the band names appear in the
        # YAML mapping.
        # NOTE(review): assumes that order matches the channel order of the
        # stored "pixels" arrays -- confirm against the chip writer.
        self.tfm = transforms.Compose(
            [
                transforms.Normalize(
                    mean=list(self.metadata.bands.mean.values()),
                    std=list(self.metadata.bands.std.values()),
                ),
            ]
        )

    def __len__(self):
        """Return the number of chips in the dataset."""
        return len(self.chips_path)

    def __getitem__(self, idx):
        """Load chip ``idx`` and return it as a dictionary.

        Returns:
            A dict with the normalized float32 ``pixels`` tensor, the
            ``platform`` and ``date`` as strings, and the ``hour``, ``week``,
            ``lat`` and ``lon`` values as float16 tensors.
        """
        chip_path = self.chips_path[idx]
        # NOTE(review): allow_pickle=True can execute arbitrary code when
        # loading untrusted files; kept as in the original, only safe for
        # chips produced by a trusted pipeline.
        # The NpzFile keeps a file handle open; the ``with`` block closes it
        # once every array has been read, avoiding a descriptor leak per
        # sample.
        with np.load(chip_path, allow_pickle=True) as chip:
            pixels = torch.from_numpy(chip["pixels"].astype(np.float32))
            return {
                "pixels": self.tfm(pixels),
                "platform": str(chip["platform"]),
                "date": str(chip["date"]),
                "hour": torch.as_tensor(chip["hour_norm"], dtype=torch.float16),
                "week": torch.as_tensor(chip["week_norm"], dtype=torch.float16),
                "lat": torch.as_tensor(chip["lat_norm"], dtype=torch.float16),
                "lon": torch.as_tensor(chip["lon_norm"], dtype=torch.float16),
            }


class ClayDataModule(L.LightningDataModule):
def __init__(
self,
data_dir: str = "data",
platform: str = "naip",
metadata_path: str = "configs/metadata.yaml",
batch_size: int = 10,
num_workers: int = 8,
):
super().__init__()
self.data_dir = data_dir
self.platform = platform
self.metadata_path = metadata_path
self.batch_size = batch_size
self.num_workers = num_workers
self.split_ratio = 0.8
self.tfm = v2.Compose([v2.Normalize(mean=self.MEAN, std=self.STD)])

def setup(self, stage: Literal["fit", "predict"] | None = None) -> None:
# Get list of GeoTIFF filepaths from s3 bucket or data/ folder
if self.data_dir.startswith("s3://"):
dp = torchdata.datapipes.iter.IterableWrapper(iterable=[self.data_dir])
chips_path = list(dp.list_files_by_s3(masks="*.tif"))
chips_path = list(dp.list_files_by_s3(masks="*.npz"))
else: # if self.data_dir is a local data path
chips_path = sorted(list(Path(self.data_dir).glob("**/*.tif")))
chips_path = sorted(list(Path(self.data_dir).glob("**/*.npz")))
print(f"Total number of chips: {len(chips_path)}")

if stage == "fit":
random.shuffle(chips_path)
split = int(len(chips_path) * self.split_ratio)

self.trn_ds = ClayDataset(chips_path=chips_path[:split], transform=self.tfm)
self.val_ds = ClayDataset(chips_path=chips_path[split:], transform=self.tfm)
self.trn_ds = EODataset(
chips_path=chips_path[:split],
platform=self.platform,
metadata_path=self.metadata_path,
)
self.val_ds = EODataset(
chips_path=chips_path[split:],
platform=self.platform,
metadata_path=self.metadata_path,
)

elif stage == "predict":
self.prd_ds = ClayDataset(chips_path=chips_path, transform=self.tfm)
self.prd_ds = EODataset(
chips_path=chips_path,
platform=self.platform,
metadata_path=self.metadata_path,
)

def train_dataloader(self):
return DataLoader(
46 changes: 0 additions & 46 deletions src/metadata.yaml

This file was deleted.