@@ -65,6 +65,24 @@ def _check_ns_shape_dtype(
65
65
return desired_xp
66
66
67
67
68
+ def _prepare_for_test (array : Array , xp : ModuleType ) -> Array :
69
+ """
70
+ Ensure that the array can be compared with xp.testing or np.testing.
71
+
72
+ This involves transferring it from GPU to CPU memory, densifying it, etc.
73
+ """
74
+ if is_torch_namespace (xp ):
75
+ return array .cpu () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
76
+ if is_pydata_sparse_namespace (xp ):
77
+ return array .todense () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
78
+ if is_array_api_strict_namespace (xp ):
79
+ # Note: we deliberately did not add a `.to_device` method in _typing.pyi
80
+ # even if it is required by the standard as many backends don't support it
81
+ return array .to_device (xp .Device ("CPU_DEVICE" )) # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
82
+ # Note: nothing to do for CuPy, because it uses a bespoke test function
83
+ return array
84
+
85
+
68
86
def xp_assert_equal (actual : Array , desired : Array , err_msg : str = "" ) -> None :
69
87
"""
70
88
Array-API compatible version of `np.testing.assert_array_equal`.
@@ -84,6 +102,8 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
84
102
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
85
103
"""
86
104
xp = _check_ns_shape_dtype (actual , desired )
105
+ actual = _prepare_for_test (actual , xp )
106
+ desired = _prepare_for_test (desired , xp )
87
107
88
108
if is_cupy_namespace (xp ):
89
109
xp .testing .assert_array_equal (actual , desired , err_msg = err_msg )
@@ -102,22 +122,7 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
102
122
else :
103
123
import numpy as np # pylint: disable=import-outside-toplevel
104
124
105
- if is_pydata_sparse_namespace (xp ):
106
- actual = actual .todense () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
107
- desired = desired .todense () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
108
-
109
- actual_np = None
110
- desired_np = None
111
- if is_array_api_strict_namespace (xp ):
112
- # __array__ doesn't work on array-api-strict device arrays
113
- # We need to convert to the CPU device first
114
- actual_np = np .asarray (xp .asarray (actual , device = xp .Device ("CPU_DEVICE" )))
115
- desired_np = np .asarray (xp .asarray (desired , device = xp .Device ("CPU_DEVICE" )))
116
-
117
- # JAX/Dask arrays work with `np.testing`
118
- actual_np = actual if actual_np is None else actual_np
119
- desired_np = desired if desired_np is None else desired_np
120
- np .testing .assert_array_equal (actual_np , desired_np , err_msg = err_msg ) # pyright: ignore[reportUnknownArgumentType]
125
+ np .testing .assert_array_equal (actual , desired , err_msg = err_msg )
121
126
122
127
123
128
def xp_assert_close (
@@ -165,6 +170,9 @@ def xp_assert_close(
165
170
elif rtol is None :
166
171
rtol = 1e-7
167
172
173
+ actual = _prepare_for_test (actual , xp )
174
+ desired = _prepare_for_test (desired , xp )
175
+
168
176
if is_cupy_namespace (xp ):
169
177
xp .testing .assert_allclose (
170
178
actual , desired , rtol = rtol , atol = atol , err_msg = err_msg
@@ -176,26 +184,11 @@ def xp_assert_close(
176
184
else :
177
185
import numpy as np # pylint: disable=import-outside-toplevel
178
186
179
- if is_pydata_sparse_namespace (xp ):
180
- actual = actual .todense () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
181
- desired = desired .todense () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
182
-
183
- actual_np = None
184
- desired_np = None
185
- if is_array_api_strict_namespace (xp ):
186
- # __array__ doesn't work on array-api-strict device arrays
187
- # We need to convert to the CPU device first
188
- actual_np = np .asarray (xp .asarray (actual , device = xp .Device ("CPU_DEVICE" )))
189
- desired_np = np .asarray (xp .asarray (desired , device = xp .Device ("CPU_DEVICE" )))
190
-
191
- # JAX/Dask arrays work with `np.testing`
192
- actual_np = actual if actual_np is None else actual_np
193
- desired_np = desired if desired_np is None else desired_np
194
-
187
+ # JAX/Dask arrays work directly with `np.testing`
195
188
assert isinstance (rtol , float )
196
- np .testing .assert_allclose ( # pyright: ignore[reportCallIssue]
197
- actual_np , # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
198
- desired_np , # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
189
+ np .testing .assert_allclose ( # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
190
+ actual , # pyright: ignore[reportArgumentType]
191
+ desired , # pyright: ignore[reportArgumentType]
199
192
rtol = rtol ,
200
193
atol = atol ,
201
194
err_msg = err_msg ,
0 commit comments