Skip to content

Commit c677cc9

Browse files
committed
Unify pytato reduction handling, support 'initial' arg
1 parent 1810b0e commit c677cc9

File tree

1 file changed

+43
-11
lines changed

1 file changed

+43
-11
lines changed

arraycontext/impl/pytato/fake_numpy.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ class PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
4242
pass
4343

4444

45+
class _NoValue:
46+
pass
47+
48+
4549
class PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace):
4650
"""
4751
A :mod:`numpy` mimic for :class:`PytatoPyOpenCLArrayContext`.
@@ -91,22 +95,50 @@ def minimum(self, x, y):
9195
def where(self, criterion, then, else_):
9296
return rec_multimap_array_container(pt.where, criterion, then, else_)
9397

94-
def sum(self, a, axis=None, dtype=None):
95-
def _pt_sum(ary):
98+
@staticmethod
99+
def _reduce(container_binop, array_reduce,
100+
ary, *,
101+
axis, dtype, initial):
102+
def container_reduce(ctr):
103+
if initial is _NoValue:
104+
try:
105+
return reduce(container_binop, ctr)
106+
except TypeError as exc:
107+
assert "empty sequence" in str(exc)
108+
raise ValueError("zero-size reduction operation "
109+
"without supplied 'initial' value")
110+
else:
111+
return reduce(container_binop, ctr, initial)
112+
113+
def actual_array_reduce(ary):
96114
if dtype not in [ary.dtype, None]:
97115
raise NotImplementedError
98116

99-
return pt.sum(ary, axis=axis)
100-
101-
return rec_map_reduce_array_container(sum, _pt_sum, a)
102-
103-
def min(self, a, axis=None):
104-
return rec_map_reduce_array_container(
105-
partial(reduce, pt.minimum), partial(pt.amin, axis=axis), a)
117+
if initial is _NoValue:
118+
return array_reduce(ary, axis=axis)
119+
else:
120+
return array_reduce(ary, axis=axis, initial=initial)
106121

107-
def max(self, a, axis=None):
108122
return rec_map_reduce_array_container(
109-
partial(reduce, pt.maximum), partial(pt.amax, axis=axis), a)
123+
container_reduce,
124+
actual_array_reduce,
125+
ary)
126+
127+
# * appears where positional signature starts diverging from numpy
128+
def sum(self, a, axis=None, dtype=None, *, initial=0):
129+
import operator
130+
return self._reduce(operator.add, pt.sum, a,
131+
axis=axis, dtype=dtype, initial=initial)
132+
133+
# * appears where positional signature starts diverging from numpy
134+
def min(self, a, axis=None, *, initial=_NoValue):
135+
return self._reduce(pt.minimum, pt.amin, a,
136+
axis=axis, dtype=None, initial=initial)
137+
138+
# * appears where positional signature starts diverging from numpy
139+
def max(self, a, axis=None, *, initial=_NoValue):
140+
return self._reduce(pt.maximum, pt.amax, a,
141+
axis=axis, dtype=None, initial=initial)
110142

111143
def stack(self, arrays, axis=0):
112144
return rec_multimap_array_container(

0 commit comments

Comments
 (0)