Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 34 additions & 15 deletions ocf_data_sampler/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class DropoutMixin(Base):
"negative or zero.",
)

dropout_fraction: float|list[float] = Field(
dropout_fraction: float | list[float] = Field(
default=0,
description="Either a float(Chance of dropout being applied to each sample) or a list of "
"floats (probability that dropout of the corresponding timedelta is applied)",
Expand All @@ -104,11 +104,11 @@ def dropout_timedeltas_minutes_negative(cls, v: list[int]) -> list[int]:
raise ValueError("Dropout timedeltas must be negative")
return v


@field_validator("dropout_fraction")
def dropout_fractions(cls, dropout_frac: float|list[float]) -> float|list[float]:
def dropout_fractions(cls, dropout_frac: float | list[float]) -> float | list[float]:
"""Validate 'dropout_frac'."""
from math import isclose

if isinstance(dropout_frac, float):
if not (dropout_frac <= 1):
raise ValueError("Input should be less than or equal to 1")
Expand All @@ -128,21 +128,40 @@ def dropout_fractions(cls, dropout_frac: float|list[float]) -> float|list[float]
if not isclose(sum(dropout_frac), 1.0, rel_tol=1e-9):
raise ValueError("Sum of all floats in the list must be 1.0")


else:
raise TypeError("Must be either a float or a list of floats")
return dropout_frac


@model_validator(mode="after")
def dropout_instructions_consistent(self) -> "DropoutMixin":
"""Validator for dropout instructions."""
if self.dropout_fraction == 0:
if self.dropout_timedeltas_minutes != []:
raise ValueError("To use dropout timedeltas dropout fraction should be > 0")
else:
if self.dropout_timedeltas_minutes == []:
raise ValueError("To dropout fraction > 0 requires a list of dropout timedeltas")
# For float dropout_fraction > 0, dropout_timedeltas_minutes must be empty
if isinstance(self.dropout_fraction, float) and self.dropout_timedeltas_minutes != []:
raise ValueError(
"If dropout_fraction is a float and > 0, "
"dropout_timedeltas_minutes must be empty",
)
# For list dropout_fraction, must have matching dropout_timedeltas_minutes
if isinstance(self.dropout_fraction, list) and self.dropout_timedeltas_minutes == []:
raise ValueError(
"If dropout_fraction is a list and > 0, "
"dropout_timedeltas_minutes must be non-empty",
)
return self

@model_validator(mode="after")
def validate_dropout(self) -> "DropoutMixin":
"""Validator for length match: dropout_fraction list and dropout_timedeltas_minutes."""
if isinstance(self.dropout_fraction, list) and len(self.dropout_fraction) != len(
self.dropout_timedeltas_minutes,
):
raise ValueError(
"If dropout_fraction is a list, its length must match dropout_timedeltas_minutes.",
)
return self


Expand All @@ -164,12 +183,14 @@ class SpatialWindowMixin(Base):

class NormalisationValues(Base):
"""Normalisation mean and standard deviation."""

mean: float = Field(..., description="Mean value for normalization")
std: float = Field(..., gt=0, description="Standard deviation (must be positive)")


class NormalisationConstantsMixin(Base):
"""Normalisation constants for multiple channels."""

normalisation_constants: dict[str, NormalisationValues]

@property
Expand All @@ -180,7 +201,6 @@ def channel_means(self) -> dict[str, float]:
for channel, norm_values in self.normalisation_constants.items()
}


@property
def channel_stds(self) -> dict[str, float]:
"""Return the channel standard deviations."""
Expand Down Expand Up @@ -209,9 +229,9 @@ def check_all_channel_have_normalisation_constants(self) -> "Satellite":
"""Check that all the channels have normalisation constants."""
normalisation_channels = set(self.normalisation_constants.keys())
missing_norm_values = set(self.channels) - set(normalisation_channels)
if len(missing_norm_values)>0:
if len(missing_norm_values) > 0:
raise ValueError(
"Normalsation constants must be provided for all channels. Missing values for "
"Normalisation constants must be provided for all channels. Missing values for "
f"channels: {missing_norm_values}",
)
return self
Expand Down Expand Up @@ -261,7 +281,6 @@ def validate_provider(cls, v: str) -> str:
raise OSError(f"NWP provider {v} is not in {NWP_PROVIDERS}")
return v


@model_validator(mode="after")
def check_all_channel_have_normalisation_constants(self) -> "NWP":
"""Check that all the channels have normalisation constants."""
Expand All @@ -270,16 +289,16 @@ def check_all_channel_have_normalisation_constants(self) -> "NWP":
accum_channel_names = [f"diff_{c}" for c in self.accum_channels]

missing_norm_values = set(non_accum_channels) - set(normalisation_channels)
if len(missing_norm_values)>0:
if len(missing_norm_values) > 0:
raise ValueError(
"Normalsation constants must be provided for all channels. Missing values for "
"Normalisation constants must be provided for all channels. Missing values for "
f"channels: {missing_norm_values}",
)

missing_norm_values = set(accum_channel_names) - set(normalisation_channels)
if len(missing_norm_values)>0:
if len(missing_norm_values) > 0:
raise ValueError(
"Normalsation constants must be provided for all channels. Accumulated "
"Normalisation constants must be provided for all channels. Accumulated "
"channels which will be diffed require normalisation constant names which "
"start with the prefix 'diff_'. The following channels were missing: "
f"{missing_norm_values}.",
Expand Down
Loading