Skip to content

Commit 50a007e

Browse files
authored
ENH: special: softmax / log_softmax Array API support (scipy#22633)
1 parent 737d134 commit 50a007e

File tree

4 files changed

+155
-167
lines changed

4 files changed

+155
-167
lines changed

scipy/special/_logsumexp.py

+14-15
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
from scipy._lib._util import _asarray_validated
32
from scipy._lib._array_api import (
43
array_namespace,
54
xp_size,
@@ -330,10 +329,11 @@ def softmax(x, axis=None):
330329
array([ 1., 1., 1.])
331330
332331
"""
333-
x = _asarray_validated(x, check_finite=False)
334-
x_max = np.amax(x, axis=axis, keepdims=True)
335-
exp_x_shifted = np.exp(x - x_max)
336-
return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
332+
xp = array_namespace(x)
333+
x = xp.asarray(x)
334+
x_max = xp.max(x, axis=axis, keepdims=True)
335+
exp_x_shifted = xp.exp(x - x_max)
336+
return exp_x_shifted / xp.sum(exp_x_shifted, axis=axis, keepdims=True)
337337

338338

339339
def log_softmax(x, axis=None):
@@ -387,23 +387,22 @@ def log_softmax(x, axis=None):
387387
array([ 0., -inf])
388388
389389
"""
390+
xp = array_namespace(x)
391+
x = xp.asarray(x)
390392

391-
x = _asarray_validated(x, check_finite=False)
392-
393-
x_max = np.amax(x, axis=axis, keepdims=True)
393+
x_max = xp.max(x, axis=axis, keepdims=True)
394394

395395
if x_max.ndim > 0:
396-
x_max[~np.isfinite(x_max)] = 0
397-
elif not np.isfinite(x_max):
396+
x_max = xpx.at(x_max, ~xp.isfinite(x_max)).set(0)
397+
elif not xp.isfinite(x_max):
398398
x_max = 0
399399

400400
tmp = x - x_max
401-
exp_tmp = np.exp(tmp)
401+
exp_tmp = xp.exp(tmp)
402402

403403
# suppress warnings about log of zero
404404
with np.errstate(divide='ignore'):
405-
s = np.sum(exp_tmp, axis=axis, keepdims=True)
406-
out = np.log(s)
405+
s = xp.sum(exp_tmp, axis=axis, keepdims=True)
406+
out = xp.log(s)
407407

408-
out = tmp - out
409-
return out
408+
return tmp - out

scipy/special/tests/meson.build

-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ python_sources = [
2424
'test_kolmogorov.py',
2525
'test_lambertw.py',
2626
'test_legendre.py',
27-
'test_log_softmax.py',
2827
'test_log1mexp.py',
2928
'test_loggamma.py',
3029
'test_logit.py',

scipy/special/tests/test_log_softmax.py

-109
This file was deleted.

0 commit comments

Comments
 (0)