2424
2525import firedrake
2626from firedrake import tsfc_interface , utils , functionspaceimpl
27- from firedrake .ufl_expr import Argument , action , adjoint as expr_adjoint
27+ from firedrake .ufl_expr import Argument , Coargument , action , adjoint as expr_adjoint
2828from firedrake .mesh import MissingPointsBehaviour , VertexOnlyMeshMissingPointsError , VertexOnlyMeshTopology
2929from firedrake .petsc import PETSc
3030from firedrake .halo import _get_mtype as get_dat_mpi_type
@@ -150,13 +150,13 @@ def _ufl_expr_reconstruct_(self, expr, v=None, **interp_data):
150150
151151
152152@PETSc .Log .EventDecorator ()
153- def interpolate (expr , V , * args , ** kwargs ):
153+ def interpolate (expr , V , subset = None , access = op2 . WRITE , allow_missing_dofs = False , default_missing_val = None , matfree = True ):
154154 """Returns a UFL expression for the interpolation operation of ``expr`` into ``V``.
155155
156156 :arg expr: a UFL expression.
157- :arg V: the :class:`.FunctionSpace` to interpolate into ( or else
158- an existing :class:`.Function` or :class:`.Cofunction` ).
159- Adjoint interpolation requires ``V`` to be a :class:`.Cofunction`.
157+ :arg V: a :class:`.FunctionSpace` to interpolate into, or a :class:`.Cofunction`,
158+ or :class:`.Coargument`, or a :class:`ufl.form.Form` with one argument (a one-form ).
159+ If a :class:`.Cofunction` or a one-form is provided, then we do adjoint interpolation .
160160 :kwarg subset: An optional :class:`pyop2.types.set.Subset` to apply the
161161 interpolation over. Cannot, at present, be used when interpolating
162162 across meshes unless the target mesh is a :func:`.VertexOnlyMesh`.
@@ -182,35 +182,50 @@ def interpolate(expr, V, *args, **kwargs):
182182 some ``output`` is given to the :meth:`interpolate` method or (b) set
183183 to zero. Ignored if interpolating within the same mesh or onto a
184184 :func:`.VertexOnlyMesh`.
185- :kwarg ad_block_tag: An optional string for tagging the resulting assemble block on the Pyadjoint tape.
185+ :kwarg matfree: If ``False``, then construct the permutation matrix for interpolating
186+ between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast
187+ and reduce operations.
186188 :returns: A symbolic :class:`.Interpolate` object
187189
188190 .. note::
189191
190192 If you use an access descriptor other than ``WRITE``, the
191- behaviour of interpolation is changes if interpolating into a
193+ behaviour of interpolation changes if interpolating into a
192194 function space, or an existing function. If the former, then
193195 the newly allocated function will be initialised with
194196 appropriate values (e.g. for MIN access, it will be initialised
195197 with MAX_FLOAT). On the other hand, if you provide a function,
196198 then it is assumed that its values should take part in the
197199 reduction (hence using MIN will compute the MIN between the
198200 existing values and any new values).
201+ """
202+ if isinstance (V , (Cofunction , Coargument )):
203+ dual_arg = V
204+ elif isinstance (V , ufl .Form ):
205+ rank = len (V .arguments ())
206+ if rank == 1 :
207+ dual_arg = V
208+ else :
209+ raise TypeError (f"Expected a one-form, provided form had { rank } arguments" )
210+ elif isinstance (V , functionspaceimpl .WithGeometry ):
211+ dual_arg = Coargument (V .dual (), 0 )
212+ expr_args = extract_arguments (expr )
213+ if expr_args and expr_args [0 ].number () == 0 :
214+ # In this case we are doing adjoint interpolation
215+ # When V is a FunctionSpace and expr contains Argument(0),
216+ # we need to change expr argument number to 1 (in our current implementation)
217+ v , = expr_args
218+ expr = replace (expr , {v : v .reconstruct (number = 1 )})
219+ else :
220+ raise TypeError (f"V must be a FunctionSpace, Cofunction, Coargument or one-form, not a { type (V ).__name__ } " )
199221
200- .. note::
201-
202- If you find interpolating the same expression again and again
203- (for example in a time loop) you may find you get better
204- performance by using an :class:`Interpolator` instead.
222+ interp = Interpolate ( expr , dual_arg ,
223+ subset = subset , access = access ,
224+ allow_missing_dofs = allow_missing_dofs ,
225+ default_missing_val = default_missing_val ,
226+ matfree = matfree )
205227
206- """
207- default_missing_val = kwargs .pop ("default_missing_val" , None )
208- if isinstance (V , Cofunction ):
209- adjoint = bool (extract_arguments (expr ))
210- return Interpolator (
211- expr , V .function_space ().dual (), * args , ** kwargs
212- ).interpolate (V , adjoint = adjoint , default_missing_val = default_missing_val )
213- return Interpolator (expr , V , * args , ** kwargs ).interpolate (default_missing_val = default_missing_val )
228+ return interp
214229
215230
216231class Interpolator (abc .ABC ):
0 commit comments