Skip to content

Commit c05442e

Browse files
committed
jax.numpy: move linspace & friends into array_creation
1 parent f045bb3 commit c05442e

File tree

4 files changed

+331
-325
lines changed

4 files changed

+331
-325
lines changed

jax/_src/numpy/array_creation.py

+319-2
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,23 @@
1313
# limitations under the License.
1414

1515
import types
16-
from typing import Any
16+
from functools import partial
17+
import operator
18+
from typing import Any, Literal, overload
1719

1820
import numpy as np
1921

2022
import jax
2123
from jax import lax
24+
from jax._src.api import jit
2225
from jax._src import core
2326
from jax._src import dtypes
27+
from jax._src.lax import lax as lax_internal
2428
from jax._src.lib import xla_client as xc
29+
from jax._src.numpy import ufuncs
2530
from jax._src.numpy import util
2631
from jax._src.typing import Array, ArrayLike, DuckTypedArray, DTypeLike
27-
from jax._src.util import set_module
32+
from jax._src.util import canonicalize_axis, set_module
2833
from jax.sharding import Sharding
2934

3035

@@ -405,3 +410,315 @@ def full_like(a: ArrayLike | DuckTypedArray,
405410
dtype = dtypes.result_type(a) if dtype is None else dtype
406411
return jax.device_put(
407412
util._broadcast_to(jax.numpy.asarray(fill_value, dtype=dtype), shape), device)
413+
414+
@overload
415+
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
416+
endpoint: bool = True, retstep: Literal[False] = False,
417+
dtype: DTypeLike | None = None,
418+
axis: int = 0,
419+
*, device: xc.Device | Sharding | None = None) -> Array: ...
420+
@overload
421+
def linspace(start: ArrayLike, stop: ArrayLike, num: int,
422+
endpoint: bool, retstep: Literal[True],
423+
dtype: DTypeLike | None = None,
424+
axis: int = 0,
425+
*, device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ...
426+
@overload
427+
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
428+
endpoint: bool = True, *, retstep: Literal[True],
429+
dtype: DTypeLike | None = None,
430+
axis: int = 0,
431+
device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ...
432+
@overload
433+
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
434+
endpoint: bool = True, retstep: bool = False,
435+
dtype: DTypeLike | None = None,
436+
axis: int = 0,
437+
*, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: ...
438+
@export
439+
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
440+
endpoint: bool = True, retstep: bool = False,
441+
dtype: DTypeLike | None = None,
442+
axis: int = 0,
443+
*, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]:
444+
"""Return evenly-spaced numbers within an interval.
445+
446+
JAX implementation of :func:`numpy.linspace`.
447+
448+
Args:
449+
start: scalar or array of starting values.
450+
stop: scalar or array of stop values.
451+
num: number of values to generate. Default: 50.
452+
endpoint: if True (default) then include the ``stop`` value in the result.
453+
If False, then exclude the ``stop`` value.
454+
retstep: If True, then return a ``(result, step)`` tuple, where ``step`` is the
455+
interval between adjacent values in ``result``.
456+
axis: integer axis along which to generate the linspace. Defaults to zero.
457+
device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
458+
to which the created array will be committed.
459+
460+
Returns:
461+
An array ``values``, or a tuple ``(values, step)`` if ``retstep`` is True, where:
462+
463+
- ``values`` is an array of evenly-spaced values from ``start`` to ``stop``
464+
- ``step`` is the interval between adjacent values.
465+
466+
See also:
467+
- :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting
468+
point and a step
469+
- :func:`jax.numpy.logspace`: Generate logarithmically-spaced values.
470+
- :func:`jax.numpy.geomspace`: Generate geometrically-spaced values.
471+
472+
Examples:
473+
List of 5 values between 0 and 10:
474+
475+
>>> jnp.linspace(0, 10, 5)
476+
Array([ 0. , 2.5, 5. , 7.5, 10. ], dtype=float32)
477+
478+
List of 8 values between 0 and 10, excluding the endpoint:
479+
480+
>>> jnp.linspace(0, 10, 8, endpoint=False)
481+
Array([0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75], dtype=float32)
482+
483+
List of values and the step size between them
484+
485+
>>> vals, step = jnp.linspace(0, 10, 9, retstep=True)
486+
>>> vals
487+
Array([ 0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75, 10. ], dtype=float32)
488+
>>> step
489+
Array(1.25, dtype=float32)
490+
491+
Multi-dimensional linspace:
492+
493+
>>> start = jnp.array([0, 5])
494+
>>> stop = jnp.array([5, 10])
495+
>>> jnp.linspace(start, stop, 5)
496+
Array([[ 0. , 5. ],
497+
[ 1.25, 6.25],
498+
[ 2.5 , 7.5 ],
499+
[ 3.75, 8.75],
500+
[ 5. , 10. ]], dtype=float32)
501+
"""
502+
num = core.concrete_dim_or_error(num, "'num' argument of jnp.linspace")
503+
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace")
504+
return _linspace(start, stop, num, endpoint, retstep, dtype, axis, device=device)
505+
506+
@partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis', 'device'))
507+
def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
508+
endpoint: bool = True, retstep: bool = False,
509+
dtype: DTypeLike | None = None,
510+
axis: int = 0,
511+
*, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]:
512+
"""Implementation of linspace differentiable in start and stop args."""
513+
dtypes.check_user_dtype_supported(dtype, "linspace")
514+
if num < 0:
515+
raise ValueError(f"Number of samples, {num}, must be non-negative.")
516+
start, stop = util.ensure_arraylike("linspace", start, stop)
517+
518+
if dtype is None:
519+
dtype = dtypes.to_inexact_dtype(dtypes.result_type(start, stop))
520+
dtype = dtypes.jax_dtype(dtype)
521+
computation_dtype = dtypes.to_inexact_dtype(dtype)
522+
start = start.astype(computation_dtype)
523+
stop = stop.astype(computation_dtype)
524+
525+
bounds_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop)))
526+
broadcast_start = util._broadcast_to(start, bounds_shape)
527+
broadcast_stop = util._broadcast_to(stop, bounds_shape)
528+
axis = len(bounds_shape) + axis + 1 if axis < 0 else axis
529+
bounds_shape.insert(axis, 1)
530+
div = (num - 1) if endpoint else num
531+
if num > 1:
532+
delta: Array = lax.convert_element_type(stop - start, computation_dtype) / jax.numpy.array(div, dtype=computation_dtype)
533+
iota_shape = [1,] * len(bounds_shape)
534+
iota_shape[axis] = div
535+
# This approach recovers the endpoints with float32 arithmetic,
536+
# but can lead to rounding errors for integer outputs.
537+
real_dtype = dtypes.finfo(computation_dtype).dtype
538+
step = lax.iota(real_dtype, div).reshape(iota_shape) / jax.numpy.array(div, real_dtype)
539+
step = step.astype(computation_dtype)
540+
out = (broadcast_start.reshape(bounds_shape) * (1 - step) +
541+
broadcast_stop.reshape(bounds_shape) * step)
542+
543+
if endpoint:
544+
out = lax.concatenate([out, lax.expand_dims(broadcast_stop, (axis,))],
545+
canonicalize_axis(axis, out.ndim))
546+
547+
elif num == 1:
548+
delta = jax.numpy.asarray(np.nan if endpoint else stop - start, dtype=computation_dtype)
549+
out = broadcast_start.reshape(bounds_shape)
550+
else: # num == 0 degenerate case, match numpy behavior
551+
empty_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop)))
552+
empty_shape.insert(axis, 0)
553+
delta = full((), np.nan, computation_dtype)
554+
out = empty(empty_shape, dtype)
555+
556+
if dtypes.issubdtype(dtype, np.integer) and not dtypes.issubdtype(out.dtype, np.integer):
557+
out = lax.floor(out)
558+
559+
sharding = util.canonicalize_device_to_sharding(device)
560+
result = lax_internal._convert_element_type(out, dtype, sharding=sharding)
561+
return (result, delta) if retstep else result
562+
563+
564+
@export
565+
def logspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
566+
endpoint: bool = True, base: ArrayLike = 10.0,
567+
dtype: DTypeLike | None = None, axis: int = 0) -> Array:
568+
"""Generate logarithmically-spaced values.
569+
570+
JAX implementation of :func:`numpy.logspace`.
571+
572+
Args:
573+
start: scalar or array. Used to specify the start value. The start value is
574+
``base ** start``.
575+
stop: scalar or array. Used to specify the stop value. The end value is
576+
``base ** stop``.
577+
num: int, optional, default=50. Number of values to generate.
578+
endpoint: bool, optional, default=True. If True, then include the ``stop`` value
579+
in the result. If False, then exclude the ``stop`` value.
580+
base: scalar or array, optional, default=10. Specifies the base of the logarithm.
581+
dtype: optional. Specifies the dtype of the output.
582+
axis: int, optional, default=0. Axis along which to generate the logspace.
583+
584+
Returns:
585+
An array of logarithm.
586+
587+
See also:
588+
- :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting
589+
point and a step value.
590+
- :func:`jax.numpy.linspace`: Generate evenly-spaced values.
591+
- :func:`jax.numpy.geomspace`: Generate geometrically-spaced values.
592+
593+
Examples:
594+
List 5 logarithmically spaced values between 1 (``10 ** 0``) and 100
595+
(``10 ** 2``):
596+
597+
>>> with jnp.printoptions(precision=3, suppress=True):
598+
... jnp.logspace(0, 2, 5)
599+
Array([ 1. , 3.162, 10. , 31.623, 100. ], dtype=float32)
600+
601+
List 5 logarithmically-spaced values between 1(``10 ** 0``) and 100
602+
(``10 ** 2``), excluding endpoint:
603+
604+
>>> with jnp.printoptions(precision=3, suppress=True):
605+
... jnp.logspace(0, 2, 5, endpoint=False)
606+
Array([ 1. , 2.512, 6.31 , 15.849, 39.811], dtype=float32)
607+
608+
List 7 logarithmically-spaced values between 1 (``2 ** 0``) and 4 (``2 ** 2``)
609+
with base 2:
610+
611+
>>> with jnp.printoptions(precision=3, suppress=True):
612+
... jnp.logspace(0, 2, 7, base=2)
613+
Array([1. , 1.26 , 1.587, 2. , 2.52 , 3.175, 4. ], dtype=float32)
614+
615+
Multi-dimensional logspace:
616+
617+
>>> start = jnp.array([0, 5])
618+
>>> stop = jnp.array([5, 0])
619+
>>> base = jnp.array([2, 3])
620+
>>> with jnp.printoptions(precision=3, suppress=True):
621+
... jnp.logspace(start, stop, 5, base=base)
622+
Array([[ 1. , 243. ],
623+
[ 2.378, 61.547],
624+
[ 5.657, 15.588],
625+
[ 13.454, 3.948],
626+
[ 32. , 1. ]], dtype=float32)
627+
"""
628+
num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.logspace")
629+
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.logspace")
630+
return _logspace(start, stop, num, endpoint, base, dtype, axis)
631+
632+
@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis'))
633+
def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
634+
endpoint: bool = True, base: ArrayLike = 10.0,
635+
dtype: DTypeLike | None = None, axis: int = 0) -> Array:
636+
"""Implementation of logspace differentiable in start and stop args."""
637+
dtypes.check_user_dtype_supported(dtype, "logspace")
638+
if dtype is None:
639+
dtype = dtypes.to_inexact_dtype(dtypes.result_type(start, stop))
640+
dtype = dtypes.jax_dtype(dtype)
641+
computation_dtype = dtypes.to_inexact_dtype(dtype)
642+
start, stop = util.ensure_arraylike("logspace", start, stop)
643+
start = start.astype(computation_dtype)
644+
stop = stop.astype(computation_dtype)
645+
lin = linspace(start, stop, num,
646+
endpoint=endpoint, retstep=False, dtype=None, axis=axis)
647+
return lax.convert_element_type(ufuncs.power(base, lin), dtype)
648+
649+
650+
@export
651+
def geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True,
652+
dtype: DTypeLike | None = None, axis: int = 0) -> Array:
653+
"""Generate geometrically-spaced values.
654+
655+
JAX implementation of :func:`numpy.geomspace`.
656+
657+
Args:
658+
start: scalar or array. Specifies the starting values.
659+
stop: scalar or array. Specifies the stop values.
660+
num: int, optional, default=50. Number of values to generate.
661+
endpoint: bool, optional, default=True. If True, then include the ``stop`` value
662+
in the result. If False, then exclude the ``stop`` value.
663+
dtype: optional. Specifies the dtype of the output.
664+
axis: int, optional, default=0. Axis along which to generate the geomspace.
665+
666+
Returns:
667+
An array containing the geometrically-spaced values.
668+
669+
See also:
670+
- :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting
671+
point and a step value.
672+
- :func:`jax.numpy.linspace`: Generate evenly-spaced values.
673+
- :func:`jax.numpy.logspace`: Generate logarithmically-spaced values.
674+
675+
Examples:
676+
List 5 geometrically-spaced values between 1 and 16:
677+
678+
>>> with jnp.printoptions(precision=3, suppress=True):
679+
... jnp.geomspace(1, 16, 5)
680+
Array([ 1., 2., 4., 8., 16.], dtype=float32)
681+
682+
List 4 geomtrically-spaced values between 1 and 16, with ``endpoint=False``:
683+
684+
>>> with jnp.printoptions(precision=3, suppress=True):
685+
... jnp.geomspace(1, 16, 4, endpoint=False)
686+
Array([1., 2., 4., 8.], dtype=float32)
687+
688+
Multi-dimensional geomspace:
689+
690+
>>> start = jnp.array([1, 1000])
691+
>>> stop = jnp.array([27, 1])
692+
>>> with jnp.printoptions(precision=3, suppress=True):
693+
... jnp.geomspace(start, stop, 4)
694+
Array([[ 1., 1000.],
695+
[ 3., 100.],
696+
[ 9., 10.],
697+
[ 27., 1.]], dtype=float32)
698+
"""
699+
num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.geomspace")
700+
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.geomspace")
701+
return _geomspace(start, stop, num, endpoint, dtype, axis)
702+
703+
@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis'))
704+
def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True,
705+
dtype: DTypeLike | None = None, axis: int = 0) -> Array:
706+
"""Implementation of geomspace differentiable in start and stop args."""
707+
dtypes.check_user_dtype_supported(dtype, "geomspace")
708+
if dtype is None:
709+
dtype = dtypes.to_inexact_dtype(dtypes.result_type(start, stop))
710+
dtype = dtypes.jax_dtype(dtype)
711+
computation_dtype = dtypes.to_inexact_dtype(dtype)
712+
start, stop = util.ensure_arraylike("geomspace", start, stop)
713+
start = start.astype(computation_dtype)
714+
stop = stop.astype(computation_dtype)
715+
716+
sign = ufuncs.sign(start)
717+
res = sign * logspace(ufuncs.log10(start / sign), ufuncs.log10(stop / sign),
718+
num, endpoint=endpoint, base=10.0,
719+
dtype=computation_dtype, axis=0)
720+
axis = canonicalize_axis(axis, res.ndim)
721+
if axis != 0:
722+
# res = moveaxis(res, 0, axis)
723+
res = lax.transpose(res, permutation=(*range(1, axis + 1), 0, *range(axis + 1, res.ndim)))
724+
return lax.convert_element_type(res, dtype)

0 commit comments

Comments
 (0)