diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 4fb288fb4d..12ed650f48 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -930,8 +930,8 @@ def callable(): return callable else: loops = [] - - if access == op2.INC: + # Initialise to zero if needed + if access is op2.INC: loops.append(tensor.zero) # Arguments in the operand are allowed to be from a MixedFunctionSpace @@ -957,7 +957,7 @@ def callable(): for indices, sub_expr in expressions.items(): sub_tensor = tensor[indices[0]] if rank == 1 else tensor loops.extend(_interpolator(sub_tensor, sub_expr, subset, access, bcs=bcs)) - + # Apply bcs if bcs and rank == 1: loops.extend(partial(bc.apply, f) for bc in bcs) @@ -1038,32 +1038,36 @@ def _interpolator(tensor, expr, subset, access, bcs=None): parameters = {} parameters['scalar_type'] = utils.ScalarType - callables = () + copyin = () + copyout = () # For the matfree adjoint 1-form and the 0-form, the cellwise kernel will add multiple # contributions from the facet DOFs of the dual argument. # The incoming Cofunction needs to be weighted by the reciprocal of the DOF multiplicity. needs_weight = isinstance(dual_arg, ufl.Cofunction) and not to_element.is_dg() if needs_weight: - # Compute the reciprocal of the DOF multiplicity + # Create a buffer for the weighted Cofunction W = dual_arg.function_space() + v = firedrake.Function(W) + expr = expr._ufl_expr_reconstruct_(operand, v=v) + copyin += (partial(dual_arg.dat.copy, v.dat),) + + # Compute the reciprocal of the DOF multiplicity + wdat = W.make_dat() + m_ = get_interp_node_map(source_mesh, target_mesh, W) wsize = W.finat_element.space_dimension() * W.block_size kernel_code = f""" void multiplicity(PetscScalar *restrict w) {{ for (PetscInt i=0; i<{wsize}; i++) w[i] += 1; }}""" - kernel = op2.Kernel(kernel_code, "multiplicity", requires_zeroed_output_arguments=False) - weight = firedrake.Function(W) - m_ = get_interp_node_map(source_mesh, target_mesh, W) - op2.par_loop(kernel, cell_set, weight.dat(op2.INC, m_)) - with weight.dat.vec as w: + kernel = op2.Kernel(kernel_code, "multiplicity") + op2.par_loop(kernel, cell_set, wdat(op2.INC, m_)) + with wdat.vec as w: w.reciprocal() - # Create a buffer for the weighted Cofunction and a callable to apply the weight - v = firedrake.Function(W) - expr = expr._ufl_expr_reconstruct_(operand, v=v) - with weight.dat.vec_ro as w, dual_arg.dat.vec_ro as x, v.dat.vec_wo as y: - callables += (partial(y.pointwiseMult, x, w),) + # Create a callable to apply the weight + with wdat.vec_ro as w, v.dat.vec as y: + copyin += (partial(y.pointwiseMult, y, w),) # We need to pass both the ufl element and the finat element # because the finat elements might not have the right mapping @@ -1079,7 +1083,7 @@ def _interpolator(tensor, expr, subset, access, bcs=None): coefficient_numbers = kernel.coefficient_numbers needs_external_coords = kernel.needs_external_coords name = kernel.name - kernel = op2.Kernel(ast, name, requires_zeroed_output_arguments=True, + kernel = op2.Kernel(ast, name, requires_zeroed_output_arguments=(access is not op2.INC), flop_count=kernel.flop_count, events=(kernel.event,)) parloop_args = [kernel, cell_set] @@ -1092,17 +1096,12 @@ def _interpolator(tensor, expr, subset, access, bcs=None): output = tensor tensor = op2.Dat(tensor.dataset) if access is not op2.WRITE: - copyin = (partial(output.copy, tensor), ) - else: - copyin = () - copyout = (partial(tensor.copy, output), ) - else: - copyin = () - copyout = () + copyin += (partial(output.copy, tensor), ) + copyout += (partial(tensor.copy, output), ) if isinstance(tensor, op2.Global): parloop_args.append(tensor(access)) elif isinstance(tensor, op2.Dat): - V_dest = arguments[-1].function_space() if isinstance(dual_arg, ufl.Cofunction) else V + V_dest = arguments[-1].function_space() m_ = get_interp_node_map(source_mesh, target_mesh, V_dest) parloop_args.append(tensor(access, m_)) else: @@ -1162,11 +1161,10 @@ def _interpolator(tensor, expr, subset, access, bcs=None): parloop_args.append(target_ref_coords.dat(op2.READ, m_)) parloop = op2.ParLoop(*parloop_args) - parloop_compute_callable = parloop.compute if isinstance(tensor, op2.Mat): - return parloop_compute_callable, tensor.assemble + return parloop, tensor.assemble else: - return copyin + callables + (parloop_compute_callable, ) + copyout + return copyin + (parloop, ) + copyout def get_interp_node_map(source_mesh, target_mesh, fs): diff --git a/tests/firedrake/regression/test_interpolate.py b/tests/firedrake/regression/test_interpolate.py index 47b3dc7a6d..e8b5cb595a 100644 --- a/tests/firedrake/regression/test_interpolate.py +++ b/tests/firedrake/regression/test_interpolate.py @@ -327,6 +327,7 @@ def test_trace(): assert np.allclose(x_tr_cg.dat.data, x_tr_dir.dat.data) +@pytest.mark.parallel([1, 3]) @pytest.mark.parametrize("rank", (0, 1)) @pytest.mark.parametrize("mat_type", ("matfree", "aij")) @pytest.mark.parametrize("degree", (1, 3)) @@ -566,3 +567,35 @@ def test_mixed_matrix(mode): result_explicit = assemble(action(a, u)) for x, y in zip(result_explicit.subfunctions, result_matfree.subfunctions): assert np.allclose(x.dat.data, y.dat.data) + + +@pytest.mark.parallel(2) +@pytest.mark.parametrize("mode", ["forward", "adjoint"]) +@pytest.mark.parametrize("family,degree", [("CG", 1), ("DG", 0)]) +def test_interpolator_reuse(family, degree, mode): + mesh = UnitSquareMesh(1, 1) + V = FunctionSpace(mesh, family, degree) + rg = RandomGenerator(PCG64(seed=123456789)) + if mode == "forward": + u = Function(V) + expr = interpolate(u, V) + + elif mode == "adjoint": + u = Function(V.dual()) + expr = interpolate(TestFunction(V), u) + + I = Interpolator(expr, V) + + for k in range(3): + u.assign(rg.uniform(u.function_space())) + expected = u.dat.data.copy() + + tensor = Function(expr.function_space()) + result = I.assemble(tensor=tensor) + assert result is tensor + + # Test that the input was not modified + assert np.allclose(u.dat.data, expected) + + # Test for correctness + assert np.allclose(result.dat.data, expected)