diff --git a/tests/test_spatial.py b/tests/test_spatial.py index fe0361cd..90c5e095 100644 --- a/tests/test_spatial.py +++ b/tests/test_spatial.py @@ -15,13 +15,13 @@ def setup(self): ) def test__init__(self): - ds = self.ds.copy() + ds = self.ds.copy(deep=True) obj = SpatialAccessor(ds) assert obj._dataset.identical(ds) def test_decorator_call(self): - ds = self.ds.copy() + ds = self.ds.copy(deep=True) obj = ds.spatial assert obj._dataset.identical(ds) @@ -50,7 +50,7 @@ def test_raises_error_if_axis_list_contains_unsupported_axis(self): self.ds.spatial.average("ts", axis=["Y", "incorrect_axis"]) def test_raises_error_if_lat_axis_coords_cant_be_found(self): - ds = self.ds.copy() + ds = self.ds.copy(deep=True) # Update CF metadata to invalid values so cf_xarray can't interpret them. del ds.lat.attrs["axis"] @@ -64,7 +64,7 @@ def test_raises_error_if_lat_axis_coords_cant_be_found(self): ds.spatial.average("ts", axis=["X", "Y"]) def test_raises_error_if_lon_axis_coords_cant_be_found(self): - ds = self.ds.copy() + ds = self.ds.copy(deep=True) # Update CF metadata to invalid values so cf_xarray can't interpret them. del ds.lon.attrs["axis"] @@ -141,13 +141,13 @@ def test_raises_error_if_weights_lat_and_lon_dims_dont_align_with_data_var_dims( self.ds.spatial.average("ts", axis=["X", "Y"], weights=weights) def test_spatial_average_for_lat_region_and_keep_weights(self): - ds = self.ds.copy() + ds = self.ds.copy(deep=True) result = ds.spatial.average( "ts", axis=["Y"], lat_bounds=(-5.0, 5), keep_weights=True ) - expected = self.ds.copy() + expected = self.ds.copy(deep=True) expected["ts"] = xr.DataArray( data=np.array( [[2.25, 2.25, 2.25, 2.25], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]] @@ -164,12 +164,12 @@ def test_spatial_average_for_lat_region_and_keep_weights(self): xr.testing.assert_allclose(result, expected) def test_spatial_average_for_lat_region(self): - ds = self.ds.copy() + ds = self.ds.copy(deep=True) # Specifying axis as a str instead of list of str. 
result = ds.spatial.average("ts", axis=["Y"], lat_bounds=(-5.0, 5)) - expected = self.ds.copy() + expected = self.ds.copy(deep=True) expected["ts"] = xr.DataArray( data=np.array( [[2.25, 2.25, 2.25, 2.25], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]] @@ -183,7 +183,7 @@ def test_spatial_average_for_lat_region(self): def test_spatial_average_for_domain_wrapping_p_meridian_non_cf_conventions( self, ): - ds = self.ds.copy() + ds = self.ds.copy(deep=True) # get spatial average for original dataset ref = ds.spatial.average("ts").ts @@ -200,14 +200,14 @@ def test_spatial_average_for_domain_wrapping_p_meridian_non_cf_conventions( @requires_dask def test_spatial_average_for_lat_region_and_keep_weights_with_dask(self): - ds = self.ds.copy().chunk(2) + ds = self.ds.copy(deep=True).chunk(2) # Specifying axis as a str instead of list of str. result = ds.spatial.average( "ts", axis=["Y"], lat_bounds=(-5.0, 5), keep_weights=True ) - expected = self.ds.copy() + expected = self.ds.copy(deep=True) expected["ts"] = xr.DataArray( data=np.array( [[2.25, 2.25, 2.25, 2.25], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]] @@ -224,7 +224,7 @@ def test_spatial_average_for_lat_region_and_keep_weights_with_dask(self): xr.testing.assert_allclose(result, expected) def test_spatial_average_for_lat_and_lon_region_and_keep_weights(self): - ds = self.ds.copy() + ds = self.ds.copy(deep=True) result = ds.spatial.average( "ts", axis=["X", "Y"], @@ -233,7 +233,7 @@ def test_spatial_average_for_lat_and_lon_region_and_keep_weights(self): keep_weights=True, ) - expected = self.ds.copy() + expected = self.ds.copy(deep=True) expected["ts"] = xr.DataArray( data=np.array([2.25, 1.0, 1.0]), coords={"time": expected.time}, @@ -255,7 +255,7 @@ def test_spatial_average_for_lat_and_lon_region_and_keep_weights(self): xr.testing.assert_allclose(result, expected) def test_spatial_average_for_lat_and_lon_region_with_custom_weights(self): - ds = self.ds.copy() + ds = self.ds.copy(deep=True) weights = xr.DataArray( 
data=np.array([[1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12], [4, 8, 12, 16]]), @@ -270,7 +270,7 @@ def test_spatial_average_for_lat_and_lon_region_with_custom_weights(self): data_var="ts", ) - expected = self.ds.copy() + expected = self.ds.copy(deep=True) expected["ts"] = xr.DataArray( data=np.array([2.25, 1.0, 1.0]), coords={"time": expected.time}, @@ -280,6 +280,179 @@ def test_spatial_average_for_lat_and_lon_region_with_custom_weights(self): assert result.identical(expected) +class TestAveragerMinWeight: + @pytest.fixture(autouse=True) + def setup(self): + self.ds = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) + + # Limit to just 3 data points to simplify testing. + self.ds = self.ds.isel(time=slice(None, 3)) + + # Change the value of the first element so that it is easier to identify + # changes in the output. + self.ds["ts"].data[0] = np.full((4, 4), 2.25) + + def test_raises_error_if_min_weight_is_negative(self): + with pytest.raises(ValueError): + self.ds.spatial.average("ts", axis=["X", "Y"], min_weight=-0.1) + + def test_raises_error_if_min_weight_is_greater_than_one(self): + with pytest.raises(ValueError): + self.ds.spatial.average("ts", axis=["X", "Y"], min_weight=1.1) + + def test_spatial_average_with_min_weight_zero(self): + ds = self.ds.copy(deep=True) + + # Insert NaN values into the dataset (no minimum required to compute value). + ds["ts"][0, :, 2] = np.nan + + # min_weight=0.0 means no minimum weight threshold required to compute value + result = ds.spatial.average( + "ts", + axis=["X", "Y"], + min_weight=0.0, + ) + + expected = self.ds.copy(deep=True) + expected["ts"] = xr.DataArray( + data=np.array([2.25, 1.0, 1.0]), + coords={"time": expected.time}, + dims="time", + ) + + xr.testing.assert_allclose(result, expected) + + def test_spatial_average_with_min_weight_none_equivalent_to_zero(self): + ds = self.ds.copy(deep=True) + + # Insert NaN values into the dataset. 
+ ds["ts"][0, :, 2] = np.nan + + # min_weight=None means no minimum weight threshold required to compute value + result = ds.spatial.average( + "ts", + axis=["X", "Y"], + min_weight=None, + ) + + expected = self.ds.copy(deep=True) + expected["ts"] = xr.DataArray( + data=np.array([2.25, 1.0, 1.0]), + coords={"time": expected.time}, + dims="time", + ) + + xr.testing.assert_allclose(result, expected) + + def test_spatial_average_with_min_weight_half(self): + ds = self.ds.copy(deep=True) + + # Insert NaN values into the dataset > 50% at second time point. + ds["ts"][1, :, :] = np.nan + + # At least 50% of the weights must be non-NaN to compute value. + result = ds.spatial.average( + "ts", + axis=["X", "Y"], + min_weight=0.5, + ) + + # The second grouping window will be NaN because the minimum weight + # threshold is not met (>50%). + expected = self.ds.copy(deep=True) + expected["ts"] = xr.DataArray( + data=np.array([2.25, np.nan, 1.0]), + coords={"time": expected.time}, + dims="time", + ) + + xr.testing.assert_allclose(result, expected) + + def test_spatial_average_with_min_weight_one(self): + ds = self.ds.copy(deep=True) + + # Insert a single NaN value at the last time point. + ds["ts"][2, 0, 0] = np.nan + + # 100% of the weights must be non-NaN to compute value. + result = ds.spatial.average( + "ts", + axis=["X", "Y"], + min_weight=1.0, + ) + + # The last grouping window will be NaN because the minimum weight + # threshold is not met (1 value is NaN out of 16). + expected = self.ds.copy(deep=True) + expected["ts"] = xr.DataArray( + data=np.array([2.25, 1.0, np.nan]), + coords={"time": expected.time}, + dims="time", + ) + + xr.testing.assert_allclose(result, expected) + + def test_spatial_average_with_min_weight_edge_case_zero_weights(self): + ds = self.ds.copy(deep=True) + + # Set all weights to zero. 
+ weights = xr.DataArray( + data=np.zeros((4, 4)), + coords={"lat": ds.lat, "lon": ds.lon}, + dims=["lat", "lon"], + ) + + result = ds.spatial.average( + "ts", + axis=["X", "Y"], + weights=weights, + min_weight=0.5, + ) + + # With all weights set to zero, the results for all grouping windows + # should be NaN because the minimum weight threshold is not met (>50%). + expected = self.ds.copy(deep=True) + expected["ts"] = xr.DataArray( + data=np.array([np.nan, np.nan, np.nan]), + coords={"time": expected.time}, + dims="time", + ) + + xr.testing.assert_allclose(result, expected) + + def test_spatial_average_with_min_weight_edge_case_partial_nan_weights(self): + ds = self.ds.copy(deep=True) + + # Insert NaN values into the weights. + weights = xr.DataArray( + data=np.array( + [[1, np.nan, 1, 1], [1, 1, np.nan, 1], [1, 1, 1, np.nan], [1, 1, 1, 1]] + ), + coords={"lat": ds.lat, "lon": ds.lon}, + dims=["lat", "lon"], + ) + + result = ds.spatial.average( + "ts", + axis=["X", "Y"], + weights=weights, + min_weight=0.5, + ) + + # With partial NaN weights, the averages are still computed because + # at least 50% of the weights are non-NaN for each grouping window. 
+ expected = self.ds.copy(deep=True) + expected["ts"] = xr.DataArray( + data=np.array([2.25, 1.0, 1.0]), + coords={"time": expected.time}, + dims="time", + ) + + xr.testing.assert_allclose(result, expected) + + class TestGetWeights: @pytest.fixture(autouse=True) def setup(self): @@ -300,7 +473,7 @@ def test_value_error_thrown_for_multiple_out_of_order_lon_bounds(self): self.ds.spatial._get_longitude_weights(domain_bounds, region_bounds=None) def test_raises_error_if_dataset_has_multiple_bounds_variables_for_an_axis(self): - ds = self.ds.copy() + ds = self.ds.copy(deep=True) # Create a second "Y" axis dimension and associated bounds ds["lat2"] = ds.lat.copy() @@ -314,7 +487,7 @@ def test_raises_error_if_dataset_has_multiple_bounds_variables_for_an_axis(self) ds.spatial.get_weights(axis=["Y", "X"]) def test_data_var_weights_for_region_in_lat_and_lon_domains(self): - ds = self.ds.copy() + ds = self.ds.copy(deep=True) result = ds.spatial.get_weights( axis=["Y", "X"], lat_bounds=(-5, 5), lon_bounds=(-170, -120), data_var="ts" @@ -394,7 +567,7 @@ def test_weights_for_region_in_lon_domain(self): def test_dataset_weights_for_region_in_lon_domain_with_region_spanning_p_meridian( self, ): - ds = self.ds.copy() + ds = self.ds.copy(deep=True) result = ds.spatial._get_longitude_weights( domain_bounds=ds.lon_bnds, diff --git a/tests/test_utils.py b/tests/test_utils.py index 1d4dcbe8..30d3cbfb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,7 @@ import pytest import xarray as xr -from xcdat.utils import compare_datasets, str_to_bool +from xcdat.utils import _validate_min_weight, compare_datasets, str_to_bool class TestCompareDatasets: @@ -103,3 +103,23 @@ def test_raises_error_if_str_is_not_a_python_bool(self): with pytest.raises(ValueError): str_to_bool("1") + + +class TestValidateMinWeight: + def test_pass_None_returns_0(self): + result = _validate_min_weight(None) + + assert result == 0 + + def test_returns_error_if_less_than_0(self): + with 
pytest.raises(ValueError): + _validate_min_weight(-1) + + def test_returns_error_if_greater_than_1(self): + with pytest.raises(ValueError): + _validate_min_weight(1.1) + + def test_returns_valid_min_weight(self): + result = _validate_min_weight(1) + + assert result == 1 diff --git a/xcdat/spatial.py b/xcdat/spatial.py index 15bec956..c5f12c10 100644 --- a/xcdat/spatial.py +++ b/xcdat/spatial.py @@ -27,7 +27,11 @@ get_dim_keys, ) from xcdat.dataset import _get_data_var -from xcdat.utils import _if_multidim_dask_array_then_load +from xcdat.utils import ( + _get_masked_weights, + _if_multidim_dask_array_then_load, + _validate_min_weight, +) #: Type alias for a dictionary of axis keys mapped to their bounds. AxisWeights = Dict[Hashable, xr.DataArray] @@ -74,8 +78,9 @@ def average( axis: List[SpatialAxis] | Tuple[SpatialAxis, ...] = ("X", "Y"), weights: Union[Literal["generate"], xr.DataArray] = "generate", keep_weights: bool = False, - lat_bounds: Optional[RegionAxisBounds] = None, - lon_bounds: Optional[RegionAxisBounds] = None, + lat_bounds: RegionAxisBounds | None = None, + lon_bounds: RegionAxisBounds | None = None, + min_weight: float | None = None, ) -> xr.Dataset: """ Calculates the spatial average for a rectilinear grid over an optionally @@ -114,17 +119,28 @@ def average( keep_weights : bool, optional If calculating averages using weights, keep the weights in the final dataset output, by default False. - lat_bounds : Optional[RegionAxisBounds], optional + lat_bounds : RegionAxisBounds | None, optional A tuple of floats/ints for the regional latitude lower and upper boundaries. This arg is used when calculating axis weights, but is ignored if ``weights`` are supplied. The lower bound cannot be larger than the upper bound, by default None. - lon_bounds : Optional[RegionAxisBounds], optional + lon_bounds : RegionAxisBounds | None, optional A tuple of floats/ints for the regional longitude lower and upper boundaries. 
This arg is used when calculating axis weights, but is ignored if ``weights`` are supplied. The lower bound can be larger than the upper bound (e.g., across the prime meridian, dateline), by default None. + min_weight : float | None, optional + Minimum threshold of data coverage (weight) required to compute + a spatial average for a grouping window. Must be between 0 and 1. + Useful for ensuring accurate averages in regions with missing data, + by default None (equivalent to 0.0). + + The value must be between 0 and 1, where: + - 0/``None`` means no minimum threshold (all data is considered, + even if coverage is minimal). + - 1 means full data coverage is required (no missing data is + allowed). Returns ------- @@ -184,7 +200,9 @@ def average( """ ds = self._dataset.copy() dv = _get_data_var(ds, data_var) + self._validate_axis_arg(axis) + min_weight = _validate_min_weight(min_weight) if isinstance(weights, str) and weights == "generate": if lat_bounds is not None: @@ -196,7 +214,7 @@ def average( self._weights = weights self._validate_weights(dv, axis) - ds[dv.name] = self._averager(dv, axis) + ds[dv.name] = self._averager(dv, axis, min_weight=min_weight) if keep_weights: ds[self._weights.name] = self._weights @@ -206,9 +224,9 @@ def average( def get_weights( self, axis: List[SpatialAxis] | Tuple[SpatialAxis, ...], - lat_bounds: Optional[RegionAxisBounds] = None, - lon_bounds: Optional[RegionAxisBounds] = None, - data_var: Optional[str] = None, + lat_bounds: RegionAxisBounds | None = None, + lon_bounds: RegionAxisBounds | None = None, + data_var: str | None = None, ) -> xr.DataArray: """ Get area weights for specified axis keys and an optional target domain. @@ -227,13 +245,13 @@ def get_weights( ---------- axis : List[SpatialAxis] | Tuple[SpatialAxis, ...] List of axis dimensions to average over. - lat_bounds : Optional[RegionAxisBounds] + lat_bounds : RegionAxisBounds | None Tuple of latitude boundaries for regional selection, by default None. 
- lon_bounds : Optional[RegionAxisBounds] + lon_bounds : RegionAxisBounds | None Tuple of longitude boundaries for regional selection, by default None. - data_var: Optional[str] + data_var: str | None The key of the data variable, by default None. Pass this argument when the dataset has more than one bounds per axis (e.g., "lon" and "zlon_bnds" for the "X" axis), or you want weights for a @@ -377,7 +395,7 @@ def _validate_region_bounds(self, axis: SpatialAxis, bounds: RegionAxisBounds): ) def _get_longitude_weights( - self, domain_bounds: xr.DataArray, region_bounds: Optional[np.ndarray] + self, domain_bounds: xr.DataArray, region_bounds: np.ndarray | None ) -> xr.DataArray: """Gets weights for the longitude axis. @@ -404,7 +422,7 @@ def _get_longitude_weights( ---------- domain_bounds : xr.DataArray The array of bounds for the longitude domain. - region_bounds : Optional[np.ndarray] + region_bounds : np.ndarray | None The array of bounds for longitude regional selection. Returns @@ -418,7 +436,7 @@ def _get_longitude_weights( If the there are multiple instances in which the domain_bounds[:, 0] > domain_bounds[:, 1] """ - p_meridian_index: Optional[np.ndarray] = None + p_meridian_index: np.ndarray | None = None d_bounds = domain_bounds.copy() pm_cells = np.where(domain_bounds[:, 1] - domain_bounds[:, 0] < 0)[0] @@ -450,7 +468,7 @@ def _get_longitude_weights( return weights def _get_latitude_weights( - self, domain_bounds: xr.DataArray, region_bounds: Optional[np.ndarray] + self, domain_bounds: xr.DataArray, region_bounds: np.ndarray | None ) -> xr.DataArray: """Gets weights for the latitude axis. @@ -462,7 +480,7 @@ def _get_latitude_weights( ---------- domain_bounds : xr.DataArray The array of bounds for the latitude domain. - region_bounds : Optional[np.ndarray] + region_bounds : np.ndarray | None The array of bounds for latitude regional selection. 
Returns @@ -702,7 +720,10 @@ def _validate_weights( ) def _averager( - self, data_var: xr.DataArray, axis: List[SpatialAxis] | Tuple[SpatialAxis, ...] + self, + data_var: xr.DataArray, + axis: List[SpatialAxis] | Tuple[SpatialAxis, ...], + min_weight: float, ): """Perform a weighted average of a data variable. @@ -721,6 +742,17 @@ def _averager( Data variable inside a Dataset. axis : List[SpatialAxis] | Tuple[SpatialAxis, ...] List of axis dimensions to average over. + min_weight : optional, float + Minimum threshold of data coverage (weight) required to compute + a spatial average for a grouping window. Must be between 0 and 1. + Useful for ensuring accurate averages in regions with missing data, + by default None (equivalent to 0.0). + + The value must be between 0 and 1, where: + - 0/``None`` means no minimum threshold (all data is considered, + even if coverage is minimal). + - 1 means full data coverage is required (no missing data is + allowed). Returns ------- @@ -732,13 +764,81 @@ def _averager( ``weights`` must be a DataArray and cannot contain missing values. Missing values are replaced with 0 using ``weights.fillna(0)``. """ + dv = data_var.copy() weights = self._weights.fillna(0) - dim = [] + dim: List[str] = [] for key in axis: - dim.append(get_dim_keys(data_var, key)) + dim.append(get_dim_keys(dv, key)) # type: ignore with xr.set_options(keep_attrs=True): - weighted_mean = data_var.cf.weighted(weights).mean(dim=dim) + dv_mean = dv.cf.weighted(weights).mean(dim=dim) + + if min_weight > 0.0: + dv_mean = self._mask_var_with_weight_threshold( + dv, dv_mean, dim, weights, min_weight + ) + + return dv_mean + + def _mask_var_with_weight_threshold( + self, + dv: xr.DataArray, + dv_mean: xr.DataArray, + dim: List[str], + weights: xr.DataArray, + min_weight: float, + ) -> xr.DataArray: + """Mask values that do not meet the minimum weight threshold with np.nan. 
+ + This function is useful for cases where the weighting of data might be + skewed based on the availability of data. For example, if a portion of + cells in a region has significantly more missing data than other + regions, it can result in inaccurate calculations of spatial averaging. + Masking values that do not meet the minimum weight threshold ensures + more accurate calculations. + + Parameters + ---------- + dv : xr.DataArray + The weighted variable used for getting masked weights. + dv_mean : xr.DataArray + The average of the weighted variable. + dim : List[str] + List of axis dimensions to average over. + weights : xr.DataArray + A DataArray containing the regional weights used for weighted + averaging. ``weights`` must include the same axis dimensions and + dimensional sizes as the data variable. + min_weight : optional, float + Minimum threshold of data coverage (weight) required to compute + a spatial average for a grouping window. Must be between 0 and 1. + Useful for ensuring accurate averages in regions with missing data, + by default None (equivalent to 0.0). + + The value must be between 0 and 1, where: + - 0/``None`` means no minimum threshold (all data is considered, + even if coverage is minimal). + - 1 means full data coverage is required (no missing data is + allowed). + + Returns + ------- + xr.DataArray + The average of the weighted variable with the minimum weight + threshold applied. + """ + # Sum all weights, including zero for missing values. + weight_sum_all = weights.sum(dim=dim) + + masked_weights = _get_masked_weights(dv, weights) + weight_sum_masked = masked_weights.sum(dim=dim) + + # Get fraction of the available weight. + frac = weight_sum_masked / weight_sum_all + + # Nan out values that don't meet specified weight threshold. 
+ dv_new = xr.where(frac >= min_weight, dv_mean, np.nan, keep_attrs=True) + dv_new.name = dv_mean.name - return weighted_mean + return dv_new diff --git a/xcdat/utils.py b/xcdat/utils.py index 83596561..8828272a 100644 --- a/xcdat/utils.py +++ b/xcdat/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import importlib import json from typing import Dict, List, Optional, Union @@ -132,3 +134,60 @@ def _if_multidim_dask_array_then_load( return obj.load() return None + + +def _get_masked_weights(dv: xr.DataArray, weights: xr.DataArray) -> xr.DataArray: + """Get weights with missing data (`np.nan`) receiving no weight (zero). + + Parameters + ---------- + dv : xr.DataArray + The variable. + weights : xr.DataArray + A DataArray containing either the regional or temporal weights used for + weighted averaging. ``weights`` must include the same axis dimensions + and dimensional sizes as the data variable. + + Returns + ------- + xr.DataArray + The masked weights. + """ + masked_weights = xr.where(dv.copy().isnull(), 0.0, weights) + + return masked_weights + + +def _validate_min_weight(min_weight: float | None) -> float: + """Validate the ``min_weight`` value. + + Parameters + ---------- + min_weight : float | None + Fraction of data coverage (i.e., weight) needed to return a + spatial average value. Value must range from 0 to 1. + + Returns + ------- + float + The required weight fraction (0.0 when ``min_weight`` is None). + + Raises + ------ + ValueError + If the `min_weight` argument is less than 0. + ValueError + If the `min_weight` argument is greater than 1. + """ + if min_weight is None: + return 0.0 + elif min_weight < 0.0: + raise ValueError( + "min_weight argument is less than 0. min_weight must be between 0 and 1.", + ) + elif min_weight > 1.0: + raise ValueError( + "min_weight argument is greater than 1. min_weight must be between 0 and 1.", + ) + + return min_weight