From 1b4742c1b918d364d428f749f286fb01df273283 Mon Sep 17 00:00:00 2001
From: Stephen Po-Chedley <pochedley@gmail.com>
Date: Fri, 28 Jun 2024 12:54:47 -0700
Subject: [PATCH 01/11] initial attempt at #531 (for spatial averaging)

---
 tests/test_spatial.py | 34 +++++++++++++++++++++++++++++
 xcdat/spatial.py      | 51 +++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 83 insertions(+), 2 deletions(-)

diff --git a/tests/test_spatial.py b/tests/test_spatial.py
index fe0361cd..4f27b226 100644
--- a/tests/test_spatial.py
+++ b/tests/test_spatial.py
@@ -140,6 +140,17 @@ def test_raises_error_if_weights_lat_and_lon_dims_dont_align_with_data_var_dims(
         with pytest.raises(ValueError):
             self.ds.spatial.average("ts", axis=["X", "Y"], weights=weights)
 
+    def test_raises_error_if_required_weight_not_between_zero_and_one(
+        self,
+    ):
+        # ensure error if required_weight less than zero
+        with pytest.raises(ValueError):
+            self.ds.spatial.average("ts", axis=["X", "Y"], required_weight=-0.01)
+
+        # ensure error if required_weight greater than 1
+        with pytest.raises(ValueError):
+            self.ds.spatial.average("ts", axis=["X", "Y"], required_weight=1.01)
+
     def test_spatial_average_for_lat_region_and_keep_weights(self):
         ds = self.ds.copy()
 
@@ -254,6 +265,29 @@ def test_spatial_average_for_lat_and_lon_region_and_keep_weights(self):
 
         xr.testing.assert_allclose(result, expected)
 
+    def test_spatial_average_with_required_weight(self):
+        ds = self.ds.copy()
+
+        # insert a nan
+        ds["ts"][0, :, 2] = np.nan
+
+        result = ds.spatial.average(
+            "ts",
+            axis=["X", "Y"],
+            lat_bounds=(-5.0, 5),
+            lon_bounds=(-170, -120.1),
+            required_weight=1.0,
+        )
+
+        expected = self.ds.copy()
+        expected["ts"] = xr.DataArray(
+            data=np.array([np.nan, 1.0, 1.0]),
+            coords={"time": expected.time},
+            dims="time",
+        )
+
+        xr.testing.assert_allclose(result, expected)
+
     def test_spatial_average_for_lat_and_lon_region_with_custom_weights(self):
         ds = self.ds.copy()
 
diff --git a/xcdat/spatial.py b/xcdat/spatial.py
index 15bec956..1beab80a 100644
--- a/xcdat/spatial.py
+++ b/xcdat/spatial.py
@@ -76,6 +76,7 @@ def average(
         keep_weights: bool = False,
         lat_bounds: Optional[RegionAxisBounds] = None,
         lon_bounds: Optional[RegionAxisBounds] = None,
+        required_weight: Optional[float] = 0.0,
     ) -> xr.Dataset:
         """
         Calculates the spatial average for a rectilinear grid over an optionally
@@ -125,6 +126,9 @@ def average(
             ignored if ``weights`` are supplied. The lower bound can be larger
             than the upper bound (e.g., across the prime meridian, dateline), by
             default None.
+        required_weight : optional, float
+            Fraction of data coverage (i..e, weight) needed to return a
+            spatial average value. Value must range from 0 to 1.
 
         Returns
         -------
@@ -196,7 +200,7 @@ def average(
             self._weights = weights
 
         self._validate_weights(dv, axis)
-        ds[dv.name] = self._averager(dv, axis)
+        ds[dv.name] = self._averager(dv, axis, required_weight=required_weight)
 
         if keep_weights:
             ds[self._weights.name] = self._weights
@@ -702,7 +706,10 @@ def _validate_weights(
                 )
 
     def _averager(
-        self, data_var: xr.DataArray, axis: List[SpatialAxis] | Tuple[SpatialAxis, ...]
+        self,
+        data_var: xr.DataArray,
+        axis: List[SpatialAxis] | Tuple[SpatialAxis, ...],
+        required_weight: Optional[float] = 0.0,
     ):
         """Perform a weighted average of a data variable.
 
@@ -721,6 +728,9 @@ def _averager(
             Data variable inside a Dataset.
         axis : List[SpatialAxis] | Tuple[SpatialAxis, ...]
             List of axis dimensions to average over.
+        required_weight : optional, float
+            Fraction of data coverage (i..e, weight) needed to return a
+            spatial average value. Value must range from 0 to 1.
 
         Returns
         -------
@@ -734,11 +744,48 @@ def _averager(
         """
         weights = self._weights.fillna(0)
 
+        # ensure required weight is between 0 and 1
+        if required_weight is None:
+            required_weight = 0.0
+
+        if required_weight < 0.0:
+            raise ValueError(
+                "required_weight argment is less than zero. "
+                "required_weight must be between 0 and 1."
+            )
+
+        if required_weight > 1.0:
+            raise ValueError(
+                "required_weight argment is greater than zero. "
+                "required_weight must be between 0 and 1."
+            )
+
+        # need weights to match data_var dimensionality
+        if required_weight > 0.0:
+            weights, data_var = xr.broadcast(weights, data_var)
+
+        # get averaging dimensions
         dim = []
         for key in axis:
             dim.append(get_dim_keys(data_var, key))
 
+        # compute weighed mean
         with xr.set_options(keep_attrs=True):
             weighted_mean = data_var.cf.weighted(weights).mean(dim=dim)
 
+        # if weight thresholds applied, calculate fraction of data availability
+        # replace values that do not meet minimum weight with nan
+        if required_weight > 0.0:
+            # sum all weights (assuming no missing values exist)
+            print(dim)
+            weight_sum_all = weights.sum(dim=dim)  # type: ignore
+            # zero out cells with missing values in data_var
+            weights = xr.where(~np.isnan(data_var), weights, 0)
+            # sum all weights (including zero for missing values)
+            weight_sum_masked = weights.sum(dim=dim)  # type: ignore
+            # get fraction of weight available
+            frac = weight_sum_masked / weight_sum_all
+            # nan out values that don't meet specified weight threshold
+            weighted_mean = xr.where(frac >= required_weight, weighted_mean, np.nan)
+
         return weighted_mean

From f69853e8d22f9ef4e19eab32d73d33bb7fa6024b Mon Sep 17 00:00:00 2001
From: Stephen Po-Chedley <pochedley@gmail.com>
Date: Fri, 28 Jun 2024 13:08:26 -0700
Subject: [PATCH 02/11] cleanup print statement and complete code coverage

---
 tests/test_spatial.py | 20 ++++++++++++++++++++
 xcdat/spatial.py      |  1 -
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/tests/test_spatial.py b/tests/test_spatial.py
index 4f27b226..ea9f44a6 100644
--- a/tests/test_spatial.py
+++ b/tests/test_spatial.py
@@ -288,6 +288,26 @@ def test_spatial_average_with_required_weight(self):
 
         xr.testing.assert_allclose(result, expected)
 
+    def test_spatial_average_with_required_weight_as_None(self):
+        ds = self.ds.copy()
+
+        result = ds.spatial.average(
+            "ts",
+            axis=["X", "Y"],
+            lat_bounds=(-5.0, 5),
+            lon_bounds=(-170, -120.1),
+            required_weight=None,
+        )
+
+        expected = self.ds.copy()
+        expected["ts"] = xr.DataArray(
+            data=np.array([2.25, 1.0, 1.0]),
+            coords={"time": expected.time},
+            dims="time",
+        )
+
+        xr.testing.assert_allclose(result, expected)
+
     def test_spatial_average_for_lat_and_lon_region_with_custom_weights(self):
         ds = self.ds.copy()
 
diff --git a/xcdat/spatial.py b/xcdat/spatial.py
index 1beab80a..07b8eab6 100644
--- a/xcdat/spatial.py
+++ b/xcdat/spatial.py
@@ -777,7 +777,6 @@ def _averager(
         # replace values that do not meet minimum weight with nan
         if required_weight > 0.0:
             # sum all weights (assuming no missing values exist)
-            print(dim)
             weight_sum_all = weights.sum(dim=dim)  # type: ignore
             # zero out cells with missing values in data_var
             weights = xr.where(~np.isnan(data_var), weights, 0)

From 56dc0f98010dd2ba0b6c1f438dc913b84eb5e7b1 Mon Sep 17 00:00:00 2001
From: Stephen Po-Chedley <pochedley@gmail.com>
Date: Fri, 23 Aug 2024 11:21:24 -0700
Subject: [PATCH 03/11] Apply review suggestion.

Co-authored-by: Tom Vo <tomvothecoder@gmail.com>
---
 xcdat/spatial.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/xcdat/spatial.py b/xcdat/spatial.py
index 07b8eab6..d34d34de 100644
--- a/xcdat/spatial.py
+++ b/xcdat/spatial.py
@@ -747,19 +747,17 @@ def _averager(
         # ensure required weight is between 0 and 1
         if required_weight is None:
             required_weight = 0.0
-
-        if required_weight < 0.0:
+        elif required_weight < 0.0:
             raise ValueError(
-                "required_weight argment is less than zero. "
+                "required_weight argument is less than 0. "
                 "required_weight must be between 0 and 1."
             )
-
-        if required_weight > 1.0:
+        elif required_weight > 1.0:
             raise ValueError(
-                "required_weight argment is greater than zero. "
+                "required_weight argument is greater than 1. "
                 "required_weight must be between 0 and 1."
             )
-
+            
         # need weights to match data_var dimensionality
         if required_weight > 0.0:
             weights, data_var = xr.broadcast(weights, data_var)

From dba659e45085b1fa46bb0d288c69b6398cc46d08 Mon Sep 17 00:00:00 2001
From: Stephen Po-Chedley <pochedley@gmail.com>
Date: Fri, 23 Aug 2024 11:34:37 -0700
Subject: [PATCH 04/11] update required_weight argument (to minimum_weight)

---
 xcdat/spatial.py | 40 +++++++++++++++++++++-------------------
 1 file changed, 21 insertions(+), 19 deletions(-)

diff --git a/xcdat/spatial.py b/xcdat/spatial.py
index d34d34de..7dc4075f 100644
--- a/xcdat/spatial.py
+++ b/xcdat/spatial.py
@@ -76,7 +76,7 @@ def average(
         keep_weights: bool = False,
         lat_bounds: Optional[RegionAxisBounds] = None,
         lon_bounds: Optional[RegionAxisBounds] = None,
-        required_weight: Optional[float] = 0.0,
+        minimum_weight: Optional[float] = None,
     ) -> xr.Dataset:
         """
         Calculates the spatial average for a rectilinear grid over an optionally
@@ -126,9 +126,10 @@ def average(
             ignored if ``weights`` are supplied. The lower bound can be larger
             than the upper bound (e.g., across the prime meridian, dateline), by
             default None.
-        required_weight : optional, float
+        minimum_weight : optional, float
             Fraction of data coverage (i..e, weight) needed to return a
-            spatial average value. Value must range from 0 to 1.
+            spatial average value. Value must range from 0 to 1, by default None
+            (equivalent to minimum_weight=0.0).
 
         Returns
         -------
@@ -200,7 +201,7 @@ def average(
             self._weights = weights
 
         self._validate_weights(dv, axis)
-        ds[dv.name] = self._averager(dv, axis, required_weight=required_weight)
+        ds[dv.name] = self._averager(dv, axis, minimum_weight=minimum_weight)
 
         if keep_weights:
             ds[self._weights.name] = self._weights
@@ -709,7 +710,7 @@ def _averager(
         self,
         data_var: xr.DataArray,
         axis: List[SpatialAxis] | Tuple[SpatialAxis, ...],
-        required_weight: Optional[float] = 0.0,
+        minimum_weight: Optional[float] = None,
     ):
         """Perform a weighted average of a data variable.
 
@@ -728,9 +729,10 @@ def _averager(
             Data variable inside a Dataset.
         axis : List[SpatialAxis] | Tuple[SpatialAxis, ...]
             List of axis dimensions to average over.
-        required_weight : optional, float
+        minimum_weight : optional, float
             Fraction of data coverage (i..e, weight) needed to return a
-            spatial average value. Value must range from 0 to 1.
+            spatial average value. Value must range from 0 to 1, by default None
+            (equivalent to minimum_weight=0.0).
 
         Returns
         -------
@@ -745,21 +747,21 @@ def _averager(
         weights = self._weights.fillna(0)
 
         # ensure required weight is between 0 and 1
-        if required_weight is None:
-            required_weight = 0.0
-        elif required_weight < 0.0:
+        if minimum_weight is None:
+            minimum_weight = 0.0
+        elif minimum_weight < 0.0:
             raise ValueError(
-                "required_weight argument is less than 0. "
-                "required_weight must be between 0 and 1."
+                "minimum_weight argument is less than 0. "
+                "minimum_weight must be between 0 and 1."
             )
-        elif required_weight > 1.0:
+        elif minimum_weight > 1.0:
             raise ValueError(
-                "required_weight argument is greater than 1. "
-                "required_weight must be between 0 and 1."
+                "minimum_weight argument is greater than 1. "
+                "minimum_weight must be between 0 and 1."
             )
-            
+
         # need weights to match data_var dimensionality
-        if required_weight > 0.0:
+        if minimum_weight > 0.0:
             weights, data_var = xr.broadcast(weights, data_var)
 
         # get averaging dimensions
@@ -773,7 +775,7 @@ def _averager(
 
         # if weight thresholds applied, calculate fraction of data availability
         # replace values that do not meet minimum weight with nan
-        if required_weight > 0.0:
+        if minimum_weight > 0.0:
             # sum all weights (assuming no missing values exist)
             weight_sum_all = weights.sum(dim=dim)  # type: ignore
             # zero out cells with missing values in data_var
@@ -783,6 +785,6 @@ def _averager(
             # get fraction of weight available
             frac = weight_sum_masked / weight_sum_all
             # nan out values that don't meet specified weight threshold
-            weighted_mean = xr.where(frac >= required_weight, weighted_mean, np.nan)
+            weighted_mean = xr.where(frac >= minimum_weight, weighted_mean, np.nan)
 
         return weighted_mean

From 95f57a5f35a77bb59ad9fd0e76ad9c28e478cd2f Mon Sep 17 00:00:00 2001
From: Stephen Po-Chedley <pochedley@gmail.com>
Date: Fri, 23 Aug 2024 11:51:34 -0700
Subject: [PATCH 05/11] update tests for minimum_weight parameter

---
 tests/test_spatial.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/tests/test_spatial.py b/tests/test_spatial.py
index ea9f44a6..ea2ad27f 100644
--- a/tests/test_spatial.py
+++ b/tests/test_spatial.py
@@ -140,16 +140,16 @@ def test_raises_error_if_weights_lat_and_lon_dims_dont_align_with_data_var_dims(
         with pytest.raises(ValueError):
             self.ds.spatial.average("ts", axis=["X", "Y"], weights=weights)
 
-    def test_raises_error_if_required_weight_not_between_zero_and_one(
+    def test_raises_error_if_minimum_weight_not_between_zero_and_one(
         self,
     ):
-        # ensure error if required_weight less than zero
+        # ensure error if minimum_weight less than zero
         with pytest.raises(ValueError):
-            self.ds.spatial.average("ts", axis=["X", "Y"], required_weight=-0.01)
+            self.ds.spatial.average("ts", axis=["X", "Y"], minimum_weight=-0.01)
 
-        # ensure error if required_weight greater than 1
+        # ensure error if minimum_weight greater than 1
         with pytest.raises(ValueError):
-            self.ds.spatial.average("ts", axis=["X", "Y"], required_weight=1.01)
+            self.ds.spatial.average("ts", axis=["X", "Y"], minimum_weight=1.01)
 
     def test_spatial_average_for_lat_region_and_keep_weights(self):
         ds = self.ds.copy()
@@ -265,7 +265,7 @@ def test_spatial_average_for_lat_and_lon_region_and_keep_weights(self):
 
         xr.testing.assert_allclose(result, expected)
 
-    def test_spatial_average_with_required_weight(self):
+    def test_spatial_average_with_minimum_weight(self):
         ds = self.ds.copy()
 
         # insert a nan
@@ -276,7 +276,7 @@ def test_spatial_average_with_required_weight(self):
             axis=["X", "Y"],
             lat_bounds=(-5.0, 5),
             lon_bounds=(-170, -120.1),
-            required_weight=1.0,
+            minimum_weight=1.0,
         )
 
         expected = self.ds.copy()
@@ -288,7 +288,7 @@ def test_spatial_average_with_required_weight(self):
 
         xr.testing.assert_allclose(result, expected)
 
-    def test_spatial_average_with_required_weight_as_None(self):
+    def test_spatial_average_with_minimum_weight_as_None(self):
         ds = self.ds.copy()
 
         result = ds.spatial.average(
@@ -296,7 +296,7 @@ def test_spatial_average_with_required_weight_as_None(self):
             axis=["X", "Y"],
             lat_bounds=(-5.0, 5),
             lon_bounds=(-170, -120.1),
-            required_weight=None,
+            minimum_weight=None,
         )
 
         expected = self.ds.copy()

From fd06f56990197724a5c34b24c2b8d72839619f65 Mon Sep 17 00:00:00 2001
From: Tom Vo <tomvothecoder@gmail.com>
Date: Thu, 5 Sep 2024 10:46:47 -0700
Subject: [PATCH 06/11] Updates from code review - Rename arg `minimum_weight`
 to `min_weight` - Add `_get_masked_weights()` and `_validate_min_weight()` to
 `utils.py` - Update `SpatialAccessor` to use `_get_masked_weights()` and
 `_validate_min_weight()` - Replace type annotation `Optional` with `|` -
 Extract `_mask_var_with_weight_threshold()` from `_averager()` for readability

---
 tests/test_spatial.py |  18 ++---
 tests/test_utils.py   |  22 +++++-
 xcdat/spatial.py      | 160 +++++++++++++++++++++++++-----------------
 xcdat/utils.py        |  60 ++++++++++++++++
 4 files changed, 185 insertions(+), 75 deletions(-)

diff --git a/tests/test_spatial.py b/tests/test_spatial.py
index ea2ad27f..244bcf31 100644
--- a/tests/test_spatial.py
+++ b/tests/test_spatial.py
@@ -140,16 +140,16 @@ def test_raises_error_if_weights_lat_and_lon_dims_dont_align_with_data_var_dims(
         with pytest.raises(ValueError):
             self.ds.spatial.average("ts", axis=["X", "Y"], weights=weights)
 
-    def test_raises_error_if_minimum_weight_not_between_zero_and_one(
+    def test_raises_error_if_min_weight_not_between_zero_and_one(
         self,
     ):
-        # ensure error if minimum_weight less than zero
+        # ensure error if min_weight less than zero
         with pytest.raises(ValueError):
-            self.ds.spatial.average("ts", axis=["X", "Y"], minimum_weight=-0.01)
+            self.ds.spatial.average("ts", axis=["X", "Y"], min_weight=-0.01)
 
-        # ensure error if minimum_weight greater than 1
+        # ensure error if min_weight greater than 1
         with pytest.raises(ValueError):
-            self.ds.spatial.average("ts", axis=["X", "Y"], minimum_weight=1.01)
+            self.ds.spatial.average("ts", axis=["X", "Y"], min_weight=1.01)
 
     def test_spatial_average_for_lat_region_and_keep_weights(self):
         ds = self.ds.copy()
@@ -265,7 +265,7 @@ def test_spatial_average_for_lat_and_lon_region_and_keep_weights(self):
 
         xr.testing.assert_allclose(result, expected)
 
-    def test_spatial_average_with_minimum_weight(self):
+    def test_spatial_average_with_min_weight(self):
         ds = self.ds.copy()
 
         # insert a nan
@@ -276,7 +276,7 @@ def test_spatial_average_with_minimum_weight(self):
             axis=["X", "Y"],
             lat_bounds=(-5.0, 5),
             lon_bounds=(-170, -120.1),
-            minimum_weight=1.0,
+            min_weight=1.0,
         )
 
         expected = self.ds.copy()
@@ -288,7 +288,7 @@ def test_spatial_average_with_minimum_weight(self):
 
         xr.testing.assert_allclose(result, expected)
 
-    def test_spatial_average_with_minimum_weight_as_None(self):
+    def test_spatial_average_with_min_weight_as_None(self):
         ds = self.ds.copy()
 
         result = ds.spatial.average(
@@ -296,7 +296,7 @@ def test_spatial_average_with_minimum_weight_as_None(self):
             axis=["X", "Y"],
             lat_bounds=(-5.0, 5),
             lon_bounds=(-170, -120.1),
-            minimum_weight=None,
+            min_weight=None,
         )
 
         expected = self.ds.copy()
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 1d4dcbe8..30d3cbfb 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,7 +1,7 @@
 import pytest
 import xarray as xr
 
-from xcdat.utils import compare_datasets, str_to_bool
+from xcdat.utils import _validate_min_weight, compare_datasets, str_to_bool
 
 
 class TestCompareDatasets:
@@ -103,3 +103,23 @@ def test_raises_error_if_str_is_not_a_python_bool(self):
 
         with pytest.raises(ValueError):
             str_to_bool("1")
+
+
+class TestValidateMinWeight:
+    def test_pass_None_returns_0(self):
+        result = _validate_min_weight(None)
+
+        assert result == 0
+
+    def test_returns_error_if_less_than_0(self):
+        with pytest.raises(ValueError):
+            _validate_min_weight(-1)
+
+    def test_returns_error_if_greater_than_1(self):
+        with pytest.raises(ValueError):
+            _validate_min_weight(1.1)
+
+    def test_returns_valid_min_weight(self):
+        result = _validate_min_weight(1)
+
+        assert result == 1
diff --git a/xcdat/spatial.py b/xcdat/spatial.py
index 7dc4075f..20d8cc0c 100644
--- a/xcdat/spatial.py
+++ b/xcdat/spatial.py
@@ -1,5 +1,8 @@
 """Module containing geospatial averaging functions."""
+<<<<<<< HEAD
 
+=======
+>>>>>>> 34b570d6 (Updates from code review)
 from __future__ import annotations
 
 from functools import reduce
@@ -9,7 +12,6 @@
     Hashable,
     List,
     Literal,
-    Optional,
     Tuple,
     TypedDict,
     Union,
@@ -27,7 +29,11 @@
     get_dim_keys,
 )
 from xcdat.dataset import _get_data_var
-from xcdat.utils import _if_multidim_dask_array_then_load
+from xcdat.utils import (
+    _get_masked_weights,
+    _if_multidim_dask_array_then_load,
+    _validate_min_weight,
+)
 
 #: Type alias for a dictionary of axis keys mapped to their bounds.
 AxisWeights = Dict[Hashable, xr.DataArray]
@@ -74,9 +80,9 @@ def average(
         axis: List[SpatialAxis] | Tuple[SpatialAxis, ...] = ("X", "Y"),
         weights: Union[Literal["generate"], xr.DataArray] = "generate",
         keep_weights: bool = False,
-        lat_bounds: Optional[RegionAxisBounds] = None,
-        lon_bounds: Optional[RegionAxisBounds] = None,
-        minimum_weight: Optional[float] = None,
+        lat_bounds: RegionAxisBounds | None = None,
+        lon_bounds: RegionAxisBounds | None = None,
+        min_weight: float | None = None,
     ) -> xr.Dataset:
         """
         Calculates the spatial average for a rectilinear grid over an optionally
@@ -115,21 +121,21 @@ def average(
         keep_weights : bool, optional
             If calculating averages using weights, keep the weights in the
             final dataset output, by default False.
-        lat_bounds : Optional[RegionAxisBounds], optional
+        lat_bounds : RegionAxisBounds | None, optional
             A tuple of floats/ints for the regional latitude lower and upper
             boundaries. This arg is used when calculating axis weights, but is
             ignored if ``weights`` are supplied. The lower bound cannot be
             larger than the upper bound, by default None.
-        lon_bounds : Optional[RegionAxisBounds], optional
+        lon_bounds : RegionAxisBounds | None, optional
             A tuple of floats/ints for the regional longitude lower and upper
             boundaries. This arg is used when calculating axis weights, but is
             ignored if ``weights`` are supplied. The lower bound can be larger
             than the upper bound (e.g., across the prime meridian, dateline), by
             default None.
-        minimum_weight : optional, float
-            Fraction of data coverage (i..e, weight) needed to return a
+        min_weight : optional, float
+            Fraction of data coverage (i.e, weight) needed to return a
             spatial average value. Value must range from 0 to 1, by default None
-            (equivalent to minimum_weight=0.0).
+            (equivalent to ``min_weight=0.0``).
 
         Returns
         -------
@@ -189,7 +195,9 @@ def average(
         """
         ds = self._dataset.copy()
         dv = _get_data_var(ds, data_var)
+
         self._validate_axis_arg(axis)
+        min_weight = _validate_min_weight(min_weight)
 
         if isinstance(weights, str) and weights == "generate":
             if lat_bounds is not None:
@@ -201,7 +209,7 @@ def average(
             self._weights = weights
 
         self._validate_weights(dv, axis)
-        ds[dv.name] = self._averager(dv, axis, minimum_weight=minimum_weight)
+        ds[dv.name] = self._averager(dv, axis, min_weight=min_weight)
 
         if keep_weights:
             ds[self._weights.name] = self._weights
@@ -211,9 +219,9 @@ def average(
     def get_weights(
         self,
         axis: List[SpatialAxis] | Tuple[SpatialAxis, ...],
-        lat_bounds: Optional[RegionAxisBounds] = None,
-        lon_bounds: Optional[RegionAxisBounds] = None,
-        data_var: Optional[str] = None,
+        lat_bounds: RegionAxisBounds | None = None,
+        lon_bounds: RegionAxisBounds | None = None,
+        data_var: str | None = None,
     ) -> xr.DataArray:
         """
         Get area weights for specified axis keys and an optional target domain.
@@ -232,13 +240,13 @@ def get_weights(
         ----------
         axis : List[SpatialAxis] | Tuple[SpatialAxis, ...]
             List of axis dimensions to average over.
-        lat_bounds : Optional[RegionAxisBounds]
+        lat_bounds : RegionAxisBounds | None
             Tuple of latitude boundaries for regional selection, by default
             None.
-        lon_bounds : Optional[RegionAxisBounds]
+        lon_bounds : RegionAxisBounds | None
             Tuple of longitude boundaries for regional selection, by default
             None.
-        data_var: Optional[str]
+        data_var: str | None
             The key of the data variable, by default None. Pass this argument
             when the dataset has more than one bounds per axis (e.g., "lon"
             and "zlon_bnds" for the "X" axis), or you want weights for a
@@ -259,7 +267,7 @@ def get_weights(
         and pressure).
         """
         Bounds = TypedDict(
-            "Bounds", {"weights_method": Callable, "region": Optional[np.ndarray]}
+            "Bounds", {"weights_method": Callable, "region": np.ndarray | None}
         )
 
         axis_bounds: Dict[SpatialAxis, Bounds] = {
@@ -382,7 +390,7 @@ def _validate_region_bounds(self, axis: SpatialAxis, bounds: RegionAxisBounds):
             )
 
     def _get_longitude_weights(
-        self, domain_bounds: xr.DataArray, region_bounds: Optional[np.ndarray]
+        self, domain_bounds: xr.DataArray, region_bounds: np.ndarray | None
     ) -> xr.DataArray:
         """Gets weights for the longitude axis.
 
@@ -409,7 +417,7 @@ def _get_longitude_weights(
         ----------
         domain_bounds : xr.DataArray
             The array of bounds for the longitude domain.
-        region_bounds : Optional[np.ndarray]
+        region_bounds : np.ndarray | None
             The array of bounds for longitude regional selection.
 
         Returns
@@ -423,7 +431,7 @@ def _get_longitude_weights(
             If the there are multiple instances in which the
             domain_bounds[:, 0] > domain_bounds[:, 1]
         """
-        p_meridian_index: Optional[np.ndarray] = None
+        p_meridian_index: np.ndarray | None = None
         d_bounds = domain_bounds.copy()
 
         pm_cells = np.where(domain_bounds[:, 1] - domain_bounds[:, 0] < 0)[0]
@@ -455,7 +463,7 @@ def _get_longitude_weights(
         return weights
 
     def _get_latitude_weights(
-        self, domain_bounds: xr.DataArray, region_bounds: Optional[np.ndarray]
+        self, domain_bounds: xr.DataArray, region_bounds: np.ndarray | None
     ) -> xr.DataArray:
         """Gets weights for the latitude axis.
 
@@ -467,7 +475,7 @@ def _get_latitude_weights(
         ----------
         domain_bounds : xr.DataArray
             The array of bounds for the latitude domain.
-        region_bounds : Optional[np.ndarray]
+        region_bounds : np.ndarray | None
             The array of bounds for latitude regional selection.
 
         Returns
@@ -710,7 +718,7 @@ def _averager(
         self,
         data_var: xr.DataArray,
         axis: List[SpatialAxis] | Tuple[SpatialAxis, ...],
-        minimum_weight: Optional[float] = None,
+        min_weight: float,
     ):
         """Perform a weighted average of a data variable.
 
@@ -729,10 +737,9 @@ def _averager(
             Data variable inside a Dataset.
         axis : List[SpatialAxis] | Tuple[SpatialAxis, ...]
             List of axis dimensions to average over.
-        minimum_weight : optional, float
-            Fraction of data coverage (i..e, weight) needed to return a
-            spatial average value. Value must range from 0 to 1, by default None
-            (equivalent to minimum_weight=0.0).
+        min_weight : float
+            Fraction of data coverage (i.e, weight) needed to return a
+            spatial average value. Value must range from 0 to 1.
 
         Returns
         -------
@@ -746,45 +753,68 @@ def _averager(
         """
         weights = self._weights.fillna(0)
 
-        # ensure required weight is between 0 and 1
-        if minimum_weight is None:
-            minimum_weight = 0.0
-        elif minimum_weight < 0.0:
-            raise ValueError(
-                "minimum_weight argument is less than 0. "
-                "minimum_weight must be between 0 and 1."
-            )
-        elif minimum_weight > 1.0:
-            raise ValueError(
-                "minimum_weight argument is greater than 1. "
-                "minimum_weight must be between 0 and 1."
-            )
-
-        # need weights to match data_var dimensionality
-        if minimum_weight > 0.0:
+        # TODO: This conditional might not be needed because Xarray will
+        # automatically broadcast the weights to the data variable for
+        # operations such as .mean() and .where().
+        if min_weight > 0.0:
             weights, data_var = xr.broadcast(weights, data_var)
 
-        # get averaging dimensions
-        dim = []
+        dim: List[str] = []
         for key in axis:
-            dim.append(get_dim_keys(data_var, key))
+            dim.append(get_dim_keys(data_var, key))  # type: ignore
 
-        # compute weighed mean
         with xr.set_options(keep_attrs=True):
-            weighted_mean = data_var.cf.weighted(weights).mean(dim=dim)
-
-        # if weight thresholds applied, calculate fraction of data availability
-        # replace values that do not meet minimum weight with nan
-        if minimum_weight > 0.0:
-            # sum all weights (assuming no missing values exist)
-            weight_sum_all = weights.sum(dim=dim)  # type: ignore
-            # zero out cells with missing values in data_var
-            weights = xr.where(~np.isnan(data_var), weights, 0)
-            # sum all weights (including zero for missing values)
-            weight_sum_masked = weights.sum(dim=dim)  # type: ignore
-            # get fraction of weight available
-            frac = weight_sum_masked / weight_sum_all
-            # nan out values that don't meet specified weight threshold
-            weighted_mean = xr.where(frac >= minimum_weight, weighted_mean, np.nan)
-
-        return weighted_mean
+            dv_mean = data_var.cf.weighted(weights).mean(dim=dim)
+
+        if min_weight > 0.0:
+            dv_mean = self._mask_var_with_weight_threshold(
+                dv_mean, dim, weights, min_weight
+            )
+
+        return dv_mean
+
+    def _mask_var_with_weight_threshold(
+        self, dv: xr.DataArray, dim: List[str], weights: xr.DataArray, min_weight: float
+    ) -> xr.DataArray:
+        """Mask values that do not meet the minimum weight threshold with np.nan.
+
+        This function is useful for cases where the weighting of data might be
+        skewed based on the availability of data. For example, if a portion of
+        cells in a region has significantly more missing data than other
+        regions, it can result in inaccurate calculations of spatial averaging.
+        Masking values that do not meet the minimum weight threshold ensures
+        more accurate calculations.
+
+        Parameters
+        ----------
+        dv : xr.DataArray
+            The weighted variable.
+        dim: List[str]:
+            List of axis dimensions to average over.
+        weights : xr.DataArray
+            A DataArray containing the regional weights used for weighted
+            averaging. ``weights`` must include the same axis dimensions and
+            dimensional sizes as the data variable.
+        min_weight : float
+            Fraction of data coverage (i.e., weight) needed to return a
+            spatial average value. Value must range from 0 to 1.
+
+        Returns
+        -------
+        xr.DataArray
+            The variable with the minimum weight threshold applied.
+        """
+        # Sum all weights, including zero for missing values.
+        weight_sum_all = weights.sum(dim=dim)
+
+        masked_weights = _get_masked_weights(dv, weights)
+        weight_sum_masked = masked_weights.sum(dim=dim)
+
+        # Get fraction of the available weight.
+        frac = weight_sum_masked / weight_sum_all
+
+        # Nan out values that don't meet specified weight threshold.
+        dv_new = xr.where(frac >= min_weight, dv, np.nan, keep_attrs=True)
+        dv_new.name = dv.name
+
+        return dv_new
diff --git a/xcdat/utils.py b/xcdat/utils.py
index 83596561..a2f674fa 100644
--- a/xcdat/utils.py
+++ b/xcdat/utils.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import importlib
 import json
 from typing import Dict, List, Optional, Union
@@ -132,3 +134,61 @@ def _if_multidim_dask_array_then_load(
         return obj.load()
 
     return None
+
+
+def _get_masked_weights(dv: xr.DataArray, weights: xr.DataArray) -> xr.DataArray:
+    """Get weights with missing data (`np.nan`) receiving no weight (zero).
+
+    Parameters
+    ----------
+    dv : xr.DataArray
+        The variable.
+    weights : xr.DataArray
+        A DataArray containing either the regional or temporal weights used for
+        weighted averaging. ``weights`` must include the same axis dimensions
+        and dimensional sizes as the data variable.
+
+    Returns
+    -------
+    xr.DataArray
+        The masked weights.
+    """
+    masked_weights = xr.where(dv.copy().isnull(), 0.0, weights)
+
+    return masked_weights
+
+
+def _validate_min_weight(min_weight: float | None) -> float:
+    """Validate the ``min_weight`` value.
+
+    Parameters
+    ----------
+    min_weight : float | None
+        Fraction of data coverage (i.e., weight) needed to return a
+        spatial average value. Value must range from 0 to 1.
+
+    Returns
+    -------
+    float
+        The validated weight fraction (0.0 when ``min_weight`` is None).
+
+    Raises
+    ------
+    ValueError
+        If the `min_weight` argument is less than 0.
+    ValueError
+        If the `min_weight` argument is greater than 1.
+    """
+    if min_weight is None:
+        return 0.0
+    elif min_weight < 0.0:
+        raise ValueError(
+            "min_weight argument is less than 0. " "min_weight must be between 0 and 1."
+        )
+    elif min_weight > 1.0:
+        raise ValueError(
+            "min_weight argument is greater than 1. "
+            "min_weight must be between 0 and 1."
+        )
+
+    return min_weight

From 22397a4a282c923a3e5420b24aed22def2f85075 Mon Sep 17 00:00:00 2001
From: Tom Vo <tomvothecoder@gmail.com>
Date: Thu, 5 Sep 2024 11:29:33 -0700
Subject: [PATCH 07/11] Fix `TypeError` for optional `region` arg in
 `TypedDict`

---
 xcdat/spatial.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/xcdat/spatial.py b/xcdat/spatial.py
index 20d8cc0c..f3508116 100644
--- a/xcdat/spatial.py
+++ b/xcdat/spatial.py
@@ -12,6 +12,7 @@
     Hashable,
     List,
     Literal,
+    Optional,
     Tuple,
     TypedDict,
     Union,
@@ -267,7 +268,7 @@ def get_weights(
         and pressure).
         """
         Bounds = TypedDict(
-            "Bounds", {"weights_method": Callable, "region": np.ndarray | None}
+            "Bounds", {"weights_method": Callable, "region": Optional[np.ndarray]}
         )
 
         axis_bounds: Dict[SpatialAxis, Bounds] = {

From fbd0d550008facf6bdc24edcbe9ba3c59e4f4c6e Mon Sep 17 00:00:00 2001
From: Tom Vo <tomvothecoder@gmail.com>
Date: Thu, 21 Nov 2024 12:35:58 -0800
Subject: [PATCH 08/11] Remove rebase comment

---
 xcdat/spatial.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/xcdat/spatial.py b/xcdat/spatial.py
index f3508116..d105666e 100644
--- a/xcdat/spatial.py
+++ b/xcdat/spatial.py
@@ -1,8 +1,5 @@
 """Module containing geospatial averaging functions."""
-<<<<<<< HEAD
 
-=======
->>>>>>> 34b570d6 (Updates from code review)
 from __future__ import annotations
 
 from functools import reduce

From ce058ce26d87a906f66e47ceb5edc5f42c70b46c Mon Sep 17 00:00:00 2001
From: Tom Vo <tomvothecoder@gmail.com>
Date: Wed, 26 Feb 2025 10:40:03 -0700
Subject: [PATCH 09/11] Fix incorrect weight generation - Add `dv_mean`
 parameter to `_mask_var_with_weight_threshold()` and fix parameter references
 - Remove unnecessary broadcasting of variable and weights in `_averager()`

---
 xcdat/spatial.py | 31 +++++++++++++++++--------------
 1 file changed, 17 insertions(+), 14 deletions(-)

diff --git a/xcdat/spatial.py b/xcdat/spatial.py
index d105666e..e74af7d4 100644
--- a/xcdat/spatial.py
+++ b/xcdat/spatial.py
@@ -749,30 +749,30 @@ def _averager(
         ``weights`` must be a DataArray and cannot contain missing values.
         Missing values are replaced with 0 using ``weights.fillna(0)``.
         """
+        dv = data_var.copy()
         weights = self._weights.fillna(0)
 
-        # TODO: This conditional might not be needed because Xarray will
-        # automatically broadcast the weights to the data variable for
-        # operations such as .mean() and .where().
-        if min_weight > 0.0:
-            weights, data_var = xr.broadcast(weights, data_var)
-
         dim: List[str] = []
         for key in axis:
-            dim.append(get_dim_keys(data_var, key))  # type: ignore
+            dim.append(get_dim_keys(dv, key))  # type: ignore
 
         with xr.set_options(keep_attrs=True):
-            dv_mean = data_var.cf.weighted(weights).mean(dim=dim)
+            dv_mean = dv.cf.weighted(weights).mean(dim=dim)
 
         if min_weight > 0.0:
             dv_mean = self._mask_var_with_weight_threshold(
-                dv_mean, dim, weights, min_weight
+                dv, dv_mean, dim, weights, min_weight
             )
 
         return dv_mean
 
     def _mask_var_with_weight_threshold(
-        self, dv: xr.DataArray, dim: List[str], weights: xr.DataArray, min_weight: float
+        self,
+        dv: xr.DataArray,
+        dv_mean: xr.DataArray,
+        dim: List[str],
+        weights: xr.DataArray,
+        min_weight: float,
     ) -> xr.DataArray:
         """Mask values that do not meet the minimum weight threshold with np.nan.
 
@@ -786,7 +786,9 @@ def _mask_var_with_weight_threshold(
         Parameters
         ----------
         dv : xr.DataArray
-            The weighted variable.
+            The weighted variable used for getting masked weights.
+        dv_mean : xr.DataArray
+            The average of the weighted variable.
         dim: List[str]:
             List of axis dimensions to average over.
         weights : xr.DataArray
@@ -800,7 +802,8 @@ def _mask_var_with_weight_threshold(
         Returns
         -------
         xr.DataArray
-            The variable with the minimum weight threshold applied.
+            The average of the weighted with the minimum weight threshold
+            applied.
         """
         # Sum all weights, including zero for missing values.
         weight_sum_all = weights.sum(dim=dim)
@@ -812,7 +815,7 @@ def _mask_var_with_weight_threshold(
         frac = weight_sum_masked / weight_sum_all
 
         # Nan out values that don't meet specified weight threshold.
-        dv_new = xr.where(frac >= min_weight, dv, np.nan, keep_attrs=True)
-        dv_new.name = dv.name
+        dv_new = xr.where(frac >= min_weight, dv_mean, np.nan, keep_attrs=True)
+        dv_new.name = dv_mean.name
 
         return dv_new

From a9c4c521acf64f465b038f126fb7f236f1c274c7 Mon Sep 17 00:00:00 2001
From: Tom Vo <tomvothecoder@gmail.com>
Date: Wed, 26 Feb 2025 10:43:14 -0700
Subject: [PATCH 10/11] Fix return docstring for
 `_mask_var_with_weight_threshold()`

---
 xcdat/spatial.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/xcdat/spatial.py b/xcdat/spatial.py
index e74af7d4..6494d3f2 100644
--- a/xcdat/spatial.py
+++ b/xcdat/spatial.py
@@ -802,8 +802,8 @@ def _mask_var_with_weight_threshold(
         Returns
         -------
         xr.DataArray
-            The average of the weighted with the minimum weight threshold
-            applied.
+            The average of the weighted variable with the minimum weight
+            threshold applied.
         """
         # Sum all weights, including zero for missing values.
         weight_sum_all = weights.sum(dim=dim)

From 7033584fbb7d79cf3c4f6a8b3a881fd8bd3cc04a Mon Sep 17 00:00:00 2001
From: Tom Vo <tomvothecoder@gmail.com>
Date: Wed, 26 Feb 2025 10:47:17 -0700
Subject: [PATCH 11/11] Fix pre-commit hooks

---
 xcdat/utils.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/xcdat/utils.py b/xcdat/utils.py
index a2f674fa..8828272a 100644
--- a/xcdat/utils.py
+++ b/xcdat/utils.py
@@ -183,12 +183,11 @@ def _validate_min_weight(min_weight: float | None) -> float:
         return 0.0
     elif min_weight < 0.0:
         raise ValueError(
-            "min_weight argument is less than 0. " "min_weight must be between 0 and 1."
+            "min_weight argument is less than 0. min_weight must be between 0 and 1.",
         )
     elif min_weight > 1.0:
         raise ValueError(
-            "min_weight argument is greater than 1. "
-            "min_weight must be between 0 and 1."
+            "min_weight argument is greater than 1. min_weight must be between 0 and 1.",
         )
 
     return min_weight