Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit ea08915

Browse files
pochedls, tomvothecoder
authored and committed Nov 21, 2024
initial attempt at #531 (for spatial averaging)
1 parent 27396e5 commit ea08915

File tree

2 files changed

+83
-2
lines changed

2 files changed

+83
-2
lines changed
 

‎tests/test_spatial.py

+34
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,17 @@ def test_raises_error_if_weights_lat_and_lon_dims_dont_align_with_data_var_dims(
140140
with pytest.raises(ValueError):
141141
self.ds.spatial.average("ts", axis=["X", "Y"], weights=weights)
142142

143+
def test_raises_error_if_required_weight_not_between_zero_and_one(
144+
self,
145+
):
146+
# ensure error if required_weight less than zero
147+
with pytest.raises(ValueError):
148+
self.ds.spatial.average("ts", axis=["X", "Y"], required_weight=-0.01)
149+
150+
# ensure error if required_weight greater than 1
151+
with pytest.raises(ValueError):
152+
self.ds.spatial.average("ts", axis=["X", "Y"], required_weight=1.01)
153+
143154
def test_spatial_average_for_lat_region_and_keep_weights(self):
144155
ds = self.ds.copy()
145156

@@ -254,6 +265,29 @@ def test_spatial_average_for_lat_and_lon_region_and_keep_weights(self):
254265

255266
xr.testing.assert_allclose(result, expected)
256267

268+
def test_spatial_average_with_required_weight(self):
269+
ds = self.ds.copy()
270+
271+
# insert a nan
272+
ds["ts"][0, :, 2] = np.nan
273+
274+
result = ds.spatial.average(
275+
"ts",
276+
axis=["X", "Y"],
277+
lat_bounds=(-5.0, 5),
278+
lon_bounds=(-170, -120.1),
279+
required_weight=1.0,
280+
)
281+
282+
expected = self.ds.copy()
283+
expected["ts"] = xr.DataArray(
284+
data=np.array([np.nan, 1.0, 1.0]),
285+
coords={"time": expected.time},
286+
dims="time",
287+
)
288+
289+
xr.testing.assert_allclose(result, expected)
290+
257291
def test_spatial_average_for_lat_and_lon_region_with_custom_weights(self):
258292
ds = self.ds.copy()
259293

‎xcdat/spatial.py

+49-2
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def average(
7676
keep_weights: bool = False,
7777
lat_bounds: Optional[RegionAxisBounds] = None,
7878
lon_bounds: Optional[RegionAxisBounds] = None,
79+
required_weight: Optional[float] = 0.0,
7980
) -> xr.Dataset:
8081
"""
8182
Calculates the spatial average for a rectilinear grid over an optionally
@@ -125,6 +126,9 @@ def average(
125126
ignored if ``weights`` are supplied. The lower bound can be larger
126127
than the upper bound (e.g., across the prime meridian, dateline), by
127128
default None.
129+
required_weight : optional, float
130+
Fraction of data coverage (i.e., weight) needed to return a
131+
spatial average value. Value must range from 0 to 1.
128132
129133
Returns
130134
-------
@@ -196,7 +200,7 @@ def average(
196200
self._weights = weights
197201

198202
self._validate_weights(dv, axis)
199-
ds[dv.name] = self._averager(dv, axis)
203+
ds[dv.name] = self._averager(dv, axis, required_weight=required_weight)
200204

201205
if keep_weights:
202206
ds[self._weights.name] = self._weights
@@ -702,7 +706,10 @@ def _validate_weights(
702706
)
703707

704708
def _averager(
705-
self, data_var: xr.DataArray, axis: List[SpatialAxis] | Tuple[SpatialAxis, ...]
709+
self,
710+
data_var: xr.DataArray,
711+
axis: List[SpatialAxis] | Tuple[SpatialAxis, ...],
712+
required_weight: Optional[float] = 0.0,
706713
):
707714
"""Perform a weighted average of a data variable.
708715
@@ -721,6 +728,9 @@ def _averager(
721728
Data variable inside a Dataset.
722729
axis : List[SpatialAxis] | Tuple[SpatialAxis, ...]
723730
List of axis dimensions to average over.
731+
required_weight : optional, float
732+
Fraction of data coverage (i.e., weight) needed to return a
733+
spatial average value. Value must range from 0 to 1.
724734
725735
Returns
726736
-------
@@ -734,11 +744,48 @@ def _averager(
734744
"""
735745
weights = self._weights.fillna(0)
736746

747+
# ensure required weight is between 0 and 1
748+
if required_weight is None:
749+
required_weight = 0.0
750+
751+
if required_weight < 0.0:
752+
raise ValueError(
753+
"required_weight argument is less than zero. "
754+
"required_weight must be between 0 and 1."
755+
)
756+
757+
if required_weight > 1.0:
758+
raise ValueError(
759+
"required_weight argument is greater than one. "
760+
"required_weight must be between 0 and 1."
761+
)
762+
763+
# need weights to match data_var dimensionality
764+
if required_weight > 0.0:
765+
weights, data_var = xr.broadcast(weights, data_var)
766+
767+
# get averaging dimensions
737768
dim = []
738769
for key in axis:
739770
dim.append(get_dim_keys(data_var, key))
740771

772+
# compute weighted mean
741773
with xr.set_options(keep_attrs=True):
742774
weighted_mean = data_var.cf.weighted(weights).mean(dim=dim)
743775

776+
# if weight thresholds applied, calculate fraction of data availability
777+
# replace values that do not meet minimum weight with nan
778+
if required_weight > 0.0:
779+
# sum all weights (assuming no missing values exist)
780+
print(dim)
781+
weight_sum_all = weights.sum(dim=dim) # type: ignore
782+
# zero out cells with missing values in data_var
783+
weights = xr.where(~np.isnan(data_var), weights, 0)
784+
# sum all weights (including zero for missing values)
785+
weight_sum_masked = weights.sum(dim=dim) # type: ignore
786+
# get fraction of weight available
787+
frac = weight_sum_masked / weight_sum_all
788+
# nan out values that don't meet specified weight threshold
789+
weighted_mean = xr.where(frac >= required_weight, weighted_mean, np.nan)
790+
744791
return weighted_mean

0 commit comments

Comments
 (0)
Please sign in to comment.