@@ -749,30 +749,30 @@ def _averager(
749
749
``weights`` must be a DataArray and cannot contain missing values.
750
750
Missing values are replaced with 0 using ``weights.fillna(0)``.
751
751
"""
752
+ dv = data_var .copy ()
752
753
weights = self ._weights .fillna (0 )
753
754
754
- # TODO: This conditional might not be needed because Xarray will
755
- # automatically broadcast the weights to the data variable for
756
- # operations such as .mean() and .where().
757
- if min_weight > 0.0 :
758
- weights , data_var = xr .broadcast (weights , data_var )
759
-
760
755
dim : List [str ] = []
761
756
for key in axis :
762
- dim .append (get_dim_keys (data_var , key )) # type: ignore
757
+ dim .append (get_dim_keys (dv , key )) # type: ignore
763
758
764
759
with xr .set_options (keep_attrs = True ):
765
- dv_mean = data_var .cf .weighted (weights ).mean (dim = dim )
760
+ dv_mean = dv .cf .weighted (weights ).mean (dim = dim )
766
761
767
762
if min_weight > 0.0 :
768
763
dv_mean = self ._mask_var_with_weight_threshold (
769
- dv_mean , dim , weights , min_weight
764
+ dv , dv_mean , dim , weights , min_weight
770
765
)
771
766
772
767
return dv_mean
773
768
774
769
def _mask_var_with_weight_threshold (
775
- self , dv : xr .DataArray , dim : List [str ], weights : xr .DataArray , min_weight : float
770
+ self ,
771
+ dv : xr .DataArray ,
772
+ dv_mean : xr .DataArray ,
773
+ dim : List [str ],
774
+ weights : xr .DataArray ,
775
+ min_weight : float ,
776
776
) -> xr .DataArray :
777
777
"""Mask values that do not meet the minimum weight threshold with np.nan.
778
778
@@ -786,7 +786,9 @@ def _mask_var_with_weight_threshold(
786
786
Parameters
787
787
----------
788
788
dv : xr.DataArray
789
- The weighted variable.
789
+ The weighted variable used for getting masked weights.
790
+ dv_mean : xr.DataArray
791
+ The average of the weighted variable.
790
792
dim: List[str]:
791
793
List of axis dimensions to average over.
792
794
weights : xr.DataArray
@@ -800,7 +802,8 @@ def _mask_var_with_weight_threshold(
800
802
Returns
801
803
-------
802
804
xr.DataArray
803
- The variable with the minimum weight threshold applied.
805
+ The average of the weighted with the minimum weight threshold
806
+ applied.
804
807
"""
805
808
# Sum all weights, including zero for missing values.
806
809
weight_sum_all = weights .sum (dim = dim )
@@ -812,7 +815,7 @@ def _mask_var_with_weight_threshold(
812
815
frac = weight_sum_masked / weight_sum_all
813
816
814
817
# Nan out values that don't meet specified weight threshold.
815
- dv_new = xr .where (frac >= min_weight , dv , np .nan , keep_attrs = True )
816
- dv_new .name = dv .name
818
+ dv_new = xr .where (frac >= min_weight , dv_mean , np .nan , keep_attrs = True )
819
+ dv_new .name = dv_mean .name
817
820
818
821
return dv_new
0 commit comments