|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 | import types
|
16 |
| -from typing import Any |
| 16 | +from functools import partial |
| 17 | +import operator |
| 18 | +from typing import Any, Literal, overload |
17 | 19 |
|
18 | 20 | import numpy as np
|
19 | 21 |
|
20 | 22 | import jax
|
21 | 23 | from jax import lax
|
| 24 | +from jax._src.api import jit |
22 | 25 | from jax._src import core
|
23 | 26 | from jax._src import dtypes
|
| 27 | +from jax._src.lax import lax as lax_internal |
24 | 28 | from jax._src.lib import xla_client as xc
|
| 29 | +from jax._src.numpy import ufuncs |
25 | 30 | from jax._src.numpy import util
|
26 | 31 | from jax._src.typing import Array, ArrayLike, DuckTypedArray, DTypeLike
|
27 |
| -from jax._src.util import set_module |
| 32 | +from jax._src.util import canonicalize_axis, set_module |
28 | 33 | from jax.sharding import Sharding
|
29 | 34 |
|
30 | 35 |
|
@@ -405,3 +410,315 @@ def full_like(a: ArrayLike | DuckTypedArray,
|
405 | 410 | dtype = dtypes.result_type(a) if dtype is None else dtype
|
406 | 411 | return jax.device_put(
|
407 | 412 | util._broadcast_to(jax.numpy.asarray(fill_value, dtype=dtype), shape), device)
|
| 413 | + |
| 414 | +@overload |
| 415 | +def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, |
| 416 | + endpoint: bool = True, retstep: Literal[False] = False, |
| 417 | + dtype: DTypeLike | None = None, |
| 418 | + axis: int = 0, |
| 419 | + *, device: xc.Device | Sharding | None = None) -> Array: ... |
| 420 | +@overload |
| 421 | +def linspace(start: ArrayLike, stop: ArrayLike, num: int, |
| 422 | + endpoint: bool, retstep: Literal[True], |
| 423 | + dtype: DTypeLike | None = None, |
| 424 | + axis: int = 0, |
| 425 | + *, device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ... |
| 426 | +@overload |
| 427 | +def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, |
| 428 | + endpoint: bool = True, *, retstep: Literal[True], |
| 429 | + dtype: DTypeLike | None = None, |
| 430 | + axis: int = 0, |
| 431 | + device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ... |
| 432 | +@overload |
| 433 | +def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, |
| 434 | + endpoint: bool = True, retstep: bool = False, |
| 435 | + dtype: DTypeLike | None = None, |
| 436 | + axis: int = 0, |
| 437 | + *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: ... |
| 438 | +@export |
| 439 | +def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, |
| 440 | + endpoint: bool = True, retstep: bool = False, |
| 441 | + dtype: DTypeLike | None = None, |
| 442 | + axis: int = 0, |
| 443 | + *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: |
| 444 | + """Return evenly-spaced numbers within an interval. |
| 445 | +
|
| 446 | + JAX implementation of :func:`numpy.linspace`. |
| 447 | +
|
| 448 | + Args: |
| 449 | + start: scalar or array of starting values. |
| 450 | + stop: scalar or array of stop values. |
| 451 | + num: number of values to generate. Default: 50. |
| 452 | + endpoint: if True (default) then include the ``stop`` value in the result. |
| 453 | + If False, then exclude the ``stop`` value. |
| 454 | + retstep: If True, then return a ``(result, step)`` tuple, where ``step`` is the |
| 455 | + interval between adjacent values in ``result``. |
| 456 | + axis: integer axis along which to generate the linspace. Defaults to zero. |
| 457 | + device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` |
| 458 | + to which the created array will be committed. |
| 459 | +
|
| 460 | + Returns: |
| 461 | + An array ``values``, or a tuple ``(values, step)`` if ``retstep`` is True, where: |
| 462 | +
|
| 463 | + - ``values`` is an array of evenly-spaced values from ``start`` to ``stop`` |
| 464 | + - ``step`` is the interval between adjacent values. |
| 465 | +
|
| 466 | + See also: |
| 467 | + - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting |
| 468 | + point and a step |
| 469 | + - :func:`jax.numpy.logspace`: Generate logarithmically-spaced values. |
| 470 | + - :func:`jax.numpy.geomspace`: Generate geometrically-spaced values. |
| 471 | +
|
| 472 | + Examples: |
| 473 | + List of 5 values between 0 and 10: |
| 474 | +
|
| 475 | + >>> jnp.linspace(0, 10, 5) |
| 476 | + Array([ 0. , 2.5, 5. , 7.5, 10. ], dtype=float32) |
| 477 | +
|
| 478 | + List of 8 values between 0 and 10, excluding the endpoint: |
| 479 | +
|
| 480 | + >>> jnp.linspace(0, 10, 8, endpoint=False) |
| 481 | + Array([0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75], dtype=float32) |
| 482 | +
|
| 483 | + List of values and the step size between them |
| 484 | +
|
| 485 | + >>> vals, step = jnp.linspace(0, 10, 9, retstep=True) |
| 486 | + >>> vals |
| 487 | + Array([ 0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75, 10. ], dtype=float32) |
| 488 | + >>> step |
| 489 | + Array(1.25, dtype=float32) |
| 490 | +
|
| 491 | + Multi-dimensional linspace: |
| 492 | +
|
| 493 | + >>> start = jnp.array([0, 5]) |
| 494 | + >>> stop = jnp.array([5, 10]) |
| 495 | + >>> jnp.linspace(start, stop, 5) |
| 496 | + Array([[ 0. , 5. ], |
| 497 | + [ 1.25, 6.25], |
| 498 | + [ 2.5 , 7.5 ], |
| 499 | + [ 3.75, 8.75], |
| 500 | + [ 5. , 10. ]], dtype=float32) |
| 501 | + """ |
| 502 | + num = core.concrete_dim_or_error(num, "'num' argument of jnp.linspace") |
| 503 | + axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace") |
| 504 | + return _linspace(start, stop, num, endpoint, retstep, dtype, axis, device=device) |
| 505 | + |
| 506 | +@partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis', 'device')) |
| 507 | +def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, |
| 508 | + endpoint: bool = True, retstep: bool = False, |
| 509 | + dtype: DTypeLike | None = None, |
| 510 | + axis: int = 0, |
| 511 | + *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: |
| 512 | + """Implementation of linspace differentiable in start and stop args.""" |
| 513 | + dtypes.check_user_dtype_supported(dtype, "linspace") |
| 514 | + if num < 0: |
| 515 | + raise ValueError(f"Number of samples, {num}, must be non-negative.") |
| 516 | + start, stop = util.ensure_arraylike("linspace", start, stop) |
| 517 | + |
| 518 | + if dtype is None: |
| 519 | + dtype = dtypes.to_inexact_dtype(dtypes.result_type(start, stop)) |
| 520 | + dtype = dtypes.jax_dtype(dtype) |
| 521 | + computation_dtype = dtypes.to_inexact_dtype(dtype) |
| 522 | + start = start.astype(computation_dtype) |
| 523 | + stop = stop.astype(computation_dtype) |
| 524 | + |
| 525 | + bounds_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop))) |
| 526 | + broadcast_start = util._broadcast_to(start, bounds_shape) |
| 527 | + broadcast_stop = util._broadcast_to(stop, bounds_shape) |
| 528 | + axis = len(bounds_shape) + axis + 1 if axis < 0 else axis |
| 529 | + bounds_shape.insert(axis, 1) |
| 530 | + div = (num - 1) if endpoint else num |
| 531 | + if num > 1: |
| 532 | + delta: Array = lax.convert_element_type(stop - start, computation_dtype) / jax.numpy.array(div, dtype=computation_dtype) |
| 533 | + iota_shape = [1,] * len(bounds_shape) |
| 534 | + iota_shape[axis] = div |
| 535 | + # This approach recovers the endpoints with float32 arithmetic, |
| 536 | + # but can lead to rounding errors for integer outputs. |
| 537 | + real_dtype = dtypes.finfo(computation_dtype).dtype |
| 538 | + step = lax.iota(real_dtype, div).reshape(iota_shape) / jax.numpy.array(div, real_dtype) |
| 539 | + step = step.astype(computation_dtype) |
| 540 | + out = (broadcast_start.reshape(bounds_shape) * (1 - step) + |
| 541 | + broadcast_stop.reshape(bounds_shape) * step) |
| 542 | + |
| 543 | + if endpoint: |
| 544 | + out = lax.concatenate([out, lax.expand_dims(broadcast_stop, (axis,))], |
| 545 | + canonicalize_axis(axis, out.ndim)) |
| 546 | + |
| 547 | + elif num == 1: |
| 548 | + delta = jax.numpy.asarray(np.nan if endpoint else stop - start, dtype=computation_dtype) |
| 549 | + out = broadcast_start.reshape(bounds_shape) |
| 550 | + else: # num == 0 degenerate case, match numpy behavior |
| 551 | + empty_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop))) |
| 552 | + empty_shape.insert(axis, 0) |
| 553 | + delta = full((), np.nan, computation_dtype) |
| 554 | + out = empty(empty_shape, dtype) |
| 555 | + |
| 556 | + if dtypes.issubdtype(dtype, np.integer) and not dtypes.issubdtype(out.dtype, np.integer): |
| 557 | + out = lax.floor(out) |
| 558 | + |
| 559 | + sharding = util.canonicalize_device_to_sharding(device) |
| 560 | + result = lax_internal._convert_element_type(out, dtype, sharding=sharding) |
| 561 | + return (result, delta) if retstep else result |
| 562 | + |
| 563 | + |
| 564 | +@export |
| 565 | +def logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, |
| 566 | + endpoint: bool = True, base: ArrayLike = 10.0, |
| 567 | + dtype: DTypeLike | None = None, axis: int = 0) -> Array: |
| 568 | + """Generate logarithmically-spaced values. |
| 569 | +
|
| 570 | + JAX implementation of :func:`numpy.logspace`. |
| 571 | +
|
| 572 | + Args: |
| 573 | + start: scalar or array. Used to specify the start value. The start value is |
| 574 | + ``base ** start``. |
| 575 | + stop: scalar or array. Used to specify the stop value. The end value is |
| 576 | + ``base ** stop``. |
| 577 | + num: int, optional, default=50. Number of values to generate. |
| 578 | + endpoint: bool, optional, default=True. If True, then include the ``stop`` value |
| 579 | + in the result. If False, then exclude the ``stop`` value. |
| 580 | + base: scalar or array, optional, default=10. Specifies the base of the logarithm. |
| 581 | + dtype: optional. Specifies the dtype of the output. |
| 582 | + axis: int, optional, default=0. Axis along which to generate the logspace. |
| 583 | +
|
| 584 | + Returns: |
| 585 | + An array of logarithm. |
| 586 | +
|
| 587 | + See also: |
| 588 | + - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting |
| 589 | + point and a step value. |
| 590 | + - :func:`jax.numpy.linspace`: Generate evenly-spaced values. |
| 591 | + - :func:`jax.numpy.geomspace`: Generate geometrically-spaced values. |
| 592 | +
|
| 593 | + Examples: |
| 594 | + List 5 logarithmically spaced values between 1 (``10 ** 0``) and 100 |
| 595 | + (``10 ** 2``): |
| 596 | +
|
| 597 | + >>> with jnp.printoptions(precision=3, suppress=True): |
| 598 | + ... jnp.logspace(0, 2, 5) |
| 599 | + Array([ 1. , 3.162, 10. , 31.623, 100. ], dtype=float32) |
| 600 | +
|
| 601 | + List 5 logarithmically-spaced values between 1(``10 ** 0``) and 100 |
| 602 | + (``10 ** 2``), excluding endpoint: |
| 603 | +
|
| 604 | + >>> with jnp.printoptions(precision=3, suppress=True): |
| 605 | + ... jnp.logspace(0, 2, 5, endpoint=False) |
| 606 | + Array([ 1. , 2.512, 6.31 , 15.849, 39.811], dtype=float32) |
| 607 | +
|
| 608 | + List 7 logarithmically-spaced values between 1 (``2 ** 0``) and 4 (``2 ** 2``) |
| 609 | + with base 2: |
| 610 | +
|
| 611 | + >>> with jnp.printoptions(precision=3, suppress=True): |
| 612 | + ... jnp.logspace(0, 2, 7, base=2) |
| 613 | + Array([1. , 1.26 , 1.587, 2. , 2.52 , 3.175, 4. ], dtype=float32) |
| 614 | +
|
| 615 | + Multi-dimensional logspace: |
| 616 | +
|
| 617 | + >>> start = jnp.array([0, 5]) |
| 618 | + >>> stop = jnp.array([5, 0]) |
| 619 | + >>> base = jnp.array([2, 3]) |
| 620 | + >>> with jnp.printoptions(precision=3, suppress=True): |
| 621 | + ... jnp.logspace(start, stop, 5, base=base) |
| 622 | + Array([[ 1. , 243. ], |
| 623 | + [ 2.378, 61.547], |
| 624 | + [ 5.657, 15.588], |
| 625 | + [ 13.454, 3.948], |
| 626 | + [ 32. , 1. ]], dtype=float32) |
| 627 | + """ |
| 628 | + num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.logspace") |
| 629 | + axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.logspace") |
| 630 | + return _logspace(start, stop, num, endpoint, base, dtype, axis) |
| 631 | + |
| 632 | +@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis')) |
| 633 | +def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, |
| 634 | + endpoint: bool = True, base: ArrayLike = 10.0, |
| 635 | + dtype: DTypeLike | None = None, axis: int = 0) -> Array: |
| 636 | + """Implementation of logspace differentiable in start and stop args.""" |
| 637 | + dtypes.check_user_dtype_supported(dtype, "logspace") |
| 638 | + if dtype is None: |
| 639 | + dtype = dtypes.to_inexact_dtype(dtypes.result_type(start, stop)) |
| 640 | + dtype = dtypes.jax_dtype(dtype) |
| 641 | + computation_dtype = dtypes.to_inexact_dtype(dtype) |
| 642 | + start, stop = util.ensure_arraylike("logspace", start, stop) |
| 643 | + start = start.astype(computation_dtype) |
| 644 | + stop = stop.astype(computation_dtype) |
| 645 | + lin = linspace(start, stop, num, |
| 646 | + endpoint=endpoint, retstep=False, dtype=None, axis=axis) |
| 647 | + return lax.convert_element_type(ufuncs.power(base, lin), dtype) |
| 648 | + |
| 649 | + |
| 650 | +@export |
| 651 | +def geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, |
| 652 | + dtype: DTypeLike | None = None, axis: int = 0) -> Array: |
| 653 | + """Generate geometrically-spaced values. |
| 654 | +
|
| 655 | + JAX implementation of :func:`numpy.geomspace`. |
| 656 | +
|
| 657 | + Args: |
| 658 | + start: scalar or array. Specifies the starting values. |
| 659 | + stop: scalar or array. Specifies the stop values. |
| 660 | + num: int, optional, default=50. Number of values to generate. |
| 661 | + endpoint: bool, optional, default=True. If True, then include the ``stop`` value |
| 662 | + in the result. If False, then exclude the ``stop`` value. |
| 663 | + dtype: optional. Specifies the dtype of the output. |
| 664 | + axis: int, optional, default=0. Axis along which to generate the geomspace. |
| 665 | +
|
| 666 | + Returns: |
| 667 | + An array containing the geometrically-spaced values. |
| 668 | +
|
| 669 | + See also: |
| 670 | + - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting |
| 671 | + point and a step value. |
| 672 | + - :func:`jax.numpy.linspace`: Generate evenly-spaced values. |
| 673 | + - :func:`jax.numpy.logspace`: Generate logarithmically-spaced values. |
| 674 | +
|
| 675 | + Examples: |
| 676 | + List 5 geometrically-spaced values between 1 and 16: |
| 677 | +
|
| 678 | + >>> with jnp.printoptions(precision=3, suppress=True): |
| 679 | + ... jnp.geomspace(1, 16, 5) |
| 680 | + Array([ 1., 2., 4., 8., 16.], dtype=float32) |
| 681 | +
|
| 682 | + List 4 geomtrically-spaced values between 1 and 16, with ``endpoint=False``: |
| 683 | +
|
| 684 | + >>> with jnp.printoptions(precision=3, suppress=True): |
| 685 | + ... jnp.geomspace(1, 16, 4, endpoint=False) |
| 686 | + Array([1., 2., 4., 8.], dtype=float32) |
| 687 | +
|
| 688 | + Multi-dimensional geomspace: |
| 689 | +
|
| 690 | + >>> start = jnp.array([1, 1000]) |
| 691 | + >>> stop = jnp.array([27, 1]) |
| 692 | + >>> with jnp.printoptions(precision=3, suppress=True): |
| 693 | + ... jnp.geomspace(start, stop, 4) |
| 694 | + Array([[ 1., 1000.], |
| 695 | + [ 3., 100.], |
| 696 | + [ 9., 10.], |
| 697 | + [ 27., 1.]], dtype=float32) |
| 698 | + """ |
| 699 | + num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.geomspace") |
| 700 | + axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.geomspace") |
| 701 | + return _geomspace(start, stop, num, endpoint, dtype, axis) |
| 702 | + |
| 703 | +@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis')) |
| 704 | +def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, |
| 705 | + dtype: DTypeLike | None = None, axis: int = 0) -> Array: |
| 706 | + """Implementation of geomspace differentiable in start and stop args.""" |
| 707 | + dtypes.check_user_dtype_supported(dtype, "geomspace") |
| 708 | + if dtype is None: |
| 709 | + dtype = dtypes.to_inexact_dtype(dtypes.result_type(start, stop)) |
| 710 | + dtype = dtypes.jax_dtype(dtype) |
| 711 | + computation_dtype = dtypes.to_inexact_dtype(dtype) |
| 712 | + start, stop = util.ensure_arraylike("geomspace", start, stop) |
| 713 | + start = start.astype(computation_dtype) |
| 714 | + stop = stop.astype(computation_dtype) |
| 715 | + |
| 716 | + sign = ufuncs.sign(start) |
| 717 | + res = sign * logspace(ufuncs.log10(start / sign), ufuncs.log10(stop / sign), |
| 718 | + num, endpoint=endpoint, base=10.0, |
| 719 | + dtype=computation_dtype, axis=0) |
| 720 | + axis = canonicalize_axis(axis, res.ndim) |
| 721 | + if axis != 0: |
| 722 | + # res = moveaxis(res, 0, axis) |
| 723 | + res = lax.transpose(res, permutation=(*range(1, axis + 1), 0, *range(axis + 1, res.ndim))) |
| 724 | + return lax.convert_element_type(res, dtype) |
0 commit comments