Skip to content

Commit 879ad35

Browse files
alexfikl authored and inducer committed
simplify elementwise reductions
1 parent 0c42040 commit 879ad35

File tree

2 files changed: +69 -45 lines changed

2 files changed

+69
-45
lines changed

pytential/symbolic/execution.py

+24 -32
Original file line number | Diff line number | Diff line change
@@ -124,29 +124,31 @@ def map_node_min(self, expr):
124124
def _map_elementwise_reduction(self, reduction_name, expr):
125125
import loopy as lp
126126
from arraycontext import make_loopy_program
127-
from meshmode.transform_metadata import (
128-
ConcurrentElementInameTag, ConcurrentDOFInameTag)
127+
from meshmode.transform_metadata import ConcurrentElementInameTag
128+
actx = self.array_context
129129

130-
@memoize_in(self.places, "elementwise_node_"+reduction_name)
130+
@memoize_in(actx, (
131+
EvaluationMapperBase._map_elementwise_reduction,
132+
f"elementwise_node_{reduction_name}"))
131133
def node_knl():
132134
t_unit = make_loopy_program(
133135
"""{[iel, idof, jdof]:
134136
0<=iel<nelements and
135137
0<=idof, jdof<ndofs}""",
136138
"""
137-
result[iel, idof] = %s(jdof, operand[iel, jdof])
139+
<> el_result = %s(jdof, operand[iel, jdof])
140+
result[iel, idof] = el_result
138141
""" % reduction_name,
139-
name="nodewise_reduce")
142+
name=f"elementwise_node_{reduction_name}")
140143

141144
return lp.tag_inames(t_unit, {
142145
"iel": ConcurrentElementInameTag(),
143-
"idof": ConcurrentDOFInameTag(),
144146
})
145147

146-
@memoize_in(self.places, "elementwise_"+reduction_name)
148+
@memoize_in(actx, (
149+
EvaluationMapperBase._map_elementwise_reduction,
150+
f"elementwise_element_{reduction_name}"))
147151
def element_knl():
148-
# FIXME: This computes the reduction value redundantly for each
149-
# output DOF.
150152
t_unit = make_loopy_program(
151153
"""{[iel, jdof]:
152154
0<=iel<nelements and
@@ -155,37 +157,27 @@ def element_knl():
155157
"""
156158
result[iel, 0] = %s(jdof, operand[iel, jdof])
157159
""" % reduction_name,
158-
name="elementwise_reduce")
160+
name=f"elementwise_element_{reduction_name}")
159161

160162
return lp.tag_inames(t_unit, {
161163
"iel": ConcurrentElementInameTag(),
162164
})
163165

164-
discr = self.places.get_discretization(
165-
expr.dofdesc.geometry, expr.dofdesc.discr_stage)
166+
dofdesc = expr.dofdesc
166167
operand = self.rec(expr.operand)
167-
assert operand.shape == (len(discr.groups),)
168-
169-
def _reduce(knl, result):
170-
for g_operand, g_result in zip(operand, result):
171-
self.array_context.call_loopy(
172-
knl, operand=g_operand, result=g_result)
173-
174-
return result
175-
176-
dtype = operand.entry_dtype
177-
granularity = expr.dofdesc.granularity
178-
if granularity is sym.GRANULARITY_NODE:
179-
return _reduce(node_knl(),
180-
discr.empty(self.array_context, dtype=dtype))
181-
elif granularity is sym.GRANULARITY_ELEMENT:
182-
result = DOFArray(self.array_context, tuple([
183-
self.array_context.empty((grp.nelements, 1), dtype=dtype)
184-
for grp in discr.groups
168+
169+
if dofdesc.granularity is sym.GRANULARITY_NODE:
170+
return type(operand)(actx, tuple([
171+
actx.call_loopy(node_knl(), operand=operand_i)["result"]
172+
for operand_i in operand
173+
]))
174+
elif dofdesc.granularity is sym.GRANULARITY_ELEMENT:
175+
return type(operand)(actx, tuple([
176+
actx.call_loopy(element_knl(), operand=operand_i)["result"]
177+
for operand_i in operand
185178
]))
186-
return _reduce(element_knl(), result)
187179
else:
188-
raise ValueError(f"unsupported granularity: {granularity}")
180+
raise ValueError(f"unsupported granularity: {dofdesc.granularity}")
189181

190182
def map_elementwise_sum(self, expr):
191183
return self._map_elementwise_reduction("sum", expr)

test/test_symbolic.py

+45 -13
Original file line number | Diff line number | Diff line change
@@ -306,26 +306,58 @@ def test_node_reduction(actx_factory):
306306

307307
# {{{ test
308308

309-
# create a shuffled [1, nelements + 1] array
310-
ary = []
311-
el_nr_base = 0
312-
for grp in discr.groups:
313-
x = 1 + np.arange(el_nr_base, grp.nelements)
314-
np.random.shuffle(x)
309+
# create a shuffled [1, ndofs + 1] array
310+
rng = np.random.default_rng(seed=42)
315311

316-
ary.append(actx.freeze(actx.from_numpy(x.reshape(-1, 1))))
317-
el_nr_base += grp.nelements
312+
def randrange_like(xi, offset):
313+
x = offset + np.arange(1, xi.size + 1)
314+
rng.shuffle(x)
315+
316+
return actx.from_numpy(x.reshape(xi.shape))
318317

319318
from meshmode.dof_array import DOFArray
320-
ary = DOFArray(actx, tuple(ary))
319+
base_node_nrs = np.cumsum([0] + [grp.ndofs for grp in discr.groups])
320+
ary = DOFArray(actx, tuple([
321+
randrange_like(xi, offset)
322+
for xi, offset in zip(discr.nodes()[0], base_node_nrs)
323+
]))
321324

325+
n = discr.ndofs
322326
for func, expected in [
323-
(sym.NodeSum, nelements * (nelements + 1) // 2),
324-
(sym.NodeMax, nelements),
327+
(sym.NodeSum, n * (n + 1) // 2),
328+
(sym.NodeMax, n),
325329
(sym.NodeMin, 1),
326330
]:
327-
r = bind(discr, func(sym.var("x")))(actx, x=ary)
328-
assert abs(actx.to_numpy(r) - expected) < 1.0e-15, r
331+
r = actx.to_numpy(
332+
bind(discr, func(sym.var("x")))(actx, x=ary)
333+
)
334+
assert r == expected, (r, expected)
335+
336+
arys = tuple([rng.random(size=xi.shape) for xi in ary])
337+
x = DOFArray(actx, tuple([actx.from_numpy(xi) for xi in arys]))
338+
339+
from meshmode.dof_array import flat_norm
340+
for func, np_func in [
341+
(sym.ElementwiseSum, np.sum),
342+
(sym.ElementwiseMax, np.max)
343+
]:
344+
expected = DOFArray(actx, tuple([
345+
actx.from_numpy(np.tile(np_func(xi, axis=1, keepdims=True), xi.shape[1]))
346+
for xi in arys
347+
]))
348+
r = bind(
349+
discr, func(sym.var("x"), dofdesc=sym.GRANULARITY_NODE)
350+
)(actx, x=x)
351+
assert actx.to_numpy(flat_norm(r - expected)) < 1.0e-15
352+
353+
expected = DOFArray(actx, tuple([
354+
actx.from_numpy(np_func(xi, axis=1, keepdims=True))
355+
for xi in arys
356+
]))
357+
r = bind(
358+
discr, func(sym.var("x"), dofdesc=sym.GRANULARITY_ELEMENT)
359+
)(actx, x=x)
360+
assert actx.to_numpy(flat_norm(r - expected)) < 1.0e-15
329361

330362
# }}}
331363

0 commit comments

Comments
 (0)