|
1 | 1 | import numpy as np
|
2 |
| -from scipy._lib._util import _asarray_validated |
3 | 2 | from scipy._lib._array_api import (
|
4 | 3 | array_namespace,
|
5 | 4 | xp_size,
|
@@ -330,10 +329,11 @@ def softmax(x, axis=None):
|
330 | 329 | array([ 1., 1., 1.])
|
331 | 330 |
|
332 | 331 | """
|
333 |
| - x = _asarray_validated(x, check_finite=False) |
334 |
| - x_max = np.amax(x, axis=axis, keepdims=True) |
335 |
| - exp_x_shifted = np.exp(x - x_max) |
336 |
| - return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True) |
| 332 | + xp = array_namespace(x) |
| 333 | + x = xp.asarray(x) |
| 334 | + x_max = xp.max(x, axis=axis, keepdims=True) |
| 335 | + exp_x_shifted = xp.exp(x - x_max) |
| 336 | + return exp_x_shifted / xp.sum(exp_x_shifted, axis=axis, keepdims=True) |
337 | 337 |
|
338 | 338 |
|
339 | 339 | def log_softmax(x, axis=None):
|
@@ -387,23 +387,22 @@ def log_softmax(x, axis=None):
|
387 | 387 | array([ 0., -inf])
|
388 | 388 |
|
389 | 389 | """
|
| 390 | + xp = array_namespace(x) |
| 391 | + x = xp.asarray(x) |
390 | 392 |
|
391 |
| - x = _asarray_validated(x, check_finite=False) |
392 |
| - |
393 |
| - x_max = np.amax(x, axis=axis, keepdims=True) |
| 393 | + x_max = xp.max(x, axis=axis, keepdims=True) |
394 | 394 |
|
395 | 395 | if x_max.ndim > 0:
|
396 |
| - x_max[~np.isfinite(x_max)] = 0 |
397 |
| - elif not np.isfinite(x_max): |
| 396 | + x_max = xpx.at(x_max, ~xp.isfinite(x_max)).set(0) |
| 397 | + elif not xp.isfinite(x_max): |
398 | 398 | x_max = 0
|
399 | 399 |
|
400 | 400 | tmp = x - x_max
|
401 |
| - exp_tmp = np.exp(tmp) |
| 401 | + exp_tmp = xp.exp(tmp) |
402 | 402 |
|
403 | 403 | # suppress warnings about log of zero
|
404 | 404 | with np.errstate(divide='ignore'):
|
405 |
| - s = np.sum(exp_tmp, axis=axis, keepdims=True) |
406 |
| - out = np.log(s) |
| 405 | + s = xp.sum(exp_tmp, axis=axis, keepdims=True) |
| 406 | + out = xp.log(s) |
407 | 407 |
|
408 |
| - out = tmp - out |
409 |
| - return out |
| 408 | + return tmp - out |
0 commit comments