201 downscaling #472

Merged
merged 10 commits into from
Mar 10, 2025
test_wxc_downscaling_pincer_instantiate
romeokienzler committed Feb 14, 2025

commit 7c1e9c057708b52c0777c4aa855a2e00deccfc5d
70 changes: 70 additions & 0 deletions integrationtests/test_prithvi_wxc_model_factory.py
@@ -1,6 +1,7 @@
# Copyright contributors to the Terratorch project

import os
import sys

import pytest
import torch
@@ -187,6 +188,75 @@ def test_wxc_unet_pincer_inference():

    dist.destroy_process_group()

def test_wxc_downscaling_pincer_instantiate():
    kwargs = {
        "in_channels": 1280,
        "input_size_time": 1,
        "n_lats_px": 64,
        "n_lons_px": 128,
        "in_channels_static": 3,
        "input_scalers_mu": torch.tensor([0] * 1280),
        "input_scalers_sigma": torch.tensor([1] * 1280),
        "input_scalers_epsilon": 0,
        "static_input_scalers_mu": torch.tensor([0] * 3),
        "static_input_scalers_sigma": torch.tensor([1] * 3),
        "static_input_scalers_epsilon": 0,
        "output_scalers": torch.tensor([0] * 1280),
        "patch_size_px": [2, 2],
        "mask_unit_size_px": [8, 16],
        "mask_ratio_inputs": 0.5,
        "embed_dim": 2560,
        "n_blocks_encoder": 12,
        "n_blocks_decoder": 2,
        "mlp_multiplier": 4,
        "n_heads": 16,
        "dropout": 0.0,
        "drop_path": 0.05,
        "parameter_dropout": 0.0,
        "residual": "none",
        "masking_mode": "both",
        "positional_encoding": "absolute",
        "config_path": "./integrationtests/test_prithvi_wxc_model_factory_config.yaml",
    }

    WxCModelFactory().build_model(backbone="prithviwxc", aux_decoders="downscaler", **kwargs)

def test_wxc_downscaling_pincer_task():
    model_args = {
        "in_channels": 1280,
        "input_size_time": 1,
        "n_lats_px": 64,
        "n_lons_px": 128,
        "in_channels_static": 3,
        "input_scalers_mu": torch.tensor([0] * 1280),
        "input_scalers_sigma": torch.tensor([1] * 1280),
        "input_scalers_epsilon": 0,
        "static_input_scalers_mu": torch.tensor([0] * 3),
        "static_input_scalers_sigma": torch.tensor([1] * 3),
        "static_input_scalers_epsilon": 0,
        "output_scalers": torch.tensor([0] * 1280),
        "patch_size_px": [2, 2],
        "mask_unit_size_px": [8, 16],
        "mask_ratio_inputs": 0.5,
        "embed_dim": 2560,
        "n_blocks_encoder": 12,
        "n_blocks_decoder": 2,
        "mlp_multiplier": 4,
        "n_heads": 16,
        "dropout": 0.0,
        "drop_path": 0.05,
        "parameter_dropout": 0.0,
        "residual": "none",
        "masking_mode": "both",
        "positional_encoding": "absolute",
        "backbone": "prithviwxc",
        "aux_decoders": "downscaler",
        # config_path is required for the downscaler path in WxCModelFactory (see wxc_model_factory.py below)
        "config_path": "./integrationtests/test_prithvi_wxc_model_factory_config.yaml",
    }

    task = WxCTask(WxCModelFactory(), model_args=model_args, mode='train')

def test_wxc_unet_pincer_train():
    os.environ['MASTER_ADDR'] = 'localhost'
118 changes: 118 additions & 0 deletions integrationtests/test_prithvi_wxc_model_factory_config.yaml
@@ -0,0 +1,118 @@
data:
  type: merra2

  # Input variables definition
  input_surface_vars:
    - EFLUX
    - GWETROOT
    - HFLUX
    - LAI
    - LWGAB # surface absorbed longwave radiation
    - LWGEM # longwave flux emitted from surface
    - LWTUP # upwelling longwave flux at TOA
    - PS # surface pressure
    - QV2M # 2-meter specific humidity
    - SLP # sea level pressure
    - SWGNT # surface net downward shortwave flux
    - SWTNT # TOA net downward shortwave flux
    - T2M # near-surface temperature
    - TQI # total precipitable ice water
    - TQL # total precipitable liquid water
    - TQV # total precipitable water vapor
    - TS # surface skin temperature
    - U10M # 10m eastward wind
    - V10M # 10m northward wind
    - Z0M # surface roughness
  input_static_surface_vars: [FRACI, FRLAND, FROCEAN, PHIS]
  input_vertical_vars:
    - CLOUD # cloud fraction for radiation
    - H # geopotential / mid-layer heights
    - OMEGA # vertical pressure velocity
    - PL # mid-level pressure
    - QI # mass fraction of cloud ice water
    - QL # mass fraction of cloud liquid water
    - QV # specific humidity
    - T # temperature
    - U # eastward wind
    - V # northward wind
  # (model level/ml ~ pressure level/hPa)
  # 52ml ~ 562.5hPa, 56ml ~ 700hPa, 63ml ~ 850hPa
  input_levels: [34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0, 53.0, 56.0, 63.0, 68.0, 72.0]
  ## remove: n_input_timestamps: 1
  # Output variables definition
  output_vars:
    - T2M # near-surface temperature

  n_input_timestamps: 2

  # Data transformations
  # Initial crop before any other processing
  crop_lat: [0, 1]
  # crop_lon: [0, 0]
  # coarsening of target -- applied after crop
  input_size_lat: 60 # 6x coarsening
  input_size_lon: 96 # 6x coarsening
  apply_smoothen: True
  data_path_surface: ~/Downloads/merra-2
  data_path_vertical: ~/Downloads/merra-2
  climatology_path_surface: ~/Downloads/climatology
  climatology_path_vertical: ~/Downloads/climatology
model:
  # Platform independent config
  num_static_channels: 7
  embed_dim: 2560
  token_size:
    - 1
    - 1
  n_blocks_encoder: 12
  mlp_multiplier: 4
  n_heads: 16
  dropout_rate: 0.0
  drop_path: 0.05

  # Accepted values: temporal, climate, none
  residual: climate

  residual_connection: True
  encoder_shift: False

  downscaling_patch_size: [2, 2]
  downscaling_embed_dim: 256
  encoder_decoder_type: 'conv' # ['conv', 'transformer']
  encoder_decoder_upsampling_mode: pixel_shuffle # ['nearest', 'bilinear', 'pixel_shuffle', 'conv_transpose']
  encoder_decoder_kernel_size_per_stage: [[3], [3]] # Optional, default = 3; for conv_transpose: [[3], [2]]
  encoder_decoder_scale_per_stage: [[2], [3]] # First list determines before/after backbone
  encoder_decoder_conv_channels: 128
  input_scalers_surface_path: ~/Downloads/climatology/musigma_surface.nc
  input_scalers_vertical_path: ~/Downloads/climatology/musigma_vertical.nc
  output_scalers_surface_path: ~/Downloads/climatology/anomaly_variance_surface.nc
  output_scalers_vertical_path: ~/Downloads/climatology/anomaly_variance_vertical.nc

job_id: inference-test
batch_size: 1
num_epochs: 400
dl_num_workers: 2
dl_prefetch_size: 1
learning_rate: 0.0001
limit_steps_train: 250
limit_steps_valid: 25
min_lr: 0.00001
max_lr: 0.0002
warm_up_steps: 0
mask_unit_size:
  - 15
  - 16
mask_ratio_inputs: 0.0
mask_ratio_targets: 0.0
max_batch_size: 16

path_experiment: experiment

backbone_freeze: True
backbone_prefix: encoder.
finetune_w_static: True
strict_matching: true
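
Note on how this YAML is consumed: granitewxc's own get_config import was flaky at the time of this PR (see the TODO in wxc_model_factory.py below), so the factory parses the file with a small local helper. A minimal sketch of that loading path, assuming ExperimentConfig.from_dict accepts the parsed dictionary as-is:

import yaml
from granitewxc.utils.config import ExperimentConfig

def get_config(config_path: str) -> ExperimentConfig:
    # Parse the YAML above and hydrate granitewxc's experiment config object
    with open(config_path, 'r') as f:
        cfg = yaml.safe_load(f)
    return ExperimentConfig.from_dict(cfg)

config = get_config('./integrationtests/test_prithvi_wxc_model_factory_config.yaml')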
8 changes: 4 additions & 4 deletions terratorch/models/pincers/wxc_downscaling_pincer.py
@@ -1,12 +1,12 @@
 from granitewxc.utils.config import ExperimentConfig
 from granitewxc.utils.downscaling_model import ClimateDownscaleFinetuneModel
-from wxc_embedding_network import get_embedding_network
-from wxc_upscaler import get_upscaler
+from terratorch.models.pincers.wxc_embedding_network import get_embedding_network
+from terratorch.models.pincers.wxc_upscaler import get_upscaler
 from torch import nn
 import numpy as np
 from granitewxc.utils.downscaling_model import get_scalers
 
-def get_downscaling_pincer(config: ExperimentConfig, backbone: nn.Model):
+def get_downscaling_pincer(config: ExperimentConfig, backbone: nn.Module):
 
     n_output_parameters = len(config.data.output_vars)
     if config.model.__dict__.get('loss_type', 'patch_rmse_loss')=='cross_entropy':
@@ -16,7 +16,7 @@ def get_downscaling_pincer(config: ExperimentConfig, backbone: nn.Model):
         n_output_parameters = len(np.load(config.model.cross_entropy_bin_boundaries_file)) + 1
 
     embedding, embedding_static, upscale = get_embedding_network(config)
-    head = get_upscaler()
+    head = get_upscaler(config, n_output_parameters)
     scalers = get_scalers(config)
 
     model = ClimateDownscaleFinetuneModel(
28 changes: 9 additions & 19 deletions terratorch/models/pincers/wxc_upscaler.py
@@ -1,24 +1,14 @@
 from granitewxc.decoders.downscaling import ConvEncoderDecoder
+from granitewxc.utils.config import ExperimentConfig
 
-def get_upscaler(model_embed_dim: int,
-                 model_encoder_decoder_conv_channels: int,
-                 out_channels: int,
-                 kernel_size: int,
-                 scale,
-                 upscale_mode,
-                 ):
-
+def get_upscaler(config: ExperimentConfig, n_output_parameters: int) -> ConvEncoderDecoder:
     head = ConvEncoderDecoder(
-        in_channels=model_embed_dim,
-        channels=model_encoder_decoder_conv_channels,
-        out_channels=out_channels,
-        kernel_size=kernel_size,
-        scale=scale,
-        upsampling_mode=upscale_mode,
-        #in_channels=config.model.embed_dim,
-        #channels=config.model.encoder_decoder_conv_channels,
-        #out_channels=n_output_parameters,
-        #kernel_size=config.model.encoder_decoder_kernel_size_per_stage[1],
-        #scale=config.model.encoder_decoder_scale_per_stage[1],
-        #upsampling_mode=config.model.encoder_decoder_upsampling_mode,
+        in_channels=config.model.embed_dim,
+        channels=config.model.encoder_decoder_conv_channels,
+        out_channels=n_output_parameters,
+        kernel_size=config.model.encoder_decoder_kernel_size_per_stage[1],
+        scale=config.model.encoder_decoder_scale_per_stage[1],
+        upsampling_mode=config.model.encoder_decoder_upsampling_mode,
     )
     return head
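
For concreteness: with the sample YAML above, and assuming loss_type is left unset so that n_output_parameters == len(config.data.output_vars) == 1 (only T2M), the call get_upscaler(config, 1) resolves to roughly the following. Values are illustrative, taken from test_prithvi_wxc_model_factory_config.yaml:

from granitewxc.decoders.downscaling import ConvEncoderDecoder

head = ConvEncoderDecoder(
    in_channels=2560,                 # config.model.embed_dim
    channels=128,                     # config.model.encoder_decoder_conv_channels
    out_channels=1,                   # n_output_parameters (output_vars: [T2M])
    kernel_size=[3],                  # encoder_decoder_kernel_size_per_stage[1]
    scale=[3],                        # encoder_decoder_scale_per_stage[1]
    upsampling_mode='pixel_shuffle',  # config.model.encoder_decoder_upsampling_mode
)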
27 changes: 20 additions & 7 deletions terratorch/models/wxc_model_factory.py
@@ -1,5 +1,3 @@
 # Copyright contributors to the Terratorch project
-import timm
 import torch
 from torch import nn
 
@@ -8,8 +6,6 @@
 import logging
 import importlib
 
-import terratorch.models.decoders as decoder_registry
-from terratorch.datasets import HLSBands
 from terratorch.models.model import (
     Model,
     ModelFactory,
@@ -18,6 +14,8 @@
 from terratorch.registry import MODEL_FACTORY_REGISTRY
 
 from terratorch.models.pincers.unet_pincer import UNetPincer
+from terratorch.models.pincers.wxc_downscaling_pincer import get_downscaling_pincer
+
 
 logger = logging.getLogger(__name__)

@@ -61,8 +59,11 @@ def build_model(
             raise
 
         # remove parameters not meant for the backbone but for other parts of the model
-        logger.trace(kwargs)
-        skip_connection = kwargs.pop('skip_connection')
+        logger.debug(kwargs)
+        if 'skip_connection' in kwargs.keys():
+            skip_connection = kwargs.pop('skip_connection')
+        if 'config_path' in kwargs.keys():
+            config_path = kwargs.pop('config_path')
 
         backbone = prithviwxc.PrithviWxC(**kwargs)

@@ -111,9 +112,21 @@ def build_model(

         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         backbone.to(device)
-        if aux_decoders is not None:
+        if aux_decoders == 'unetpincer':
             model_to_return = UNetPincer(backbone, skip_connection=skip_connection).to(device)
             return model_to_return
+        if aux_decoders == 'downscaler':
+            # from granitewxc.utils.config import get_config #TODO rkie fix: import flaky
+            from granitewxc.utils.config import ExperimentConfig
+            import yaml
+            def get_config(config_path: str) -> ExperimentConfig:
+                cfg = yaml.safe_load(open(config_path, 'r'))
+                return ExperimentConfig.from_dict(cfg)
+
+            # end TODO
+            config = get_config(config_path)
+            model_to_return = get_downscaling_pincer(config, backbone)
+            return model_to_return
 
         return WxCModuleWrapper(backbone)
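
Putting the pieces together: the factory pops config_path from the kwargs, builds the PrithviWxC backbone from the remaining kwargs, and wraps it with the downscaling pincer defined by the YAML. A minimal usage sketch mirroring test_wxc_downscaling_pincer_instantiate above, where backbone_kwargs is a placeholder for the PrithviWxC keyword arguments listed in that test:

from terratorch.models.wxc_model_factory import WxCModelFactory

model = WxCModelFactory().build_model(
    backbone="prithviwxc",
    aux_decoders="downscaler",
    config_path="./integrationtests/test_prithvi_wxc_model_factory_config.yaml",
    **backbone_kwargs,  # placeholder: the PrithviWxC kwargs from the test above
)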