Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ jobs:
# Run temporarily on a subdirectory before the main restyling.
run: ./scripts/actions.sh lint-check

- name: TOML checks
run: ./scripts/actions.sh toml-check

- name: Type checker (pyrefly, experimental)
# Do not attempt to install the default dependencies, this is much faster.
# Run temporarily on a subdirectory before the main restyling.
Expand Down
27 changes: 17 additions & 10 deletions packages/common/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "weathergen-common"
version = "0.1.0"
description = "The WeatherGenerator Machine Learning Earth System Model"
readme = "../../README.md"
requires-python = ">=3.11,<3.13"
requires-python = ">=3.12,<3.13"
dependencies = [
"xarray>=2025.6.1",
"dask>=2024.9.1",
Expand Down Expand Up @@ -42,12 +42,6 @@ not-callable = false



[tool.black]

# Wide rows
line-length = 100


# The linting configuration
[tool.ruff]

Expand All @@ -70,7 +64,11 @@ select = [
# isort
"I",
# Banned imports
"TID"
"TID",
# Naming conventions
"N",
# print
"T201"
]

# These rules are sensible and should be enabled at a later stage.
Expand All @@ -82,11 +80,20 @@ ignore = [
"SIM118",
"SIM102",
"SIM401",
"UP040", # TODO: enable later
# To ignore, not relevant for us
"SIM108" # in case additional norm layer supports are added in future
"SIM108", # in case support for additional norm layers is added in the future
"N817", # we use heavy acronyms, e.g., allowing 'import LongModuleName as LMN' (LMN is accepted)
"E731", # overly restrictive and less readable code
"N812", # prevents us following the convention for importing torch.nn.functional as F
]

[tool.ruff.lint.flake8-tidy-imports.banned-api]
"numpy.ndarray".msg = "Do not use 'ndarray' to describe a numpy array type, it is a function. Use numpy.typing.NDArray or numpy.typing.NDArray[np.float32] for example"

[tool.ruff.format]
# Use Unix `\n` line endings for all files
line-ending = "lf"



[build-system]
Expand Down
5 changes: 1 addition & 4 deletions packages/common/src/weathergen/common/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,13 @@

# experimental value, should be inferred more intelligently
CHUNK_N_SAMPLES = 16392
DType: typing.TypeAlias = np.float32
type DType = np.float32
type NPDT64 = datetime64


_logger = logging.getLogger(__name__)


np.ndarray(3)


@dataclasses.dataclass
class IOReaderData:
"""
Expand Down
30 changes: 25 additions & 5 deletions packages/evaluate/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "weathergen-evaluate"
version = "0.1.0"
description = "The WeatherGenerator Machine Learning Earth System Model"
readme = "../../README.md"
requires-python = ">=3.11,<3.13"
requires-python = ">=3.12,<3.13"
dependencies = [
"cartopy>=0.24.1",
"xskillscore",
Expand All @@ -25,6 +25,12 @@ dev = [
evaluation = "weathergen.evaluate.run_evaluation:evaluate"


# The linting configuration
[tool.ruff]

# Wide rows
line-length = 100

[tool.ruff.lint]
# All disabled until the code is formatted.
select = [
Expand All @@ -40,13 +46,16 @@ select = [
"SIM",
# isort
"I",
# Banned imports
"TID",
# Naming conventions
"N",
# print
"T201"
]

# These rules are sensible and should be enabled at a later stage.
ignore = [
"E501",
"E721",
"E722",
# "B006",
"B011",
"UP008",
Expand All @@ -55,9 +64,20 @@ ignore = [
"SIM102",
"SIM401",
# To ignore, not relevant for us
"E741",
"SIM108", # in case support for additional norm layers is added in the future
"N817", # we use heavy acronyms, e.g., allowing 'import LongModuleName as LMN' (LMN is accepted)
"E731", # overly restrictive and less readable code
"N812", # prevents us following the convention for importing torch.nn.functional as F
]

[tool.ruff.lint.flake8-tidy-imports.banned-api]
"numpy.ndarray".msg = "Do not use 'ndarray' to describe a numpy array type, it is a function. Use numpy.typing.NDArray or numpy.typing.NDArray[np.float32] for example"

[tool.ruff.format]
# Use Unix `\n` line endings for all files
line-ending = "lf"


[tool.pyrefly]
project-includes = ["src/"]
project-excludes = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,7 @@ def get_channel(self, data_tars, data_preds, tag, level, calc_func) -> None:

data_updated.append(conc)

self.channels = self.channels + (
[tag] if tag not in self.channels else []
)
self.channels = self.channels + ([tag] if tag not in self.channels else [])

else:
data_updated.append(data)
Expand Down
50 changes: 18 additions & 32 deletions packages/evaluate/src/weathergen/evaluate/io_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,19 +252,15 @@ def _get_channels_fsteps_samples(self, stream: str, mode: str) -> DataAvailabili
)

stream_cfg = self.get_stream(stream)
assert stream_cfg.get(mode, False), (
"Mode does not exist in stream config. Please add it."
)
assert stream_cfg.get(mode, False), "Mode does not exist in stream config. Please add it."

samples = stream_cfg[mode].get("sample", None)
fsteps = stream_cfg[mode].get("forecast_step", None)
channels = stream_cfg.get("channels", None)

return DataAvailability(
score_availability=True,
channels=None
if (channels == "all" or channels is None)
else list(channels),
channels=None if (channels == "all" or channels is None) else list(channels),
fsteps=None if (fsteps == "all" or fsteps is None) else list(fsteps),
samples=None if (samples == "all" or samples is None) else list(samples),
)
Expand All @@ -284,9 +280,7 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non

if not self.results_base_dir:
self.results_base_dir = Path(self.inference_cfg["run_path"])
_logger.info(
f"Results directory obtained from model config: {self.results_base_dir}"
)
_logger.info(f"Results directory obtained from model config: {self.results_base_dir}")
else:
_logger.info(f"Results directory parsed: {self.results_base_dir}")

Expand All @@ -304,19 +298,15 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non
)
# for backward compatibility, allow metrics_dir to be specified in the run config
self.metrics_dir = Path(
self.eval_cfg.get(
"metrics_dir", self.metrics_base_dir / self.run_id / "evaluation"
)
self.eval_cfg.get("metrics_dir", self.metrics_base_dir / self.run_id / "evaluation")
)

self.fname_zarr = self.results_dir.joinpath(
f"validation_epoch{self.epoch:05d}_rank{self.rank:04d}.zarr"
)

if not self.fname_zarr.exists() or not self.fname_zarr.is_dir():
_logger.error(
f"Zarr file {self.fname_zarr} does not exist or is not a directory."
)
_logger.error(f"Zarr file {self.fname_zarr} does not exist or is not a directory.")
raise FileNotFoundError(
f"Zarr file {self.fname_zarr} does not exist or is not a directory."
)
Expand Down Expand Up @@ -399,9 +389,7 @@ def get_data(
# TODO: Avoid conversion of fsteps and sample to integers (as obtained from the ZarrIO)
fsteps = sorted([int(fstep) for fstep in fsteps])
samples = sorted(
[int(sample) for sample in self.get_samples()]
if samples is None
else samples
[int(sample) for sample in self.get_samples()] if samples is None else samples
)
channels = channels or stream_cfg.get("channels", all_channels)
channels = to_list(channels)
Expand All @@ -427,15 +415,11 @@ def get_data(
fsteps_final = []

for fstep in fsteps:
_logger.info(
f"RUN {self.run_id} - {stream}: Processing fstep {fstep}..."
)
_logger.info(f"RUN {self.run_id} - {stream}: Processing fstep {fstep}...")
da_tars_fs, da_preds_fs = [], []
pps = []

for sample in tqdm(
samples, desc=f"Processing {self.run_id} - {stream} - {fstep}"
):
for sample in tqdm(samples, desc=f"Processing {self.run_id} - {stream} - {fstep}"):
out = zio.get_data(sample, stream, fstep)
target, pred = out.target.as_xarray(), out.prediction.as_xarray()

Expand Down Expand Up @@ -470,10 +454,16 @@ def get_data(
if len(samples) == 1:
# Ensure sample coordinate is repeated along ipoint even if only one sample
da_tars_fs = da_tars_fs.assign_coords(
sample=("ipoint", np.repeat(da_tars_fs.sample.values, len(da_tars_fs.ipoint)))
sample=(
"ipoint",
np.repeat(da_tars_fs.sample.values, len(da_tars_fs.ipoint)),
)
)
da_preds_fs = da_preds_fs.assign_coords(
sample=("ipoint", np.repeat(da_preds_fs.sample.values, len(da_preds_fs.ipoint)))
sample=(
"ipoint",
np.repeat(da_preds_fs.sample.values, len(da_preds_fs.ipoint)),
)
)

if set(channels) != set(all_channels):
Expand All @@ -494,12 +484,8 @@ def get_data(
points_per_sample.loc[{"forecast_step": fstep}] = np.array(pps)

# Safer than a list
da_tars = {
fstep: da for fstep, da in zip(fsteps_final, da_tars, strict=True)
}
da_preds = {
fstep: da for fstep, da in zip(fsteps_final, da_preds, strict=True)
}
da_tars = {fstep: da for fstep, da in zip(fsteps_final, da_tars, strict=True)}
da_preds = {fstep: da for fstep, da in zip(fsteps_final, da_preds, strict=True)}

return ReaderOutput(
target=da_tars, prediction=da_preds, points_per_sample=points_per_sample
Expand Down
8 changes: 2 additions & 6 deletions packages/evaluate/src/weathergen/evaluate/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,7 @@ def plot_metric_region(
run_ids.append(run_id)

if selected_data:
_logger.info(
f"Creating plot for {metric} - {region} - {stream} - {ch}."
)
_logger.info(f"Creating plot for {metric} - {region} - {stream} - {ch}.")
name = "_".join([metric, region] + sorted(set(run_ids)) + [stream, ch])
plotter.plot(
selected_data,
Expand Down Expand Up @@ -146,9 +144,7 @@ def get_marker_size(cls, stream_name: str) -> float:
float
The default marker size for the stream.
"""
return cls._marker_size_stream.get(
stream_name.lower(), cls._default_marker_size
)
return cls._marker_size_stream.get(stream_name.lower(), cls._default_marker_size)

@classmethod
def list_streams(cls):
Expand Down
40 changes: 10 additions & 30 deletions packages/evaluate/src/weathergen/evaluate/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,23 +90,17 @@ def update_data_selection(self, select: dict):
self.select = select

if "sample" not in select:
_logger.warning(
"No sample in the selection. Might lead to unexpected results."
)
_logger.warning("No sample in the selection. Might lead to unexpected results.")
else:
self.sample = select["sample"]

if "stream" not in select:
_logger.warning(
"No stream in the selection. Might lead to unexpected results."
)
_logger.warning("No stream in the selection. Might lead to unexpected results.")
else:
self.stream = select["stream"]

if "forecast_step" not in select:
_logger.warning(
"No forecast_step in the selection. Might lead to unexpected results."
)
_logger.warning("No forecast_step in the selection. Might lead to unexpected results.")
else:
self.fstep = select["forecast_step"]

Expand Down Expand Up @@ -205,21 +199,15 @@ def create_histograms_per_sample(
f"Creating histograms for {ntimes_unique} valid times of variable {var}."
)

groups = zip(
targ.groupby("valid_time"), prd.groupby("valid_time"), strict=False
)
groups = zip(targ.groupby("valid_time"), prd.groupby("valid_time"), strict=False)
else:
_logger.info(f"Plotting histogram for all valid times of {var}")

groups = [
((None, targ), (None, prd))
] # wrap once with dummy valid_time
groups = [((None, targ), (None, prd))] # wrap once with dummy valid_time

for (valid_time, targ_t), (_, prd_t) in groups:
if valid_time is not None:
_logger.debug(
f"Plotting histogram for {var} at valid_time {valid_time}"
)
_logger.debug(f"Plotting histogram for {var} at valid_time {valid_time}")
name = self.plot_histogram(targ_t, prd_t, hist_output_dir, var, tag=tag)
plot_names.append(name)

Expand Down Expand Up @@ -460,9 +448,7 @@ def scatter_plot(
**map_kwargs_save,
)

plt.colorbar(
scatter_plt, ax=ax, orientation="horizontal", label=f"Variable: {varname}"
)
plt.colorbar(scatter_plt, ax=ax, orientation="horizontal", label=f"Variable: {varname}")
plt.title(f"{self.stream}, {varname} : fstep = {self.fstep:03} ({valid_time})")
ax.set_global()
ax.gridlines(draw_labels=False, linestyle="--", color="black", linewidth=1)
Expand Down Expand Up @@ -583,9 +569,7 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path):

_logger.info(f"Saving summary plots to: {self.out_plot_dir}")

def _check_lengths(
self, data: xr.DataArray | list, labels: str | list
) -> tuple[list, list]:
def _check_lengths(self, data: xr.DataArray | list, labels: str | list) -> tuple[list, list]:
"""
Check if the lengths of data and labels match.

Expand All @@ -612,9 +596,7 @@ def _check_lengths(
data_list = [data] if type(data) == xr.DataArray else data
label_list = [labels] if type(labels) == str else labels

assert len(data_list) == len(label_list), (
"Compare::plot - Data and Labels do not match"
)
assert len(data_list) == len(label_list), "Compare::plot - Data and Labels do not match"

return data_list, label_list

Expand Down Expand Up @@ -667,9 +649,7 @@ def plot(
fig = plt.figure(figsize=(12, 6), dpi=self.dpi_val)

for i, data in enumerate(data_list):
non_zero_dims = [
dim for dim in data.dims if dim != x_dim and data[dim].shape[0] > 1
]
non_zero_dims = [dim for dim in data.dims if dim != x_dim and data[dim].shape[0] > 1]
if non_zero_dims:
_logger.info(
f"LinePlot:: Found multiple entries for dimensions: {non_zero_dims}. Averaging..."
Expand Down
Loading