From 0adf769781f4c1c60c1bf26be681f258e95d5d11 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Wed, 2 Jul 2025 17:30:27 +0200 Subject: [PATCH 01/23] initial commit for ICON data reader - new version --- config/streams/streams_icon/icon.yml | 4 +- src/weathergen/datasets/data_reader_icon.py | 305 +++++++++++ src/weathergen/datasets/icon_dataset.py | 484 ------------------ .../datasets/multi_stream_data_sampler.py | 4 +- 4 files changed, 309 insertions(+), 488 deletions(-) create mode 100644 src/weathergen/datasets/data_reader_icon.py delete mode 100644 src/weathergen/datasets/icon_dataset.py diff --git a/config/streams/streams_icon/icon.yml b/config/streams/streams_icon/icon.yml index a38bbdc97..47275aed6 100644 --- a/config/streams/streams_icon/icon.yml +++ b/config/streams/streams_icon/icon.yml @@ -10,8 +10,8 @@ ICON : type : icon filenames : ['icon-art-NWP_OH_CHEMISTRY-chem_DOM01_ML_daily_repeat_reduced_levels.zarr'] - source : ['u_00', 'v_00', 'w_80', 'temp_00'] - target : ['u_00', 'v_00', 'w_80', 'temp_00'] + source_channels : ['u_00', 'v_00', 'w_80', 'temp_00'] + target_channels : ['u_00', 'v_00', 'w_80', 'temp_00'] loss_weight : 1. diagnostic : False masking_rate : 0.6 diff --git a/src/weathergen/datasets/data_reader_icon.py b/src/weathergen/datasets/data_reader_icon.py new file mode 100644 index 000000000..bdc72a810 --- /dev/null +++ b/src/weathergen/datasets/data_reader_icon.py @@ -0,0 +1,305 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
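+
+# Overview: this reader wraps ICON model output stored as Zarr. Per-variable
+# normalization statistics (mean/std) are read from a JSON file placed next to
+# the dataset, clat/clon are converted from radians to degrees when stored in
+# radians, and each temporal-window index maps to a block of `ncells`
+# (mesh_size) rows per time step.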
+ +import logging +from pathlib import Path +from typing import override + +import numpy as np +import json +import zarr + + +from weathergen.datasets.data_reader_base import ( + DataReaderTimestep, + ReaderData, + TimeWindowHandler, + TIndex, + check_reader_data, +) + +_logger = logging.getLogger(__name__) + +class DataReaderIcon(DataReaderTimestep): + "Wrapper for ICON data variables" + def __init__( + self, + tw_handler: TimeWindowHandler, + filename: Path, + stream_info: dict, + ) -> None: + + """ + Construct data reader for ICON data variables + + Parameters + ---------- + filename : Path + filename (and path) of json kerchunk generated file + stream_info : Omega object + information about stream + + Attributes + ---------- + self.filename + self.ds + self.mesh_size + self.colnames + self.cols_idx + self.stats + self.time + self.start_idx + self.end_idx + self.len + self.lat + self.lon + self.step_hrs + self.properties + self.mean + self.stdev + self.source_channels + self.source_idx + self.target_channels + self.target_idx + self.geoinfo_channels + self.geoinfo_idx + + Returns + ------- + None + """ + + # loading datafile + self.filename = filename + self.ds = zarr.open(filename, mode="r") + self.mesh_size = self.ds.attrs["ncells"] + + # variables + self.colnames = list(self.ds) + self.cols_idx = np.array(list(np.arange(len(self.colnames)))) + + stats_filename = Path(filename).with_name(Path(filename).stem + ".json") + with open(stats_filename) as stats_file: + self.stats = json.load(stats_file) + + stats_vars = self.stats["metadata"]["variables"] + assert stats_vars == self.colnames, ( + f"Variables in normalization file {stats_vars} do not match dataset columns {self.colnames}" + ) + + # time + self.time = np.array(self.ds["time"], dtype="timedelta64[D]") + np.datetime64( + self.ds["time"].attrs["units"].split("since ")[-1] + ) + + start_ds = self.time[0] + end_ds = self.time[-1] + + if start_ds > tw_handler.t_end or end_ds < tw_handler.t_start: + name = stream_info["name"] + _logger.warning(f"{name} is not supported over data loader window. 
Stream is skipped.") + super().__init__(tw_handler, stream_info) + self.init_empty() + return + + self.start_idx = (tw_handler.t_start - start_ds).astype("timedelta64[D]").astype( + int + ) * self.mesh_size + self.end_idx = ( + (tw_handler.t_end - start_ds).astype("timedelta64[D]").astype(int) + 1 + ) * self.mesh_size - 1 + + self.len = (self.end_idx - self.start_idx) // self.mesh_size + + assert self.end_idx > self.start_idx, ( + f"Abort: Final index of {self.end_idx} is the same of larger than start index {self.start_idx}" + ) + + # TODO @Asma - use something more generalizable + period = self.time[1] - self.time[0] + + super().__init__( + tw_handler, + stream_info, + start_ds, + end_ds, + period, + ) + + len_data_entries = len(self.time) * self.mesh_size + len_hrs = tw_handler.t_window_len + assert self.end_idx + len_hrs <= len_data_entries, ( + f"Abort: end_date must be set at least {len_hrs} before the last date in the dataset" + ) + + # coordinates + coords_units = self.ds["clat"].attrs['units'] + + if coords_units == "radian": + self.lat = np.rad2deg(self.ds["clat"][:].astype("f")) + self.lon = np.rad2deg(self.ds["clon"][:].astype("f")) + + else: + self.lat = self.ds["clat"][:].astype("f") + self.lon = self.ds["clon"][:].astype("f") + + # Ignore step_hrs, idk how it supposed to work + # TODO, TODO, TODO: + self.step_hrs = 1 + + self.properties = { + "stream_id": 0, + } + + # stats + stats_vars = self.stats["metadata"]["variables"] + assert stats_vars == self.colnames, ( + f"Variables in normalization file {stats_vars} do not match dataset columns {self.colnames}" + ) + self.mean = np.array(self.stats["statistics"]["mean"], dtype="d") + self.stdev = np.array(self.stats["statistics"]["std"], dtype="d") + + + source_channels = stream_info.get("source_channels") + if source_channels: + self.source_channels, self.source_idx = self.select(source_channels) + else: + self.source_channels = self.colnames + self.source_idx = self.cols_idx + + target_channels = stream_info.get("target_channels") + if target_channels: + self.target_channels, self.target_idx = self.select(target_channels) + else: + self.target_channels = self.colnames + self.target_idx = self.cols_idx + + # Check if standard deviations are strictly positive for selected channels + selected_channel_indices = list(set(self.source_idx).union(set(self.target_idx))) + non_positive_stds = np.where(self.stdev[selected_channel_indices] <= 0)[0] + assert len(non_positive_stds) == 0, ( + f"Abort: Encountered non-positive standard deviations for selected columns {[self.colnames[selected_channel_indices][i] for i in non_positive_stds]}." + ) + + self.geoinfo_channels = [] + self.geoinfo_idx = [] + + def select(self, ch_filters: list[str]) -> (np.array, list[str]): + """ + Allow user to specify which columns they want to access. + Get functions only returned for these specified columns. 
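+        Matching is by substring: a column is selected if any pattern in
+        ch_filters occurs anywhere in its name (e.g. 'u_00' selects every
+        column whose name contains 'u_00').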
+ + Parameters + ---------- + ch_filters: list[str] + list of patterns to access + + Returns + ------- + selected_colnames: np.array, + Selected columns according to the patterns specified in ch_filters + selected_cols_idx + respective index of these patterns in the data array + """ + mask = [np.array([f in c for f in ch_filters]).any() for c in self.colnames] + + selected_cols_idx = self.cols_idx[np.where(mask)[0]] + selected_colnames = [self.colnames[int(i)] for i in np.where(mask)[0]] + + return selected_colnames, selected_cols_idx + + @override + def init_empty(self) -> None: + super().init_empty() + self.len = 0 + + @override + def length(self) -> int: + """ + Length of dataset + + Parameters + ---------- + None + + Returns + ------- + length of dataset + """ + return self.len + + @override + def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: + """ + Get data for temporal window + + Parameters + ---------- + idx : int + Index of temporal window + channels_idx : np.array + Selection of channels + + Returns + ------- + data (coords, geoinfos, data, datetimes) + """ + + (t_idxs, dtr) = self._get_dataset_idxs(idx) + + if self.ds is None or self.len == 0 or len(t_idxs) == 0: + return ReaderData.empty( + num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx) + ) + + # TODO: handle sub-sampling + + t_idxs_start = t_idxs[0] + t_idxs_end = t_idxs[-1] + 1 + + # datetime + datetimes = self.time[t_idxs_start:t_idxs_end] + + # lat/lon coordinates + tiling to match time steps + lat = self.lat[:, np.newaxis] + lon = self.lon[:, np.newaxis] + + lat = np.tile(lat, len(datetimes)) + lon = np.tile(lon, len(datetimes)) + + coords = np.concatenate([lat, lon], axis=1) + + # time coordinate repeated to match grid points + datetimes = np.repeat(datetimes, self.mesh_size).reshape(-1, 1) + datetimes = np.squeeze(datetimes) + # print(f"datetimes.shape = {datetimes.shape}", flush = True) + + # expanding indexes for data + start_row = t_idxs_start * self.mesh_size + end_row = t_idxs_end * self.mesh_size + + # data + channels = np.array(self.colnames)[channels_idx] + data_reshaped = [ + np.asarray(self.ds[ch_]).reshape(-1, 1)[start_row:end_row] for ch_ in channels + ] + data = np.concatenate(data_reshaped, axis=1) + + # empty geoinfos + geoinfos = np.zeros((data.shape[0], 0), dtype=data.dtype) + + rd = ReaderData( + coords=coords, + geoinfos=geoinfos, + data=data, + datetimes=datetimes, + ) + check_reader_data(rd, dtr) + + return rd \ No newline at end of file diff --git a/src/weathergen/datasets/icon_dataset.py b/src/weathergen/datasets/icon_dataset.py deleted file mode 100644 index abc17e32a..000000000 --- a/src/weathergen/datasets/icon_dataset.py +++ /dev/null @@ -1,484 +0,0 @@ -# (C) Copyright 2025 WeatherGenerator contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import json -from datetime import datetime -from pathlib import Path - -import numpy as np -import torch -import zarr - - -class IconDataset: - """ - A data reader for ICON model output stored in zarr. 
- - Parameters - ---------- - start : datetime | int - Start time of the data period as datetime object or integer in "%Y%m%d%H%M" format - end : datetime | int - End time of the data period (inclusive) with same format as start - len_hrs : int - Length of temporal windows in days - step_hrs : int - (Currently unused) Intended step size between windows in hours - filename : Path - Path to Zarr dataset containing ICON output - stream_info : dict[str, list[str]] - Dictionary with "source" and "target" keys specifying channel subsets to use - (e.g., {"source": ["temp_00"], "target": ["TRCH4_chemtr_00"]}) - - Attributes - ---------- - len_hrs : int - Temporal window length in days - mesh_size : int - Number of nodes in the ICON mesh - source_channels : list[str] - Patterns of selected source channels - target_channels : list[str] - Patterns of selected target channels - mean : np.ndarray - Per-channel means for normalization (includes coordinates) - stdev : np.ndarray - Per-channel standard deviations for normalization (includes coordinates) - properties : dict[str, list[str]] - Dataset metadata including 'stream_id' from Zarr attributes - - """ - - def __init__( - self, - start: datetime | int, - end: datetime | int, - len_hrs: int, - step_hrs: int, - filename: Path, - stream_info: dict, - ): - self.len_hrs = len_hrs - - format_str = "%Y%m%d%H%M" - if type(start) is not datetime: - start = datetime.strptime(str(start), format_str) - start = np.datetime64(start).astype("datetime64[D]") - - if type(end) is not datetime: - end = datetime.strptime(str(end), format_str) - end = np.datetime64(end).astype("datetime64[D]") - - # loading datafile - self.filename = filename - self.ds = zarr.open(filename, mode="r") - self.mesh_size = self.ds.attrs["ncells"] - - # Loading stat file - stats_filename = Path(filename).with_suffix(".json") - with open(stats_filename) as stats_file: - self.stats = json.load(stats_file) - - time_as_in_data_file = np.array(self.ds["time"], dtype="timedelta64[D]") + np.datetime64( - self.ds["time"].attrs["units"].split("since ")[-1] - ) - - start_ds = time_as_in_data_file[0] - end_ds = time_as_in_data_file[-1] - - # asserting start and end times - if start_ds > end or end_ds < start: - # TODO: this should be set in the base class - self.source_channels = [] - self.target_channels = [] - self.source_idx = np.array([]) - self.target_idx = np.array([]) - self.geoinfo_idx = [] - self.len = 0 - self.ds = None - return - - self.start_idx = (start - start_ds).astype("timedelta64[D]").astype(int) * self.mesh_size - self.end_idx = ( - (end - start_ds).astype("timedelta64[D]").astype(int) + 1 - ) * self.mesh_size - 1 - - self.len = (self.end_idx - self.start_idx) // self.mesh_size - - assert self.end_idx > self.start_idx, ( - f"Abort: Final index of {self.end_idx} is the same of larger than", - f" start index {self.start_idx}", - ) - - len_data_entries = len(self.ds["time"]) * self.mesh_size - - assert self.end_idx + len_hrs <= len_data_entries, ( - f"Abort: end_date must be set at least {len_hrs} before the last date in the dataset" - ) - - # variables - self.colnames = list(self.ds) - self.cols_idx = np.array(list(np.arange(len(self.colnames)))) - - # Ignore step_hrs, idk how it supposed to work - # TODO, TODO, TODO: - self.step_hrs = 1 - - # time - repeated_times = np.repeat(time_as_in_data_file, self.mesh_size).reshape(-1, 1) - self.time = repeated_times - - # coordinates - coords_units = self.ds["clat"].attrs["units"] - - if coords_units == "radian": - lat_as_in_data_file = 
np.rad2deg(self.ds["clat"][:].astype("f")) - lon_as_in_data_file = np.rad2deg(self.ds["clon"][:].astype("f")) - - else: - lat_as_in_data_file = self.ds["clat"][:].astype("f") - lon_as_in_data_file = self.ds["clon"][:].astype("f") - - self.lat = np.tile(lat_as_in_data_file, len(time_as_in_data_file)) - self.lon = np.tile(lon_as_in_data_file, len(time_as_in_data_file)) - - self.properties = {"stream_id": 0} - - # stats - stats_vars = self.stats["metadata"]["variables"] - assert stats_vars == self.colnames, ( - f"Variables in normalization file {stats_vars}" - f"do not match dataset columns {self.colnames}" - ) - - self.mean = np.array(self.stats["statistics"]["mean"], dtype="d") - self.stdev = np.array(self.stats["statistics"]["std"], dtype="d") - - # Channel selection and indexing - source_channels = stream_info["source"] if "source" in stream_info else None - if source_channels: - self.source_channels, self.source_idx = self.select(source_channels) - else: - self.source_channels = self.colnames - self.source_idx = self.cols_idx - - target_channels = stream_info["target"] if "target" in stream_info else None - if target_channels: - self.target_channels, self.target_idx = self.select(target_channels) - else: - self.target_channels = self.colnames - self.target_idx = self.cols_idx - - # Check if standard deviations are strictly positive for selected channels - selected_channel_indices = list(set(self.source_idx).union(set(self.target_idx))) - non_positive_stds = np.where(self.stdev[selected_channel_indices] <= 0)[0] - assert len(non_positive_stds) == 0, ( - f"Abort: Encountered non-positive standard deviations " - f"for selected columns { - [self.colnames[selected_channel_indices][i] for i in non_positive_stds] - }." - ) - # TODO: define in base class - self.geoinfo_idx = [] - - def select(self, ch_filters: list[str]) -> tuple[list[str], np.array]: - """ - Allow user to specify which columns they want to access. - Get functions only returned for these specified columns. 
- """ - - mask = [np.array([f in c for f in ch_filters]).any() for c in self.colnames] - - selected_cols_idx = np.where(mask)[0] - selected_colnames = [self.colnames[i] for i in selected_cols_idx] - - return selected_colnames, selected_cols_idx - - def __len__(self) -> int: - """ - Length of dataset - - Parameters - ---------- - None - - Returns - ------- - length of dataset - """ - return self.len - - def _get(self, idx: int, channels: np.array) -> tuple: - """ - Get data for window - - Parameters - ---------- - idx : int - Index of temporal window - channels_idx : np.array - Selection of channels - - Returns - ------- - data (coords, geoinfos, data, datetimes) - """ - if self.ds is None: - fp32 = np.float32 - return ( - np.array([], dtype=fp32), - np.array([], dtype=fp32), - np.array([], dtype=fp32), - np.array([], dtype=fp32), - ) - - # indexing - start_row = self.start_idx + idx * self.mesh_size - end_row = start_row + self.len_hrs * self.mesh_size - - # data - data_reshaped = [ - np.asarray(self.ds[ch_]).reshape(-1, 1)[start_row:end_row] for ch_ in channels - ] - data = np.concatenate(data_reshaped, axis=1) - - lat = np.expand_dims(self.lat[start_row:end_row], 1) - lon = np.expand_dims(self.lon[start_row:end_row], 1) - - latlon = np.concatenate([lat, lon], 1) - - # empty geoinfos - geoinfos = np.zeros((data.shape[0], 0), dtype=data.dtype) - datetimes = np.squeeze(self.time[start_row:end_row]) - - return (latlon, geoinfos, data, datetimes) - - def get_source(self, idx: int) -> tuple[np.array, np.array, np.array, np.array]: - """ - Get source data for idx - - Parameters - ---------- - idx : int - Index of temporal window - - Returns - ------- - source data (coords, geoinfos, data, datetimes) - """ - return self._get(idx, self.source_channels) - - def get_target(self, idx: int) -> tuple[np.array, np.array, np.array, np.array]: - """ - Get target data for idx - - Parameters - ---------- - idx : int - Index of temporal window - - Returns - ------- - target data (coords, geoinfos, data, datetimes) - """ - return self._get(idx, self.target_channels) - - def get_source_size(self) -> int: - """ - Get size of all columns, including coordinates and geoinfo, with source - - Parameters - ---------- - None - - Returns - ------- - size of coords - """ - return 2 + len(self.geoinfo_idx) + len(self.source_idx) if self.ds else 0 - - def get_target_size(self) -> int: - """ - Get size of all columns, including coordinates and geoinfo, with source - - Parameters - ---------- - None - - Returns - ------- - size of coords - """ - return 2 + len(self.geoinfo_idx) + len(self.target_idx) if self.ds else 0 - - def get_coords_size(self) -> int: - """ - Get size of coords - - Parameters - ---------- - None - - Returns - ------- - size of coords - """ - return 2 - - def normalize_coords(self, coords: torch.tensor) -> torch.tensor: - """ - Normalize coordinates - - Parameters - ---------- - coords : - coordinates to be normalized - - Returns - ------- - Normalized coordinates - """ - coords[..., 0] = np.sin(np.deg2rad(coords[..., 0])) - coords[..., 1] = np.sin(0.5 * np.deg2rad(coords[..., 1])) - - return coords - - def normalize_source_channels(self, source: torch.tensor) -> torch.tensor: - """ - Normalize source channels - - Parameters - ---------- - source : - data to be normalized - - Returns - ------- - Normalized data - """ - assert source.shape[1] == len(self.source_idx) - for i, ch in enumerate(self.source_idx): - source[..., i] = (source[..., i] - self.mean[ch]) / self.stdev[ch] - - return source - - def 
normalize_target_channels(self, target: torch.tensor) -> torch.tensor: - """ - Normalize target channels - - Parameters - ---------- - target : - data to be normalized - - Returns - ------- - Normalized data - """ - assert target.shape[1] == len(self.target_idx) - for i, ch in enumerate(self.target_idx): - target[..., i] = (target[..., i] - self.mean[ch]) / self.stdev[ch] - - return target - - def time_window(self, idx: int) -> tuple[np.datetime64, np.datetime64]: - """ - Temporal window corresponding to index - - Parameters - ---------- - idx : - index of temporal window - - Returns - ------- - start and end of temporal window - """ - start_row = self.start_idx + idx * self.mesh_size - end_row = start_row + self.len_hrs * self.mesh_size - - return (self.time[start_row, 0], self.time[end_row, 0]) - - def denormalize_target_channels(self, data: torch.tensor) -> torch.tensor: - """ - Denormalize target channels - - Parameters - ---------- - data : - data to be denormalized (target or pred) - - Returns - ------- - Denormalized data - """ - assert data.shape[-1] == len(self.target_idx), "incorrect number of channels" - for i, ch in enumerate(self.target_idx): - data[..., i] = (data[..., i] * self.stdev[ch]) + self.mean[ch] - - return data - - def get_source_num_channels(self) -> int: - """ - Get number of source channels - - Parameters - ---------- - None - - Returns - ------- - number of source channels - """ - return len(self.source_idx) - - def get_target_num_channels(self) -> int: - """ - Get number of target channels - - Parameters - ---------- - None - - Returns - ------- - number of target channels - """ - return len(self.target_idx) - - def get_geoinfo_size(self) -> int: - """ - Get size of geoinfos - - Parameters - ---------- - None - - Returns - ------- - size of geoinfos - """ - return len(self.geoinfo_idx) - - def normalize_geoinfos(self, geoinfos: torch.tensor) -> torch.tensor: - """ - Normalize geoinfos - - Parameters - ---------- - geoinfos : - geoinfos to be normalized - - Returns - ------- - Normalized geoinfo - """ - - assert geoinfos.shape[-1] == 0, "incorrect number of geoinfo channels" - return geoinfos diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 3a6b819ed..df3a253e1 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -23,7 +23,7 @@ ) from weathergen.datasets.data_reader_fesom import DataReaderFesom from weathergen.datasets.data_reader_obs import DataReaderObs -from weathergen.datasets.icon_dataset import IconDataset +from weathergen.datasets.data_reader_icon import DataReaderIcon from weathergen.datasets.masking import Masker from weathergen.datasets.stream_data import StreamData from weathergen.datasets.tokenizer_forecast import TokenizerForecast @@ -113,7 +113,7 @@ def __init__( dataset = DataReaderFesom datapath = cf.data_path_fesom case "icon": - dataset = IconDataset + dataset = DataReaderIcon datapath = cf.data_path_icon case _: msg = f"Unsupported stream type {stream_info['type']}" From 97b507577f6b7667faa2eb48c2d10a67e3622901 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Fri, 8 Aug 2025 07:01:32 +0200 Subject: [PATCH 02/23] modifications to make d4fcwk5b pretrained model compatible with new core updates --- config/default_config.yml | 3 +-- config/streams/streams_anemoi/era5.yml | 2 +- src/weathergen/train/trainer.py | 5 +++++ 3 files changed, 7 insertions(+), 3 deletions(-) diff --git 
a/config/default_config.yml b/config/default_config.yml index 5049c87cd..7005fa3fc 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -129,5 +129,4 @@ run_id: ??? # Parameters for logging/printing in the training loop train_log: # The period to log metrics (in number of batch steps) - log_interval: 20 - + log_interval: 20 \ No newline at end of file diff --git a/config/streams/streams_anemoi/era5.yml b/config/streams/streams_anemoi/era5.yml index ffac94a3f..68da2e480 100644 --- a/config/streams/streams_anemoi/era5.yml +++ b/config/streams/streams_anemoi/era5.yml @@ -35,4 +35,4 @@ ERA5 : # sampling_rate : 0.2 pred_head : ens_size : 1 - num_layers : 1 + num_layers : 1 \ No newline at end of file diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 67a8505f4..b2c36da96 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -81,6 +81,11 @@ def init( ########################################### def inference(self, cf, run_id_trained, epoch): # general initalization + + # Asma: This is a quick fix, won't be useful in the future + cf.batch_size_per_gpu = 1 + cf.batch_size_validation_per_gpu = 1 + ######## End of stupid code self.init(cf) self.dataset_val = MultiStreamDataSampler( From b0c08239bd16a710123418ca1d0d5d5f9b019ad4 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Mon, 18 Aug 2025 05:52:43 +0200 Subject: [PATCH 03/23] CAMS data reader python script --- src/weathergen/datasets/data_reader_base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/weathergen/datasets/data_reader_base.py b/src/weathergen/datasets/data_reader_base.py index a84e3a441..0c2bc0a58 100644 --- a/src/weathergen/datasets/data_reader_base.py +++ b/src/weathergen/datasets/data_reader_base.py @@ -319,7 +319,6 @@ def __len__(self) -> int: ------- length of dataset """ - return self.length() def get_source(self, idx: TIndex) -> ReaderData: From c083eeea400a573943c6d19bad6c94fa2cf78efb Mon Sep 17 00:00:00 2001 From: sbAsma Date: Mon, 18 Aug 2025 05:53:27 +0200 Subject: [PATCH 04/23] addition CAMS reader to the stream reading script --- src/weathergen/datasets/multi_stream_data_sampler.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index df3a253e1..4b7d4bebd 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -24,6 +24,7 @@ from weathergen.datasets.data_reader_fesom import DataReaderFesom from weathergen.datasets.data_reader_obs import DataReaderObs from weathergen.datasets.data_reader_icon import DataReaderIcon +from weathergen.datasets.data_reader_cams import DataReaderCams from weathergen.datasets.masking import Masker from weathergen.datasets.stream_data import StreamData from weathergen.datasets.tokenizer_forecast import TokenizerForecast @@ -115,6 +116,9 @@ def __init__( case "icon": dataset = DataReaderIcon datapath = cf.data_path_icon + case "camseac4": + dataset = DataReaderCams + datapath = cf.data_path_cams case _: msg = f"Unsupported stream type {stream_info['type']}" f"for stream name '{stream_info['name']}'." 
@@ -298,6 +302,7 @@ def __iter__(self): # idx_raw is used to index into the dataset; the decoupling is needed # since there are empty batches idx_raw = iter_start + for i, _bidx in enumerate(range(iter_start, iter_end, self.batch_size)): # forecast_dt needs to be constant per batch (amortized through data parallel training) forecast_dt = self.perms_forecast_dt[i] From 12fdc2c0fdc93b9f8e6adce88815481ecf81727a Mon Sep 17 00:00:00 2001 From: sbAsma Date: Mon, 18 Aug 2025 05:54:59 +0200 Subject: [PATCH 05/23] CAMS EAC4 stream config --- .../streams/streams_cams_eac4/cams_eac4.yml | 136 ++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100644 config/streams/streams_cams_eac4/cams_eac4.yml diff --git a/config/streams/streams_cams_eac4/cams_eac4.yml b/config/streams/streams_cams_eac4/cams_eac4.yml new file mode 100644 index 000000000..e6585970f --- /dev/null +++ b/config/streams/streams_cams_eac4/cams_eac4.yml @@ -0,0 +1,136 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +CAMSEAC4 : + type : camseac4 + filenames : ['cams_eac4_2017_2022.zarr'] + source : [ + # Surface variables + '2t', 'u10', 'v10', 'msl', 'pm1', 'pm2p5', 'pm10', + 'tc_co', 'tc_no', 'tc_no2', 'tc_o3', 'tc_so2', + + # Specific humidity (q) + 'q_1000', 'q_700', 'q_400', + 'q_250', 'q_150', 'q_50', + + # Temperature (t) + 't_1000', 't_700', 't_400', + 't_250', 't_150', 't_50', + + # U wind (u) + 'u_1000', 'u_700', 'u_400', + 'u_250', 'u_150', 'u_50', + + # V wind (v) + 'v_1000', 'v_700', 'v_400', + 'v_250', 'v_150', 'v_50', + + # Ozone (o3) + 'o3_1000', 'o3_700', 'o3_400', + 'o3_250', 'o3_150', 'o3_50', + + # Sulfur dioxide (so2) + 'so2_1000', 'so2_700', 'so2_400', + 'so2_250', 'so2_150', 'so2_50', + + # Geopotential height (z) + 'z_1000', 'z_700', 'z_400', + 'z_250', 'z_150', 'z_50', + ] + target : [ + # Surface variables + '2t', 'u10', 'v10', 'msl', 'pm1', 'pm2p5', 'pm10', + 'tc_co', 'tc_no', 'tc_no2', 'tc_o3', 'tc_so2', + + # Specific humidity (q) + 'q_1000', 'q_700', 'q_400', + 'q_250', 'q_150', 'q_50', + + # Temperature (t) + 't_1000', 't_700', 't_400', + 't_250', 't_150', 't_50', + + # U wind (u) + 'u_1000', 'u_700', 'u_400', + 'u_250', 'u_150', 'u_50', + + # V wind (v) + 'v_1000', 'v_700', 'v_400', + 'v_250', 'v_150', 'v_50', + + # Ozone (o3) + 'o3_1000', 'o3_700', 'o3_400', + 'o3_250', 'o3_150', 'o3_50', + + # Sulfur dioxide (so2) + 'so2_1000', 'so2_700', 'so2_400', + 'so2_250', 'so2_150', 'so2_50', + + # Geopotential height (z) + 'z_1000', 'z_700', 'z_400', + 'z_250', 'z_150', 'z_50', + + ] + # source_exclude : [] + # target_exclude : [] + variables: [ + # Surface variables + '2t', 'u10', 'v10', 'msl', 'pm1', 'pm2p5', 'pm10', + 'tc_co', 'tc_no', 'tc_no2', 'tc_o3', 'tc_so2', + + # Specific humidity (q) + 'q_1000', 'q_925', 'q_850', 'q_700', 'q_600', 'q_500', + 'q_400', 'q_300', 'q_250', 'q_200', 'q_150', 'q_100', 'q_50', + + # Temperature (t) + 't_1000', 't_925', 't_850', 't_700', 't_600', 't_500', + 't_400', 't_300', 't_250', 't_200', 't_150', 't_100', 't_50', + + # U wind (u) + 'u_1000', 'u_925', 'u_850', 'u_700', 'u_600', 'u_500', + 'u_400', 'u_300', 'u_250', 'u_200', 'u_150', 'u_100', 'u_50', + + # V wind (v) + 'v_1000', 'v_925', 'v_850', 
'v_700', 'v_600', 'v_500', + 'v_400', 'v_300', 'v_250', 'v_200', 'v_150', 'v_100', 'v_50', + + # Ozone (o3) + 'o3_1000', 'o3_925', 'o3_850', 'o3_700', 'o3_600', 'o3_500', + 'o3_400', 'o3_300', 'o3_250', 'o3_200', 'o3_150', 'o3_100', 'o3_50', + + # Sulfur dioxide (so2) + 'so2_1000', 'so2_925', 'so2_850', 'so2_700', 'so2_600', 'so2_500', + 'so2_400', 'so2_300', 'so2_250', 'so2_200', 'so2_150', 'so2_100', 'so2_50', + + # Geopotential height (z) + 'z_1000', 'z_925', 'z_850', 'z_700', 'z_600', 'z_500', + 'z_400', 'z_250', 'z_200', 'z_150', 'z_100', 'z_50', + ] + loss_weight : 1. + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 32 + tokenize_spacetime : True + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 256 + target_readout : + type : 'obs_value' # token or obs_value + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 \ No newline at end of file From dfcb0b1dfc15fbbbc35a62f4e668bafb9fd8a2fd Mon Sep 17 00:00:00 2001 From: sbAsma Date: Mon, 18 Aug 2025 05:55:42 +0200 Subject: [PATCH 06/23] CAMS EAC4 model config --- config/cams_eac4_config.yml | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 config/cams_eac4_config.yml diff --git a/config/cams_eac4_config.yml b/config/cams_eac4_config.yml new file mode 100644 index 000000000..a89ee10f3 --- /dev/null +++ b/config/cams_eac4_config.yml @@ -0,0 +1,21 @@ +streams_directory: "./config/streams/streams_cams_eac4/" + +start_date: 201710010000 +end_date: 202112312100 +start_date_val: 202201010000 +end_date_val: 202205312100 + +num_epochs: 100 # 10 +token_size: 64 + +step_hrs: 1 + +samples_per_epoch: 90 # works # 100 works +samples_per_validation: 17 +shuffle: True + +loader_num_workers: 16 + +masking_rate: 0.8 + +with_mixed_precision: True \ No newline at end of file From 7c1a7c8039f6bbb4c6963b4b57c76a76b76356c9 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Mon, 18 Aug 2025 05:57:37 +0200 Subject: [PATCH 07/23] some temporary config files --- config/cams_eac4_config_test.yml | 21 ++++++++++++++ config/era5_d4fcwk5b_config.yml | 48 ++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) create mode 100644 config/cams_eac4_config_test.yml create mode 100644 config/era5_d4fcwk5b_config.yml diff --git a/config/cams_eac4_config_test.yml b/config/cams_eac4_config_test.yml new file mode 100644 index 000000000..af557d0ee --- /dev/null +++ b/config/cams_eac4_config_test.yml @@ -0,0 +1,21 @@ +streams_directory: "./config/streams/streams_cams_eac4/" + +start_date: 201710010000 +end_date: 202112312100 +start_date_val: 202201010000 +end_date_val: 202205312100 + +num_epochs: 1 # 10 +token_size: 64 + +step_hrs: 1 + +samples_per_epoch: 17 # works # 100 works +samples_per_validation: 17 +shuffle: True + +loader_num_workers: 16 + +masking_rate: 0.8 + +with_mixed_precision: True \ No newline at end of file diff --git a/config/era5_d4fcwk5b_config.yml b/config/era5_d4fcwk5b_config.yml new file mode 100644 index 000000000..aead5cd89 --- /dev/null +++ b/config/era5_d4fcwk5b_config.yml @@ -0,0 +1,48 @@ +ae_local_dim_embed: 1024 +ae_local_num_blocks: 0 +ae_local_num_queries: 1 +ae_global_num_blocks: 4 + +forecast_offset: 1 +forecast_steps: 2 +forecast_policy: "fixed" +forecast_att_dense_rate: 1.0 + +fe_num_blocks: 8 + +len_hrs: 6 +step_hrs: 6 + +grad_clip: 1.0 +lr_max: 0.00005 + +istep: 12288 +data_loader_rng_seed: 1752001227 +run_id: "d4fcwk5b" + +### The following was in 
path: WeatherGenerator-private/hpc/levante/config/era5_d4fcwk5b_config.yml +### Which I deleted (for cleaning) +# ae_local_dim_embed: 1024 +# ae_local_num_blocks: 0 +# ae_local_num_queries: 1 +# # ae_global_num_blocks: 4 + +# ae_global_num_blocks: 12 + + +# forecast_offset: 1 +# forecast_steps: 2 +# forecast_policy: "fixed" +# forecast_att_dense_rate: 1.0 + +# fe_num_blocks: 8 + +# len_hrs: 6 +# step_hrs: 6 + +# grad_clip: 1.0 +# lr_max: 0.00005 + +# istep: 12288 +# data_loader_rng_seed: 1752001227 +# run_id: "d4fcwk5b" From c293ffd5eb18be1aa11f8ae46547783f9cba8a1e Mon Sep 17 00:00:00 2001 From: sbAsma Date: Mon, 18 Aug 2025 05:57:59 +0200 Subject: [PATCH 08/23] CAMS data reader python script --- src/weathergen/datasets/data_reader_cams.py | 281 ++++++++++++++++++++ 1 file changed, 281 insertions(+) create mode 100644 src/weathergen/datasets/data_reader_cams.py diff --git a/src/weathergen/datasets/data_reader_cams.py b/src/weathergen/datasets/data_reader_cams.py new file mode 100644 index 000000000..7dca62bec --- /dev/null +++ b/src/weathergen/datasets/data_reader_cams.py @@ -0,0 +1,281 @@ +import json +import logging +from pathlib import Path +from typing import override + +import numpy as np +import xarray as xr + +from weathergen.datasets.data_reader_anemoi import _clip_lat, _clip_lon + +from weathergen.datasets.data_reader_base import ( + DataReaderTimestep, + ReaderData, + TimeWindowHandler, + TIndex, + check_reader_data, +) + +_logger = logging.getLogger(__name__) + + +class DataReaderCams(DataReaderTimestep): + "Wrapper for CAMs data variables" + + def __init__( + self, + tw_handler: TimeWindowHandler, + filename: Path, + stream_info: dict, + ) -> None: + """ + Parameters + ---------- + tw_handler : TimeWindowHandler + Handles temporal slicing and mapping from time indices to datetime + filename : + filename (and path) of dataset + stream_info : dict + Stream metadata + """ + + # ======= Reading the Dataset ================ + + # Open the dataset using Xarray with Zarr engine + self.ds = xr.open_dataset(filename, engine="zarr") + + # Column (variable) names and indices + self.colnames = stream_info["variables"] # list(self.ds) + self.cols_idx = np.array(list(np.arange(len(self.colnames)))) + + # Load associated statistics file for normalization + stats_filename = Path(filename).with_name(Path(filename).stem + "_stats.json") + with open(stats_filename) as stats_file: + self.stats = json.load(stats_file) + + # Variables included in the stats + self.stats_vars = list(self.stats) + + # Load mean and standard deviation per variable + self.mean = np.array([self.stats[var]["mean"] for var in self.stats_vars], dtype=np.float64) + self.stdev = np.array([self.stats[var]["std"] for var in self.stats_vars], dtype=np.float64) + + # Extract coordinates and pressure level + self.lat = _clip_lat(self.ds["latitude"].values) + self.lon = _clip_lon(self.ds["longitude"].values) + + # Time range in the dataset + self.time = self.ds["time"].values + start_ds = np.datetime64(self.time[0]) + end_ds = np.datetime64(self.time[-1]) + self.temporal_frequency = self.time[1] - self.time[0] + + # # Skip stream if it doesn't intersect with time window + # print(f"start_ds = {start_ds}") + # print(f"tw_handler.t_end = {tw_handler.t_end}") + # print(f"end_ds = {end_ds}") + # print(f"tw_handler.t_start = {tw_handler.t_start}") + # """ + # 0: start_ds = 2017-10-01T00:00:00.000000000 + # 0: tw_handler.t_end = 2017-01-03T00:00:00.000000 + # 0: end_ds = 2022-05-31T21:00:00.000000000 + # 0: tw_handler.t_start = 
2017-01-01T00:00:00.000000 + + # """ + + if start_ds > tw_handler.t_end or end_ds < tw_handler.t_start: + # print("inside skipping stream") + name = "plop" # stream_info["name"] + _logger.warning(f"{name} is not supported over data loader window. Stream is skipped.") + super().__init__(tw_handler, stream_info) + self.init_empty() + return + + # Initialize parent class with resolved time window + super().__init__( + tw_handler, + stream_info, + start_ds, + end_ds, + self.temporal_frequency, + ) + + # Compute absolute start/end indices in the dataset based on time window + self.start_idx = (tw_handler.t_start - start_ds).astype("timedelta64[ns]").astype(int) + self.end_idx = (tw_handler.t_end - start_ds).astype("timedelta64[ns]").astype(int) + 1 + + # Asma TODO check if self.len checks out + # Number of time steps in selected range + self.len = self.end_idx - self.start_idx + 1 + + # Placeholder; currently unused + self.step_hrs = 1 + + # Stream metadata + self.properties = { + "stream_id": 0, + } + + # === Normalization statistics === + + # Ensure stats match dataset columns + assert self.stats_vars == self.colnames, ( + f"Variables in normalization file {self.stats_vars} do not match " + f"dataset columns {self.colnames}" + ) + + # === Channel selection === + + # Source channels and levels + source_channels = stream_info.get("source") + if source_channels: + self.source_channels, self.source_idx = self.select(source_channels) + else: + self.source_channels = self.colnames + self.source_idx = self.cols_idx + # self.source_levels = self.get_levels(self.source_channels) + + # Target channels and levels + target_channels = stream_info.get("target") + if target_channels: + self.target_channels, self.target_idx = self.select(target_channels) + else: + self.target_channels = self.colnames + self.target_idx = self.cols_idx + # self.target_levels = self.get_levels(self.target_channels) + + + + # Ensure all selected channels have valid standard deviations + selected_channel_indices = list(set(self.source_idx).union(set(self.target_idx))) + non_positive_stds = np.where(self.stdev[selected_channel_indices] <= 0)[0] + assert len(non_positive_stds) == 0, ( + f"Abort: Encountered non-positive standard deviations for selected columns " + f"{[self.colnames[selected_channel_indices][i] for i in non_positive_stds]}." + ) + + # === Geo-info channels (currently unused) === + self.geoinfo_channels = [] + self.geoinfo_idx = [] + + + def select(self, ch_filters: list[str]) -> (np.array, list[str]): + """ + Allow user to specify which columns they want to access. + Get functions only returned for these specified columns. 
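+        Matching is by substring, so a pattern such as 'so2_' would select
+        every pressure level of that variable.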
+ Parameters + ---------- + ch_filters: list[str] + list of patterns to access + Returns + ------- + selected_colnames: np.array, + Selected columns according to the patterns specified in ch_filters + selected_cols_idx + respective index of these patterns in the data array + """ + mask = [np.array([f in c for f in ch_filters]).any() for c in self.colnames] + + selected_cols_idx = self.cols_idx[np.where(mask)[0]] + selected_colnames = [self.colnames[int(i)] for i in np.where(mask)[0]] + + return selected_colnames, selected_cols_idx + + # # Asma TODO test it once kerchunk is ready + # def get_levels(self, channels: list[str]) -> list: + + @override + def init_empty(self) -> None: + super().init_empty() + self.len = 0 + + @override + def length(self) -> int: + """ + Length of dataset + Parameters + ---------- + None + Returns + ------- + length of dataset + """ + return self.len + + + @override + def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: + """ + Get data for temporal window + Parameters + ---------- + idx : int + Index of temporal window + channels_idx : np.array + Selection of channels + Returns + ------- + data (coords, geoinfos, data, datetimes) + """ + + (t_idxs, dtr) = self._get_dataset_idxs(idx) + + if self.ds is None or self.len == 0 or len(t_idxs) == 0: + return ReaderData.empty( + num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx) + ) + + # TODO: handle sub-sampling + # print(f"#2.5.1 inside DataReaderCams._get()", flush=True) + t_idxs_start = t_idxs[0] + t_idxs_end = t_idxs[-1] + 1 + + # datetime + datetimes = self.time[t_idxs_start:t_idxs_end] + + # =========== lat/lon coordinates + tiling to match time steps ========== + + # making lon, lat like a mesh for easier flattening later + lon2d, lat2d = np.meshgrid(self.lon, self.lat) + + # Flatten to match (lat, lon) storage order in your array + lat = lat2d.flatten()[:, np.newaxis] # shape (241*480,) + lon = lon2d.flatten()[:, np.newaxis] + + lat = np.tile(lat, len(datetimes)) + lon = np.tile(lon, len(datetimes)) + + coords = np.concatenate([lat, lon], axis=0) + + # data + channels = np.array(self.colnames)[channels_idx] + # print(f"#2.5.2 inside DataReaderCams._get() before data", flush=True) + # for ch_ in channels: + # print(f"self.ds[{ch_}] = {self.ds[ch_].shape}", flush=True) + data_reshaped = [ + np.asarray(self.ds[ch_][t_idxs_start:t_idxs_end, :, :]).reshape(-1, 1) for ch_ in channels + ] + # print(f"#2.5.3 inside DataReaderCams._get() after data", flush=True) + data = np.concatenate(data_reshaped, axis=1) + + # time coordinate repeated to match grid points + datetimes = np.repeat(datetimes, len(data) // len(t_idxs)) + + # empty geoinfos + geoinfos = np.zeros((data.shape[0], 0), dtype=data.dtype) + # print(f"self.lon.shape = {self.lon.shape} self.lat.shape = {self.lat.shape}",flush=True) + # print(f"lon.shape = {lon.shape} lat.shape = {lat.shape}",flush=True) + # print(f"datetimes.shape = {datetimes.shape}",flush=True) + # print(f"data.shape = {data.shape}",flush=True) + # print(f"coords.shape = {coords.shape}",flush=True) + + + rd = ReaderData( + coords=coords, + geoinfos=geoinfos, + data=data, + datetimes=datetimes, + ) + check_reader_data(rd, dtr) + + return rd \ No newline at end of file From 7632dad4b77da26edceade7f8679f1b713233edb Mon Sep 17 00:00:00 2001 From: sbAsma Date: Thu, 28 Aug 2025 15:46:35 +0200 Subject: [PATCH 09/23] changes for flexible finetuning datasets adding --- config/cams_eac4_config.yml | 17 ++- .../streams/streams_cams_eac4/cams_eac4.yml | 95 
+++------------- src/weathergen/model/engines.py | 1 + src/weathergen/model/model.py | 104 +++++++++++++++--- src/weathergen/train/trainer.py | 12 +- src/weathergen/utils/config.py | 2 + 6 files changed, 128 insertions(+), 103 deletions(-) diff --git a/config/cams_eac4_config.yml b/config/cams_eac4_config.yml index a89ee10f3..2a0346497 100644 --- a/config/cams_eac4_config.yml +++ b/config/cams_eac4_config.yml @@ -1,17 +1,16 @@ -streams_directory: "./config/streams/streams_cams_eac4/" +streams_directory: "./config/streams/streams_cams_eac4/" -start_date: 201710010000 -end_date: 202112312100 -start_date_val: 202201010000 -end_date_val: 202205312100 +start_date: 200301010000 +# end_date: 202112310000 +# start_date_val: 202201010000 +# end_date_val: 202205300000 -num_epochs: 100 # 10 -token_size: 64 +num_epochs: 63 # 10 step_hrs: 1 -samples_per_epoch: 90 # works # 100 works -samples_per_validation: 17 +samples_per_epoch: 1500 +samples_per_validation: 300 shuffle: True loader_num_workers: 16 diff --git a/config/streams/streams_cams_eac4/cams_eac4.yml b/config/streams/streams_cams_eac4/cams_eac4.yml index e6585970f..09601ba06 100644 --- a/config/streams/streams_cams_eac4/cams_eac4.yml +++ b/config/streams/streams_cams_eac4/cams_eac4.yml @@ -9,108 +9,47 @@ CAMSEAC4 : type : camseac4 - filenames : ['cams_eac4_2017_2022.zarr'] + filenames : ['cams_eac4_2003_2024.zarr'] source : [ # Surface variables - '2t', 'u10', 'v10', 'msl', 'pm1', 'pm2p5', 'pm10', + 'pm1', 'pm2p5', 'pm10', 'tc_co', 'tc_no', 'tc_no2', 'tc_o3', 'tc_so2', - # Specific humidity (q) - 'q_1000', 'q_700', 'q_400', - 'q_250', 'q_150', 'q_50', - - # Temperature (t) - 't_1000', 't_700', 't_400', - 't_250', 't_150', 't_50', - - # U wind (u) - 'u_1000', 'u_700', 'u_400', - 'u_250', 'u_150', 'u_50', - - # V wind (v) - 'v_1000', 'v_700', 'v_400', - 'v_250', 'v_150', 'v_50', - # Ozone (o3) - 'o3_1000', 'o3_700', 'o3_400', - 'o3_250', 'o3_150', 'o3_50', + 'go3_1000', 'go3_925', 'go3_850', 'go3_700', 'go3_600', 'go3_500', + 'go3_400', 'go3_300', 'go3_250', 'go3_200', 'go3_150', 'go3_100', 'go3_50', # Sulfur dioxide (so2) - 'so2_1000', 'so2_700', 'so2_400', - 'so2_250', 'so2_150', 'so2_50', - - # Geopotential height (z) - 'z_1000', 'z_700', 'z_400', - 'z_250', 'z_150', 'z_50', + 'so2_1000', 'so2_925', 'so2_850', 'so2_700', 'so2_600', 'so2_500', + 'so2_400', 'so2_300', 'so2_250', 'so2_200', 'so2_150', 'so2_100', 'so2_50', ] target : [ # Surface variables - '2t', 'u10', 'v10', 'msl', 'pm1', 'pm2p5', 'pm10', + 'pm1', 'pm2p5', 'pm10', 'tc_co', 'tc_no', 'tc_no2', 'tc_o3', 'tc_so2', - # Specific humidity (q) - 'q_1000', 'q_700', 'q_400', - 'q_250', 'q_150', 'q_50', - - # Temperature (t) - 't_1000', 't_700', 't_400', - 't_250', 't_150', 't_50', - - # U wind (u) - 'u_1000', 'u_700', 'u_400', - 'u_250', 'u_150', 'u_50', - - # V wind (v) - 'v_1000', 'v_700', 'v_400', - 'v_250', 'v_150', 'v_50', - # Ozone (o3) - 'o3_1000', 'o3_700', 'o3_400', - 'o3_250', 'o3_150', 'o3_50', + 'go3_1000', 'go3_925', 'go3_850', 'go3_700', 'go3_600', 'go3_500', + 'go3_400', 'go3_300', 'go3_250', 'go3_200', 'go3_150', 'go3_100', 'go3_50', # Sulfur dioxide (so2) - 'so2_1000', 'so2_700', 'so2_400', - 'so2_250', 'so2_150', 'so2_50', - - # Geopotential height (z) - 'z_1000', 'z_700', 'z_400', - 'z_250', 'z_150', 'z_50', - + 'so2_1000', 'so2_925', 'so2_850', 'so2_700', 'so2_600', 'so2_500', + 'so2_400', 'so2_300', 'so2_250', 'so2_200', 'so2_150', 'so2_100', 'so2_50', ] # source_exclude : [] # target_exclude : [] variables: [ # Surface variables - '2t', 'u10', 'v10', 'msl', 'pm1', 
'pm2p5', 'pm10', + 'pm1', 'pm2p5', 'pm10', 'tc_co', 'tc_no', 'tc_no2', 'tc_o3', 'tc_so2', - # Specific humidity (q) - 'q_1000', 'q_925', 'q_850', 'q_700', 'q_600', 'q_500', - 'q_400', 'q_300', 'q_250', 'q_200', 'q_150', 'q_100', 'q_50', - - # Temperature (t) - 't_1000', 't_925', 't_850', 't_700', 't_600', 't_500', - 't_400', 't_300', 't_250', 't_200', 't_150', 't_100', 't_50', - - # U wind (u) - 'u_1000', 'u_925', 'u_850', 'u_700', 'u_600', 'u_500', - 'u_400', 'u_300', 'u_250', 'u_200', 'u_150', 'u_100', 'u_50', - - # V wind (v) - 'v_1000', 'v_925', 'v_850', 'v_700', 'v_600', 'v_500', - 'v_400', 'v_300', 'v_250', 'v_200', 'v_150', 'v_100', 'v_50', - # Ozone (o3) - 'o3_1000', 'o3_925', 'o3_850', 'o3_700', 'o3_600', 'o3_500', - 'o3_400', 'o3_300', 'o3_250', 'o3_200', 'o3_150', 'o3_100', 'o3_50', + 'go3_1000', 'go3_925', 'go3_850', 'go3_700', 'go3_600', 'go3_500', + 'go3_400', 'go3_300', 'go3_250', 'go3_200', 'go3_150', 'go3_100', 'go3_50', # Sulfur dioxide (so2) 'so2_1000', 'so2_925', 'so2_850', 'so2_700', 'so2_600', 'so2_500', 'so2_400', 'so2_300', 'so2_250', 'so2_200', 'so2_150', 'so2_100', 'so2_50', - - # Geopotential height (z) - 'z_1000', 'z_925', 'z_850', 'z_700', 'z_600', 'z_500', - 'z_400', 'z_250', 'z_200', 'z_150', 'z_100', 'z_50', ] loss_weight : 1. masking_rate : 0.6 @@ -121,11 +60,11 @@ CAMSEAC4 : net : transformer num_tokens : 1 num_heads : 8 - dim_embed : 256 - num_blocks : 2 + dim_embed : 512 + num_blocks : 4 embed_target_coords : net : linear - dim_embed : 256 + dim_embed : 512 target_readout : type : 'obs_value' # token or obs_value num_layers : 2 diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 631d9fec9..184478d21 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -345,6 +345,7 @@ def __init__( dim_internal = dim_embed * hidden_factor # norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm enl = ens_num_layers + # dim_out = 72 self.pred_heads = torch.nn.ModuleList() for i in range(ens_size): diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 585df149b..f88150ded 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -431,31 +431,95 @@ def print_num_parameters(self) -> None: print("-----------------") ######################################### + ### original version + # def load(self, run_id: str, epoch: str = -1) -> None: + # """Loads model state from checkpoint and checks for missing and unused keys. + # Args: + # run_id : model_id of the trained model + # epoch : The epoch to load. Default (-1) is the latest epoch + # """ + + # path_run = Path(self.cf.model_path) / run_id + # epoch_id = f"epoch{epoch:05d}" if epoch != -1 and epoch is not None else "latest" + # filename = f"{run_id}_{epoch_id}.chkpt" + + # params = torch.load( + # path_run / filename, map_location=torch.device("cpu"), weights_only=True + # ) + # params_renamed = {} + # for k in params.keys(): + # params_renamed[k.replace("module.", "")] = params[k] + # mkeys, ukeys = self.load_state_dict(params_renamed, strict=False) + # # mkeys, ukeys = self.load_state_dict( params, strict=False) + + # if len(mkeys) > 0: + # logger.warning(f"Missing keys when loading model: {mkeys}") + + # if len(ukeys) > 0: + # logger.warning(f"Unused keys when loading model: {mkeys}") + + ### Version 2 def load(self, run_id: str, epoch: str = -1) -> None: - """Loads model state from checkpoint and checks for missing and unused keys. 
- Args: - run_id : model_id of the trained model - epoch : The epoch to load. Default (-1) is the latest epoch + """Loads model state from checkpoint, matching weights by size + and partially copying overlapping dimensions if needed. """ - path_run = Path(self.cf.model_path) / run_id epoch_id = f"epoch{epoch:05d}" if epoch != -1 and epoch is not None else "latest" filename = f"{run_id}_{epoch_id}.chkpt" - params = torch.load( + checkpoint = torch.load( path_run / filename, map_location=torch.device("cpu"), weights_only=True ) - params_renamed = {} - for k in params.keys(): - params_renamed[k.replace("module.", "")] = params[k] - mkeys, ukeys = self.load_state_dict(params_renamed, strict=False) - # mkeys, ukeys = self.load_state_dict( params, strict=False) + + # remove "module." prefix + params_renamed = {k.replace("module.", ""): v for k, v in checkpoint.items()} + + model_dict = self.state_dict() + updated_dict = {} + + skipped_params = [] + partial_params = [] + + for k, v in params_renamed.items(): + if k not in model_dict: + continue + + target = model_dict[k] + + if target.shape == v.shape: + # exact match → copy fully + updated_dict[k] = v + + else: + # try partial overlap + min_shape = tuple(min(s1, s2) for s1, s2 in zip(target.shape, v.shape)) + + if all(m > 0 for m in min_shape): + # copy overlapping region only + new_tensor = target.clone() + slices = tuple(slice(0, m) for m in min_shape) + new_tensor[slices] = v[slices] + updated_dict[k] = new_tensor + partial_params.append((k, v.shape, target.shape)) + else: + skipped_params.append((k, v.shape, target.shape)) + + # actually load + mkeys, ukeys = self.load_state_dict(updated_dict, strict=False) if len(mkeys) > 0: - logger.warning(f"Missing keys when loading model: {mkeys}") + logger.warning(f"Missing keys (random init kept): {mkeys}") if len(ukeys) > 0: - logger.warning(f"Unused keys when loading model: {mkeys}") + logger.warning(f"Unused keys (not in model): {ukeys}") + + if len(skipped_params) > 0: + for k, old_shape, new_shape in skipped_params: + logger.warning(f"Skipped {k}: checkpoint {old_shape} vs model {new_shape}") + + if len(partial_params) > 0: + for k, old_shape, new_shape in partial_params: + logger.warning(f"Partially loaded {k}: {old_shape} → {new_shape}") ######################################### def forward_jac(self, *args): @@ -575,10 +639,22 @@ def embed_cells(self, model_params: ModelParams, streams_data) -> torch.Tensor: # ) # scatter write to reorder from per stream to per cell ordering + ################################################################ + # scatter write to reorder from per stream to per cell ordering + # but first ensure pe_embed is big enough + if idxs_pe.max() >= model_params.pe_embed.shape[0]: + old_size, emb_dim = model_params.pe_embed.shape + new_size = int(idxs_pe.max().item()) + 1 + new_pe = model_params.pe_embed.new_empty((new_size, emb_dim)) + new_pe[:old_size] = model_params.pe_embed + torch.nn.init.normal_(new_pe[old_size:], mean=0.0, std=0.02) + model_params.pe_embed = torch.nn.Parameter(new_pe, requires_grad=True) + print(f"[INFO] Resized pe_embed from {old_size} → {new_size}", flush=True) + + ################################################################ tokens_all.scatter_(0, idxs, x_embed + model_params.pe_embed[idxs_pe]) return tokens_all - ######################################### def assimilate_local( self, model_params: ModelParams, tokens: torch.Tensor, cell_lens: torch.Tensor diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 
09ad36906..caadecfce 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -79,8 +79,8 @@ def inference(self, cf, run_id_trained, epoch): # general initalization # Asma: This is a quick fix, won't be useful in the future - cf.batch_size_per_gpu = 1 - cf.batch_size_validation_per_gpu = 1 + # cf.batch_size_per_gpu = 1 + # cf.batch_size_validation_per_gpu = 1 ######## End of stupid code self.init(cf) @@ -298,8 +298,11 @@ def run(self, cf, run_id_contd=None, epoch_contd=None): self.loss_calculator_val = LossCalculator(cf=cf, stage=VAL, device=self.devices[0]) # recover epoch when continuing run + print(f"self.num_ranks_original = {self.num_ranks_original}") if self.num_ranks_original is None: epoch_base = int(self.cf.istep / len(self.data_loader)) + elif epoch_contd is not None: + epoch_base = epoch_contd + 1 else: len_per_rank = ( len(self.dataset) // (self.num_ranks_original * cf.batch_size_per_gpu) @@ -308,6 +311,7 @@ def run(self, cf, run_id_contd=None, epoch_contd=None): self.cf.istep / (min(len_per_rank, cf.samples_per_epoch) * self.num_ranks_original) ) + # torch.autograd.set_detect_anomaly(True) if cf.forecast_policy is not None: torch._dynamo.config.optimize_ddp = False @@ -322,6 +326,10 @@ def run(self, cf, run_id_contd=None, epoch_contd=None): if cf.val_initial: self.validate(-1) + print(f"before training") + print(f"epoch_base = {epoch_base} cf.num_epochs = {cf.num_epochs}") + print(f"epoch_contd = {epoch_contd}") + for epoch in range(epoch_base, cf.num_epochs): logger.info(f"Epoch {epoch} of {cf.num_epochs}: train.") self.train(epoch) diff --git a/src/weathergen/utils/config.py b/src/weathergen/utils/config.py index 3781fdf6a..cebabf270 100644 --- a/src/weathergen/utils/config.py +++ b/src/weathergen/utils/config.py @@ -63,6 +63,7 @@ def load_model_config(run_id: str, epoch: int | None, model_path: str | None) -> Load a configuration file from a given run_id and epoch. If run_id is a full path, loads it from the full path. """ + print(f"model_path = {model_path}") if Path(run_id).exists(): # load from the full path if a full path is provided fname = Path(run_id) _logger.info(f"Loading config from provided full run_id path: {fname}") @@ -75,6 +76,7 @@ def load_model_config(run_id: str, epoch: int | None, model_path: str | None) -> ) model_path = Path(model_path) fname = model_path / run_id / _get_model_config_file_name(run_id, epoch) + print(f"fname = {fname}") assert fname.exists(), ( "The fallback path to the model does not exist. Please provide a `model_path`." ) From 8a25482a4a4e1afaff0073e8802c9f077f4932a7 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Mon, 1 Sep 2025 14:44:47 +0200 Subject: [PATCH 10/23] fixed bugs in _get() method --- src/weathergen/datasets/data_reader_cams.py | 170 ++++++++++++++------ 1 file changed, 120 insertions(+), 50 deletions(-) diff --git a/src/weathergen/datasets/data_reader_cams.py b/src/weathergen/datasets/data_reader_cams.py index 7dca62bec..34b2001bd 100644 --- a/src/weathergen/datasets/data_reader_cams.py +++ b/src/weathergen/datasets/data_reader_cams.py @@ -85,7 +85,7 @@ def __init__( if start_ds > tw_handler.t_end or end_ds < tw_handler.t_start: # print("inside skipping stream") - name = "plop" # stream_info["name"] + name = stream_info["name"] _logger.warning(f"{name} is not supported over data loader window. 
Stream is skipped.") super().__init__(tw_handler, stream_info) self.init_empty() @@ -207,15 +207,10 @@ def length(self) -> int: def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: """ Get data for temporal window - Parameters - ---------- - idx : int - Index of temporal window - channels_idx : np.array - Selection of channels + Returns ------- - data (coords, geoinfos, data, datetimes) + ReaderData providing coords, geoinfos, data, datetimes """ (t_idxs, dtr) = self._get_dataset_idxs(idx) @@ -225,50 +220,54 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx) ) - # TODO: handle sub-sampling - # print(f"#2.5.1 inside DataReaderCams._get()", flush=True) - t_idxs_start = t_idxs[0] - t_idxs_end = t_idxs[-1] + 1 - - # datetime - datetimes = self.time[t_idxs_start:t_idxs_end] - - # =========== lat/lon coordinates + tiling to match time steps ========== - - # making lon, lat like a mesh for easier flattening later - lon2d, lat2d = np.meshgrid(self.lon, self.lat) - - # Flatten to match (lat, lon) storage order in your array - lat = lat2d.flatten()[:, np.newaxis] # shape (241*480,) - lon = lon2d.flatten()[:, np.newaxis] - - lat = np.tile(lat, len(datetimes)) - lon = np.tile(lon, len(datetimes)) - - coords = np.concatenate([lat, lon], axis=0) + assert t_idxs[0] >= 0, "index must be non-negative" + t0 = t_idxs[0] + t1 = t_idxs[-1] + 1 # end is exclusive + T = t1 - t0 + + # channels to read + channels = np.array(self.colnames)[channels_idx].tolist() + + # --- read & shape data to match anemoi path: (T, C, G) -> (T, G, C) -> (T*G, C) + try: + data_per_channel = [] + for ch in channels: + # expect ds[ch] with shape (time, lat, lon) for this dataset + arr = np.asarray(self.ds[ch][t0:t1, :, :], dtype=np.float32) # (T, nlat, nlon) + if arr.ndim != 3: + raise ValueError(f"Expected 3D array (time, lat, lon) for '{ch}', got shape {arr.shape}") + _, nlat, nlon = arr.shape + data_per_channel.append(arr.reshape(T, nlat * nlon)) # (T, G) + + # stack channels to (T, C, G) + data_TCG = np.stack(data_per_channel, axis=1) # (T, C, G) + # move channels to last and flatten time: (T, G, C) -> (T*G, C) + data = np.transpose(data_TCG, (0, 2, 1)).reshape(T * (nlat * nlon), len(channels)).astype(np.float32) + + except MissingDateError as e: + _logger.debug(f"Date not present in CAMS dataset: {str(e)}. 
Skipping.") + return ReaderData.empty( + num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx) + ) - # data - channels = np.array(self.colnames)[channels_idx] - # print(f"#2.5.2 inside DataReaderCams._get() before data", flush=True) - # for ch_ in channels: - # print(f"self.ds[{ch_}] = {self.ds[ch_].shape}", flush=True) - data_reshaped = [ - np.asarray(self.ds[ch_][t_idxs_start:t_idxs_end, :, :]).reshape(-1, 1) for ch_ in channels - ] - # print(f"#2.5.3 inside DataReaderCams._get() after data", flush=True) - data = np.concatenate(data_reshaped, axis=1) + # --- coords: build flattened [lat, lon] once, then repeat for each time + lon2d, lat2d = np.meshgrid(np.asarray(self.lon), np.asarray(self.lat)) # shapes (nlat, nlon) + G = lon2d.size + latlon_flat = np.column_stack([lat2d.ravel(order="C"), lon2d.ravel(order="C")]) # (G, 2); LAT first, LON second + coords = np.vstack([latlon_flat] * T) # (T*G, 2) - # time coordinate repeated to match grid points - datetimes = np.repeat(datetimes, len(data) // len(t_idxs)) + # --- datetimes: repeat each timestamp for all grid points + datetimes = np.repeat(self.time[t0:t1], G) - # empty geoinfos - geoinfos = np.zeros((data.shape[0], 0), dtype=data.dtype) - # print(f"self.lon.shape = {self.lon.shape} self.lat.shape = {self.lat.shape}",flush=True) - # print(f"lon.shape = {lon.shape} lat.shape = {lat.shape}",flush=True) - # print(f"datetimes.shape = {datetimes.shape}",flush=True) - # print(f"data.shape = {data.shape}",flush=True) - # print(f"coords.shape = {coords.shape}",flush=True) + # --- empty geoinfos (match anemoi) + geoinfos = np.zeros((data.shape[0], 0), dtype=np.float32) + # debug (optional) + print(f"from CAMS, T={T}, nlat={nlat}, nlon={nlon}, G={G}") + print(f"data_TCG.shape = {(T, len(channels), G)}") + print(f"final data.shape = {data.shape}") + print(f"coords.shape = {coords.shape}") + print(f"len(datetimes) = {len(datetimes)}") rd = ReaderData( coords=coords, @@ -277,5 +276,76 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: datetimes=datetimes, ) check_reader_data(rd, dtr) - - return rd \ No newline at end of file + return rd + + # def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: + # """ + # Get data for temporal window + # Parameters + # ---------- + # idx : int + # Index of temporal window + # channels_idx : np.array + # Selection of channels + # Returns + # ------- + # data (coords, geoinfos, data, datetimes) + # """ + + # (t_idxs, dtr) = self._get_dataset_idxs(idx) + + # if self.ds is None or self.len == 0 or len(t_idxs) == 0: + # return ReaderData.empty( + # num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx) + # ) + + # # TODO: handle sub-sampling + # # print(f"#2.5.1 inside DataReaderCams._get()", flush=True) + # t_idxs_start = t_idxs[0] + # t_idxs_end = t_idxs[-1] + 1 + + # # datetime + # datetimes = self.time[t_idxs_start:t_idxs_end] + + # # =========== lat/lon coordinates + tiling to match time steps ========== + + # # making lon, lat like a mesh for easier flattening later + # lon2d, lat2d = np.meshgrid(self.lon, self.lat) + + # print(f"len(self.lon) = {len(self.lon)}") + # print(f"len(self.lat) = {len(self.lat)}") + # print(f"len(datetimes) = {len(datetimes)}") + + # # Flatten to match (lat, lon) storage order in your array + # lat = lat2d.flatten()[:, np.newaxis] # shape (241*480,) + # lon = lon2d.flatten()[:, np.newaxis] + + # lat = np.tile(lat, len(datetimes)) + # lon = np.tile(lon, len(datetimes)) + + # coords = np.concatenate([lat, lon], axis=0) + + 
# # data + # channels = np.array(self.colnames)[channels_idx] + + # data_reshaped = [ + # np.asarray(self.ds[ch_][t_idxs_start:t_idxs_end, :, :]).reshape(-1, 1) for ch_ in channels + # ] + # data = np.concatenate(data_reshaped, axis=1) + + # # time coordinate repeated to match grid points + # datetimes = np.repeat(datetimes, len(data) // len(t_idxs)) + # print(f"len(datetimes) = {len(datetimes)}") + + # # empty geoinfos + # geoinfos = np.zeros((data.shape[0], 0), dtype=data.dtype) + + # rd = ReaderData( + # coords=coords, + # geoinfos=geoinfos, + # data=data, + # datetimes=datetimes, + # ) + # check_reader_data(rd, dtr) + + # return rd \ No newline at end of file From 96616f2d378c4d1af0ce00ed95f7324a6716be0b Mon Sep 17 00:00:00 2001 From: sbAsma Date: Mon, 1 Sep 2025 14:46:04 +0200 Subject: [PATCH 11/23] levante cartopy assets --- packages/evaluate/src/weathergen/evaluate/plotter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotter.py index 85e2ba0a4..617095664 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotter.py @@ -12,7 +12,7 @@ from weathergen.utils.config import _load_private_conf -work_dir = Path(_load_private_conf(None)["path_shared_working_dir"]) / "assets/cartopy" +work_dir = Path("./assets/cartopy" ) # Path(_load_private_conf(None)["path_shared_working_dir"]) / "assets/cartopy" cartopy.config["data_dir"] = str(work_dir) cartopy.config["pre_existing_data_dir"] = str(work_dir) From f194f77ba8bd4b0ec38c6f29cf1f0aa3bd554468 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Mon, 1 Sep 2025 14:47:05 +0200 Subject: [PATCH 12/23] removed custom load method --- src/weathergen/model/model.py | 134 +++++++++++++++++----------------- 1 file changed, 67 insertions(+), 67 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 574e035dc..a735adbe2 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -438,94 +438,94 @@ def print_num_parameters(self) -> None: ######################################### ### original version + def load(self, run_id: str, epoch: str = -1) -> None: + """Loads model state from checkpoint and checks for missing and unused keys. + Args: + run_id : model_id of the trained model + epoch : The epoch to load. Default (-1) is the latest epoch + """ + + path_run = Path(self.cf.model_path) / run_id + epoch_id = f"epoch{epoch:05d}" if epoch != -1 and epoch is not None else "latest" + filename = f"{run_id}_{epoch_id}.chkpt" + + params = torch.load( + path_run / filename, map_location=torch.device("cpu"), weights_only=True + ) + params_renamed = {} + for k in params.keys(): + params_renamed[k.replace("module.", "")] = params[k] + mkeys, ukeys = self.load_state_dict(params_renamed, strict=False) + # mkeys, ukeys = self.load_state_dict( params, strict=False) + + if len(mkeys) > 0: + logger.warning(f"Missing keys when loading model: {mkeys}") + + if len(ukeys) > 0: + logger.warning(f"Unused keys when loading model: {mkeys}") + + # ### Version 2 # def load(self, run_id: str, epoch: str = -1) -> None: - # """Loads model state from checkpoint and checks for missing and unused keys. - # Args: - # run_id : model_id of the trained model - # epoch : The epoch to load. Default (-1) is the latest epoch + # """Loads model state from checkpoint, matching weights by size + # and partially copying overlapping dimensions if needed. 
# """ - # path_run = Path(self.cf.model_path) / run_id # epoch_id = f"epoch{epoch:05d}" if epoch != -1 and epoch is not None else "latest" # filename = f"{run_id}_{epoch_id}.chkpt" - # params = torch.load( + # checkpoint = torch.load( # path_run / filename, map_location=torch.device("cpu"), weights_only=True # ) - # params_renamed = {} - # for k in params.keys(): - # params_renamed[k.replace("module.", "")] = params[k] - # mkeys, ukeys = self.load_state_dict(params_renamed, strict=False) - # # mkeys, ukeys = self.load_state_dict( params, strict=False) - # if len(mkeys) > 0: - # logger.warning(f"Missing keys when loading model: {mkeys}") + # # remove "module." prefix + # params_renamed = {k.replace("module.", ""): v for k, v in checkpoint.items()} - # if len(ukeys) > 0: - # logger.warning(f"Unused keys when loading model: {mkeys}") + # model_dict = self.state_dict() + # updated_dict = {} - ### Version 2 - def load(self, run_id: str, epoch: str = -1) -> None: - """Loads model state from checkpoint, matching weights by size - and partially copying overlapping dimensions if needed. - """ - path_run = Path(self.cf.model_path) / run_id - epoch_id = f"epoch{epoch:05d}" if epoch != -1 and epoch is not None else "latest" - filename = f"{run_id}_{epoch_id}.chkpt" + # skipped_params = [] + # partial_params = [] - checkpoint = torch.load( - path_run / filename, map_location=torch.device("cpu"), weights_only=True - ) + # for k, v in params_renamed.items(): + # if k not in model_dict: + # continue - # remove "module." prefix - params_renamed = {k.replace("module.", ""): v for k, v in checkpoint.items()} + # target = model_dict[k] - model_dict = self.state_dict() - updated_dict = {} + # if target.shape == v.shape: + # # exact match → copy fully + # updated_dict[k] = v - skipped_params = [] - partial_params = [] - - for k, v in params_renamed.items(): - if k not in model_dict: - continue + # else: + # # try partial overlap + # min_shape = tuple(min(s1, s2) for s1, s2 in zip(target.shape, v.shape)) - target = model_dict[k] + # if all(m > 0 for m in min_shape): + # # copy overlapping region only + # new_tensor = target.clone() + # slices = tuple(slice(0, m) for m in min_shape) + # new_tensor[slices] = v[slices] + # updated_dict[k] = new_tensor + # partial_params.append((k, v.shape, target.shape)) + # else: + # skipped_params.append((k, v.shape, target.shape)) - if target.shape == v.shape: - # exact match → copy fully - updated_dict[k] = v + # # actually load + # mkeys, ukeys = self.load_state_dict(updated_dict, strict=False) - else: - # try partial overlap - min_shape = tuple(min(s1, s2) for s1, s2 in zip(target.shape, v.shape)) - - if all(m > 0 for m in min_shape): - # copy overlapping region only - new_tensor = target.clone() - slices = tuple(slice(0, m) for m in min_shape) - new_tensor[slices] = v[slices] - updated_dict[k] = new_tensor - partial_params.append((k, v.shape, target.shape)) - else: - skipped_params.append((k, v.shape, target.shape)) - - # actually load - mkeys, ukeys = self.load_state_dict(updated_dict, strict=False) - - if len(mkeys) > 0: - logger.warning(f"Missing keys (random init kept): {mkeys}") + # if len(mkeys) > 0: + # logger.warning(f"Missing keys (random init kept): {mkeys}") - if len(ukeys) > 0: - logger.warning(f"Unused keys (not in model): {ukeys}") + # if len(ukeys) > 0: + # logger.warning(f"Unused keys (not in model): {ukeys}") - if len(skipped_params) > 0: - for k, old_shape, new_shape in skipped_params: - logger.warning(f"Skipped {k}: checkpoint {old_shape} vs model 
{new_shape}") + # if len(skipped_params) > 0: + # for k, old_shape, new_shape in skipped_params: + # logger.warning(f"Skipped {k}: checkpoint {old_shape} vs model {new_shape}") - if len(partial_params) > 0: - for k, old_shape, new_shape in partial_params: - logger.warning(f"Partially loaded {k}: {old_shape} → {new_shape}") + # if len(partial_params) > 0: + # for k, old_shape, new_shape in partial_params: + # logger.warning(f"Partially loaded {k}: {old_shape} → {new_shape}") ######################################### def forward_jac(self, *args): From 489d4bb0f6c9b49d54a1dc6ee2eb870f96d6bab1 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Wed, 17 Sep 2025 16:01:00 +0200 Subject: [PATCH 13/23] changed some training params --- config/cams_eac4_config.yml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/config/cams_eac4_config.yml b/config/cams_eac4_config.yml index 2a0346497..9fa94d96e 100644 --- a/config/cams_eac4_config.yml +++ b/config/cams_eac4_config.yml @@ -5,16 +5,16 @@ start_date: 200301010000 # start_date_val: 202201010000 # end_date_val: 202205300000 -num_epochs: 63 # 10 +num_epochs: 45 # 10 -step_hrs: 1 +# samples_per_epoch: 700 # HERE +# samples_per_validation: 200 # HERE +# shuffle: True -samples_per_epoch: 1500 -samples_per_validation: 300 -shuffle: True - -loader_num_workers: 16 +loader_num_workers: 4 masking_rate: 0.8 +# forecast_offset : 1 # HERE + with_mixed_precision: True \ No newline at end of file From ea2eea66c99876a687290ce4016eba6207455bd1 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Wed, 17 Sep 2025 16:01:38 +0200 Subject: [PATCH 14/23] deleted unecessary files --- config/cams_eac4_config_test.yml | 21 -------------- config/era5_d4fcwk5b_config.yml | 48 -------------------------------- 2 files changed, 69 deletions(-) delete mode 100644 config/cams_eac4_config_test.yml delete mode 100644 config/era5_d4fcwk5b_config.yml diff --git a/config/cams_eac4_config_test.yml b/config/cams_eac4_config_test.yml deleted file mode 100644 index af557d0ee..000000000 --- a/config/cams_eac4_config_test.yml +++ /dev/null @@ -1,21 +0,0 @@ -streams_directory: "./config/streams/streams_cams_eac4/" - -start_date: 201710010000 -end_date: 202112312100 -start_date_val: 202201010000 -end_date_val: 202205312100 - -num_epochs: 1 # 10 -token_size: 64 - -step_hrs: 1 - -samples_per_epoch: 17 # works # 100 works -samples_per_validation: 17 -shuffle: True - -loader_num_workers: 16 - -masking_rate: 0.8 - -with_mixed_precision: True \ No newline at end of file diff --git a/config/era5_d4fcwk5b_config.yml b/config/era5_d4fcwk5b_config.yml deleted file mode 100644 index aead5cd89..000000000 --- a/config/era5_d4fcwk5b_config.yml +++ /dev/null @@ -1,48 +0,0 @@ -ae_local_dim_embed: 1024 -ae_local_num_blocks: 0 -ae_local_num_queries: 1 -ae_global_num_blocks: 4 - -forecast_offset: 1 -forecast_steps: 2 -forecast_policy: "fixed" -forecast_att_dense_rate: 1.0 - -fe_num_blocks: 8 - -len_hrs: 6 -step_hrs: 6 - -grad_clip: 1.0 -lr_max: 0.00005 - -istep: 12288 -data_loader_rng_seed: 1752001227 -run_id: "d4fcwk5b" - -### The following was in path: WeatherGenerator-private/hpc/levante/config/era5_d4fcwk5b_config.yml -### Which I deleted (for cleaning) -# ae_local_dim_embed: 1024 -# ae_local_num_blocks: 0 -# ae_local_num_queries: 1 -# # ae_global_num_blocks: 4 - -# ae_global_num_blocks: 12 - - -# forecast_offset: 1 -# forecast_steps: 2 -# forecast_policy: "fixed" -# forecast_att_dense_rate: 1.0 - -# fe_num_blocks: 8 - -# len_hrs: 6 -# step_hrs: 6 - -# grad_clip: 1.0 -# lr_max: 0.00005 - -# 
istep: 12288 -# data_loader_rng_seed: 1752001227 -# run_id: "d4fcwk5b" From a2357178cf84fc4c6ffc7def133c72cbc73b11e7 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Wed, 17 Sep 2025 16:03:05 +0200 Subject: [PATCH 15/23] title: changes to data and local embeddings details: - added more variables - changed params to optimal ones --- .../streams/streams_cams_eac4/cams_eac4.yml | 42 +++++++++++++++++-- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/config/streams/streams_cams_eac4/cams_eac4.yml b/config/streams/streams_cams_eac4/cams_eac4.yml index 09601ba06..b3ca70f7d 100644 --- a/config/streams/streams_cams_eac4/cams_eac4.yml +++ b/config/streams/streams_cams_eac4/cams_eac4.yml @@ -22,6 +22,18 @@ CAMSEAC4 : # Sulfur dioxide (so2) 'so2_1000', 'so2_925', 'so2_850', 'so2_700', 'so2_600', 'so2_500', 'so2_400', 'so2_300', 'so2_250', 'so2_200', 'so2_150', 'so2_100', 'so2_50', + + # Nitrogen monoxide (no) + 'no_1000', 'no_925', 'no_850', 'no_700', 'no_600', 'no_500', + 'no_400', 'no_300', 'no_250', 'no_200', 'no_150', 'no_100', 'no_50', + + # Nitrogen dioxide (no2) + 'no2_1000', 'no2_925', 'no2_850', 'no2_700', 'no2_600', 'no2_500', + 'no2_400', 'no2_300', 'no2_250', 'no2_200', 'no2_150', 'no2_100', 'no2_50', + + # Carbon monoxide (co) + 'co_1000', 'co_925', 'co_850', 'co_700', 'co_600', 'co_500', + 'co_400', 'co_300', 'co_250', 'co_200', 'co_150', 'co_100', 'co_50', ] target : [ # Surface variables @@ -35,6 +47,18 @@ CAMSEAC4 : # Sulfur dioxide (so2) 'so2_1000', 'so2_925', 'so2_850', 'so2_700', 'so2_600', 'so2_500', 'so2_400', 'so2_300', 'so2_250', 'so2_200', 'so2_150', 'so2_100', 'so2_50', + + # Nitrogen monoxide (no) + 'no_1000', 'no_925', 'no_850', 'no_700', 'no_600', 'no_500', + 'no_400', 'no_300', 'no_250', 'no_200', 'no_150', 'no_100', 'no_50', + + # Nitrogen dioxide (no2) + 'no2_1000', 'no2_925', 'no2_850', 'no2_700', 'no2_600', 'no2_500', + 'no2_400', 'no2_300', 'no2_250', 'no2_200', 'no2_150', 'no2_100', 'no2_50', + + # Carbon monoxide (co) + 'co_1000', 'co_925', 'co_850', 'co_700', 'co_600', 'co_500', + 'co_400', 'co_300', 'co_250', 'co_200', 'co_150', 'co_100', 'co_50', ] # source_exclude : [] # target_exclude : [] @@ -50,21 +74,33 @@ CAMSEAC4 : # Sulfur dioxide (so2) 'so2_1000', 'so2_925', 'so2_850', 'so2_700', 'so2_600', 'so2_500', 'so2_400', 'so2_300', 'so2_250', 'so2_200', 'so2_150', 'so2_100', 'so2_50', + + # Nitrogen monoxide (no) + 'no_1000', 'no_925', 'no_850', 'no_700', 'no_600', 'no_500', + 'no_400', 'no_300', 'no_250', 'no_200', 'no_150', 'no_100', 'no_50', + + # Nitrogen dioxide (no2) + 'no2_1000', 'no2_925', 'no2_850', 'no2_700', 'no2_600', 'no2_500', + 'no2_400', 'no2_300', 'no2_250', 'no2_200', 'no2_150', 'no2_100', 'no2_50', + + # Carbon monoxide (co) + 'co_1000', 'co_925', 'co_850', 'co_700', 'co_600', 'co_500', + 'co_400', 'co_300', 'co_250', 'co_200', 'co_150', 'co_100', 'co_50', ] loss_weight : 1. 
masking_rate : 0.6
   masking_rate_none : 0.05
-  token_size : 32
+  token_size : 16
   tokenize_spacetime : True
   embed :
     net : transformer
     num_tokens : 1
     num_heads : 8
     dim_embed : 512
-    num_blocks : 4
+    num_blocks : 2
   embed_target_coords :
     net : linear
-    dim_embed : 512
+    dim_embed : 512
   target_readout :
     type : 'obs_value'   # token or obs_value
     num_layers : 2

From e2ea48d0db07ec24917d3bd6ff3c492f66a456ea Mon Sep 17 00:00:00 2001
From: sbAsma
Date: Wed, 17 Sep 2025 16:22:45 +0200
Subject: [PATCH 16/23] title: handling _get() hanging and removing unused code
 description: - added prints for when the _get() method hangs during variable
 access - removed unused code - added chunking for zarr data reading - added
 code that aborts the data reading when it hangs for more than 30 seconds

---
 src/weathergen/datasets/data_reader_cams.py | 136 ++++++--------------
 1 file changed, 38 insertions(+), 98 deletions(-)

diff --git a/src/weathergen/datasets/data_reader_cams.py b/src/weathergen/datasets/data_reader_cams.py
index 34b2001bd..8b32f3393 100644
--- a/src/weathergen/datasets/data_reader_cams.py
+++ b/src/weathergen/datasets/data_reader_cams.py
@@ -16,6 +16,26 @@
     check_reader_data,
 )
 
+############################################################################
+import os, time
+from typing import Sequence
+
+def _now_ms() -> int:
+    return int(time.time() * 1000)
+
+def _pfx() -> str:
+    # Helpful when multiple workers/ranks print at once
+    return f"[DATAREADER DEBUG:{os.environ.get('RANK', '?')}/{os.getpid()}]"
+
+import signal
+
+class _Timeout(Exception): pass
+def _alarm_handler(signum, frame): raise _Timeout()
+
+signal.signal(signal.SIGALRM, _alarm_handler)
+
+############################################################################
+
 _logger = logging.getLogger(__name__)
 
 
@@ -42,7 +62,7 @@ def __init__(
         # ======= Reading the Dataset ================
 
         # Open the dataset using Xarray with Zarr engine
-        self.ds = xr.open_dataset(filename, engine="zarr")
+        self.ds = xr.open_dataset(filename, engine="zarr", chunks={"time": 24})
 
         # Column (variable) names and indices
         self.colnames = stream_info["variables"]  # list(self.ds)
@@ -70,21 +90,7 @@ def __init__(
         end_ds = np.datetime64(self.time[-1])
         self.temporal_frequency = self.time[1] - self.time[0]
 
-        # # Skip stream if it doesn't intersect with time window
-        # print(f"start_ds = {start_ds}")
-        # print(f"tw_handler.t_end = {tw_handler.t_end}")
-        # print(f"end_ds = {end_ds}")
-        # print(f"tw_handler.t_start = {tw_handler.t_start}")
-        # """
-        # 0: start_ds = 2017-10-01T00:00:00.000000000
-        # 0: tw_handler.t_end = 2017-01-03T00:00:00.000000
-        # 0: end_ds = 2022-05-31T21:00:00.000000000
-        # 0: tw_handler.t_start = 2017-01-01T00:00:00.000000
-
-        # """
-
         if start_ds > tw_handler.t_end or end_ds < tw_handler.t_start:
-            # print("inside skipping stream")
            name = stream_info["name"]
            _logger.warning(f"{name} is not supported over data loader window. 
Stream is skipped.") super().__init__(tw_handler, stream_info) @@ -202,7 +208,6 @@ def length(self) -> int: """ return self.len - @override def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: """ @@ -233,7 +238,22 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: data_per_channel = [] for ch in channels: # expect ds[ch] with shape (time, lat, lon) for this dataset - arr = np.asarray(self.ds[ch][t0:t1, :, :], dtype=np.float32) # (T, nlat, nlon) + + ################################################################################### + + signal.alarm(30) # seconds + arr = 0 + try: + arr = np.asarray(self.ds[ch][t0:t1, :, :], dtype=np.float32) # (T, nlat, nlon) + except _Timeout: + print(f"{_pfx()} idx={idx} TIMEOUT while reading channel '{ch}' [{t0}:{t1}] after 30s", flush=True) + print(f"{_pfx()} idx={idx} TIMEOUT time steps: {self.time[t0:t1]}", flush=True) + print(f"{_pfx()} idx={idx} TIMEOUT data: {arr}", flush=True) + + finally: + signal.alarm(0) # always cancel alarm + + ################################################################################### if arr.ndim != 3: raise ValueError(f"Expected 3D array (time, lat, lon) for '{ch}', got shape {arr.shape}") _, nlat, nlon = arr.shape @@ -244,12 +264,11 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: # move channels to last and flatten time: (T, G, C) -> (T*G, C) data = np.transpose(data_TCG, (0, 2, 1)).reshape(T * (nlat * nlon), len(channels)).astype(np.float32) - except MissingDateError as e: + except Exception as e: _logger.debug(f"Date not present in CAMS dataset: {str(e)}. Skipping.") return ReaderData.empty( num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx) ) - # --- coords: build flattened [lat, lon] once, then repeat for each time lon2d, lat2d = np.meshgrid(np.asarray(self.lon), np.asarray(self.lat)) # shapes (nlat, nlon) G = lon2d.size @@ -262,13 +281,6 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: # --- empty geoinfos (match anemoi) geoinfos = np.zeros((data.shape[0], 0), dtype=np.float32) - # debug (optional) - print(f"from CAMS, T={T}, nlat={nlat}, nlon={nlon}, G={G}") - print(f"data_TCG.shape = {(T, len(channels), G)}") - print(f"final data.shape = {data.shape}") - print(f"coords.shape = {coords.shape}") - print(f"len(datetimes) = {len(datetimes)}") - rd = ReaderData( coords=coords, geoinfos=geoinfos, @@ -277,75 +289,3 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: ) check_reader_data(rd, dtr) return rd - - # def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: - # """ - # Get data for temporal window - # Parameters - # ---------- - # idx : int - # Index of temporal window - # channels_idx : np.array - # Selection of channels - # Returns - # ------- - # data (coords, geoinfos, data, datetimes) - # """ - - # (t_idxs, dtr) = self._get_dataset_idxs(idx) - - # if self.ds is None or self.len == 0 or len(t_idxs) == 0: - # return ReaderData.empty( - # num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx) - # ) - - # # TODO: handle sub-sampling - # # print(f"#2.5.1 inside DataReaderCams._get()", flush=True) - # t_idxs_start = t_idxs[0] - # t_idxs_end = t_idxs[-1] + 1 - - # # datetime - # datetimes = self.time[t_idxs_start:t_idxs_end] - - # # =========== lat/lon coordinates + tiling to match time steps ========== - - # # making lon, lat like a mesh for easier flattening later - # lon2d, lat2d = np.meshgrid(self.lon, self.lat) - - # print(f"len(self.lon) = {len(self.lon)}") 
- # print(f"len(self.lat) = {len(self.lat)}") - # print(f"len(datetimes) = {len(datetimes)}") - - # # Flatten to match (lat, lon) storage order in your array - # lat = lat2d.flatten()[:, np.newaxis] # shape (241*480,) - # lon = lon2d.flatten()[:, np.newaxis] - - # lat = np.tile(lat, len(datetimes)) - # lon = np.tile(lon, len(datetimes)) - - # coords = np.concatenate([lat, lon], axis=0) - - # # data - # channels = np.array(self.colnames)[channels_idx] - - # data_reshaped = [ - # np.asarray(self.ds[ch_][t_idxs_start:t_idxs_end, :, :]).reshape(-1, 1) for ch_ in channels - # ] - # data = np.concatenate(data_reshaped, axis=1) - - # # time coordinate repeated to match grid points - # datetimes = np.repeat(datetimes, len(data) // len(t_idxs)) - # print(f"len(datetimes) = {len(datetimes)}") - - # # empty geoinfos - # geoinfos = np.zeros((data.shape[0], 0), dtype=data.dtype) - - # rd = ReaderData( - # coords=coords, - # geoinfos=geoinfos, - # data=data, - # datetimes=datetimes, - # ) - # check_reader_data(rd, dtr) - - # return rd \ No newline at end of file From a504f6e18f75542f24e5f325026981c34851f6b5 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Wed, 17 Sep 2025 16:24:49 +0200 Subject: [PATCH 17/23] added ERA5 in the stream --- config/streams/streams_cams_eac4/era5.yml | 37 +++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 config/streams/streams_cams_eac4/era5.yml diff --git a/config/streams/streams_cams_eac4/era5.yml b/config/streams/streams_cams_eac4/era5.yml new file mode 100644 index 000000000..5561ef0c6 --- /dev/null +++ b/config/streams/streams_cams_eac4/era5.yml @@ -0,0 +1,37 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5 : + type : anemoi + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2022-6h-v6.zarr'] + source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] + target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] + loss_weight : 1. + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 8 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 256 + target_readout : + type : 'obs_value' # token or obs_value + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 \ No newline at end of file From 6489ecb1c80db72d09733e9434b406eb78aff4d4 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Wed, 17 Sep 2025 16:26:28 +0200 Subject: [PATCH 18/23] removed previously added code that expands the architecture --- src/weathergen/model/model.py | 78 +---------------------------------- 1 file changed, 1 insertion(+), 77 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index a735adbe2..60c0d445c 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -437,7 +437,6 @@ def print_num_parameters(self) -> None: print("-----------------") ######################################### - ### original version def load(self, run_id: str, epoch: str = -1) -> None: """Loads model state from checkpoint and checks for missing and unused keys. 
Args: @@ -464,69 +463,6 @@ def load(self, run_id: str, epoch: str = -1) -> None: if len(ukeys) > 0: logger.warning(f"Unused keys when loading model: {mkeys}") - # ### Version 2 - # def load(self, run_id: str, epoch: str = -1) -> None: - # """Loads model state from checkpoint, matching weights by size - # and partially copying overlapping dimensions if needed. - # """ - # path_run = Path(self.cf.model_path) / run_id - # epoch_id = f"epoch{epoch:05d}" if epoch != -1 and epoch is not None else "latest" - # filename = f"{run_id}_{epoch_id}.chkpt" - - # checkpoint = torch.load( - # path_run / filename, map_location=torch.device("cpu"), weights_only=True - # ) - - # # remove "module." prefix - # params_renamed = {k.replace("module.", ""): v for k, v in checkpoint.items()} - - # model_dict = self.state_dict() - # updated_dict = {} - - # skipped_params = [] - # partial_params = [] - - # for k, v in params_renamed.items(): - # if k not in model_dict: - # continue - - # target = model_dict[k] - - # if target.shape == v.shape: - # # exact match → copy fully - # updated_dict[k] = v - - # else: - # # try partial overlap - # min_shape = tuple(min(s1, s2) for s1, s2 in zip(target.shape, v.shape)) - - # if all(m > 0 for m in min_shape): - # # copy overlapping region only - # new_tensor = target.clone() - # slices = tuple(slice(0, m) for m in min_shape) - # new_tensor[slices] = v[slices] - # updated_dict[k] = new_tensor - # partial_params.append((k, v.shape, target.shape)) - # else: - # skipped_params.append((k, v.shape, target.shape)) - - # # actually load - # mkeys, ukeys = self.load_state_dict(updated_dict, strict=False) - - # if len(mkeys) > 0: - # logger.warning(f"Missing keys (random init kept): {mkeys}") - - # if len(ukeys) > 0: - # logger.warning(f"Unused keys (not in model): {ukeys}") - - # if len(skipped_params) > 0: - # for k, old_shape, new_shape in skipped_params: - # logger.warning(f"Skipped {k}: checkpoint {old_shape} vs model {new_shape}") - - # if len(partial_params) > 0: - # for k, old_shape, new_shape in partial_params: - # logger.warning(f"Partially loaded {k}: {old_shape} → {new_shape}") - ######################################### def forward_jac(self, *args): sources = args[:-1] @@ -645,22 +581,10 @@ def embed_cells(self, model_params: ModelParams, streams_data) -> torch.Tensor: # ) # scatter write to reorder from per stream to per cell ordering - ################################################################ - # scatter write to reorder from per stream to per cell ordering - # but first ensure pe_embed is big enough - if idxs_pe.max() >= model_params.pe_embed.shape[0]: - old_size, emb_dim = model_params.pe_embed.shape - new_size = int(idxs_pe.max().item()) + 1 - new_pe = model_params.pe_embed.new_empty((new_size, emb_dim)) - new_pe[:old_size] = model_params.pe_embed - torch.nn.init.normal_(new_pe[old_size:], mean=0.0, std=0.02) - model_params.pe_embed = torch.nn.Parameter(new_pe, requires_grad=True) - print(f"[INFO] Resized pe_embed from {old_size} → {new_size}", flush=True) - - ################################################################ tokens_all.scatter_(0, idxs, x_embed + model_params.pe_embed[idxs_pe]) return tokens_all + ######################################### def assimilate_local( self, model_params: ModelParams, tokens: torch.Tensor, cell_lens: torch.Tensor From 040137ba83927e9bfc7bd5584225f3696f5a58b1 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Wed, 17 Sep 2025 16:27:27 +0200 Subject: [PATCH 19/23] params for finetuning on 8 forecast steps --- 
src/weathergen/run_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index 6080a9039..7213f3216 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -92,8 +92,8 @@ def train_continue_from_args(argl: list[str]): if args.finetune_forecast: finetune_overwrite = dict( training_mode="forecast", - forecast_delta_hrs=0, # 12 - forecast_steps=1, # [j for j in range(1,9) for i in range(4)] + forecast_delta_hrs=6, # 12 + forecast_steps= 7, # [j for j in range(1,9) for i in range(4)] forecast_policy="fixed", # 'sequential_random' # 'fixed' #'sequential' #_random' forecast_freeze_model=True, forecast_att_dense_rate=1.0, # 0.25 From 74c98e819b58e26d0ccbb933f43f50b89d96bd88 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Fri, 19 Sep 2025 13:45:11 +0200 Subject: [PATCH 20/23] reading zarr data with merged levels --- config/cams_eac4_config.yml | 2 +- src/weathergen/datasets/data_reader_cams.py | 60 +++++++++++---------- 2 files changed, 32 insertions(+), 30 deletions(-) diff --git a/config/cams_eac4_config.yml b/config/cams_eac4_config.yml index 9fa94d96e..fa8eebcba 100644 --- a/config/cams_eac4_config.yml +++ b/config/cams_eac4_config.yml @@ -14,7 +14,7 @@ num_epochs: 45 # 10 loader_num_workers: 4 masking_rate: 0.8 -# forecast_offset : 1 # HERE +forecast_offset : 1 # HERE with_mixed_precision: True \ No newline at end of file diff --git a/src/weathergen/datasets/data_reader_cams.py b/src/weathergen/datasets/data_reader_cams.py index 8b32f3393..4b5b08e4a 100644 --- a/src/weathergen/datasets/data_reader_cams.py +++ b/src/weathergen/datasets/data_reader_cams.py @@ -16,7 +16,6 @@ check_reader_data, ) -############################################################################ import os, time from typing import Sequence @@ -60,10 +59,13 @@ def __init__( """ # ======= Reading the Dataset ================ + # open groups + ds_surface = xr.open_zarr(filename, group="surface", chunks={"time": 24}) + ds_profiles = xr.open_zarr(filename, group="profiles", chunks={"time": 24}) - # Open the dataset using Xarray with Zarr engine - self.ds = xr.open_dataset(filename, engine="zarr", chunks={"time": 24}) - + # merge along variables + self.ds = xr.merge([ds_surface, ds_profiles]) + # Column (variable) names and indices self.colnames = stream_info["variables"] # list(self.ds) self.cols_idx = np.array(list(np.arange(len(self.colnames)))) @@ -83,6 +85,7 @@ def __init__( # Extract coordinates and pressure level self.lat = _clip_lat(self.ds["latitude"].values) self.lon = _clip_lon(self.ds["longitude"].values) + self.levels = stream_info["pressure_levels"] # Time range in the dataset self.time = self.ds["time"].values @@ -91,6 +94,7 @@ def __init__( self.temporal_frequency = self.time[1] - self.time[0] if start_ds > tw_handler.t_end or end_ds < tw_handler.t_start: + # print("inside skipping stream") name = stream_info["name"] _logger.warning(f"{name} is not supported over data loader window. 
Stream is skipped.") super().__init__(tw_handler, stream_info) @@ -110,7 +114,6 @@ def __init__( self.start_idx = (tw_handler.t_start - start_ds).astype("timedelta64[ns]").astype(int) self.end_idx = (tw_handler.t_end - start_ds).astype("timedelta64[ns]").astype(int) + 1 - # Asma TODO check if self.len checks out # Number of time steps in selected range self.len = self.end_idx - self.start_idx + 1 @@ -187,9 +190,6 @@ def select(self, ch_filters: list[str]) -> (np.array, list[str]): return selected_colnames, selected_cols_idx - # # Asma TODO test it once kerchunk is ready - # def get_levels(self, channels: list[str]) -> list: - @override def init_empty(self) -> None: super().init_empty() @@ -217,7 +217,6 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: ------- ReaderData providing coords, geoinfos, data, datetimes """ - (t_idxs, dtr) = self._get_dataset_idxs(idx) if self.ds is None or self.len == 0 or len(t_idxs) == 0: @@ -230,45 +229,48 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: t1 = t_idxs[-1] + 1 # end is exclusive T = t1 - t0 + nlat = len(self.lat) + nlon = len(self.lon) # channels to read channels = np.array(self.colnames)[channels_idx].tolist() # --- read & shape data to match anemoi path: (T, C, G) -> (T, G, C) -> (T*G, C) + data_per_channel = [] try: - data_per_channel = [] for ch in channels: - # expect ds[ch] with shape (time, lat, lon) for this dataset - - ################################################################################### + ch_parts = ch.split("_") + # retrieving profile channels + if len(ch_parts) == 2 and ch_parts[1] in self.levels : + ch_ = ch_parts[0] + level=int(ch_parts[1]) + data_lazy = self.ds[ch_].sel(isobaricInhPa=level)[t0:t1, :, :].astype("float32") + # retrieving surface channels + else: + data_lazy = self.ds[ch][t0:t1, :, :].astype("float32") signal.alarm(30) # seconds - arr = 0 try: - arr = np.asarray(self.ds[ch][t0:t1, :, :], dtype=np.float32) # (T, nlat, nlon) + data = data_lazy.compute(scheduler='synchronous').values + data_per_channel.append(data.reshape(T, nlat * nlon)) # (T, G) except _Timeout: print(f"{_pfx()} idx={idx} TIMEOUT while reading channel '{ch}' [{t0}:{t1}] after 30s", flush=True) print(f"{_pfx()} idx={idx} TIMEOUT time steps: {self.time[t0:t1]}", flush=True) - print(f"{_pfx()} idx={idx} TIMEOUT data: {arr}", flush=True) + print(f"{_pfx()} idx={idx} TIMEOUT data: {data}", flush=True) finally: signal.alarm(0) # always cancel alarm - - ################################################################################### - if arr.ndim != 3: - raise ValueError(f"Expected 3D array (time, lat, lon) for '{ch}', got shape {arr.shape}") - _, nlat, nlon = arr.shape - data_per_channel.append(arr.reshape(T, nlat * nlon)) # (T, G) - - # stack channels to (T, C, G) - data_TCG = np.stack(data_per_channel, axis=1) # (T, C, G) - # move channels to last and flatten time: (T, G, C) -> (T*G, C) - data = np.transpose(data_TCG, (0, 2, 1)).reshape(T * (nlat * nlon), len(channels)).astype(np.float32) - + except Exception as e: _logger.debug(f"Date not present in CAMS dataset: {str(e)}. 
Skipping.") return ReaderData.empty( num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx) ) + + # stack channels to (T, C, G) + data_TCG = np.stack(data_per_channel, axis=1) # (T, C, G) + # move channels to last and flatten time: (T, G, C) -> (T*G, C) + data = np.transpose(data_TCG, (0, 2, 1)).reshape(T * (nlat * nlon), len(channels)).astype(np.float32) + # --- coords: build flattened [lat, lon] once, then repeat for each time lon2d, lat2d = np.meshgrid(np.asarray(self.lon), np.asarray(self.lat)) # shapes (nlat, nlon) G = lon2d.size @@ -288,4 +290,4 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: datetimes=datetimes, ) check_reader_data(rd, dtr) - return rd + return rd \ No newline at end of file From 70cfe75fe076bca6826930451c90166fb51abf2f Mon Sep 17 00:00:00 2001 From: sbAsma Date: Sun, 21 Sep 2025 16:32:38 +0200 Subject: [PATCH 21/23] fixed forgotten argument --- config/streams/streams_cams_eac4/cams_eac4.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/config/streams/streams_cams_eac4/cams_eac4.yml b/config/streams/streams_cams_eac4/cams_eac4.yml index b3ca70f7d..079396005 100644 --- a/config/streams/streams_cams_eac4/cams_eac4.yml +++ b/config/streams/streams_cams_eac4/cams_eac4.yml @@ -87,6 +87,7 @@ CAMSEAC4 : 'co_1000', 'co_925', 'co_850', 'co_700', 'co_600', 'co_500', 'co_400', 'co_300', 'co_250', 'co_200', 'co_150', 'co_100', 'co_50', ] + pressure_levels : ["50", "100", "150", "200", "250", "300", "400", "500", "600", "700", "850", "925", "1000"] loss_weight : 1. masking_rate : 0.6 masking_rate_none : 0.05 From 25c7cce513b28f10a3d758d325b4ed8282134bcd Mon Sep 17 00:00:00 2001 From: sbAsma Date: Sun, 21 Sep 2025 16:33:08 +0200 Subject: [PATCH 22/23] longer time tolerance for dask compute --- src/weathergen/datasets/data_reader_cams.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/datasets/data_reader_cams.py b/src/weathergen/datasets/data_reader_cams.py index 4b5b08e4a..6ba6621e8 100644 --- a/src/weathergen/datasets/data_reader_cams.py +++ b/src/weathergen/datasets/data_reader_cams.py @@ -248,12 +248,12 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: else: data_lazy = self.ds[ch][t0:t1, :, :].astype("float32") - signal.alarm(30) # seconds + signal.alarm(600) # seconds try: data = data_lazy.compute(scheduler='synchronous').values data_per_channel.append(data.reshape(T, nlat * nlon)) # (T, G) except _Timeout: - print(f"{_pfx()} idx={idx} TIMEOUT while reading channel '{ch}' [{t0}:{t1}] after 30s", flush=True) + print(f"{_pfx()} idx={idx} TIMEOUT while reading channel '{ch}' [{t0}:{t1}] after 600s", flush=True) print(f"{_pfx()} idx={idx} TIMEOUT time steps: {self.time[t0:t1]}", flush=True) print(f"{_pfx()} idx={idx} TIMEOUT data: {data}", flush=True) From c862eaa94f74d9aa049a70a5c906c7cafabfa9be Mon Sep 17 00:00:00 2001 From: sbAsma Date: Sun, 21 Sep 2025 16:34:15 +0200 Subject: [PATCH 23/23] evaluation parameters --- config/config_eval.yml | 161 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 161 insertions(+) create mode 100644 config/config_eval.yml diff --git a/config/config_eval.yml b/config/config_eval.yml new file mode 100644 index 000000000..df105c63a --- /dev/null +++ b/config/config_eval.yml @@ -0,0 +1,161 @@ +verbose: true +global_plotting_options: + image_format : "png" #options: "png", "pdf", "svg", "eps", "jpg" .. 
+ dpi_val : 300 + +summary_plots : true +print_summary: false + +evaluation: + metrics : ["rmse"] + regions: ["global", "nhem"] +run_ids : + wtqfk9i5: + label: "CAMS EAC4 forecast epoch=25 dim_embed = 512 token_size = 16 num_blocks = 2" + epoch: 0 + rank: 0 + streams: + # ERA5: + # channels: ["2t", "10u", "10v", "z_500", "t_850", "u_850", "v_850", "q_850", ] + # evaluation: + # forecast_step: "all" + # sample: "all" + # plotting: + # sample: [0] + # forecast_step: [ 1, 2, 3, 4, 5, 6, 7, 8] + # plot_maps: true + # plot_histograms: true + # plot_animations: false + CAMSEAC4: + channels: [ + # Surface variables + 'pm1', 'pm2p5', 'pm10', + 'tc_co', 'tc_no', 'tc_no2', 'tc_o3', 'tc_so2', + + # Ozone (o3) + 'go3_1000', 'go3_925', 'go3_850', 'go3_700', 'go3_600', 'go3_500', + 'go3_400', 'go3_300', 'go3_250', 'go3_200', 'go3_150', 'go3_100', 'go3_50', + + # Sulfur dioxide (so2) + 'so2_1000', 'so2_925', 'so2_850', 'so2_700', 'so2_600', 'so2_500', + 'so2_400', 'so2_300', 'so2_250', 'so2_200', 'so2_150', 'so2_100', 'so2_50', + + # Nitrogen monoxide (no) + 'no_1000', 'no_925', 'no_850', 'no_700', 'no_600', 'no_500', + 'no_400', 'no_300', 'no_250', 'no_200', 'no_150', 'no_100', 'no_50', + + # Nitrogen dioxide (no2) + 'no2_1000', 'no2_925', 'no2_850', 'no2_700', 'no2_600', 'no2_500', + 'no2_400', 'no2_300', 'no2_250', 'no2_200', 'no2_150', 'no2_100', 'no2_50', + + # Carbon monoxide (co) + 'co_1000', 'co_925', 'co_850', 'co_700', 'co_600', 'co_500', + 'co_400', 'co_300', 'co_250', 'co_200', 'co_150', 'co_100', 'co_50', + ] + evaluation: + forecast_step: "all" + sample: "all" + plotting: + sample: [0] + forecast_step: [ 1, 2, 3, 4, 5, 6, 7, 8] + plot_maps: true + plot_histograms: true + plot_animations: true + # e8fzh2t1: + # label: "CAMS forecast finetune dim_embed = 256 token_size = 16 num_blocks = 2" + # epoch: 0 + # rank: 0 + # streams: + # # ERA5: + # # channels: ["2t", "10u", "10v", "z_500", "t_850", "u_850", "v_850", "q_850", ] + # # evaluation: + # # forecast_step: "all" + # # sample: "all" + # # plotting: + # # sample: [0] + # # forecast_step: [ 1, 2, 3, 4, 5, 6, 7, 8] + # # plot_maps: true + # # plot_histograms: true + # # plot_animations: false + # CAMSEAC4: + # channels: [ + # # Surface variables + # 'pm1', 'pm2p5', 'pm10', + # 'tc_co', 'tc_no', 'tc_no2', 'tc_o3', 'tc_so2', + + # # Ozone (o3) + # 'go3_1000', 'go3_500', 'go3_250','go3_50', + + # # Sulfur dioxide (so2) + # 'so2_1000', 'so2_500', 'so2_250','so2_50', + + + # # Nitrogen monoxide (no) + # 'no_1000', 'no_500', 'no_250','no_50', + + # # Nitrogen dioxide (no2) + # 'no2_1000', 'no2_500', 'no2_250','no2_50', + + + # # Carbon monoxide (co) + # 'co_1000', 'co_500', 'co_250','co_50', + + # ] + # evaluation: + # forecast_step: "all" + # sample: "all" + # plotting: + # sample: [0] + # forecast_step: [ 1, 2, 3, 4, 5, 6, 7, 8] + # plot_maps: true + # plot_histograms: true + # plot_animations: true + # z6aup4r1: + # label: "CAMS forecast finetune dim_embed = 512 token_size = 32 num_blocks = 2" + # epoch: 0 + # rank: 0 + # streams: + # # ERA5: + # # channels: ["2t", "10u", "10v", "z_500", "t_850", "u_850", "v_850", "q_850", ] + # # evaluation: + # # forecast_step: "all" + # # sample: "all" + # # plotting: + # # sample: [0] + # # forecast_step: [ 1, 2, 3, 4, 5, 6, 7, 8] + # # plot_maps: true + # # plot_histograms: true + # # plot_animations: false + # CAMSEAC4: + # channels: [ + # # Surface variables + # 'pm1', 'pm2p5', 'pm10', + # 'tc_co', 'tc_no', 'tc_no2', 'tc_o3', 'tc_so2', + + # # Ozone (o3) + # 'go3_1000', 'go3_500', 'go3_250','go3_50', + + # 
# Sulfur dioxide (so2) + # 'so2_1000', 'so2_500', 'so2_250','so2_50', + + + # # Nitrogen monoxide (no) + # 'no_1000', 'no_500', 'no_250','no_50', + + # # Nitrogen dioxide (no2) + # 'no2_1000', 'no2_500', 'no2_250','no2_50', + + + # # Carbon monoxide (co) + # 'co_1000', 'co_500', 'co_250','co_50', + + # ] + # evaluation: + # forecast_step: "all" + # sample: "all" + # plotting: + # sample: [0] + # forecast_step: [ 1, 2, 3, 4, 5, 6, 7, 8] + # plot_maps: true + # plot_histograms: true + # plot_animations: true \ No newline at end of file