Skip to content

Commit cf073e0

Browse files
MAINT: Bump Array API to 2024.12 (scipy#22687)
* Bump Array API to 2024.12 * nit [skip ci] --------- Co-authored-by: Lucas Colley <[email protected]>
1 parent f2445c5 commit cf073e0

File tree

7 files changed

+7
-14
lines changed

7 files changed

+7
-14
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ test = [
8282
"scikit-umfpack",
8383
"pooch",
8484
"hypothesis>=6.30",
85-
"array-api-strict>=2.0,<2.1.1",
85+
"array-api-strict>=2.3",
8686
"Cython",
8787
"meson",
8888
'ninja; sys_platform != "emscripten"',

pytest.ini

-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ filterwarnings =
1717
ignore:.*`numpy.core` has been made officially private.*:DeprecationWarning
1818
ignore:.*In the future `np.long` will be defined as.*:FutureWarning
1919
ignore:.*JAX is multithreaded.*:RuntimeWarning
20-
ignore:.*The 2023.12 version of the array API specification is still preliminary.*:UserWarning
2120
ignore:^Using the slower implmentation::cupy
2221
ignore:Using the slower implementation::cupy
2322
ignore:Jitify is performing a one-time only warm-up::cupy

scipy/conftest.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,8 @@ def num_parallel_threads():
143143
try:
144144
import array_api_strict
145145
xp_available_backends.update({'array_api_strict': array_api_strict})
146-
if _pep440.parse(array_api_strict.__version__) < _pep440.Version('2.0'):
147-
raise ImportError("array-api-strict must be >= version 2.0")
146+
if _pep440.parse(array_api_strict.__version__) < _pep440.Version('2.3'):
147+
raise ImportError("array-api-strict must be >= version 2.3")
148148
array_api_strict.set_array_api_strict_flags(
149149
api_version='2024.12'
150150
)

scipy/integrate/_tanhsinh.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -470,10 +470,7 @@ def customize_result(res, shape):
470470
# If the integration limits were such that b < a, we reversed them
471471
# to perform the calculation, and the final result needs to be negated.
472472
if log and xp.any(negative):
473-
dtype = res['integral'].dtype
474-
pi = xp.asarray(xp.pi, dtype=dtype)[()]
475-
j = xp.asarray(1j, dtype=xp.complex64)[()] # minimum complex type
476-
res['integral'] = res['integral'] + negative*pi*j
473+
res['integral'] = res['integral'] + negative * xp.pi * 1.0j
477474
else:
478475
res['integral'][negative] *= -1
479476

scipy/signal/tests/test_windows.py

-1
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,6 @@ def test_basic(self, xp):
499499
0.5985765418119844], dtype=xp.float64))
500500

501501

502-
@skip_xp_backends("torch", reason="implementation needs 2023.12 standard")
503502
class TestKaiserBesselDerived:
504503

505504
def test_basic(self, xp):

scipy/special/_logsumexp.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,7 @@ def _logsumexp(a, b, axis, return_sign, xp):
225225
m = xp.abs(m)
226226
else:
227227
# `a_max` can have a sign component for complex input
228-
j = xp.asarray(1j, dtype=a_max.dtype)
229-
sgn = sgn * xp.exp(xp.imag(a_max) * j)
228+
sgn = sgn * xp.exp(xp.imag(a_max) * 1.0j)
230229

231230
# Take log and undo shift
232231
out = xp.log1p(s) + xp.log(m) + a_max

scipy/stats/_morestats.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -2892,14 +2892,13 @@ def bartlett(*samples, axis=0):
28922892
ssq = [arr[xp.newaxis, ...] for arr in ssq]
28932893
Ni = xp.concat(Ni, axis=0)
28942894
ssq = xp.concat(ssq, axis=0)
2895-
# sum dtype can be removed when 2023.12 rules kick in
28962895
dtype = Ni.dtype
2897-
Ntot = xp.sum(Ni, axis=0, dtype=dtype)
2896+
Ntot = xp.sum(Ni, axis=0)
28982897
spsq = xp.sum((Ni - 1)*ssq, axis=0, dtype=dtype) / (Ntot - k)
28992898
numer = ((Ntot - k) * xp.log(spsq)
29002899
- xp.sum((Ni - 1)*xp.log(ssq), axis=0, dtype=dtype))
29012900
denom = (1 + 1/(3*(k - 1))
2902-
* ((xp.sum(1/(Ni - 1), axis=0, dtype=dtype)) - 1/(Ntot - k)))
2901+
* ((xp.sum(1/(Ni - 1), axis=0)) - 1/(Ntot - k)))
29032902
T = numer / denom
29042903

29052904
chi2 = _SimpleChi2(xp.asarray(k-1))

0 commit comments

Comments
 (0)