Skip to content

Commit 0f02ace

Browse files
authored
Make interpolate function avoid using Interpolator (#4432)
* changed `interpolate` function * take action if passed a `Cofunction` * change inputs to V; comment about adjoint interpolation
1 parent 7fe6ff3 commit 0f02ace

File tree

2 files changed

+39
-20
lines changed

2 files changed

+39
-20
lines changed

firedrake/interpolation.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
import firedrake
2626
from firedrake import tsfc_interface, utils, functionspaceimpl
27-
from firedrake.ufl_expr import Argument, action, adjoint as expr_adjoint
27+
from firedrake.ufl_expr import Argument, Coargument, action, adjoint as expr_adjoint
2828
from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology
2929
from firedrake.petsc import PETSc
3030
from firedrake.halo import _get_mtype as get_dat_mpi_type
@@ -150,13 +150,13 @@ def _ufl_expr_reconstruct_(self, expr, v=None, **interp_data):
150150

151151

152152
@PETSc.Log.EventDecorator()
153-
def interpolate(expr, V, *args, **kwargs):
153+
def interpolate(expr, V, subset=None, access=op2.WRITE, allow_missing_dofs=False, default_missing_val=None, matfree=True):
154154
"""Returns a UFL expression for the interpolation operation of ``expr`` into ``V``.
155155
156156
:arg expr: a UFL expression.
157-
:arg V: the :class:`.FunctionSpace` to interpolate into (or else
158-
an existing :class:`.Function` or :class:`.Cofunction`).
159-
Adjoint interpolation requires ``V`` to be a :class:`.Cofunction`.
157+
:arg V: a :class:`.FunctionSpace` to interpolate into, or a :class:`.Cofunction`,
158+
or :class:`.Coargument`, or a :class:`ufl.form.Form` with one argument (a one-form).
159+
If a :class:`.Cofunction` or a one-form is provided, then we do adjoint interpolation.
160160
:kwarg subset: An optional :class:`pyop2.types.set.Subset` to apply the
161161
interpolation over. Cannot, at present, be used when interpolating
162162
across meshes unless the target mesh is a :func:`.VertexOnlyMesh`.
@@ -182,35 +182,50 @@ def interpolate(expr, V, *args, **kwargs):
182182
some ``output`` is given to the :meth:`interpolate` method or (b) set
183183
to zero. Ignored if interpolating within the same mesh or onto a
184184
:func:`.VertexOnlyMesh`.
185-
:kwarg ad_block_tag: An optional string for tagging the resulting assemble block on the Pyadjoint tape.
185+
:kwarg matfree: If ``False``, then construct the permutation matrix for interpolating
186+
between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast
187+
and reduce operations.
186188
:returns: A symbolic :class:`.Interpolate` object
187189
188190
.. note::
189191
190192
If you use an access descriptor other than ``WRITE``, the
191-
behaviour of interpolation is changes if interpolating into a
193+
behaviour of interpolation changes if interpolating into a
192194
function space, or an existing function. If the former, then
193195
the newly allocated function will be initialised with
194196
appropriate values (e.g. for MIN access, it will be initialised
195197
with MAX_FLOAT). On the other hand, if you provide a function,
196198
then it is assumed that its values should take part in the
197199
reduction (hence using MIN will compute the MIN between the
198200
existing values and any new values).
201+
"""
202+
if isinstance(V, (Cofunction, Coargument)):
203+
dual_arg = V
204+
elif isinstance(V, ufl.Form):
205+
rank = len(V.arguments())
206+
if rank == 1:
207+
dual_arg = V
208+
else:
209+
raise TypeError(f"Expected a one-form, provided form had {rank} arguments")
210+
elif isinstance(V, functionspaceimpl.WithGeometry):
211+
dual_arg = Coargument(V.dual(), 0)
212+
expr_args = extract_arguments(expr)
213+
if expr_args and expr_args[0].number() == 0:
214+
# In this case we are doing adjoint interpolation
215+
# When V is a FunctionSpace and expr contains Argument(0),
216+
# we need to change expr argument number to 1 (in our current implementation)
217+
v, = expr_args
218+
expr = replace(expr, {v: v.reconstruct(number=1)})
219+
else:
220+
raise TypeError(f"V must be a FunctionSpace, Cofunction, Coargument or one-form, not a {type(V).__name__}")
199221

200-
.. note::
201-
202-
If you find interpolating the same expression again and again
203-
(for example in a time loop) you may find you get better
204-
performance by using an :class:`Interpolator` instead.
222+
interp = Interpolate(expr, dual_arg,
223+
subset=subset, access=access,
224+
allow_missing_dofs=allow_missing_dofs,
225+
default_missing_val=default_missing_val,
226+
matfree=matfree)
205227

206-
"""
207-
default_missing_val = kwargs.pop("default_missing_val", None)
208-
if isinstance(V, Cofunction):
209-
adjoint = bool(extract_arguments(expr))
210-
return Interpolator(
211-
expr, V.function_space().dual(), *args, **kwargs
212-
).interpolate(V, adjoint=adjoint, default_missing_val=default_missing_val)
213-
return Interpolator(expr, V, *args, **kwargs).interpolate(default_missing_val=default_missing_val)
228+
return interp
214229

215230

216231
class Interpolator(abc.ABC):

tests/firedrake/regression/test_interpolate.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,10 @@ def test_adjoint_Pk(degree):
403403

404404
assert np.allclose(u_Pk.dat.data, v_adj.dat.data)
405405

406+
v_adj_form = assemble(interpolate(TestFunction(Pk), v * dx))
407+
408+
assert np.allclose(v_adj_form.dat.data, v_adj.dat.data)
409+
406410

407411
def test_adjoint_quads():
408412
mesh = UnitSquareMesh(10, 10)

0 commit comments

Comments
 (0)