@@ -42,6 +42,10 @@ class PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
42
42
pass
43
43
44
44
45
+ class _NoValue :
46
+ pass
47
+
48
+
45
49
class PytatoFakeNumpyNamespace (BaseFakeNumpyNamespace ):
46
50
"""
47
51
A :mod:`numpy` mimic for :class:`PytatoPyOpenCLArrayContext`.
@@ -91,22 +95,50 @@ def minimum(self, x, y):
91
95
def where (self , criterion , then , else_ ):
92
96
return rec_multimap_array_container (pt .where , criterion , then , else_ )
93
97
94
- def sum (self , a , axis = None , dtype = None ):
95
- def _pt_sum (ary ):
98
+ @staticmethod
99
+ def _reduce (container_binop , array_reduce ,
100
+ ary , * ,
101
+ axis , dtype , initial ):
102
+ def container_reduce (ctr ):
103
+ if initial is _NoValue :
104
+ try :
105
+ return reduce (container_binop , ctr )
106
+ except TypeError as exc :
107
+ assert "empty sequence" in str (exc )
108
+ raise ValueError ("zero-size reduction operation "
109
+ "without supplied 'initial' value" )
110
+ else :
111
+ return reduce (container_binop , ctr , initial )
112
+
113
+ def actual_array_reduce (ary ):
96
114
if dtype not in [ary .dtype , None ]:
97
115
raise NotImplementedError
98
116
99
- return pt .sum (ary , axis = axis )
100
-
101
- return rec_map_reduce_array_container (sum , _pt_sum , a )
102
-
103
- def min (self , a , axis = None ):
104
- return rec_map_reduce_array_container (
105
- partial (reduce , pt .minimum ), partial (pt .amin , axis = axis ), a )
117
+ if initial is _NoValue :
118
+ return array_reduce (ary , axis = axis )
119
+ else :
120
+ return array_reduce (ary , axis = axis , initial = initial )
106
121
107
- def max (self , a , axis = None ):
108
122
return rec_map_reduce_array_container (
109
- partial (reduce , pt .maximum ), partial (pt .amax , axis = axis ), a )
123
+ container_reduce ,
124
+ actual_array_reduce ,
125
+ ary )
126
+
127
+ # * appears where positional signature starts diverging from numpy
128
+ def sum (self , a , axis = None , dtype = None , * , initial = 0 ):
129
+ import operator
130
+ return self ._reduce (operator .add , pt .sum , a ,
131
+ axis = axis , dtype = dtype , initial = initial )
132
+
133
+ # * appears where positional signature starts diverging from numpy
134
+ def min (self , a , axis = None , * , initial = _NoValue ):
135
+ return self ._reduce (pt .minimum , pt .amin , a ,
136
+ axis = axis , dtype = None , initial = initial )
137
+
138
+ # * appears where positional signature starts diverging from numpy
139
+ def max (self , a , axis = None , * , initial = _NoValue ):
140
+ return self ._reduce (pt .maximum , pt .amax , a ,
141
+ axis = axis , dtype = None , initial = initial )
110
142
111
143
def stack (self , arrays , axis = 0 ):
112
144
return rec_multimap_array_container (
0 commit comments