Skip to content

Commit df3e169

Browse files
committed
Interpolator: bugfix for reusable matfree adjoint Interpolator in parallel
1 parent 302b76d commit df3e169

File tree

2 files changed

+49
-9
lines changed

2 files changed

+49
-9
lines changed

firedrake/interpolation.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -931,9 +931,6 @@ def callable():
931931
else:
932932
loops = []
933933

934-
if access == op2.INC:
935-
loops.append(tensor.zero)
936-
937934
# Arguments in the operand are allowed to be from a MixedFunctionSpace
938935
# We need to split the target space V and generate separate kernels
939936
if len(arguments) == 2:
@@ -966,7 +963,18 @@ def callable(loops, f):
966963
l()
967964
return f
968965

969-
return partial(callable, loops, f)
966+
def inc_callable(loops, f):
967+
# We are repeatedly incrementing into the same Dat so intermediate halo exchanges
968+
# can be skipped.
969+
f.dat.local_to_global_begin(access)
970+
with f.dat.frozen_halo(access):
971+
f.dat.zero()
972+
for l in loops:
973+
l()
974+
f.dat.local_to_global_end(access)
975+
return f
976+
977+
return partial(inc_callable if (access is op2.INC) else callable, loops, f)
970978

971979

972980
@utils.known_pyop2_safe
@@ -1076,7 +1084,7 @@ def _interpolator(tensor, expr, subset, access, bcs=None):
10761084
coefficient_numbers = kernel.coefficient_numbers
10771085
needs_external_coords = kernel.needs_external_coords
10781086
name = kernel.name
1079-
kernel = op2.Kernel(ast, name, requires_zeroed_output_arguments=True,
1087+
kernel = op2.Kernel(ast, name, requires_zeroed_output_arguments=(access is op2.WRITE),
10801088
flop_count=kernel.flop_count, events=(kernel.event,))
10811089

10821090
parloop_args = [kernel, cell_set]
@@ -1099,7 +1107,7 @@ def _interpolator(tensor, expr, subset, access, bcs=None):
10991107
if isinstance(tensor, op2.Global):
11001108
parloop_args.append(tensor(access))
11011109
elif isinstance(tensor, op2.Dat):
1102-
V_dest = arguments[-1].function_space() if isinstance(dual_arg, ufl.Cofunction) else V
1110+
V_dest = arguments[0].function_space()
11031111
m_ = get_interp_node_map(source_mesh, target_mesh, V_dest)
11041112
parloop_args.append(tensor(access, m_))
11051113
else:
@@ -1159,11 +1167,10 @@ def _interpolator(tensor, expr, subset, access, bcs=None):
11591167
parloop_args.append(target_ref_coords.dat(op2.READ, m_))
11601168

11611169
parloop = op2.ParLoop(*parloop_args)
1162-
parloop_compute_callable = parloop.compute
11631170
if isinstance(tensor, op2.Mat):
1164-
return parloop_compute_callable, tensor.assemble
1171+
return parloop, tensor.assemble
11651172
else:
1166-
return copyin + callables + (parloop_compute_callable, ) + copyout
1173+
return copyin + callables + (parloop, ) + copyout
11671174

11681175

11691176
def get_interp_node_map(source_mesh, target_mesh, fs):

tests/firedrake/regression/test_interpolate.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,7 @@ def test_trace():
327327
assert np.allclose(x_tr_cg.dat.data, x_tr_dir.dat.data)
328328

329329

330+
@pytest.mark.parallel(nprocs=[1, 3])
330331
@pytest.mark.parametrize("rank", (0, 1))
331332
@pytest.mark.parametrize("mat_type", ("matfree", "aij"))
332333
@pytest.mark.parametrize("degree", (1, 3))
@@ -566,3 +567,35 @@ def test_mixed_matrix(mode):
566567
result_explicit = assemble(action(a, u))
567568
for x, y in zip(result_explicit.subfunctions, result_matfree.subfunctions):
568569
assert np.allclose(x.dat.data, y.dat.data)
570+
571+
572+
@pytest.mark.parallel(nprocs=2)
573+
@pytest.mark.parametrize("mode", ["forward", "adjoint"])
574+
@pytest.mark.parametrize("family,degree", [("CG", 1)])
575+
def test_reuse_interpolate(family, degree, mode):
576+
mesh = UnitSquareMesh(1, 1)
577+
V = FunctionSpace(mesh, family, degree)
578+
rg = RandomGenerator(PCG64(seed=123456789))
579+
if mode == "forward":
580+
u = Function(V)
581+
expr = interpolate(u, V)
582+
583+
elif mode == "adjoint":
584+
u = Function(V.dual())
585+
expr = interpolate(TestFunction(V), u)
586+
587+
I = Interpolator(expr, V)
588+
589+
for k in range(2):
590+
u.assign(k+1)
591+
expected = u.dat.data.copy()
592+
result = I.assemble()
593+
594+
# Test that the input was not modified
595+
x = u.dat.data
596+
assert np.allclose(x, expected)
597+
598+
# Test for correctness
599+
y = result.dat.data
600+
assert np.allclose(y, expected)
601+
print("pass", k)

0 commit comments

Comments
 (0)