diff --git a/src/weathergen/datasets/data_reader_base.py b/src/weathergen/datasets/data_reader_base.py
index 281a8c1db..bdf0d3641 100644
--- a/src/weathergen/datasets/data_reader_base.py
+++ b/src/weathergen/datasets/data_reader_base.py
@@ -475,6 +475,90 @@ def normalize_coords(self, coords: NDArray[DType]) -> NDArray[DType]:
 
         return coords
 
+    def _normalize(
+        self,
+        data: NDArray[DType],
+        idx: list[int],
+        mean: dict[int, float],
+        stdev: dict[int, float],
+        name: str,
+    ) -> NDArray[DType]:
+        """
+        Helper function to normalize data (modifies data in place)
+
+        Parameters
+        ----------
+        data :
+            data to be normalized; last axis indexes the channels
+        idx :
+            indices of channels to be normalized
+        mean :
+            mean values for channels
+        stdev :
+            standard deviation values for channels
+        name :
+            name of the data (for error messages)
+
+        Returns
+        -------
+        Normalized data
+
+        Raises
+        ------
+        ValueError
+            if the last dimension of data does not match len(idx)
+        """
+        if data.shape[-1] != len(idx):
+            raise ValueError(
+                f"incorrect number of {name} channels: expected {len(idx)}, got {data.shape[-1]}"
+            )
+        for i, ch in enumerate(idx):
+            data[..., i] = (data[..., i] - mean[ch]) / stdev[ch]
+
+        return data
+
+    def _denormalize(
+        self,
+        data: NDArray[DType],
+        idx: list[int],
+        mean: dict[int, float],
+        stdev: dict[int, float],
+        name: str,
+    ) -> NDArray[DType]:
+        """
+        Helper function to denormalize data (modifies data in place)
+
+        Parameters
+        ----------
+        data :
+            data to be denormalized; last axis indexes the channels
+        idx :
+            indices of channels to be denormalized
+        mean :
+            mean values for channels
+        stdev :
+            standard deviation values for channels
+        name :
+            name of the data (for error messages)
+
+        Returns
+        -------
+        Denormalized data
+
+        Raises
+        ------
+        ValueError
+            if the last dimension of data does not match len(idx)
+        """
+        if data.shape[-1] != len(idx):
+            raise ValueError(
+                f"incorrect number of {name} channels: expected {len(idx)}, got {data.shape[-1]}"
+            )
+        for i, ch in enumerate(idx):
+            data[..., i] = (data[..., i] * stdev[ch]) + mean[ch]
+
+        return data
+
     def normalize_geoinfos(self, geoinfos: NDArray[DType]) -> NDArray[DType]:
         """
         Normalize geoinfos
@@ -501,18 +585,14 @@ def normalize_source_channels(self, source: NDArray[DType]) -> NDArray[DType]:
 
         Parameters
         ----------
-        data :
+        source :
             data to be normalized
 
         Returns
         -------
-        Normalized data
+        Normalized source data
         """
-        assert source.shape[-1] == len(self.source_idx), "incorrect number of source channels"
-        for i, ch in enumerate(self.source_idx):
-            source[..., i] = (source[..., i] - self.mean[ch]) / self.stdev[ch]
-
-        return source
+        return self._normalize(source, self.source_idx, self.mean, self.stdev, "source")
 
     def normalize_target_channels(self, target: NDArray[DType]) -> NDArray[DType]:
         """
@@ -520,18 +600,14 @@ def normalize_target_channels(self, target: NDArray[DType]) -> NDArray[DType]:
 
         Parameters
         ----------
-        data :
+        target :
             data to be normalized
 
         Returns
         -------
-        Normalized data
+        Normalized target data
         """
-        assert target.shape[-1] == len(self.target_idx), "incorrect number of target channels"
-        for i, ch in enumerate(self.target_idx):
-            target[..., i] = (target[..., i] - self.mean[ch]) / self.stdev[ch]
-
-        return target
+        return self._normalize(target, self.target_idx, self.mean, self.stdev, "target")
 
     def denormalize_source_channels(self, source: NDArray[DType]) -> NDArray[DType]:
         """
@@ -539,37 +615,29 @@ def denormalize_source_channels(self, source: NDArray[DType]) -> NDArray[DType]:
 
         Parameters
         ----------
-        data :
+        source :
             data to be denormalized
 
         Returns
         -------
-        Denormalized data
+        Denormalized source data
         """
-        assert source.shape[-1] == len(self.source_idx), "incorrect number of source channels"
-        for i, ch in enumerate(self.source_idx):
-            source[..., i] = (source[..., i] * self.stdev[ch]) + self.mean[ch]
-
-        return source
+        return self._denormalize(source, self.source_idx, self.mean, self.stdev, "source")
 
     def denormalize_target_channels(self, data: NDArray[DType]) -> NDArray[DType]:
         """
         Denormalize target channels
 
         Parameters
         ----------
         data :
             data to be denormalized (target or pred)
 
         Returns
         -------
-        Denormalized data
+        Denormalized target data
         """
-        assert data.shape[-1] == len(self.target_idx), "incorrect number of target channels"
-        for i, ch in enumerate(self.target_idx):
-            data[..., i] = (data[..., i] * self.stdev[ch]) + self.mean[ch]
-
-        return data
+        return self._denormalize(data, self.target_idx, self.mean, self.stdev, "target")
 
 
 class DataReaderTimestep(DataReaderBase):