Skip to content

Commit e53fa90

Browse files
committed
Implement np.linspace in fake_numpy
1 parent 500573a commit e53fa90

File tree

3 files changed

+110
-0
lines changed

3 files changed

+110
-0
lines changed

arraycontext/fake_numpy.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
"""
2424

2525

26+
import operator
27+
from typing import Any
28+
2629
import numpy as np
2730

2831
from arraycontext.container import NotAnArrayContainerError, serialize_container
@@ -100,6 +103,89 @@ def conjugate(self, x):
100103

101104
conj = conjugate
102105

106+
# {{{ linspace
107+
108+
# based on
109+
# https://github.com/numpy/numpy/blob/v1.25.0/numpy/core/function_base.py#L24-L182
110+
111+
def linspace(self, start, stop, num=50, endpoint=True, retstep=False, dtype=None,
112+
axis=0):
113+
num = operator.index(num)
114+
if num < 0:
115+
raise ValueError("Number of samples, %s, must be non-negative." % num)
116+
div = (num - 1) if endpoint else num
117+
118+
# Convert float/complex array scalars to float, gh-3504
119+
# and make sure one can use variables that have an __array_interface__,
120+
# gh-6634
121+
122+
if isinstance(start, self._array_context.array_types):
123+
raise NotImplementedError("start as an actx array")
124+
if isinstance(stop, self._array_context.array_types):
125+
raise NotImplementedError("stop as an actx array")
126+
127+
start = np.array(start) * 1.0
128+
stop = np.array(stop) * 1.0
129+
130+
dt = np.result_type(start, stop, float(num))
131+
if dtype is None:
132+
dtype = dt
133+
integer_dtype = False
134+
else:
135+
integer_dtype = np.issubdtype(dtype, np.integer)
136+
137+
delta = stop - start
138+
139+
y = self.arange(0, num, dtype=dt).reshape((-1,) + (1,) * delta.ndim)
140+
141+
if div > 0:
142+
step = delta / div
143+
#any_step_zero = _nx.asanyarray(step == 0).any()
144+
any_step_zero = self._array_context.to_numpy((step == 0)).any()
145+
if any_step_zero:
146+
delta_actx = self._array_context.from_numpy(delta)
147+
148+
# Special handling for denormal numbers, gh-5437
149+
y = y / div
150+
y = y * delta_actx
151+
else:
152+
step_actx = self._array_context.from_numpy(step)
153+
y = y * step_actx
154+
else:
155+
delta_actx = self._array_context.from_numpy(delta)
156+
# sequences with 0 items or 1 item with endpoint=True (i.e. div <= 0)
157+
# have an undefined step
158+
step = np.NaN
159+
# Multiply with delta to allow possible override of output class.
160+
y = y * delta_actx
161+
162+
y += start
163+
164+
# FIXME reenable, without in-place ops
165+
# if endpoint and num > 1:
166+
# y[-1, ...] = stop
167+
168+
if axis != 0:
169+
# y = _nx.moveaxis(y, 0, axis)
170+
raise NotImplementedError("axis != 0")
171+
172+
if integer_dtype:
173+
y = self.floor(y) # pylint: disable=no-member
174+
175+
# FIXME: Use astype
176+
# https://github.com/inducer/pytato/issues/456
177+
if retstep:
178+
return y, step
179+
#return y.astype(dtype), step
180+
else:
181+
return y
182+
#return y.astype(dtype)
183+
184+
# }}}
185+
186+
def arange(self, *args: Any, **kwargs: Any):
187+
raise NotImplementedError
188+
103189
# }}}
104190

105191

@@ -180,6 +266,7 @@ def norm(self, ary, ord=None):
180266
return actx.np.sum(abs(ary)**ord)**(1/ord)
181267
else:
182268
raise NotImplementedError(f"unsupported value of 'ord': {ord}")
269+
183270
# }}}
184271

185272

arraycontext/impl/pytato/fake_numpy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
THE SOFTWARE.
2323
"""
2424
from functools import partial, reduce
25+
from typing import Any
2526

2627
import numpy as np
2728

test/test_arraycontext.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,6 +1580,28 @@ def test_compile_anonymous_function(actx_factory):
15801580
42)
15811581

15821582

1583+
@pytest.mark.parametrize(
1584+
("args", "kwargs"), [
1585+
((1, 2, 10), {}),
1586+
((1, 2, 10), {"endpoint": False}),
1587+
((1, 2, 10), {"endpoint": True}),
1588+
((2, -3, 20), {}),
1589+
((1, 5j, 20), {"dtype": np.complex128}),
1590+
((1, 5, 20), {"dtype": np.complex128}),
1591+
((1, 5, 20), {"dtype": np.int32}),
1592+
])
1593+
def test_linspace(actx_factory, args, kwargs):
1594+
if "Jax" in actx_factory.__class__.__name__:
1595+
pytest.xfail("jax actx does not have arange")
1596+
1597+
actx = actx_factory()
1598+
1599+
actx_linspace = actx.to_numpy(actx.np.linspace(*args, **kwargs))
1600+
np_linspace = np.linspace(*args, **kwargs)
1601+
1602+
assert np.allclose(actx_linspace, np_linspace)
1603+
1604+
15831605
if __name__ == "__main__":
15841606
import sys
15851607
if len(sys.argv) > 1:

0 commit comments

Comments
 (0)