|
2 | 2 |
|
3 | 3 | from cubed.array_api.data_type_functions import isdtype |
4 | 4 | from cubed.array_api.dtypes import ( |
5 | | - _floating_dtypes, |
6 | | - _real_floating_dtypes, |
| 5 | + _integer_dtypes, |
7 | 6 | _real_numeric_dtypes, |
8 | 7 | _upcast_integral_dtypes, |
9 | 8 | ) |
@@ -96,12 +95,17 @@ def max(x, /, *, axis=None, keepdims=False, split_every=None): |
96 | 95 |
|
97 | 96 |
|
98 | 97 | def mean(x, /, *, axis=None, keepdims=False, split_every=None): |
99 | | - if x.dtype not in _floating_dtypes: |
100 | | - raise TypeError("Only floating-point dtypes are allowed in mean") |
101 | 98 | # This implementation uses a Zarr group of two arrays to store a |
102 | 99 | # pair of fields needed to keep per-chunk counts and totals for computing |
103 | 100 | # the mean. |
104 | | - dtype = x.dtype |
| 101 | + if x.dtype in _integer_dtypes: |
| 102 | + # From the spec: "if the input array x has an integer data type, |
| 103 | + # the returned array must have the default real-valued floating-point data type" |
| 104 | + dtype = nxp.__array_namespace_info__().default_dtypes(device=x.device)[ |
| 105 | + "real floating" |
| 106 | + ] |
| 107 | + else: |
| 108 | + dtype = x.dtype |
105 | 109 | # TODO(#658): Should these be default dtypes? |
106 | 110 | if isdtype(x.dtype, "complex floating"): |
107 | 111 | intermediate_dtype = [("n", nxp.int64), ("total", nxp.complex128)] |
@@ -252,10 +256,16 @@ def var( |
252 | 256 | split_every=None, |
253 | 257 | ): |
254 | 258 | # This implementation follows https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm |
255 | | - |
256 | | - if x.dtype not in _real_floating_dtypes: |
257 | | - raise TypeError("Only real floating-point dtypes are allowed in var") |
258 | | - dtype = x.dtype |
| 259 | + if x.dtype not in _real_numeric_dtypes: |
| 260 | + raise TypeError("Only real numeric dtypes are allowed in var") |
| 261 | + if x.dtype in _integer_dtypes: |
| 262 | + # From the spec: "if the input array x has an integer data type, |
| 263 | + # the returned array must have the default real-valued floating-point data type" |
| 264 | + dtype = nxp.__array_namespace_info__().default_dtypes(device=x.device)[ |
| 265 | + "real floating" |
| 266 | + ] |
| 267 | + else: |
| 268 | + dtype = x.dtype |
259 | 269 | # TODO(#658): Should these be default dtypes? |
260 | 270 | intermediate_dtype = [("n", nxp.int64), ("mu", nxp.float64), ("M2", nxp.float64)] |
261 | 271 | extra_func_kwargs = dict(dtype=intermediate_dtype, correction=correction) |
|
0 commit comments