Skip to content

Commit 32a0b67

Browse files
authored
Allow int dtypes in mean (and std, var) (#831)
1 parent 24f7575 commit 32a0b67

File tree

2 files changed

+31
-9
lines changed

2 files changed

+31
-9
lines changed

cubed/array_api/statistical_functions.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
from cubed.array_api.data_type_functions import isdtype
44
from cubed.array_api.dtypes import (
5-
_floating_dtypes,
6-
_real_floating_dtypes,
5+
_integer_dtypes,
76
_real_numeric_dtypes,
87
_upcast_integral_dtypes,
98
)
@@ -96,12 +95,17 @@ def max(x, /, *, axis=None, keepdims=False, split_every=None):
9695

9796

9897
def mean(x, /, *, axis=None, keepdims=False, split_every=None):
99-
if x.dtype not in _floating_dtypes:
100-
raise TypeError("Only floating-point dtypes are allowed in mean")
10198
# This implementation uses a Zarr group of two arrays to store a
10299
# pair of fields needed to keep per-chunk counts and totals for computing
103100
# the mean.
104-
dtype = x.dtype
101+
if x.dtype in _integer_dtypes:
102+
# From the spec: "if the input array x has an integer data type,
103+
# the returned array must have the default real-valued floating-point data type"
104+
dtype = nxp.__array_namespace_info__().default_dtypes(device=x.device)[
105+
"real floating"
106+
]
107+
else:
108+
dtype = x.dtype
105109
# TODO(#658): Should these be default dtypes?
106110
if isdtype(x.dtype, "complex floating"):
107111
intermediate_dtype = [("n", nxp.int64), ("total", nxp.complex128)]
@@ -252,10 +256,16 @@ def var(
252256
split_every=None,
253257
):
254258
# This implementation follows https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
255-
256-
if x.dtype not in _real_floating_dtypes:
257-
raise TypeError("Only real floating-point dtypes are allowed in var")
258-
dtype = x.dtype
259+
if x.dtype not in _real_numeric_dtypes:
260+
raise TypeError("Only real numeric dtypes are allowed in var")
261+
if x.dtype in _integer_dtypes:
262+
# From the spec: "if the input array x has an integer data type,
263+
# the returned array must have the default real-valued floating-point data type"
264+
dtype = nxp.__array_namespace_info__().default_dtypes(device=x.device)[
265+
"real floating"
266+
]
267+
else:
268+
dtype = x.dtype
259269
# TODO(#658): Should these be default dtypes?
260270
intermediate_dtype = [("n", nxp.int64), ("mu", nxp.float64), ("M2", nxp.float64)]
261271
extra_func_kwargs = dict(dtype=intermediate_dtype, correction=correction)

cubed/tests/test_array_api.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,12 @@ def test_mean_complex():
910910
assert_array_equal(b.compute(), np.array([1.0+1.0j, 2.0+2.0j, 3.0+3.0j]).mean())
911911

912912

913+
def test_mean_int():
914+
a = xp.asarray([1, 2, 3], chunks=(2,))
915+
b = xp.mean(a)
916+
assert_array_equal(b.compute(), np.array([1, 2, 3]).mean())
917+
918+
913919
def test_sum(spec, executor):
914920
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
915921
b = xp.sum(a)
@@ -940,6 +946,12 @@ def test_var(spec, axis, correction, keepdims):
940946
)
941947

942948

949+
def test_var_int():
950+
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2))
951+
b = xp.var(a)
952+
assert_array_equal(b.compute(), np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).var())
953+
954+
943955
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)])
944956
@pytest.mark.parametrize("correction", [0.0, 1.0])
945957
@pytest.mark.parametrize("keepdims", [False, True])

0 commit comments

Comments
 (0)