Skip to content

Commit ce058ce

Browse files
committed
Fix incorrect weight generation
- Add `dv_mean` parameter to `_mask_var_with_weight_threshold()` and fix parameter references - Remove unnecessary broadcasting of variable and weights in `_averager()`
1 parent fbd0d55 commit ce058ce

File tree

1 file changed

+17
-14
lines changed

1 file changed

+17
-14
lines changed

xcdat/spatial.py

+17-14
Original file line numberDiff line numberDiff line change
@@ -749,30 +749,30 @@ def _averager(
749749
``weights`` must be a DataArray and cannot contain missing values.
750750
Missing values are replaced with 0 using ``weights.fillna(0)``.
751751
"""
752+
dv = data_var.copy()
752753
weights = self._weights.fillna(0)
753754

754-
# TODO: This conditional might not be needed because Xarray will
755-
# automatically broadcast the weights to the data variable for
756-
# operations such as .mean() and .where().
757-
if min_weight > 0.0:
758-
weights, data_var = xr.broadcast(weights, data_var)
759-
760755
dim: List[str] = []
761756
for key in axis:
762-
dim.append(get_dim_keys(data_var, key)) # type: ignore
757+
dim.append(get_dim_keys(dv, key)) # type: ignore
763758

764759
with xr.set_options(keep_attrs=True):
765-
dv_mean = data_var.cf.weighted(weights).mean(dim=dim)
760+
dv_mean = dv.cf.weighted(weights).mean(dim=dim)
766761

767762
if min_weight > 0.0:
768763
dv_mean = self._mask_var_with_weight_threshold(
769-
dv_mean, dim, weights, min_weight
764+
dv, dv_mean, dim, weights, min_weight
770765
)
771766

772767
return dv_mean
773768

774769
def _mask_var_with_weight_threshold(
775-
self, dv: xr.DataArray, dim: List[str], weights: xr.DataArray, min_weight: float
770+
self,
771+
dv: xr.DataArray,
772+
dv_mean: xr.DataArray,
773+
dim: List[str],
774+
weights: xr.DataArray,
775+
min_weight: float,
776776
) -> xr.DataArray:
777777
"""Mask values that do not meet the minimum weight threshold with np.nan.
778778
@@ -786,7 +786,9 @@ def _mask_var_with_weight_threshold(
786786
Parameters
787787
----------
788788
dv : xr.DataArray
789-
The weighted variable.
789+
The weighted variable used for getting masked weights.
790+
dv_mean : xr.DataArray
791+
The average of the weighted variable.
790792
dim: List[str]:
791793
List of axis dimensions to average over.
792794
weights : xr.DataArray
@@ -800,7 +802,8 @@ def _mask_var_with_weight_threshold(
800802
Returns
801803
-------
802804
xr.DataArray
803-
The variable with the minimum weight threshold applied.
805+
The average of the weighted with the minimum weight threshold
806+
applied.
804807
"""
805808
# Sum all weights, including zero for missing values.
806809
weight_sum_all = weights.sum(dim=dim)
@@ -812,7 +815,7 @@ def _mask_var_with_weight_threshold(
812815
frac = weight_sum_masked / weight_sum_all
813816

814817
# Nan out values that don't meet specified weight threshold.
815-
dv_new = xr.where(frac >= min_weight, dv, np.nan, keep_attrs=True)
816-
dv_new.name = dv.name
818+
dv_new = xr.where(frac >= min_weight, dv_mean, np.nan, keep_attrs=True)
819+
dv_new.name = dv_mean.name
817820

818821
return dv_new

0 commit comments

Comments
 (0)