Skip to content

Match numpy interface for empty reductions #136

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 61 additions & 34 deletions arraycontext/impl/pyopencl/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@

# {{{ fake numpy

class _NoValue:
pass


class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace):
def _get_fake_numpy_linalg_namespace(self):
return _PyOpenCLFakeNumpyLinalgNamespace(self._array_context)
Expand Down Expand Up @@ -110,58 +114,81 @@ def where_inner(inner_crit, inner_then, inner_else):

return rec_multimap_array_container(where_inner, criterion, then, else_)

def sum(self, a, axis=None, dtype=None):

@staticmethod
def _map_reduce(reduce_func, map_func,
ary, *,
axis, dtype, initial):
if isinstance(axis, int):
axis = axis,

def _rec_sum(ary):
def _reduce_func(partial_results):
if initial is _NoValue:
try:
return reduce(reduce_func, partial_results)
except TypeError as exc:
if "empty sequence" in str(exc):
raise ValueError("zero-size reduction operation "
"without supplied 'initial' value")
else:
raise
else:
return reduce(reduce_func, partial_results, initial)

def _map_func(ary):
if dtype not in [None, ary.dtype]:
raise NotImplementedError
if axis not in [None, tuple(range(ary.ndim))]:
raise NotImplementedError(f"Sum over '{axis}' axes not supported.")
raise NotImplementedError(
f"Mapping over '{axis}' axes not supported.")
if initial is _NoValue:
return map_func(ary)
else:
return map_func(ary, initial=initial)

return cl_array.sum(ary, dtype=dtype, queue=self._array_context.queue)
return rec_map_reduce_array_container(
_reduce_func,
_map_func,
ary)

result = rec_map_reduce_array_container(sum, _rec_sum, a)
# * appears where positional signature starts diverging from numpy
def sum(self, a, axis=None, dtype=None, *, initial=0):
queue = self._array_context.queue
import operator
result = self._map_reduce(
operator.add,
partial(cl_array.sum, queue=queue),
a,
axis=axis, dtype=dtype, initial=initial)

if not self._array_context._force_device_scalars:
result = result.get()[()]
return result

def min(self, a, axis=None):
# * appears where positional signature starts diverging from numpy
# Use _NoValue to indicate a lack of neutral element so that the caller can
# use None as a neutral element
def min(self, a, axis=None, *, initial=_NoValue):
queue = self._array_context.queue

if isinstance(axis, int):
axis = axis,

def _rec_min(ary):
if axis not in [None, tuple(range(ary.ndim))]:
raise NotImplementedError(f"Min. over '{axis}' axes not supported.")
return cl_array.min(ary, queue=queue)

result = rec_map_reduce_array_container(
partial(reduce, partial(cl_array.minimum, queue=queue)),
_rec_min,
a)
result = self._map_reduce(
partial(cl_array.minimum, queue=queue),
partial(cl_array.min, queue=queue),
a,
axis=axis, dtype=None, initial=initial)

if not self._array_context._force_device_scalars:
result = result.get()[()]
return result

def max(self, a, axis=None):
# * appears where positional signature starts diverging from numpy
# Use _NoValue to indicate a lack of neutral element so that the caller can
# use None as a neutral element
def max(self, a, axis=None, *, initial=_NoValue):
queue = self._array_context.queue

if isinstance(axis, int):
axis = axis,

def _rec_max(ary):
if axis not in [None, tuple(range(ary.ndim))]:
raise NotImplementedError(f"Max. over '{axis}' axes not supported.")
return cl_array.max(ary, queue=queue)

result = rec_map_reduce_array_container(
partial(reduce, partial(cl_array.maximum, queue=queue)),
_rec_max,
a)
result = self._map_reduce(
partial(cl_array.maximum, queue=queue),
partial(cl_array.max, queue=queue),
a,
axis=axis, dtype=None, initial=initial)

if not self._array_context._force_device_scalars:
result = result.get()[()]
Expand Down
62 changes: 50 additions & 12 deletions arraycontext/impl/pytato/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ class PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
pass


class _NoValue:
pass


class PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace):
"""
A :mod:`numpy` mimic for :class:`PytatoPyOpenCLArrayContext`.
Expand Down Expand Up @@ -91,22 +95,56 @@ def minimum(self, x, y):
def where(self, criterion, then, else_):
return rec_multimap_array_container(pt.where, criterion, then, else_)

def sum(self, a, axis=None, dtype=None):
def _pt_sum(ary):
if dtype not in [ary.dtype, None]:
raise NotImplementedError

return pt.sum(ary, axis=axis)
@staticmethod
def _map_reduce(reduce_func, map_func,
ary, *,
axis, dtype, initial):
def _reduce_func(partial_results):
if initial is _NoValue:
try:
return reduce(reduce_func, partial_results)
except TypeError as exc:
if "empty sequence" in str(exc):
raise ValueError("zero-size reduction operation "
"without supplied 'initial' value")
else:
raise
else:
return reduce(reduce_func, partial_results, initial)

return rec_map_reduce_array_container(sum, _pt_sum, a)
def _map_func(ary):
if dtype not in [None, ary.dtype]:
raise NotImplementedError

def min(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, pt.minimum), partial(pt.amin, axis=axis), a)
if initial is _NoValue:
return map_func(ary, axis=axis)
else:
return map_func(ary, axis=axis, initial=initial)

def max(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, pt.maximum), partial(pt.amax, axis=axis), a)
_reduce_func,
_map_func,
ary)

# * appears where positional signature starts diverging from numpy
def sum(self, a, axis=None, dtype=None, *, initial=0):
import operator
return self._map_reduce(operator.add, pt.sum, a,
axis=axis, dtype=dtype, initial=initial)

# * appears where positional signature starts diverging from numpy
# Use _NoValue to indicate a lack of neutral element so that the caller can
# use None as a neutral element
def min(self, a, axis=None, *, initial=_NoValue):
return self._map_reduce(pt.minimum, pt.amin, a,
axis=axis, dtype=None, initial=initial)

# * appears where positional signature starts diverging from numpy
# Use _NoValue to indicate a lack of neutral element so that the caller can
# use None as a neutral element
def max(self, a, axis=None, *, initial=_NoValue):
return self._map_reduce(pt.maximum, pt.amax, a,
axis=axis, dtype=None, initial=initial)

def stack(self, arrays, axis=0):
return rec_multimap_array_container(
Expand Down