Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions packages/common/src/weathergen/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def _get_model_config_file_name(run_id: str, epoch: int | None):
epoch_str = f"_epoch{epoch:05d}"
return f"model_{run_id}{epoch_str}.json"


def get_model_results(run_id: str, epoch: int, rank: int) -> Path:
"""
Get the path to the model results zarr store from a given run_id and epoch.
Expand All @@ -110,6 +111,7 @@ def get_model_results(run_id: str, epoch: int, rank: int) -> Path:
raise FileNotFoundError(f"Zarr file {zarr_path} does not exist or is not a directory.")
return zarr_path


def _apply_fixes(config: Config) -> Config:
"""
Apply fixes to maintain a best effort backward combatibility.
Expand All @@ -135,6 +137,7 @@ def _check_logging(config: Config) -> Config:

return config


def load_config(
private_home: Path | None,
from_run_id: str | None,
Expand Down
43 changes: 39 additions & 4 deletions packages/common/src/weathergen/common/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ class OutputDataset:

channels: list[str]
geoinfo_channels: list[str]
# lead time in hours defined as forecast step * length of forecast step (len_hours)
lead_time_hrs: int

@functools.cached_property
def arrays(self) -> dict[str, zarr.Array]:
Expand Down Expand Up @@ -149,6 +151,7 @@ def as_xarray(self, chunk_nsamples=CHUNK_N_SAMPLES) -> xr.Dataset:
"sample": [self.item_key.sample],
"stream": [self.item_key.stream],
"forecast_step": [self.item_key.forecast_step],
"lead_time_hrs": ("forecast_step", [self.lead_time_hrs]),
"ipoint": self.datapoints,
"channel": self.channels, # TODO: make sure channel names align with data
"valid_time": ("ipoint", times.astype("datetime64[ns]")),
Expand Down Expand Up @@ -247,6 +250,7 @@ def _write_dataset(self, item_group: zarr.Group, dataset: OutputDataset):
def _write_metadata(self, dataset_group: zarr.Group, dataset: OutputDataset):
dataset_group.attrs["channels"] = dataset.channels
dataset_group.attrs["geoinfo_channels"] = dataset.geoinfo_channels
dataset_group.attrs["lead_time_hrs"] = dataset.lead_time_hrs

def _write_arrays(self, dataset_group: zarr.Group, dataset: OutputDataset):
for array_name, array in dataset.arrays.items(): # suffix is eg. data or coords
Expand All @@ -263,6 +267,14 @@ def _create_dataset(self, group: zarr.Group, name: str, array: NDArray):
)
group.create_dataset(name, data=array, chunks=chunks)

@functools.cached_property
def example_key(self) -> ItemKey:
sample, example_sample = next(self.data_root.groups())
stream, example_stream = next(example_sample.groups())
fstep, example_item = next(example_stream.groups())

return ItemKey(sample, fstep, stream)

@functools.cached_property
def samples(self) -> list[int]:
"""Query available samples in this zarr store."""
Expand All @@ -281,8 +293,17 @@ def forecast_steps(self) -> list[int]:
# assume stream/samples/forecast_steps are orthogonal
_, example_sample = next(self.data_root.groups())
_, example_stream = next(example_sample.groups())

return list(example_stream.group_keys())

@functools.cached_property
def lead_times(self) -> list[int]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are you using it? I don't see a place where you would need it.

Also, let's not used cached_property unless it is vital to do so. Most people will be confused.

"""Calculate available lead times from available forecast steps and len_hrs."""
example_prediction = self.load_zarr(self.example_key).prediction
len_hrs = example_prediction.lead_time_hrs // self.example_key.forecast_step

return [step * len_hrs for step in self.forecast_steps]


@dataclasses.dataclass
class DataCoordinates:
Expand Down Expand Up @@ -326,6 +347,7 @@ class OutputBatchData:

sample_start: int
forecast_offset: int
t_window_len_hours: int

@functools.cached_property
def samples(self):
Expand Down Expand Up @@ -386,17 +408,29 @@ def extract(self, key: ItemKey) -> OutputItem:
"Number of channel names does not align with prediction data."
)

lead_time = self.t_window_len_hours * key.forecast_step

if key.with_source:
source_dataset = self._extract_sources(offset_key.sample, stream_idx, key)
source_dataset = self._extract_sources(offset_key.sample, stream_idx, key, lead_time)
else:
source_dataset = None

return OutputItem(
key=key,
source=source_dataset,
target=OutputDataset("target", key, target_data, **dataclasses.asdict(data_coords)),
target=OutputDataset(
"target",
key,
target_data,
lead_time_hrs=lead_time,
**dataclasses.asdict(data_coords),
),
prediction=OutputDataset(
"prediction", key, preds_data, **dataclasses.asdict(data_coords)
"prediction",
key,
preds_data,
lead_time_hrs=lead_time,
**dataclasses.asdict(data_coords),
),
)

Expand Down Expand Up @@ -456,7 +490,7 @@ def _extract_coordinates(self, stream_idx, offset_key, datapoints) -> DataCoordi

return DataCoordinates(times, coords, geoinfo, channels, geoinfo_channels)

def _extract_sources(self, sample, stream_idx, key):
def _extract_sources(self, sample, stream_idx, key, lead_time: int):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add types for all args.

channels = self.source_channels[stream_idx]
geoinfo_channels = self.geoinfo_channels[stream_idx]

Expand All @@ -475,6 +509,7 @@ def _extract_sources(self, sample, stream_idx, key):
source.geoinfos,
channels,
geoinfo_channels,
lead_time,
)

_logger.debug(f"source shape: {source_dataset.data.shape}")
Expand Down
20 changes: 10 additions & 10 deletions packages/evaluate/src/weathergen/evaluate/export_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def get_data(
dtype: str,
fsteps: list,
channels: list,
fstep_hours: int,
fstep_hours: int,
n_processes: list,
epoch: int,
rank: int,
Expand Down Expand Up @@ -340,7 +340,7 @@ def get_data(
_logger.info(
f"Saving sample {sample_idx} data to {output_format} format in {output_dir}."
)

save_sample_to_netcdf(
str(dtype)[:4],
da_fs,
Expand Down Expand Up @@ -495,11 +495,11 @@ def parse_args(args: list) -> argparse.Namespace:
)

parser.add_argument(
"--fstep-hours",
type = int,
default= 6,
help= "Time difference between forecast steps in hours (e.g., 6)"
)
"--fstep-hours",
type=int,
default=6,
help="Time difference between forecast steps in hours (e.g., 6)",
)

parser.add_argument(
"--epoch",
Expand Down Expand Up @@ -535,7 +535,7 @@ def export_from_args(args: list) -> None:
Export data from Zarr store to NetCDF files based on command line arguments.
Parameters
----------
args : List of command line arguments.
args : List of command line arguments.
"""
args = parse_args(sys.argv[1:])
run_id = args.run_id
Expand All @@ -559,7 +559,7 @@ def export_from_args(args: list) -> None:
config_file = Path(_REPO_ROOT, "config/evaluate/config_zarr2cf.yaml")
config = OmegaConf.load(config_file)
# check config loaded correctly
assert len(config["variables"].keys()) > 0 , "Config file not loaded correctly"
assert len(config["variables"].keys()) > 0, "Config file not loaded correctly"

for dtype in data_type:
_logger.info(f"Starting processing {dtype} for run ID {run_id}.")
Expand All @@ -570,7 +570,7 @@ def export_from_args(args: list) -> None:
dtype,
fsteps,
channels,
fstep_hours,
fstep_hours,
n_processes,
epoch,
rank,
Expand Down
2 changes: 1 addition & 1 deletion packages/evaluate/src/weathergen/evaluate/io_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ def get_data(
def get_stream(self, stream: str):
"""
returns the dictionary associated to a particular stream.
Returns an empty dictionary if the stream does not exist in the Zarr file.
Returns an empty dictionary if the stream does not exist in the Zarr file.

Parameters
----------
Expand Down
4 changes: 3 additions & 1 deletion packages/evaluate/src/weathergen/evaluate/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ def evaluate_from_config(cfg):

stream_dict = reader.get_stream(stream)
if not stream_dict:
_logger.info(f"Stream {stream} does not exist in source data or config file is empty. Skipping.")
_logger.info(
f"Stream {stream} does not exist in source data or config file is empty. Skipping."
)
continue

if stream_dict.get("plotting"):
Expand Down
2 changes: 1 addition & 1 deletion packages/evaluate/src/weathergen/evaluate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def calc_scores_per_stream(
assert int(combined_metrics.forecast_step) == int(fstep), (
"Different steps in data and metrics. Please check."
)

metric_stream.loc[
{
"forecast_step": int(combined_metrics.forecast_step),
Expand Down
1 change: 1 addition & 0 deletions src/weathergen/utils/validation_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def write_output(
geoinfo_channels,
sample_start,
cf.forecast_offset,
cf.len_hrs,
)

with io.ZarrIO(config.get_path_output(cf, epoch)) as writer:
Expand Down
Loading