@@ -124,29 +124,31 @@ def map_node_min(self, expr):
124
124
def _map_elementwise_reduction (self , reduction_name , expr ):
125
125
import loopy as lp
126
126
from arraycontext import make_loopy_program
127
- from meshmode .transform_metadata import (
128
- ConcurrentElementInameTag , ConcurrentDOFInameTag )
127
+ from meshmode .transform_metadata import ConcurrentElementInameTag
128
+ actx = self . array_context
129
129
130
- @memoize_in (self .places , "elementwise_node_" + reduction_name )
130
+ @memoize_in (actx , (
131
+ EvaluationMapperBase ._map_elementwise_reduction ,
132
+ f"elementwise_node_{ reduction_name } " ))
131
133
def node_knl ():
132
134
t_unit = make_loopy_program (
133
135
"""{[iel, idof, jdof]:
134
136
0<=iel<nelements and
135
137
0<=idof, jdof<ndofs}""" ,
136
138
"""
137
- result[iel, idof] = %s(jdof, operand[iel, jdof])
139
+ <> el_result = %s(jdof, operand[iel, jdof])
140
+ result[iel, idof] = el_result
138
141
""" % reduction_name ,
139
- name = "nodewise_reduce " )
142
+ name = f"elementwise_node_ { reduction_name } " )
140
143
141
144
return lp .tag_inames (t_unit , {
142
145
"iel" : ConcurrentElementInameTag (),
143
- "idof" : ConcurrentDOFInameTag (),
144
146
})
145
147
146
- @memoize_in (self .places , "elementwise_" + reduction_name )
148
+ @memoize_in (actx , (
149
+ EvaluationMapperBase ._map_elementwise_reduction ,
150
+ f"elementwise_element_{ reduction_name } " ))
147
151
def element_knl ():
148
- # FIXME: This computes the reduction value redundantly for each
149
- # output DOF.
150
152
t_unit = make_loopy_program (
151
153
"""{[iel, jdof]:
152
154
0<=iel<nelements and
@@ -155,37 +157,27 @@ def element_knl():
155
157
"""
156
158
result[iel, 0] = %s(jdof, operand[iel, jdof])
157
159
""" % reduction_name ,
158
- name = "elementwise_reduce " )
160
+ name = f"elementwise_element_ { reduction_name } " )
159
161
160
162
return lp .tag_inames (t_unit , {
161
163
"iel" : ConcurrentElementInameTag (),
162
164
})
163
165
164
- discr = self .places .get_discretization (
165
- expr .dofdesc .geometry , expr .dofdesc .discr_stage )
166
+ dofdesc = expr .dofdesc
166
167
operand = self .rec (expr .operand )
167
- assert operand .shape == (len (discr .groups ),)
168
-
169
- def _reduce (knl , result ):
170
- for g_operand , g_result in zip (operand , result ):
171
- self .array_context .call_loopy (
172
- knl , operand = g_operand , result = g_result )
173
-
174
- return result
175
-
176
- dtype = operand .entry_dtype
177
- granularity = expr .dofdesc .granularity
178
- if granularity is sym .GRANULARITY_NODE :
179
- return _reduce (node_knl (),
180
- discr .empty (self .array_context , dtype = dtype ))
181
- elif granularity is sym .GRANULARITY_ELEMENT :
182
- result = DOFArray (self .array_context , tuple ([
183
- self .array_context .empty ((grp .nelements , 1 ), dtype = dtype )
184
- for grp in discr .groups
168
+
169
+ if dofdesc .granularity is sym .GRANULARITY_NODE :
170
+ return type (operand )(actx , tuple ([
171
+ actx .call_loopy (node_knl (), operand = operand_i )["result" ]
172
+ for operand_i in operand
173
+ ]))
174
+ elif dofdesc .granularity is sym .GRANULARITY_ELEMENT :
175
+ return type (operand )(actx , tuple ([
176
+ actx .call_loopy (element_knl (), operand = operand_i )["result" ]
177
+ for operand_i in operand
185
178
]))
186
- return _reduce (element_knl (), result )
187
179
else :
188
- raise ValueError (f"unsupported granularity: { granularity } " )
180
+ raise ValueError (f"unsupported granularity: { dofdesc . granularity } " )
189
181
190
182
def map_elementwise_sum (self , expr ):
191
183
return self ._map_elementwise_reduction ("sum" , expr )
0 commit comments