diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 4e9d48be..7d1de932 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -50,6 +50,10 @@ # {{{ fake numpy +class _NoValue: + pass + + class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace): def _get_fake_numpy_linalg_namespace(self): return _PyOpenCLFakeNumpyLinalgNamespace(self._array_context) @@ -110,58 +114,81 @@ def where_inner(inner_crit, inner_then, inner_else): return rec_multimap_array_container(where_inner, criterion, then, else_) - def sum(self, a, axis=None, dtype=None): - + @staticmethod + def _map_reduce(reduce_func, map_func, + ary, *, + axis, dtype, initial): if isinstance(axis, int): axis = axis, - def _rec_sum(ary): + def _reduce_func(partial_results): + if initial is _NoValue: + try: + return reduce(reduce_func, partial_results) + except TypeError as exc: + if "empty sequence" in str(exc): + raise ValueError("zero-size reduction operation " + "without supplied 'initial' value") + else: + raise + else: + return reduce(reduce_func, partial_results, initial) + + def _map_func(ary): + if dtype not in [None, ary.dtype]: + raise NotImplementedError if axis not in [None, tuple(range(ary.ndim))]: - raise NotImplementedError(f"Sum over '{axis}' axes not supported.") + raise NotImplementedError( + f"Mapping over '{axis}' axes not supported.") + if initial is _NoValue: + return map_func(ary) + else: + return map_func(ary, initial=initial) - return cl_array.sum(ary, dtype=dtype, queue=self._array_context.queue) + return rec_map_reduce_array_container( + _reduce_func, + _map_func, + ary) - result = rec_map_reduce_array_container(sum, _rec_sum, a) + # * appears where positional signature starts diverging from numpy + def sum(self, a, axis=None, dtype=None, *, initial=0): + queue = self._array_context.queue + import operator + result = self._map_reduce( + operator.add, + partial(cl_array.sum, queue=queue), + a, + axis=axis, dtype=dtype, initial=initial) if not self._array_context._force_device_scalars: result = result.get()[()] return result - def min(self, a, axis=None): + # * appears where positional signature starts diverging from numpy + # Use _NoValue to indicate a lack of neutral element so that the caller can + # use None as a neutral element + def min(self, a, axis=None, *, initial=_NoValue): queue = self._array_context.queue - - if isinstance(axis, int): - axis = axis, - - def _rec_min(ary): - if axis not in [None, tuple(range(ary.ndim))]: - raise NotImplementedError(f"Min. over '{axis}' axes not supported.") - return cl_array.min(ary, queue=queue) - - result = rec_map_reduce_array_container( - partial(reduce, partial(cl_array.minimum, queue=queue)), - _rec_min, - a) + result = self._map_reduce( + partial(cl_array.minimum, queue=queue), + partial(cl_array.min, queue=queue), + a, + axis=axis, dtype=None, initial=initial) if not self._array_context._force_device_scalars: result = result.get()[()] return result - def max(self, a, axis=None): + # * appears where positional signature starts diverging from numpy + # Use _NoValue to indicate a lack of neutral element so that the caller can + # use None as a neutral element + def max(self, a, axis=None, *, initial=_NoValue): queue = self._array_context.queue - - if isinstance(axis, int): - axis = axis, - - def _rec_max(ary): - if axis not in [None, tuple(range(ary.ndim))]: - raise NotImplementedError(f"Max. over '{axis}' axes not supported.") - return cl_array.max(ary, queue=queue) - - result = rec_map_reduce_array_container( - partial(reduce, partial(cl_array.maximum, queue=queue)), - _rec_max, - a) + result = self._map_reduce( + partial(cl_array.maximum, queue=queue), + partial(cl_array.max, queue=queue), + a, + axis=axis, dtype=None, initial=initial) if not self._array_context._force_device_scalars: result = result.get()[()] diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index 905fd0f8..b4b21ba7 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -42,6 +42,10 @@ class PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): pass +class _NoValue: + pass + + class PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): """ A :mod:`numpy` mimic for :class:`PytatoPyOpenCLArrayContext`. @@ -91,22 +95,56 @@ def minimum(self, x, y): def where(self, criterion, then, else_): return rec_multimap_array_container(pt.where, criterion, then, else_) - def sum(self, a, axis=None, dtype=None): - def _pt_sum(ary): - if dtype not in [ary.dtype, None]: - raise NotImplementedError - - return pt.sum(ary, axis=axis) + @staticmethod + def _map_reduce(reduce_func, map_func, + ary, *, + axis, dtype, initial): + def _reduce_func(partial_results): + if initial is _NoValue: + try: + return reduce(reduce_func, partial_results) + except TypeError as exc: + if "empty sequence" in str(exc): + raise ValueError("zero-size reduction operation " + "without supplied 'initial' value") + else: + raise + else: + return reduce(reduce_func, partial_results, initial) - return rec_map_reduce_array_container(sum, _pt_sum, a) + def _map_func(ary): + if dtype not in [None, ary.dtype]: + raise NotImplementedError - def min(self, a, axis=None): - return rec_map_reduce_array_container( - partial(reduce, pt.minimum), partial(pt.amin, axis=axis), a) + if initial is _NoValue: + return map_func(ary, axis=axis) + else: + return map_func(ary, axis=axis, initial=initial) - def max(self, a, axis=None): return rec_map_reduce_array_container( - partial(reduce, pt.maximum), partial(pt.amax, axis=axis), a) + _reduce_func, + _map_func, + ary) + + # * appears where positional signature starts diverging from numpy + def sum(self, a, axis=None, dtype=None, *, initial=0): + import operator + return self._map_reduce(operator.add, pt.sum, a, + axis=axis, dtype=dtype, initial=initial) + + # * appears where positional signature starts diverging from numpy + # Use _NoValue to indicate a lack of neutral element so that the caller can + # use None as a neutral element + def min(self, a, axis=None, *, initial=_NoValue): + return self._map_reduce(pt.minimum, pt.amin, a, + axis=axis, dtype=None, initial=initial) + + # * appears where positional signature starts diverging from numpy + # Use _NoValue to indicate a lack of neutral element so that the caller can + # use None as a neutral element + def max(self, a, axis=None, *, initial=_NoValue): + return self._map_reduce(pt.maximum, pt.amax, a, + axis=axis, dtype=None, initial=initial) def stack(self, arrays, axis=0): return rec_multimap_array_container(