From ad462841f633680351185e5a0a5db03572948a48 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Wed, 15 Oct 2025 08:55:34 +0200 Subject: [PATCH 1/3] store lead time in OutptDataset and make it available to evaluation --- packages/common/src/weathergen/common/io.py | 38 ++++++++++++++++++--- src/weathergen/utils/validation_io.py | 1 + 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 3e7594d1c..9c0ce8e13 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -148,6 +148,7 @@ class OutputDataset: channels: list[str] geoinfo_channels: list[str] + lead_time: int @functools.cached_property def arrays(self) -> dict[str, zarr.Array | NDArray]: @@ -186,6 +187,7 @@ def as_xarray(self, chunk_nsamples=CHUNK_N_SAMPLES) -> xr.DataArray: "sample": [self.item_key.sample], "stream": [self.item_key.stream], "forecast_step": [self.item_key.forecast_step], + "lead_time": ("forecast_step", [self.lead_time]), "ipoint": self.datapoints, "channel": self.channels, # TODO: make sure channel names align with data "valid_time": ("ipoint", times.astype("datetime64[ns]")), @@ -285,6 +287,7 @@ def _write_dataset(self, item_group: zarr.Group, dataset: OutputDataset): def _write_metadata(self, dataset_group: zarr.Group, dataset: OutputDataset): dataset_group.attrs["channels"] = dataset.channels dataset_group.attrs["geoinfo_channels"] = dataset.geoinfo_channels + dataset_group.attrs["lead_time"] = dataset.lead_time def _write_arrays(self, dataset_group: zarr.Group, dataset: OutputDataset): for array_name, array in dataset.arrays.items(): # suffix is eg. data or coords @@ -302,6 +305,14 @@ def _create_dataset(self, group: zarr.Group, name: str, array: NDArray): ) group.create_dataset(name, data=array, chunks=chunks) + @functools.cached_property + def example_key(self) -> ItemKey: + sample, example_sample = next(self.data_root.groups()) + stream, example_stream = next(example_sample.groups()) + fstep, example_item = next(example_stream.groups()) + + return ItemKey(sample, fstep, stream) + @functools.cached_property def samples(self) -> list[int]: """Query available samples in this zarr store.""" @@ -320,8 +331,17 @@ def forecast_steps(self) -> list[int]: # assume stream/samples/forecast_steps are orthogonal _, example_sample = next(self.data_root.groups()) _, example_stream = next(example_sample.groups()) + return list(example_stream.group_keys()) + @functools.cached_property + def lead_times(self) -> list[int]: + """Calculate available lead times from available forecast steps and len_hrs.""" + example_prediction = self.load_zarr(self.example_key).prediction + len_hrs = example_prediction.lead_time // self.example_key.forecast_step + + return [step * len_hrs for step in self.forecast_steps] + @dataclasses.dataclass class DataCoordinates: @@ -365,6 +385,7 @@ class OutputBatchData: sample_start: int forecast_offset: int + len_hrs: int @functools.cached_property def samples(self): @@ -415,8 +436,10 @@ def extract(self, key: ItemKey) -> OutputItem: "Number of channel names does not align with prediction data." ) + lead_time = self.len_hrs * key + if key.with_source: - source_dataset = self._extract_sources(offset_key.sample, stream_idx, key) + source_dataset = self._extract_sources(offset_key.sample, stream_idx, key, lead_time) else: source_dataset = None @@ -425,9 +448,15 @@ def extract(self, key: ItemKey) -> OutputItem: return OutputItem( key=key, source=source_dataset, - target=OutputDataset("target", key, target_data, **dataclasses.asdict(data_coords)), + target=OutputDataset( + "target", key, target_data, lead_time=lead_time, **dataclasses.asdict(data_coords) + ), prediction=OutputDataset( - "prediction", key, preds_data, **dataclasses.asdict(data_coords) + "prediction", + key, + preds_data, + lead_time=lead_time, + **dataclasses.asdict(data_coords), ), ) @@ -487,7 +516,7 @@ def _extract_coordinates(self, stream_idx, offset_key, datapoints) -> DataCoordi return DataCoordinates(times, coords, geoinfo, channels, geoinfo_channels) - def _extract_sources(self, sample, stream_idx, key): + def _extract_sources(self, sample, stream_idx, key, lead_time): channels = self.source_channels[stream_idx] geoinfo_channels = self.geoinfo_channels[stream_idx] @@ -506,6 +535,7 @@ def _extract_sources(self, sample, stream_idx, key): np.asarray(source.geoinfos), channels, geoinfo_channels, + lead_time ) _logger.debug(f"source shape: {source_dataset.data.shape}") diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index e28563132..c0939cf8c 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -61,6 +61,7 @@ def write_output( geoinfo_channels, sample_start, cf.forecast_offset, + cf.len_hrs ) with io.ZarrIO(config.get_path_output(cf, epoch)) as writer: From 1c56cb00cfc6de48a2a3f88d9df7875a9c89a40c Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Wed, 15 Oct 2025 11:22:35 +0200 Subject: [PATCH 2/3] ruffed --- packages/common/src/weathergen/common/io.py | 2 +- src/weathergen/utils/validation_io.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 9c0ce8e13..d4676590b 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -535,7 +535,7 @@ def _extract_sources(self, sample, stream_idx, key, lead_time): np.asarray(source.geoinfos), channels, geoinfo_channels, - lead_time + lead_time, ) _logger.debug(f"source shape: {source_dataset.data.shape}") diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index c0939cf8c..b7108de5d 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -61,7 +61,7 @@ def write_output( geoinfo_channels, sample_start, cf.forecast_offset, - cf.len_hrs + cf.len_hrs, ) with io.ZarrIO(config.get_path_output(cf, epoch)) as writer: From 50fc8ab343f26d1caac888ac206ee55e457f5746 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Wed, 15 Oct 2025 12:17:53 +0200 Subject: [PATCH 3/3] addressed comments --- packages/common/src/weathergen/common/io.py | 23 +++++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index d4676590b..32a63ecb1 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -148,7 +148,8 @@ class OutputDataset: channels: list[str] geoinfo_channels: list[str] - lead_time: int + # lead time in hours defined as forecast step * length of forecast step (len_hours) + lead_time_hrs: int @functools.cached_property def arrays(self) -> dict[str, zarr.Array | NDArray]: @@ -187,7 +188,7 @@ def as_xarray(self, chunk_nsamples=CHUNK_N_SAMPLES) -> xr.DataArray: "sample": [self.item_key.sample], "stream": [self.item_key.stream], "forecast_step": [self.item_key.forecast_step], - "lead_time": ("forecast_step", [self.lead_time]), + "lead_time_hrs": ("forecast_step", [self.lead_time_hrs]), "ipoint": self.datapoints, "channel": self.channels, # TODO: make sure channel names align with data "valid_time": ("ipoint", times.astype("datetime64[ns]")), @@ -287,7 +288,7 @@ def _write_dataset(self, item_group: zarr.Group, dataset: OutputDataset): def _write_metadata(self, dataset_group: zarr.Group, dataset: OutputDataset): dataset_group.attrs["channels"] = dataset.channels dataset_group.attrs["geoinfo_channels"] = dataset.geoinfo_channels - dataset_group.attrs["lead_time"] = dataset.lead_time + dataset_group.attrs["lead_time_hrs"] = dataset.lead_time_hrs def _write_arrays(self, dataset_group: zarr.Group, dataset: OutputDataset): for array_name, array in dataset.arrays.items(): # suffix is eg. data or coords @@ -338,7 +339,7 @@ def forecast_steps(self) -> list[int]: def lead_times(self) -> list[int]: """Calculate available lead times from available forecast steps and len_hrs.""" example_prediction = self.load_zarr(self.example_key).prediction - len_hrs = example_prediction.lead_time // self.example_key.forecast_step + len_hrs = example_prediction.lead_time_hrs // self.example_key.forecast_step return [step * len_hrs for step in self.forecast_steps] @@ -385,7 +386,7 @@ class OutputBatchData: sample_start: int forecast_offset: int - len_hrs: int + t_window_len_hours: int @functools.cached_property def samples(self): @@ -436,7 +437,7 @@ def extract(self, key: ItemKey) -> OutputItem: "Number of channel names does not align with prediction data." ) - lead_time = self.len_hrs * key + lead_time = self.t_window_len_hours * key.forecast_step if key.with_source: source_dataset = self._extract_sources(offset_key.sample, stream_idx, key, lead_time) @@ -449,13 +450,17 @@ def extract(self, key: ItemKey) -> OutputItem: key=key, source=source_dataset, target=OutputDataset( - "target", key, target_data, lead_time=lead_time, **dataclasses.asdict(data_coords) + "target", + key, + target_data, + lead_time_hrs=lead_time, + **dataclasses.asdict(data_coords), ), prediction=OutputDataset( "prediction", key, preds_data, - lead_time=lead_time, + lead_time_hrs=lead_time, **dataclasses.asdict(data_coords), ), ) @@ -516,7 +521,7 @@ def _extract_coordinates(self, stream_idx, offset_key, datapoints) -> DataCoordi return DataCoordinates(times, coords, geoinfo, channels, geoinfo_channels) - def _extract_sources(self, sample, stream_idx, key, lead_time): + def _extract_sources(self, sample, stream_idx, key, lead_time: int): channels = self.source_channels[stream_idx] geoinfo_channels = self.geoinfo_channels[stream_idx]