|
23 | 23 | """
|
24 | 24 |
|
25 | 25 |
|
| 26 | +import operator |
| 27 | +from typing import Any |
| 28 | + |
26 | 29 | import numpy as np
|
27 | 30 |
|
28 | 31 | from arraycontext.container import NotAnArrayContainerError, serialize_container
|
@@ -100,6 +103,89 @@ def conjugate(self, x):
|
100 | 103 |
|
101 | 104 | conj = conjugate
|
102 | 105 |
|
| 106 | + # {{{ linspace |
| 107 | + |
| 108 | + # based on |
| 109 | + # https://github.com/numpy/numpy/blob/v1.25.0/numpy/core/function_base.py#L24-L182 |
| 110 | + |
| 111 | + def linspace(self, start, stop, num=50, endpoint=True, retstep=False, dtype=None, |
| 112 | + axis=0): |
| 113 | + num = operator.index(num) |
| 114 | + if num < 0: |
| 115 | + raise ValueError("Number of samples, %s, must be non-negative." % num) |
| 116 | + div = (num - 1) if endpoint else num |
| 117 | + |
| 118 | + # Convert float/complex array scalars to float, gh-3504 |
| 119 | + # and make sure one can use variables that have an __array_interface__, |
| 120 | + # gh-6634 |
| 121 | + |
| 122 | + if isinstance(start, self._array_context.array_types): |
| 123 | + raise NotImplementedError("start as an actx array") |
| 124 | + if isinstance(stop, self._array_context.array_types): |
| 125 | + raise NotImplementedError("stop as an actx array") |
| 126 | + |
| 127 | + start = np.array(start) * 1.0 |
| 128 | + stop = np.array(stop) * 1.0 |
| 129 | + |
| 130 | + dt = np.result_type(start, stop, float(num)) |
| 131 | + if dtype is None: |
| 132 | + dtype = dt |
| 133 | + integer_dtype = False |
| 134 | + else: |
| 135 | + integer_dtype = np.issubdtype(dtype, np.integer) |
| 136 | + |
| 137 | + delta = stop - start |
| 138 | + |
| 139 | + y = self.arange(0, num, dtype=dt).reshape((-1,) + (1,) * delta.ndim) |
| 140 | + |
| 141 | + if div > 0: |
| 142 | + step = delta / div |
| 143 | + #any_step_zero = _nx.asanyarray(step == 0).any() |
| 144 | + any_step_zero = self._array_context.to_numpy((step == 0)).any() |
| 145 | + if any_step_zero: |
| 146 | + delta_actx = self._array_context.from_numpy(delta) |
| 147 | + |
| 148 | + # Special handling for denormal numbers, gh-5437 |
| 149 | + y = y / div |
| 150 | + y = y * delta_actx |
| 151 | + else: |
| 152 | + step_actx = self._array_context.from_numpy(step) |
| 153 | + y = y * step_actx |
| 154 | + else: |
| 155 | + delta_actx = self._array_context.from_numpy(delta) |
| 156 | + # sequences with 0 items or 1 item with endpoint=True (i.e. div <= 0) |
| 157 | + # have an undefined step |
| 158 | + step = np.NaN |
| 159 | + # Multiply with delta to allow possible override of output class. |
| 160 | + y = y * delta_actx |
| 161 | + |
| 162 | + y += start |
| 163 | + |
| 164 | + # FIXME reenable, without in-place ops |
| 165 | + # if endpoint and num > 1: |
| 166 | + # y[-1, ...] = stop |
| 167 | + |
| 168 | + if axis != 0: |
| 169 | + # y = _nx.moveaxis(y, 0, axis) |
| 170 | + raise NotImplementedError("axis != 0") |
| 171 | + |
| 172 | + if integer_dtype: |
| 173 | + y = self.floor(y) # pylint: disable=no-member |
| 174 | + |
| 175 | + # FIXME: Use astype |
| 176 | + # https://github.com/inducer/pytato/issues/456 |
| 177 | + if retstep: |
| 178 | + return y, step |
| 179 | + #return y.astype(dtype), step |
| 180 | + else: |
| 181 | + return y |
| 182 | + #return y.astype(dtype) |
| 183 | + |
| 184 | + # }}} |
| 185 | + |
| 186 | + def arange(self, *args: Any, **kwargs: Any): |
| 187 | + raise NotImplementedError |
| 188 | + |
103 | 189 | # }}}
|
104 | 190 |
|
105 | 191 |
|
@@ -180,6 +266,7 @@ def norm(self, ary, ord=None):
|
180 | 266 | return actx.np.sum(abs(ary)**ord)**(1/ord)
|
181 | 267 | else:
|
182 | 268 | raise NotImplementedError(f"unsupported value of 'ord': {ord}")
|
| 269 | + |
183 | 270 | # }}}
|
184 | 271 |
|
185 | 272 |
|
|
0 commit comments