diff --git a/demos/boussinesq/boussinesq.py.rst b/demos/boussinesq/boussinesq.py.rst index edfdf3c1a5..5cc6708cf0 100644 --- a/demos/boussinesq/boussinesq.py.rst +++ b/demos/boussinesq/boussinesq.py.rst @@ -184,7 +184,7 @@ implements a boundary condition that fixes a field at a single point. :: # Take the basis function with the largest abs value at bc_point v = TestFunction(V) - F = assemble(Interpolate(inner(v, v), Fvom)) + F = assemble(interpolate(inner(v, v), Fvom)) with F.dat.vec as Fvec: max_index, _ = Fvec.max() nodes = V.dof_dset.lgmap.applyInverse([max_index]) diff --git a/demos/multicomponent/multicomponent.py.rst b/demos/multicomponent/multicomponent.py.rst index bf74e5d2e0..7a29d6ef1c 100644 --- a/demos/multicomponent/multicomponent.py.rst +++ b/demos/multicomponent/multicomponent.py.rst @@ -521,7 +521,7 @@ mathematically valid to do this):: # Take the basis function with the largest abs value at bc_point v = TestFunction(V) - F = assemble(Interpolate(inner(v, v), Fvom)) + F = assemble(interpolate(inner(v, v), Fvom)) with F.dat.vec as Fvec: max_index, _ = Fvec.max() nodes = V.dof_dset.lgmap.applyInverse([max_index]) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index bd4214b6fd..6f573e1a29 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -23,6 +23,7 @@ from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit from firedrake.functionspaceimpl import WithGeometry, FunctionSpace, FiredrakeDualSpace from firedrake.functionspacedata import entity_dofs_key, entity_permutations_key +from firedrake.interpolation import get_interpolator from firedrake.petsc import PETSc from firedrake.slate import slac, slate from firedrake.slate.slac.kernel_builder import CellFacetKernelArg, LayerCountKernelArg @@ -578,17 +579,9 @@ def base_form_assembly_visitor(self, expr, tensor, *args): rank = len(expr.arguments()) if rank > 2: raise ValueError("Cannot assemble an Interpolate with more than two arguments") - # Get the target space - V = v.function_space().dual() - - # Get the interpolator - interp_data = expr.interp_data.copy() - default_missing_val = interp_data.pop('default_missing_val', None) - if rank == 1 and isinstance(tensor, firedrake.Function): - V = tensor - interpolator = firedrake.Interpolator(expr, V, **interp_data) - # Assembly - return interpolator.assemble(tensor=tensor, default_missing_val=default_missing_val) + + interpolator = get_interpolator(expr) + return interpolator.assemble(tensor=tensor) elif tensor and isinstance(expr, (firedrake.Function, firedrake.Cofunction, firedrake.MatrixBase)): return tensor.assign(expr) elif tensor and isinstance(expr, ufl.ZeroBaseForm): diff --git a/firedrake/bcs.py b/firedrake/bcs.py index 17c199c1d8..c1aaf424fb 100644 --- a/firedrake/bcs.py +++ b/firedrake/bcs.py @@ -1,7 +1,7 @@ # A module implementing strong (Dirichlet) boundary conditions. import numpy as np -import functools +from functools import partial, reduce import itertools import ufl @@ -167,7 +167,7 @@ def hermite_stride(bcnodes): # Edge conditions have only been tested with Lagrange elements. # Need to expand the list. bcnodes1.append(hermite_stride(self._function_space.boundary_nodes(ss))) - bcnodes1 = functools.reduce(np.intersect1d, bcnodes1) + bcnodes1 = reduce(np.intersect1d, bcnodes1) bcnodes.append(bcnodes1) return np.concatenate(bcnodes) @@ -359,11 +359,10 @@ def function_arg(self, g): raise RuntimeError(f"Provided boundary value {g} does not match shape of space") try: self._function_arg = firedrake.Function(V) - # Use `Interpolator` instead of assembling an `Interpolate` form - # as the expression compilation needs to happen at this stage to - # determine if we should use interpolation or projection - # -> e.g. interpolation may not be supported for the element. - self._function_arg_update = firedrake.Interpolator(g, self._function_arg)._interpolate + interpolator = firedrake.get_interpolator(firedrake.interpolate(g, V)) + # Call this here to check if the element supports interpolation + interpolator._build_callable() + self._function_arg_update = partial(interpolator.assemble, tensor=self._function_arg) except (NotImplementedError, AttributeError): # Element doesn't implement interpolation self._function_arg = firedrake.Function(V).project(g) diff --git a/firedrake/cofunction.py b/firedrake/cofunction.py index f0fda10f63..9ee867c622 100644 --- a/firedrake/cofunction.py +++ b/firedrake/cofunction.py @@ -318,7 +318,7 @@ def interpolate(self, Parameters ---------- expression - A dual UFL expression to interpolate. + A UFL BaseForm to adjoint interpolate. ad_block_tag An optional string for tagging the resulting assemble block on the Pyadjoint tape. @@ -331,9 +331,9 @@ def interpolate(self, firedrake.cofunction.Cofunction Returns `self` """ - from firedrake import interpolation, assemble + from firedrake import interpolate, assemble v, = self.arguments() - interp = interpolation.Interpolate(v, expression, **kwargs) + interp = interpolate(v, expression, **kwargs) return assemble(interp, tensor=self, ad_block_tag=ad_block_tag) @property diff --git a/firedrake/external_operators/point_expr_operator.py b/firedrake/external_operators/point_expr_operator.py index 3aa40e1d5b..4e7183e47f 100644 --- a/firedrake/external_operators/point_expr_operator.py +++ b/firedrake/external_operators/point_expr_operator.py @@ -5,7 +5,7 @@ import firedrake.ufl_expr as ufl_expr from firedrake.assemble import assemble -from firedrake.interpolation import Interpolate +from firedrake.interpolation import interpolate from firedrake.external_operators import AbstractExternalOperator, assemble_method @@ -58,7 +58,7 @@ def assemble_operator(self, *args, **kwargs): V = self.function_space() expr = as_ufl(self.expr(*self.ufl_operands)) if len(V) < 2: - interp = Interpolate(expr, self.function_space()) + interp = interpolate(expr, self.function_space()) return assemble(interp) # Interpolation of UFL expressions for mixed functions is not yet supported # -> `Function.assign` might be enough in some cases. @@ -72,7 +72,7 @@ def assemble_operator(self, *args, **kwargs): def assemble_Jacobian_action(self, *args, **kwargs): V = self.function_space() expr = as_ufl(self.expr(*self.ufl_operands)) - interp = Interpolate(expr, V) + interp = interpolate(expr, V) u, = [e for i, e in enumerate(self.ufl_operands) if self.derivatives[i] == 1] w = self.argument_slots()[-1] @@ -83,7 +83,7 @@ def assemble_Jacobian_action(self, *args, **kwargs): def assemble_Jacobian(self, *args, assembly_opts, **kwargs): V = self.function_space() expr = as_ufl(self.expr(*self.ufl_operands)) - interp = Interpolate(expr, V) + interp = interpolate(expr, V) u, = [e for i, e in enumerate(self.ufl_operands) if self.derivatives[i] == 1] jac = ufl_expr.derivative(interp, u) @@ -99,7 +99,7 @@ def assemble_Jacobian_adjoint(self, *args, assembly_opts, **kwargs): def assemble_Jacobian_adjoint_action(self, *args, **kwargs): V = self.function_space() expr = as_ufl(self.expr(*self.ufl_operands)) - interp = Interpolate(expr, V) + interp = interpolate(expr, V) u, = [e for i, e in enumerate(self.ufl_operands) if self.derivatives[i] == 1] ustar = self.argument_slots()[0] diff --git a/firedrake/function.py b/firedrake/function.py index b2cda5bc4e..ce9b4b1538 100644 --- a/firedrake/function.py +++ b/firedrake/function.py @@ -382,9 +382,9 @@ def interpolate(self, firedrake.function.Function Returns `self` """ - from firedrake import interpolation, assemble + from firedrake import interpolate, assemble V = self.function_space() - interp = interpolation.Interpolate(expression, V, **kwargs) + interp = interpolate(expression, V, **kwargs) return assemble(interp, tensor=self, ad_block_tag=ad_block_tag) def zero(self, subset=None): @@ -701,7 +701,7 @@ def __init__(self, domain, point): self.point = point def __str__(self): - return "domain %s does not contain point %s" % (self.domain, self.point) + return f"Domain {self.domain} does not contain point {self.point}" class PointEvaluator: diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 12ed650f48..7ea703c74e 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -2,17 +2,19 @@ import os import tempfile import abc -import warnings -from collections.abc import Iterable -from typing import Literal + from functools import partial, singledispatch -from typing import Hashable +from typing import Hashable, Literal, Callable, Iterable +from dataclasses import asdict, dataclass +from numbers import Number import FIAT import ufl import finat.ufl -from ufl.algorithms import extract_arguments, extract_coefficients, replace -from ufl.domain import as_domain, extract_unique_domain +from ufl.algorithms import extract_arguments +from ufl.domain import extract_unique_domain +from ufl.classes import Expr +from ufl.duals import is_dual from pyop2 import op2 from pyop2.caching import memory_and_disk_cache @@ -25,395 +27,289 @@ import finat import firedrake -import firedrake.bcs -from firedrake import tsfc_interface, utils, functionspaceimpl -from firedrake.ufl_expr import Argument, Coargument, action, adjoint as expr_adjoint -from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology +from firedrake import tsfc_interface, utils +from firedrake.ufl_expr import Argument, Coargument, action +from firedrake.cofunction import Cofunction +from firedrake.function import Function +from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology, MeshGeometry from firedrake.petsc import PETSc from firedrake.halo import _get_mtype as get_dat_mpi_type -from firedrake.cofunction import Cofunction +from firedrake.functionspaceimpl import WithGeometry +from firedrake.matrix import MatrixBase +from firedrake.bcs import DirichletBC from mpi4py import MPI from pyadjoint import stop_annotating, no_annotations __all__ = ( "interpolate", - "Interpolator", "Interpolate", + "get_interpolator", "DofNotDefinedError", - "CrossMeshInterpolator", - "SameMeshInterpolator", + "InterpolateOptions", + "Interpolator" ) +@dataclass +class InterpolateOptions: + """Options for interpolation operations. + + Parameters + ---------- + subset : pyop2.types.set.Subset or None + An optional subset to apply the interpolation over. + Cannot, at present, be used when interpolating across meshes unless + the target mesh is a :func:`.VertexOnlyMesh`. + access : pyop2.types.access.Access or None + The pyop2 access descriptor for combining updates to shared + DoFs. Possible values include ``WRITE``, ``MIN``, ``MAX``, and ``INC``. + Only ``WRITE`` is supported at present when interpolating across meshes + unless the target mesh is a :func:`.VertexOnlyMesh`. Only ``INC`` is + supported for the matrix-free adjoint interpolation. + allow_missing_dofs : bool + For interpolation across meshes: allow degrees of freedom (aka DoFs/nodes) + in the target mesh that cannot be defined on the source mesh. + For example, where nodes are point evaluations, points in the target mesh + that are not in the source mesh. When ``False`` this raises a ``ValueError`` + should this occur. When ``True`` the corresponding values are either + (a) unchanged if some ``output`` is given to the :meth:`interpolate` method + or (b) set to zero. + Can be overwritten with the ``default_missing_val`` kwarg of :meth:`interpolate`. + This does not affect adjoint interpolation. Ignored if interpolating within + the same mesh or onto a :func:`.VertexOnlyMesh` (the behaviour of a + :func:`.VertexOnlyMesh` in this scenario is, at present, set when it is created). + default_missing_val : float or None + For interpolation across meshes: the optional value to assign to DoFs + in the target mesh that are outside the source mesh. If this is not set + then the values are either (a) unchanged if some ``output`` is given to + the :meth:`interpolate` method or (b) set to zero. + Ignored if interpolating within the same mesh or onto a :func:`.VertexOnlyMesh`. + matfree : bool + If ``False``, then construct the permutation matrix for interpolating + between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast + and reduce operations. + bcs : Iterable[DirichletBC] or None + An optional list of boundary conditions to zero-out in the + output function space. Interpolator rows or columns which are + associated with boundary condition nodes are zeroed out when this is + specified. By default None. + """ + subset: op2.Subset | None = None + access: Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC] | None = None + allow_missing_dofs: bool = False + default_missing_val: float | None = None + matfree: bool = True + bcs: Iterable[DirichletBC] | None = None + + class Interpolate(ufl.Interpolate): - def __init__(self, expr, v, - subset=None, - access=None, - allow_missing_dofs=False, - default_missing_val=None, - matfree=True): + def __init__(self, expr: Expr, V: WithGeometry | ufl.BaseForm, **kwargs): """Symbolic representation of the interpolation operator. Parameters ---------- - expr : ufl.core.expr.Expr or ufl.BaseForm + expr : ufl.core.expr.Expr The UFL expression to interpolate. - v : firedrake.functionspaceimpl.WithGeometryBase or firedrake.ufl_expr.Coargument + V : firedrake.functionspaceimpl.WithGeometry or ufl.BaseForm The function space to interpolate into or the coargument defined on the dual of the function space to interpolate into. - subset : pyop2.types.set.Subset - An optional subset to apply the interpolation over. - Cannot, at present, be used when interpolating across meshes unless - the target mesh is a :func:`.VertexOnlyMesh`. - access : pyop2.types.access.Access - The pyop2 access descriptor for combining updates to shared - DoFs. Possible values include ``WRITE`` and ``INC``. Only ``WRITE`` is - supported at present when interpolating across meshes. See note in - :func:`.interpolate` if changing this from default. - allow_missing_dofs : bool - For interpolation across meshes: allow degrees of freedom (aka DoFs/nodes) - in the target mesh that cannot be defined on the source mesh. - For example, where nodes are point evaluations, points in the target mesh - that are not in the source mesh. When ``False`` this raises a ``ValueError`` - should this occur. When ``True`` the corresponding values are either - (a) unchanged if some ``output`` is given to the :meth:`interpolate` method - or (b) set to zero. - Can be overwritten with the ``default_missing_val`` kwarg of :meth:`interpolate`. - This does not affect adjoint interpolation. Ignored if interpolating within - the same mesh or onto a :func:`.VertexOnlyMesh` (the behaviour of a - :func:`.VertexOnlyMesh` in this scenario is, at present, set when it is created). - default_missing_val : float - For interpolation across meshes: the optional value to assign to DoFs - in the target mesh that are outside the source mesh. If this is not set - then the values are either (a) unchanged if some ``output`` is given to - the :meth:`interpolate` method or (b) set to zero. - Ignored if interpolating within the same mesh or onto a :func:`.VertexOnlyMesh`. - matfree : bool - If ``False``, then construct the permutation matrix for interpolating - between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast - and reduce operations. + **kwargs + Additional interpolation options. See :class:`InterpolateOptions` + for available parameters and their descriptions. """ - # Check function space expr = ufl.as_ufl(expr) - if isinstance(v, functionspaceimpl.WithGeometry): - expr_args = extract_arguments(expr) - is_adjoint = len(expr_args) and expr_args[0].number() == 0 - v = Argument(v.dual(), 1 if is_adjoint else 0) - - V = v.arguments()[0].function_space() - if len(expr.ufl_shape) != len(V.value_shape): - raise RuntimeError(f'Rank mismatch: Expression rank {len(expr.ufl_shape)}, FunctionSpace rank {len(V.value_shape)}') - - if expr.ufl_shape != V.value_shape: - raise RuntimeError('Shape mismatch: Expression shape {expr.ufl_shape}, FunctionSpace shape {V.value_shape}') - super().__init__(expr, v) - - # -- Interpolate data (e.g. `subset` or `access`) -- # - self.interp_data = {"subset": subset, - "access": access, - "allow_missing_dofs": allow_missing_dofs, - "default_missing_val": default_missing_val, - "matfree": matfree} + expr_arg_numbers = {arg.number() for arg in extract_arguments(expr) if not is_dual(arg)} + self.is_adjoint = expr_arg_numbers == {0} + if isinstance(V, WithGeometry): + # Need to create a Firedrake Coargument so it has a .function_space() method + V = Argument(V.dual(), 1 if self.is_adjoint else 0) + + self.target_space = V.arguments()[0].function_space() + if expr.ufl_shape != self.target_space.value_shape: + raise ValueError(f"Shape mismatch: Expression shape {expr.ufl_shape}, FunctionSpace shape {self.target_space.value_shape}.") + + super().__init__(expr, V) + + self._options = InterpolateOptions(**kwargs) function_space = ufl.Interpolate.ufl_function_space - def _ufl_expr_reconstruct_(self, expr, v=None, **interp_data): - interp_data = interp_data or self.interp_data.copy() + def _ufl_expr_reconstruct_( + self, expr: Expr, v: WithGeometry | ufl.BaseForm | None = None, **interp_data + ): + interp_data = interp_data or asdict(self.options) return ufl.Interpolate._ufl_expr_reconstruct_(self, expr, v=v, **interp_data) + @property + def options(self) -> InterpolateOptions: + """Access the interpolation options. + + Returns + ------- + InterpolateOptions + An :class:`InterpolateOptions` instance containing the interpolation options. + """ + return self._options + @PETSc.Log.EventDecorator() -def interpolate(expr, V, subset=None, access=None, allow_missing_dofs=False, default_missing_val=None, matfree=True): +def interpolate(expr: Expr, V: WithGeometry | ufl.BaseForm, **kwargs) -> Interpolate: """Returns a UFL expression for the interpolation operation of ``expr`` into ``V``. - :arg expr: a UFL expression. - :arg V: a :class:`.FunctionSpace` to interpolate into, or a :class:`.Cofunction`, - or :class:`.Coargument`, or a :class:`ufl.form.Form` with one argument (a one-form). - If a :class:`.Cofunction` or a one-form is provided, then we do adjoint interpolation. - :kwarg subset: An optional :class:`pyop2.types.set.Subset` to apply the - interpolation over. Cannot, at present, be used when interpolating - across meshes unless the target mesh is a :func:`.VertexOnlyMesh`. - :kwarg access: The pyop2 access descriptor for combining updates to shared - DoFs. Possible values include ``WRITE`` and ``INC``. Only ``WRITE`` is - supported at present when interpolating across meshes unless the target - mesh is a :func:`.VertexOnlyMesh`. See note below. - :kwarg allow_missing_dofs: For interpolation across meshes: allow - degrees of freedom (aka DoFs/nodes) in the target mesh that cannot be - defined on the source mesh. For example, where nodes are point - evaluations, points in the target mesh that are not in the source mesh. - When ``False`` this raises a ``ValueError`` should this occur. When - ``True`` the corresponding values are either (a) unchanged if - some ``output`` is given to the :meth:`interpolate` method or (b) set - to zero. In either case, if ``default_missing_val`` is specified, that - value is used. This does not affect adjoint interpolation. Ignored if - interpolating within the same mesh or onto a :func:`.VertexOnlyMesh` - (the behaviour of a :func:`.VertexOnlyMesh` in this scenario is, at - present, set when it is created). - :kwarg default_missing_val: For interpolation across meshes: the optional - value to assign to DoFs in the target mesh that are outside the source - mesh. If this is not set then the values are either (a) unchanged if - some ``output`` is given to the :meth:`interpolate` method or (b) set - to zero. Ignored if interpolating within the same mesh or onto a - :func:`.VertexOnlyMesh`. - :kwarg matfree: If ``False``, then construct the permutation matrix for interpolating - between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast - and reduce operations. - :returns: A symbolic :class:`.Interpolate` object - - .. note:: - - If you use an access descriptor other than ``WRITE``, the - behaviour of interpolation changes if interpolating into a - function space, or an existing function. If the former, then - the newly allocated function will be initialised with - appropriate values (e.g. for MIN access, it will be initialised - with MAX_FLOAT). On the other hand, if you provide a function, - then it is assumed that its values should take part in the - reduction (hence using MIN will compute the MIN between the - existing values and any new values). - """ - if isinstance(V, (Cofunction, Coargument)): - dual_arg = V - elif isinstance(V, ufl.BaseForm): - rank = len(V.arguments()) - if rank == 1: - dual_arg = V - else: - raise TypeError(f"Expected a one-form, provided form had {rank} arguments") - elif isinstance(V, functionspaceimpl.WithGeometry): - dual_arg = Coargument(V.dual(), 0) - expr_args = extract_arguments(ufl.as_ufl(expr)) - if expr_args and expr_args[0].number() == 0: - warnings.warn("Passing argument numbered 0 in expression for forward interpolation is deprecated. " - "Use a TrialFunction in the expression.") - v, = expr_args - expr = replace(expr, {v: v.reconstruct(number=1)}) - else: - raise TypeError(f"V must be a FunctionSpace, Cofunction, Coargument or one-form, not a {type(V).__name__}") - - interp = Interpolate(expr, dual_arg, - subset=subset, access=access, - allow_missing_dofs=allow_missing_dofs, - default_missing_val=default_missing_val, - matfree=matfree) + Parameters + ---------- + expr : ufl.core.expr.Expr + The UFL expression to interpolate. + V : firedrake.functionspaceimpl.WithGeometry or ufl.BaseForm + The function space to interpolate into or the coargument defined + on the dual of the function space to interpolate into. + **kwargs + Additional interpolation options. See :class:`InterpolateOptions` + for available parameters and their descriptions. - return interp + Returns + ------- + Interpolate + A symbolic :class:`Interpolate` object representing the interpolation operation. + """ + return Interpolate(expr, V, **kwargs) class Interpolator(abc.ABC): - """A reusable interpolation object. - - This object can be used to carry out the same interpolation - multiple times (for example in a timestepping loop). + """Base class for calculating interpolation. Should not be instantiated directly; use the + :func:`get_interpolator` function. Parameters ---------- - expr - The underlying ufl.Interpolate or the operand to the ufl.Interpolate. - V - The :class:`.FunctionSpace` or :class:`.Function` to - interpolate into. - subset - An optional :class:`pyop2.types.set.Subset` to apply the - interpolation over. Cannot, at present, be used when interpolating - across meshes unless the target mesh is a :func:`.VertexOnlyMesh`. - freeze_expr - Set to True to prevent the expression being - re-evaluated on each call. Cannot, at present, be used when - interpolating across meshes unless the target mesh is a - :func:`.VertexOnlyMesh`. - access - The pyop2 access descriptor for combining updates to shared DoFs. - Only ``op2.WRITE`` is supported at present when interpolating across meshes. - Only ``op2.INC`` is supported for the matrix-free adjoint interpolation. - See note in :func:`.interpolate` if changing this from default. - bcs - An optional list of boundary conditions to zero-out in the - output function space. Interpolator rows or columns which are - associated with boundary condition nodes are zeroed out when this is - specified. - allow_missing_dofs - For interpolation across meshes: allow - degrees of freedom (aka DoFs/nodes) in the target mesh that cannot be - defined on the source mesh. For example, where nodes are point - evaluations, points in the target mesh that are not in the source mesh. - When ``False`` this raises a ``ValueError`` should this occur. When - ``True`` the corresponding values are either (a) unchanged if - some ``output`` is given to the :meth:`interpolate` method or (b) set - to zero. Can be overwritten with the ``default_missing_val`` kwarg - of :meth:`interpolate`. This does not affect adjoint interpolation. - Ignored if interpolating within the same mesh or onto a - :func:`.VertexOnlyMesh` (the behaviour of a :func:`.VertexOnlyMesh` in - this scenario is, at present, set when it is created). - matfree - If ``False``, then construct the permutation matrix for interpolating - between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast - and reduce operations. - - Notes - ----- - - The :class:`Interpolator` holds a reference to the provided - arguments (such that they won't be collected until the - :class:`Interpolator` is also collected). + expr : Interpolate + The symbolic interpolation expression. """ - - def __new__(cls, expr, V, **kwargs): - V_target = V if isinstance(V, ufl.FunctionSpace) else V.function_space() - if not isinstance(expr, ufl.Interpolate): - expr = interpolate(expr, V_target) - - arguments = expr.arguments() - has_mixed_arguments = any(len(a.function_space()) > 1 for a in arguments) - if len(arguments) == 2 and has_mixed_arguments: - return object.__new__(MixedInterpolator) - - operand, = expr.ufl_operands - target_mesh = as_domain(V) - source_mesh = extract_unique_domain(operand) or target_mesh - submesh_interp_implemented = \ - all(isinstance(m.topology, firedrake.mesh.MeshTopology) for m in [target_mesh, source_mesh]) and \ - target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1] and \ - target_mesh.topological_dimension() == source_mesh.topological_dimension() - if target_mesh is source_mesh or submesh_interp_implemented: - return object.__new__(SameMeshInterpolator) - else: - if isinstance(target_mesh.topology, VertexOnlyMeshTopology): - return object.__new__(SameMeshInterpolator) - elif has_mixed_arguments or len(V_target) > 1: - return object.__new__(MixedInterpolator) - else: - return object.__new__(CrossMeshInterpolator) - - def __init__( - self, - expr: ufl.Interpolate | ufl.classes.Expr, - V: ufl.FunctionSpace | firedrake.function.Function, - subset: op2.Subset | None = None, - freeze_expr: bool = False, - access: Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC] | None = None, - bcs: Iterable[firedrake.bcs.BCBase] | None = None, - allow_missing_dofs: bool = False, - matfree: bool = True - ): - if not isinstance(expr, ufl.Interpolate): - expr = interpolate(expr, V if isinstance(V, ufl.FunctionSpace) else V.function_space()) + def __init__(self, expr: Interpolate): dual_arg, operand = expr.argument_slots() - self.ufl_interpolate = expr - self.expr = operand - self.V = V - self.subset = subset - self.freeze_expr = freeze_expr - self.bcs = bcs - self._allow_missing_dofs = allow_missing_dofs - self.matfree = matfree + self.expr = expr + self.expr_args = expr.arguments() + self.rank = len(self.expr_args) + self.operand = operand + self.dual_arg = dual_arg + self.target_space = dual_arg.function_space().dual() + self.target_mesh = self.target_space.mesh() + self.source_mesh = extract_unique_domain(operand) or self.target_mesh + + # Interpolation options + self.subset = expr.options.subset + self.allow_missing_dofs = expr.options.allow_missing_dofs + self.default_missing_val = expr.options.default_missing_val + self.matfree = expr.options.matfree + self.bcs = expr.options.bcs self.callable = None + self.access = expr.options.access - # TODO CrossMeshInterpolator and VomOntoVomXXX are not yet aware of - # self.ufl_interpolate (which carries the dual argument). - # See github issue https://github.com/firedrakeproject/firedrake/issues/4592 - target_mesh = as_domain(V) - source_mesh = extract_unique_domain(operand) or target_mesh - vom_onto_other_vom = ((source_mesh is not target_mesh) - and isinstance(self, SameMeshInterpolator) - and isinstance(source_mesh.topology, VertexOnlyMeshTopology) - and isinstance(target_mesh.topology, VertexOnlyMeshTopology)) - if isinstance(self, CrossMeshInterpolator) or vom_onto_other_vom: - # For bespoke interpolation, we currently rely on different assembly procedures: - # 1) Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) -> Forward operator (2-form) - # 2) Interpolate(Argument(V1, 0), Argument(V2.dual(), 1)) -> Adjoint operator (2-form) - # 3) Interpolate(Coefficient(V1), Argument(V2.dual(), 0)) -> Forward action (1-form) - # 4) Interpolate(Argument(V1, 0), Cofunction(V2.dual()) -> Adjoint action (1-form) - # 5) Interpolate(Coefficient(V1), Cofunction(V2.dual()) -> Double action (0-form) - - # CrossMeshInterpolator._interpolate only supports forward interpolation (cases 1 and 3). - # For case 2, we first redundantly assemble case 1 and then construct the transpose. - # For cases 4 and 5, we take the forward Interpolate that corresponds to dropping the Cofunction, - # and we separately compute the action against the dropped Cofunction within assemble(). - if not isinstance(dual_arg, ufl.Coargument): - # Drop the Cofunction - expr = expr._ufl_expr_reconstruct_(operand, dual_arg.function_space().dual()) - expr_args = extract_arguments(operand) - if expr_args and expr_args[0].number() == 0: - # Construct the symbolic forward Interpolate - v0, v1 = expr.arguments() - expr = ufl.replace(expr, {v0: v0.reconstruct(number=v1.number()), - v1: v1.reconstruct(number=v0.number())}) - - dual_arg, operand = expr.argument_slots() - self.expr_renumbered = operand - self.ufl_interpolate_renumbered = expr + @abc.abstractmethod + def _build_callable(self, tensor: Function | Cofunction | MatrixBase | None = None) -> None: + """Builds callable to perform interpolation. Stored in ``self.callable``. - if not isinstance(dual_arg, ufl.Coargument): - # Matrix-free assembly of 0-form or 1-form requires INC access - if access and access != op2.INC: - raise ValueError("Matfree adjoint interpolation requires INC access") - access = op2.INC - elif access is None: - # Default access for forward 1-form or 2-form (forward and adjoint) - access = op2.WRITE - self.access = access + If ``self.rank == 2``, then ``self.callable()`` must return an object with a ``handle`` + attribute that stores a PETSc matrix. If ``self.rank == 1``, then `self.callable()` must + return a ``Function`` or ``Cofunction`` (in the forward and adjoint cases respectively). + If ``self.rank == 0``, then ``self.callable()`` must return a number. - def interpolate(self, *function, transpose=None, adjoint=False, default_missing_val=None): + Parameters + ---------- + tensor + Optional tensor to store the result in, by default None. """ - .. warning:: + pass - This method has been removed. Use the function :func:`interpolate` to return a symbolic - :class:`Interpolate` object. - """ - raise FutureWarning( - "The 'interpolate' method on `Interpolator` objects has been " - "removed. Use the `interpolate` function instead." - ) + def assemble( + self, tensor: Function | Cofunction | MatrixBase | None = None + ) -> Function | Cofunction | MatrixBase | Number: + """Assemble the interpolation. The result depends on the rank (number of arguments) + of the :class:`Interpolate` expression: - @abc.abstractmethod - def _interpolate(self, *args, **kwargs): - """ - Compute the interpolation operation of interest. + * rank-2: assemble the operator and return a matrix + * rank-1: assemble the action and return a function or cofunction + * rank-0: assemble the action and return a scalar by applying the dual argument - .. note:: - This method is called when an :class:`Interpolate` object is being assembled. + Parameters + ---------- + tensor : Function | Cofunction | MatrixBase + Optional tensor to store the interpolated result. For rank-2 + expressions this is expected to be a subclass of + :class:`~firedrake.matrix.MatrixBase`. For lower-rank expressions + this is a :class:`~firedrake.function.Function` or :class:`~firedrake.cofunction.Cofunction`, + for forward and adjoint interpolation respectively. + + Returns + ------- + Function | Cofunction | MatrixBase | numbers.Number + The function, cofunction, matrix, or scalar resulting from the + interpolation. """ - pass - - def assemble(self, tensor=None, default_missing_val=None): - """Assemble the operator (or its action).""" - from firedrake.assemble import assemble - needs_adjoint = self.ufl_interpolate_renumbered != self.ufl_interpolate - arguments = self.ufl_interpolate.arguments() - if len(arguments) == 2: + self._build_callable(tensor=tensor) + result = self.callable() + if self.rank == 2: # Assembling the operator - res = tensor.petscmat if tensor else PETSc.Mat() + assert isinstance(tensor, MatrixBase | None) # Get the interpolation matrix - op2mat = self.callable() - petsc_mat = op2mat.handle - if needs_adjoint: - # Out-of-place Hermitian transpose - petsc_mat.hermitianTranspose(out=res) - elif tensor: + petsc_mat = result.handle + if tensor: petsc_mat.copy(tensor.petscmat) - else: - res = petsc_mat - return tensor or firedrake.AssembledMatrix(arguments, self.bcs, res) + return tensor + return firedrake.AssembledMatrix(self.expr_args, self.bcs, petsc_mat) else: - # Assembling the action - cofunctions = () - if needs_adjoint: - # The renumbered Interpolate has dropped Cofunctions. - # We need to explicitly operate on them. - dual_arg, _ = self.ufl_interpolate.argument_slots() - if not isinstance(dual_arg, ufl.Coargument): - cofunctions = (dual_arg,) - - if needs_adjoint and len(arguments) == 0: - Iu = self._interpolate(default_missing_val=default_missing_val) - return assemble(ufl.Action(*cofunctions, Iu), tensor=tensor) - else: - return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint, - default_missing_val=default_missing_val) + assert isinstance(tensor, Function | Cofunction | None) + return tensor.assign(result) if tensor else result + + +def get_interpolator(expr: Interpolate) -> Interpolator: + """Create an Interpolator. + + Parameters + ---------- + expr : Interpolate + Symbolic interpolation expression. + + Returns + ------- + Interpolator + An appropriate :class:`Interpolator` subclass for the given + interpolation expression. + """ + arguments = expr.arguments() + has_mixed_arguments = any(len(arg.function_space()) > 1 for arg in arguments) + if len(arguments) == 2 and has_mixed_arguments: + return MixedInterpolator(expr) + + operand, = expr.ufl_operands + target_mesh = expr.target_space.mesh() + source_mesh = extract_unique_domain(operand) or target_mesh + submesh_interp_implemented = ( + all(isinstance(m.topology, firedrake.mesh.MeshTopology) for m in [target_mesh, source_mesh]) + and target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1] + and target_mesh.topological_dimension() == source_mesh.topological_dimension() + ) + if target_mesh is source_mesh or submesh_interp_implemented: + return SameMeshInterpolator(expr) + + target_topology = target_mesh.topology + source_topology = source_mesh.topology + + if isinstance(target_topology, VertexOnlyMeshTopology): + if isinstance(source_topology, VertexOnlyMeshTopology): + return VomOntoVomInterpolator(expr) + if target_mesh.geometric_dimension() != source_mesh.geometric_dimension(): + raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension") + if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh: + raise ValueError("Can only interpolate across meshes where the source mesh is the parent of the target") + return SameMeshInterpolator(expr) + + if has_mixed_arguments or len(expr.target_space) > 1: + return MixedInterpolator(expr) + + return CrossMeshInterpolator(expr) class DofNotDefinedError(Exception): @@ -453,26 +349,17 @@ class CrossMeshInterpolator(Interpolator): """ @no_annotations - def __init__( - self, - expr, - V, - subset=None, - freeze_expr=False, - access=None, - bcs=None, - allow_missing_dofs=False, - matfree=True - ): - if subset: - raise NotImplementedError("subset not implemented") - if freeze_expr: - # Probably just need to pass freeze_expr to the various - # interpolators for this to work. - raise NotImplementedError("freeze_expr not implemented") - if bcs: - raise NotImplementedError("bcs not implemented") - if V.ufl_element().mapping() != "identity": + def __init__(self, expr: Interpolate): + super().__init__(expr) + if self.access and self.access != op2.WRITE: + raise NotImplementedError( + "Access other than op2.WRITE not implemented for cross-mesh interpolation." + ) + else: + self.access = op2.WRITE + if self.bcs: + raise NotImplementedError("bcs not implemented for cross-mesh interpolation.") + if self.target_space.ufl_element().mapping() != "identity": # Identity mapping between reference cell and physical coordinates # implies point evaluation nodes. A more general version would # require finding the global coordinates of all quadrature points @@ -480,246 +367,150 @@ def __init__( raise NotImplementedError( "Can only interpolate into spaces with point evaluation nodes." ) - super().__init__(expr, V, subset, freeze_expr, access, bcs, allow_missing_dofs, matfree) - - if self.access != op2.WRITE: - raise NotImplementedError("access other than op2.WRITE not implemented") + if self.allow_missing_dofs: + self.missing_points_behaviour = MissingPointsBehaviour.IGNORE + else: + self.missing_points_behaviour = MissingPointsBehaviour.ERROR - expr = self.expr_renumbered - self.arguments = extract_arguments(expr) - self.nargs = len(self.arguments) + if self.source_mesh.geometric_dimension() != self.target_mesh.geometric_dimension(): + raise ValueError("Geometric dimensions of source and destination meshes must match.") - if self._allow_missing_dofs: - missing_points_behaviour = MissingPointsBehaviour.IGNORE + dest_element = self.target_space.ufl_element() + if isinstance(dest_element, finat.ufl.MixedElement): + if isinstance(dest_element, finat.ufl.VectorElement | finat.ufl.TensorElement): + # In this case all sub elements are equal + base_element = dest_element.sub_elements[0] + if base_element.reference_value_shape != (): + raise NotImplementedError( + "Can't yet cross-mesh interpolate onto function spaces made from VectorElements " + "or TensorElements made from sub elements with value shape other than ()." + ) + self.dest_element = base_element + else: + raise NotImplementedError("Interpolation with MixedFunctionSpace requires MixedInterpolator.") else: - missing_points_behaviour = MissingPointsBehaviour.ERROR - - # setup - V_dest = V.function_space() if isinstance(V, firedrake.Function) else V - src_mesh = extract_unique_domain(expr) - dest_mesh = as_domain(V_dest) - src_mesh_gdim = src_mesh.geometric_dimension() - dest_mesh_gdim = dest_mesh.geometric_dimension() - if src_mesh_gdim != dest_mesh_gdim: - raise ValueError( - "geometric dimensions of source and destination meshes must match" - ) - self.src_mesh = src_mesh - self.dest_mesh = dest_mesh + # scalar fiat/finat element + self.dest_element = dest_element - # Create a VOM at the nodes of V_dest in src_mesh. We don't include halo - # node coordinates because interpolation doesn't usually include halos. - # NOTE: it is very important to set redundant=False, otherwise the - # input ordering VOM will only contain the points on rank 0! - # QUESTION: Should any of the below have annotation turned off? - ufl_scalar_element = V_dest.ufl_element() - if isinstance(ufl_scalar_element, finat.ufl.MixedElement): - if type(ufl_scalar_element) is finat.ufl.MixedElement: - raise TypeError("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator") - - # For a VectorElement or TensorElement the correct - # VectorFunctionSpace equivalent is built from the scalar - # sub-element. - ufl_scalar_element, = set(ufl_scalar_element.sub_elements) - if ufl_scalar_element.reference_value_shape != (): - raise NotImplementedError( - "Can't yet cross-mesh interpolate onto function spaces made from VectorElements or TensorElements made from sub elements with value shape other than ()." - ) + self._build_symbolic_expressions() + + def _build_symbolic_expressions(self) -> None: + """Constructs the symbolic ``Interpolate`` expressions for cross-mesh interpolation. + Raises + ------ + DofNotDefinedError + If some DoFs in the target function space cannot be defined + in the source function space. + """ from firedrake.assemble import assemble - V_dest_vec = firedrake.VectorFunctionSpace(dest_mesh, ufl_scalar_element) - f_dest_node_coords = Interpolate(dest_mesh.coordinates, V_dest_vec) - f_dest_node_coords = assemble(f_dest_node_coords) - dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, dest_mesh_gdim) + # Immerse coordinates of target space point evaluation dofs in src_mesh + target_space_vec = firedrake.VectorFunctionSpace(self.target_mesh, self.dest_element) + f_dest_node_coords = assemble(interpolate(self.target_mesh.coordinates, target_space_vec)) + dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, self.target_mesh.geometric_dimension()) try: - self.vom_dest_node_coords_in_src_mesh = firedrake.VertexOnlyMesh( - src_mesh, + self.vom = firedrake.VertexOnlyMesh( + self.source_mesh, dest_node_coords, redundant=False, - missing_points_behaviour=missing_points_behaviour, + missing_points_behaviour=self.missing_points_behaviour, ) except VertexOnlyMeshMissingPointsError: - raise DofNotDefinedError(src_mesh, dest_mesh) - # vom_dest_node_coords_in_src_mesh uses the parallel decomposition of - # the global node coordinates of V_dest in the SOURCE mesh (src_mesh). - # I first point evaluate my expression at these locations, giving a - # P0DG function on the VOM. As described in the manual, this is an - # interpolation operation. - shape = V_dest.ufl_function_space().value_shape + raise DofNotDefinedError(self.source_mesh, self.target_mesh) + + # Get the correct type of function space + shape = self.target_space.ufl_function_space().value_shape if len(shape) == 0: fs_type = firedrake.FunctionSpace elif len(shape) == 1: fs_type = partial(firedrake.VectorFunctionSpace, dim=shape[0]) else: fs_type = partial(firedrake.TensorFunctionSpace, shape=shape) - P0DG_vom = fs_type(self.vom_dest_node_coords_in_src_mesh, "DG", 0) - self.point_eval_interpolate = Interpolate(self.expr_renumbered, P0DG_vom) - # The parallel decomposition of the nodes of V_dest in the DESTINATION - # mesh (dest_mesh) is retrieved using the input_ordering attribute of the - # VOM. This again is an interpolation operation, which, under the hood - # is a PETSc SF reduce. - P0DG_vom_i_o = fs_type( - self.vom_dest_node_coords_in_src_mesh.input_ordering, "DG", 0 - ) - self.to_input_ordering_interpolate = Interpolate( - firedrake.TrialFunction(P0DG_vom), P0DG_vom_i_o - ) - # The P0DG function outputted by the above interpolation has the - # correct parallel decomposition for the nodes of V_dest in dest_mesh so - # we can safely assign the dat values. This is all done in the actual - # interpolation method below. - - @PETSc.Log.EventDecorator() - def _interpolate( - self, - *function, - output=None, - transpose=None, - adjoint=False, - default_missing_val=None, - **kwargs, - ): - """Compute the interpolation. - For arguments, see :class:`.Interpolator`. - """ - from firedrake.assemble import assemble + # Get expression for point evaluation at the dest_node_coords + self.P0DG_vom = fs_type(self.vom, "DG", 0) + self.point_eval = interpolate(self.operand, self.P0DG_vom) - if transpose is not None: - warnings.warn("'transpose' argument is deprecated, use 'adjoint' instead", FutureWarning) - adjoint = transpose or adjoint - if adjoint and not self.nargs: - raise ValueError( - "Can currently only apply adjoint interpolation with arguments." - ) - if self.nargs != len(function): - raise ValueError( - "Passed %d Functions to interpolate, expected %d" - % (len(function), self.nargs) - ) + # If assembling the operator, we need the concrete permutation matrix + matfree = False if self.rank == 2 else self.matfree - if self.nargs: - (f_src,) = function - if not hasattr(f_src, "dat"): - raise ValueError( - "The expression had arguments: we therefore need to be given a Function (not an expression) to interpolate!" - ) - else: - f_src = self.expr - - if adjoint: - try: - V_dest = self.expr.function_space().dual() - except AttributeError: - if self.nargs: - V_dest = self.arguments[-1].function_space().dual() - else: - coeffs = extract_coefficients(self.expr) - if len(coeffs): - V_dest = coeffs[0].function_space().dual() - else: - raise ValueError( - "Can't adjoint interpolate an expression with no coefficients or arguments." - ) - else: - if isinstance(self.V, (firedrake.Function, firedrake.Cofunction)): - V_dest = self.V.function_space() + # Interpolate into the input-ordering VOM + self.P0DG_vom_input_ordering = fs_type(self.vom.input_ordering, "DG", 0) + + arg = Argument(self.P0DG_vom, 0 if self.expr.is_adjoint else 1) + self.point_eval_input_ordering = interpolate(arg, self.P0DG_vom_input_ordering, matfree=matfree) + + def _build_callable(self, tensor=None): + from firedrake.assemble import assemble + # self.expr.function_space() is None in the 0-form case + V_dest = self.expr.function_space() or self.target_space + f = tensor or Function(V_dest) + + if self.rank == 2: + # The cross-mesh interpolation matrix is the product of the + # `self.point_eval_interpolate` and the permutation + # given by `self.to_input_ordering_interpolate`. + if self.expr.is_adjoint: + symbolic = action(self.point_eval, self.point_eval_input_ordering) else: - V_dest = self.V - if output: - if output.function_space() != V_dest: - raise ValueError("Given output has the wrong function space!") + symbolic = action(self.point_eval_input_ordering, self.point_eval) + self.handle = assemble(symbolic).petscmat + + def callable() -> CrossMeshInterpolator: + return self + elif self.expr.is_adjoint: + assert self.rank == 1 + # f_src is a cofunction on V_dest.dual + cofunc = self.dual_arg + assert isinstance(cofunc, Cofunction) + + # Our first adjoint operation is to assign the dat values to a + # P0DG cofunction on our input ordering VOM. + f_input_ordering = Cofunction(self.P0DG_vom_input_ordering.dual()) + f_input_ordering.dat.data_wo[:] = cofunc.dat.data_ro[:] + + # The rest of the adjoint interpolation is the composition + # of the adjoint interpolators in the reverse direction. + # We don't worry about skipping over missing points here + # because we're going from the input ordering VOM to the original VOM + # and all points from the input ordering VOM are in the original. + def callable() -> Cofunction: + f_src_at_src_node_coords = assemble(action(self.point_eval_input_ordering, f_input_ordering)) + assemble(action(self.point_eval, f_src_at_src_node_coords), tensor=f) + return f else: - if isinstance(self.V, (firedrake.Function, firedrake.Cofunction)): - output = self.V - else: - output = firedrake.Function(V_dest) - - if not adjoint: - if f_src is self.expr: - # f_src is already contained in self.point_eval_interpolate - assert not self.nargs - f_src_at_dest_node_coords_src_mesh_decomp = ( - assemble(self.point_eval_interpolate) - ) - else: - f_src_at_dest_node_coords_src_mesh_decomp = ( - assemble(action(self.point_eval_interpolate, f_src)) - ) - f_src_at_dest_node_coords_dest_mesh_decomp = firedrake.Function( - self.to_input_ordering_interpolate.function_space() - ) - # We have to create the Function before interpolating so we can - # set default missing values (if requested). - if default_missing_val is not None: - f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_wo[ - : - ] = default_missing_val - elif self._allow_missing_dofs: - # If we have allowed missing points we know we might end up - # with points in the target mesh that are not in the source - # mesh. However, since we haven't specified a default missing - # value we expect the interpolation to leave these points - # unchanged. By setting the dat values to NaN we can later - # identify these points and skip over them when assigning to - # the output function. - f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_wo[:] = numpy.nan - - interp = action(self.to_input_ordering_interpolate, f_src_at_dest_node_coords_src_mesh_decomp) - assemble(interp, tensor=f_src_at_dest_node_coords_dest_mesh_decomp) - - # we can now confidently assign this to a function on V_dest - if self._allow_missing_dofs and default_missing_val is None: - indices = numpy.where( - ~numpy.isnan(f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_ro) - )[0] - output.dat.data_wo[ - indices - ] = f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_ro[indices] - else: - output.dat.data_wo[ - : - ] = f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_ro[:] + assert self.rank in {0, 1} + # We evaluate the operand at the node coordinates of the destination space + f_point_eval = assemble(self.point_eval) + + # We create the input-ordering Function before interpolating so we can + # set default missing values if required. + f_point_eval_input_ordering = Function(self.P0DG_vom_input_ordering) + if self.default_missing_val is not None: + f_point_eval_input_ordering.assign(self.default_missing_val) + elif self.allow_missing_dofs: + # If we allow missing points there may be points in the target + # mesh that are not in the source mesh. If we don't specify a + # default missing value we set these to NaN so we can identify + # them later. + f_point_eval_input_ordering.dat.data_wo[:] = numpy.nan + + def callable() -> Function | Number: + assemble(action(self.point_eval_input_ordering, f_point_eval), tensor=f_point_eval_input_ordering) + # We assign these values to the output function + if self.allow_missing_dofs and self.default_missing_val is None: + indices = numpy.where(~numpy.isnan(f_point_eval_input_ordering.dat.data_ro))[0] + f.dat.data_wo[indices] = f_point_eval_input_ordering.dat.data_ro[indices] + else: + f.dat.data_wo[:] = f_point_eval_input_ordering.dat.data_ro[:] - else: - # adjoint interpolation - - # f_src is a cofunction on V_dest.dual as originally specified when - # creating the interpolator. Our first adjoint operation is to - # assign the dat values to a P0DG cofunction on our input ordering - # VOM. This has the parallel decomposition V_dest on our orinally - # specified dest_mesh. We can therefore safely create a P0DG - # cofunction on the input-ordering VOM (which has this parallel - # decomposition and ordering) and assign the dat values. - f_src_at_dest_node_coords_dest_mesh_decomp = firedrake.Cofunction( - self.to_input_ordering_interpolate.function_space().dual() - ) - f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_wo[ - : - ] = f_src.dat.data_ro[:] - - # The rest of the adjoint interpolation is merely the composition - # of the adjoint interpolators in the reverse direction. NOTE: I - # don't have to worry about skipping over missing points here - # because I'm going from the input ordering VOM to the original VOM - # and all points from the input ordering VOM are in the original. - interp = action(expr_adjoint(self.to_input_ordering_interpolate), f_src_at_dest_node_coords_dest_mesh_decomp) - f_src_at_src_node_coords = assemble(interp) - # NOTE: if I wanted the default missing value to be applied to - # adjoint interpolation I would have to do it here. However, - # this would require me to implement default missing values for - # adjoint interpolation from a point evaluation interpolator - # which I haven't done. I wonder if it is necessary - perhaps the - # adjoint operator always sets all the values of the resulting - # cofunction? My initial attempt to insert setting the dat values - # prior to performing the multHermitian operation in - # SameMeshInterpolator.interpolate did not effect the result. For - # now, I say in the docstring that it only applies to forward - # interpolation. - interp = action(expr_adjoint(self.point_eval_interpolate), f_src_at_src_node_coords) - assemble(interp, tensor=output) - - return output + if self.rank == 0: + # We take the action of the dual_arg on the interpolated function + assert not isinstance(self.dual_arg, ufl.Coargument) + return assemble(action(self.dual_arg, f)) + else: + return f + self.callable = callable class SameMeshInterpolator(Interpolator): @@ -731,246 +522,204 @@ class SameMeshInterpolator(Interpolator): """ @no_annotations - def __init__(self, expr, V, subset=None, freeze_expr=False, access=None, - bcs=None, matfree=True, allow_missing_dofs=False, **kwargs): + def __init__(self, expr): + super().__init__(expr) + subset = self.subset if subset is None: - if isinstance(expr, ufl.Interpolate): - operand, = expr.ufl_operands - else: - operand = expr - target_mesh = as_domain(V) - source_mesh = extract_unique_domain(operand) or target_mesh - target = target_mesh.topology - source = source_mesh.topology + target = self.target_mesh.topology + source = self.source_mesh.topology if all(isinstance(m, firedrake.mesh.MeshTopology) for m in [target, source]) and target is not source: composed_map, result_integral_type = source.trans_mesh_entity_map(target, "cell", "everywhere", None) if result_integral_type != "cell": - raise AssertionError("Only cell-cell interpolation supported") + raise AssertionError("Only cell-cell interpolation supported.") indices_active = composed_map.indices_active_with_halo make_subset = not indices_active.all() make_subset = target.comm.allreduce(make_subset, op=MPI.LOR) if make_subset: - if not allow_missing_dofs: - raise ValueError("iteration (sub)set unclear: run with `allow_missing_dofs=True`") + if not self.allow_missing_dofs: + raise ValueError("Iteration (sub)set unclear: run with `allow_missing_dofs=True`.") subset = op2.Subset(target.cell_set, numpy.where(indices_active)) else: # Do not need subset as target <= source. pass - super().__init__(expr, V, subset=subset, freeze_expr=freeze_expr, - access=access, bcs=bcs, matfree=matfree, allow_missing_dofs=allow_missing_dofs) - expr = self.ufl_interpolate_renumbered - try: - self.callable = make_interpolator(expr, V, subset, self.access, bcs=bcs, matfree=matfree) - except FIAT.hdiv_trace.TraceError: - raise NotImplementedError("Can't interpolate onto traces sorry") - self.arguments = expr.arguments() - - @PETSc.Log.EventDecorator() - def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **kwargs): - """Compute the interpolation. - - For arguments, see :class:`.Interpolator`. - """ - - if transpose is not None: - warnings.warn("'transpose' argument is deprecated, use 'adjoint' instead", FutureWarning) - adjoint = transpose or adjoint - try: - assembled_interpolator = self.frozen_assembled_interpolator - copy_required = True - except AttributeError: - assembled_interpolator = self.callable() - copy_required = False # Return the original - if self.freeze_expr: - if len(self.arguments) == 2: - # Interpolation operator - self.frozen_assembled_interpolator = assembled_interpolator - else: - # Interpolation action - self.frozen_assembled_interpolator = assembled_interpolator.copy() - - if len(self.arguments) == 2 and len(function) > 0: - function, = function - if not hasattr(function, "dat"): - raise ValueError("The expression had arguments: we therefore need to be given a Function (not an expression) to interpolate!") - if adjoint: - mul = assembled_interpolator.handle.multHermitian - col, row = self.arguments - else: - mul = assembled_interpolator.handle.mult - row, col = self.arguments - V = row.function_space().dual() - assert function.function_space() == col.function_space() - - result = output or firedrake.Function(V) - with function.dat.vec_ro as x, result.dat.vec_wo as out: - if x is not out: - mul(x, out) - else: - out_ = out.duplicate() - mul(x, out_) - out_.copy(result=out) - return result - - else: - if output: - output.assign(assembled_interpolator) - return output - if isinstance(self.V, firedrake.Function): - if copy_required: - self.V.assign(assembled_interpolator) - return self.V - else: - if len(self.arguments) == 0: - return assembled_interpolator.dat.data.item() - elif copy_required: - return assembled_interpolator.copy() - else: - return assembled_interpolator + self.subset = subset + if not isinstance(self.dual_arg, ufl.Coargument): + # Matrix-free assembly of 0-form or 1-form requires INC access + if self.access and self.access != op2.INC: + raise ValueError("Matfree adjoint interpolation requires INC access") + self.access = op2.INC + elif self.access is None: + # Default access for forward 1-form or 2-form (forward and adjoint) + self.access = op2.WRITE -@PETSc.Log.EventDecorator() -def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): - if not isinstance(expr, ufl.Interpolate): - raise ValueError(f"Expecting to interpolate a ufl.Interpolate, got {type(expr).__name__}.") - dual_arg, operand = expr.argument_slots() - target_mesh = as_domain(dual_arg) - source_mesh = extract_unique_domain(operand) or target_mesh - vom_onto_other_vom = ((source_mesh is not target_mesh) - and isinstance(source_mesh.topology, VertexOnlyMeshTopology) - and isinstance(target_mesh.topology, VertexOnlyMeshTopology)) + def _get_tensor(self) -> op2.Mat | Function | Cofunction: + """Return the tensor to interpolate into. - arguments = expr.arguments() - rank = len(arguments) - if rank <= 1: - if rank == 0: - R = firedrake.FunctionSpace(target_mesh, "Real", 0) - f = firedrake.Function(R, dtype=utils.ScalarType) - elif isinstance(V, firedrake.Function): - f = V - V = f.function_space() - else: - V_dest = arguments[0].function_space().dual() - f = firedrake.Function(V_dest) - if access in {firedrake.MIN, firedrake.MAX}: + Returns + ------- + op2.Mat | Function | Cofunction + The tensor to interpolate into. + """ + if self.rank == 0: + R = firedrake.FunctionSpace(self.target_mesh, "Real", 0) + f = Function(R, dtype=utils.ScalarType) + elif self.rank == 1: + f = Function(self.expr.function_space()) + if self.access in {firedrake.MIN, firedrake.MAX}: finfo = numpy.finfo(f.dat.dtype) - if access == firedrake.MIN: + if self.access == firedrake.MIN: val = firedrake.Constant(finfo.max) else: val = firedrake.Constant(finfo.min) f.assign(val) - tensor = f.dat - elif rank == 2: - if isinstance(V, firedrake.Function): - raise ValueError("Cannot interpolate an expression with an argument into a Function") - Vrow = arguments[0].function_space() - Vcol = arguments[1].function_space() - if len(Vrow) > 1 or len(Vcol) > 1: - raise TypeError("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator") - if isinstance(target_mesh.topology, VertexOnlyMeshTopology) and target_mesh is not source_mesh and not vom_onto_other_vom: - if not isinstance(target_mesh.topology, VertexOnlyMeshTopology): - raise NotImplementedError("Can only interpolate onto a VertexOnlyMesh") - if target_mesh.geometric_dimension() != source_mesh.geometric_dimension(): - raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension") - if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh: - raise ValueError("Can only interpolate across meshes where the source mesh is the parent of the target") - - if vom_onto_other_vom: - # We make our own linear operator for this case using PETSc SFs - tensor = None - else: - Vrow_map = get_interp_node_map(source_mesh, target_mesh, Vrow) - Vcol_map = get_interp_node_map(source_mesh, target_mesh, Vcol) + elif self.rank == 2: + Vrow = self.expr_args[0].function_space() + Vcol = self.expr_args[1].function_space() + if len(Vrow) > 1 or len(Vcol) > 1: + raise NotImplementedError("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator") + Vrow_map = get_interp_node_map(self.source_mesh, self.target_mesh, Vrow) + Vcol_map = get_interp_node_map(self.source_mesh, self.target_mesh, Vcol) sparsity = op2.Sparsity((Vrow.dof_dset, Vcol.dof_dset), [(Vrow_map, Vcol_map, None)], # non-mixed - name="%s_%s_sparsity" % (Vrow.name, Vcol.name), + name=f"{Vrow.name}_{Vcol.name}_sparsity", nest=False, block_sparse=True) - tensor = op2.Mat(sparsity) - f = tensor - else: - raise ValueError(f"Cannot interpolate an expression with {rank} arguments") - - if vom_onto_other_vom: - wrapper = VomOntoVomWrapper(V, source_mesh, target_mesh, operand, matfree) - # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the - # data, including the correct data size and dimensional information - # (so for vector function spaces in 2 dimensions we might need a - # concatenation of 2 MPI.DOUBLE types when we are in real mode) - if tensor is not None: - # Callable will do interpolation into our pre-supplied function f - # when it is called. - assert f.dat is tensor - wrapper.mpi_type, _ = get_dat_mpi_type(f.dat) - assert len(arguments) == 1 - - def callable(): - wrapper.forward_operation(f.dat) - return f + f = op2.Mat(sparsity) else: - assert len(arguments) == 2 - assert tensor is None - # we know we will be outputting either a function or a cofunction, - # both of which will use a dat as a data carrier. At present, the - # data type does not depend on function space dimension, so we can - # safely use the argument function space. NOTE: If this changes - # after cofunctions are fully implemented, this will need to be - # reconsidered. - temp_source_func = firedrake.Function(Vcol) - wrapper.mpi_type, _ = get_dat_mpi_type(temp_source_func.dat) + raise ValueError(f"Cannot interpolate an expression with {self.rank} arguments") + return f - # Leave wrapper inside a callable so we can access the handle - # property. If matfree is True, then the handle is a PETSc SF - # pretending to be a PETSc Mat. If matfree is False, then this - # will be a PETSc Mat representing the equivalent permutation - # matrix - def callable(): - return wrapper + def _build_callable(self, tensor=None) -> None: + f = tensor or self._get_tensor() + op2_tensor = f if isinstance(f, op2.Mat) else f.dat - return callable - else: loops = [] - # Initialise to zero if needed - if access is op2.INC: - loops.append(tensor.zero) # Arguments in the operand are allowed to be from a MixedFunctionSpace # We need to split the target space V and generate separate kernels - if len(arguments) == 2: - # Matrix case assumes that the spaces are not mixed - expressions = {(0,): expr} - elif isinstance(dual_arg, Coargument): + if self.rank == 2: + expressions = {(0,): self.expr} + elif isinstance(self.dual_arg, Coargument): # Split in the coargument - expressions = dict(firedrake.formmanipulation.split_form(expr)) + expressions = dict(firedrake.formmanipulation.split_form(self.expr)) else: + assert isinstance(self.dual_arg, Cofunction) # Split in the cofunction: split_form can only split in the coargument # Replace the cofunction with a coargument to construct the Jacobian - interp = expr._ufl_expr_reconstruct_(operand, V) + interp = self.expr._ufl_expr_reconstruct_(self.operand, self.target_space) # Split the Jacobian into blocks interp_split = dict(firedrake.formmanipulation.split_form(interp)) # Split the cofunction - dual_split = dict(firedrake.formmanipulation.split_form(dual_arg)) + dual_split = dict(firedrake.formmanipulation.split_form(self.dual_arg)) # Combine the splits by taking their action expressions = {i: action(interp_split[i], dual_split[i[-1:]]) for i in interp_split} # Interpolate each sub expression into each function space for indices, sub_expr in expressions.items(): - sub_tensor = tensor[indices[0]] if rank == 1 else tensor - loops.extend(_interpolator(sub_tensor, sub_expr, subset, access, bcs=bcs)) - # Apply bcs - if bcs and rank == 1: - loops.extend(partial(bc.apply, f) for bc in bcs) + sub_op2_tensor = op2_tensor[indices[0]] if self.rank == 1 else op2_tensor + loops.extend(_build_interpolation_callables(sub_expr, sub_op2_tensor, self.access, self.subset, self.bcs)) + + if self.bcs and self.rank == 1: + loops.extend(partial(bc.apply, f) for bc in self.bcs) def callable(loops, f): for l in loops: l() - return f + return f.dat.data.item() if self.rank == 0 else f + + self.callable = partial(callable, loops, f) + + +class VomOntoVomInterpolator(SameMeshInterpolator): + + def __init__(self, expr: Interpolate): + super().__init__(expr) + + def _build_callable(self, tensor=None): + self.mat = VomOntoVomMat(self) + if self.rank == 2: + # We make our own linear operator for this case using PETSc SFs + op2_tensor = None + else: + f = tensor or self._get_tensor() + op2_tensor = f.dat + # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the + # data, including the correct data size and dimensional information + # (so for vector function spaces in 2 dimensions we might need a + # concatenation of 2 MPI.DOUBLE types when we are in real mode) + if op2_tensor is not None: + assert self.rank == 1 + self.mat.mpi_type = get_dat_mpi_type(f.dat)[0] + if self.expr.is_adjoint: + assert isinstance(self.dual_arg, ufl.Cofunction) + assert isinstance(f, Cofunction) + + def callable() -> Cofunction: + with self.dual_arg.dat.vec_ro as source_vec, f.dat.vec_wo as target_vec: + self.mat.handle.multHermitian(source_vec, target_vec) + return f + else: + assert isinstance(f, Function) + + def callable() -> Function: + coeff = self.mat.expr_as_coeff() + with coeff.dat.vec_ro as coeff_vec, f.dat.vec_wo as target_vec: + self.mat.handle.mult(coeff_vec, target_vec) + return f + else: + assert self.rank == 2 + # we know we will be outputting either a function or a cofunction, + # both of which will use a dat as a data carrier. At present, the + # data type does not depend on function space dimension, so we can + # safely use the argument function space. NOTE: If this changes + # after cofunctions are fully implemented, this will need to be + # reconsidered. + temp_source_func = Function(self.expr_args[1].function_space()) + self.mat.mpi_type = get_dat_mpi_type(temp_source_func.dat)[0] + # Leave mat inside a callable so we can access the handle + # property. If matfree is True, then the handle is a PETSc SF + # pretending to be a PETSc Mat. If matfree is False, then this + # will be a PETSc Mat representing the equivalent permutation matrix + + def callable() -> VomOntoVomMat: + return self.mat - return partial(callable, loops, f) + self.callable = callable @utils.known_pyop2_safe -def _interpolator(tensor, expr, subset, access, bcs=None): +def _build_interpolation_callables( + expr: ufl.Interpolate | ufl.ZeroBaseForm, + tensor: op2.Dat | op2.Mat | op2.Global, + access: Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC], + subset: op2.Subset | None = None, + bcs: Iterable[DirichletBC] | None = None +) -> tuple[Callable, ...]: + """Returns tuple of callables which calculate the interpolation. + + Parameters + ---------- + expr : ufl.Interpolate | ufl.ZeroBaseForm + The symbolic interpolation expression, or a zero form. Zero forms + are simplified here to avoid code generation when access is WRITE or INC. + tensor : op2.Dat | op2.Mat | op2.Global + Object to hold the result of the interpolation. + access : Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC] + op2 access descriptor + subset : op2.Subset | None + An optional subset to apply the interpolation over, by default None. + bcs : Iterable[DirichletBC] | None + An optional list of boundary conditions to zero-out in the + output function space. Interpolator rows or columns which are + associated with boundary condition nodes are zeroed out when this is + specified. By default None, by default None. + + Returns + ------- + tuple[Callable, ...] + Tuple of callables which perform the interpolation. + """ if isinstance(expr, ufl.ZeroBaseForm): # Zero simplification, avoid code-generation if access is op2.INC: @@ -981,54 +730,43 @@ def _interpolator(tensor, expr, subset, access, bcs=None): # Reconstruct the expression as an Interpolate V = expr.arguments()[-1].function_space().dual() expr = interpolate(ufl.zero(V.value_shape), V) - if not isinstance(expr, ufl.Interpolate): raise ValueError("Expecting to interpolate a ufl.Interpolate") - - arguments = expr.arguments() dual_arg, operand = expr.argument_slots() - V = dual_arg.arguments()[0].function_space() - + V = dual_arg.function_space().dual() try: to_element = create_element(V.ufl_element()) except KeyError: # FInAT only elements - raise NotImplementedError("Don't know how to create FIAT element for %s" % V.ufl_element()) + raise NotImplementedError(f"Don't know how to create FIAT element for {V.ufl_element()}") if access is op2.READ: raise ValueError("Can't have READ access for output function") # NOTE: The par_loop is always over the target mesh cells. - target_mesh = as_domain(V) + target_mesh = V.mesh() source_mesh = extract_unique_domain(operand) or target_mesh if isinstance(target_mesh.topology, VertexOnlyMeshTopology): - if target_mesh is not source_mesh: - if not isinstance(target_mesh.topology, VertexOnlyMeshTopology): - raise NotImplementedError("Can only interpolate onto a Vertex Only Mesh") - if target_mesh.geometric_dimension() != source_mesh.geometric_dimension(): - raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension") - if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh: - raise ValueError("Can only interpolate across meshes where the source mesh is the parent of the target") - # For trans-mesh interpolation we use a FInAT QuadratureElement as the - # (base) target element with runtime point set expressions as their - # quadrature rule point set and weights from their dual basis. - # NOTE: This setup is useful for thinking about future design - in the - # future this `rebuild` function can be absorbed into FInAT as a - # transformer that eats an element and gives you an equivalent (which - # may or may not be a QuadratureElement) that lets you do run time - # tabulation. Alternatively (and this all depends on future design - # decision about FInAT how dual evaluation should work) the - # to_element's dual basis (which look rather like quadrature rules) can - # have their pointset(s) directly replaced with run-time tabulated - # equivalent(s) (i.e. finat.point_set.UnknownPointSet(s)) - rt_var_name = 'rt_X' - try: - cell = operand.ufl_element().ufl_cell() - except AttributeError: - # expression must be pure function of spatial coordinates so - # domain has correct ufl cell - cell = source_mesh.ufl_cell() - to_element = rebuild(to_element, cell, rt_var_name) + # For trans-mesh interpolation we use a FInAT QuadratureElement as the + # (base) target element with runtime point set expressions as their + # quadrature rule point set and weights from their dual basis. + # NOTE: This setup is useful for thinking about future design - in the + # future this `rebuild` function can be absorbed into FInAT as a + # transformer that eats an element and gives you an equivalent (which + # may or may not be a QuadratureElement) that lets you do run time + # tabulation. Alternatively (and this all depends on future design + # decision about FInAT how dual evaluation should work) the + # to_element's dual basis (which look rather like quadrature rules) can + # have their pointset(s) directly replaced with run-time tabulated + # equivalent(s) (i.e. finat.point_set.UnknownPointSet(s)) + rt_var_name = 'rt_X' + try: + cell = operand.ufl_element().ufl_cell() + except AttributeError: + # expression must be pure function of spatial coordinates so + # domain has correct ufl cell + cell = source_mesh.ufl_cell() + to_element = rebuild(to_element, cell, rt_var_name) cell_set = target_mesh.cell_set if subset is not None: @@ -1038,36 +776,32 @@ def _interpolator(tensor, expr, subset, access, bcs=None): parameters = {} parameters['scalar_type'] = utils.ScalarType - copyin = () - copyout = () + callables = () # For the matfree adjoint 1-form and the 0-form, the cellwise kernel will add multiple # contributions from the facet DOFs of the dual argument. # The incoming Cofunction needs to be weighted by the reciprocal of the DOF multiplicity. needs_weight = isinstance(dual_arg, ufl.Cofunction) and not to_element.is_dg() if needs_weight: - # Create a buffer for the weighted Cofunction - W = dual_arg.function_space() - v = firedrake.Function(W) - expr = expr._ufl_expr_reconstruct_(operand, v=v) - copyin += (partial(dual_arg.dat.copy, v.dat),) - # Compute the reciprocal of the DOF multiplicity - wdat = W.make_dat() - m_ = get_interp_node_map(source_mesh, target_mesh, W) + W = dual_arg.function_space() wsize = W.finat_element.space_dimension() * W.block_size kernel_code = f""" void multiplicity(PetscScalar *restrict w) {{ for (PetscInt i=0; i<{wsize}; i++) w[i] += 1; }}""" - kernel = op2.Kernel(kernel_code, "multiplicity") - op2.par_loop(kernel, cell_set, wdat(op2.INC, m_)) - with wdat.vec as w: + kernel = op2.Kernel(kernel_code, "multiplicity", requires_zeroed_output_arguments=False) + weight = firedrake.Function(W) + m_ = get_interp_node_map(source_mesh, target_mesh, W) + op2.par_loop(kernel, cell_set, weight.dat(op2.INC, m_)) + with weight.dat.vec as w: w.reciprocal() - # Create a callable to apply the weight - with wdat.vec_ro as w, v.dat.vec as y: - copyin += (partial(y.pointwiseMult, y, w),) + # Create a buffer for the weighted Cofunction and a callable to apply the weight + v = firedrake.Function(W) + expr = expr._ufl_expr_reconstruct_(operand, v=v) + with weight.dat.vec_ro as w, dual_arg.dat.vec_ro as x, v.dat.vec_wo as y: + callables += (partial(y.pointwiseMult, x, w),) # We need to pass both the ufl element and the finat element # because the finat elements might not have the right mapping @@ -1083,7 +817,7 @@ def _interpolator(tensor, expr, subset, access, bcs=None): coefficient_numbers = kernel.coefficient_numbers needs_external_coords = kernel.needs_external_coords name = kernel.name - kernel = op2.Kernel(ast, name, requires_zeroed_output_arguments=(access is not op2.INC), + kernel = op2.Kernel(ast, name, requires_zeroed_output_arguments=True, flop_count=kernel.flop_count, events=(kernel.event,)) parloop_args = [kernel, cell_set] @@ -1096,12 +830,19 @@ def _interpolator(tensor, expr, subset, access, bcs=None): output = tensor tensor = op2.Dat(tensor.dataset) if access is not op2.WRITE: - copyin += (partial(output.copy, tensor), ) - copyout += (partial(tensor.copy, output), ) + copyin = (partial(output.copy, tensor), ) + else: + copyin = () + copyout = (partial(tensor.copy, output), ) + else: + copyin = () + copyout = () + + arguments = expr.arguments() if isinstance(tensor, op2.Global): parloop_args.append(tensor(access)) elif isinstance(tensor, op2.Dat): - V_dest = arguments[-1].function_space() + V_dest = arguments[-1].function_space() if isinstance(dual_arg, ufl.Cofunction) else V m_ = get_interp_node_map(source_mesh, target_mesh, V_dest) parloop_args.append(tensor(access, m_)) else: @@ -1122,9 +863,11 @@ def _interpolator(tensor, expr, subset, access, bcs=None): bc_cols = [bc for bc in bcs if bc.function_space() == Vcol] lgmaps = [(Vrow.local_to_global_map(bc_rows), Vcol.local_to_global_map(bc_cols))] parloop_args.append(tensor(access, (rows_map, columns_map), lgmaps=lgmaps)) + if oriented: co = target_mesh.cell_orientations() parloop_args.append(co.dat(op2.READ, co.cell_node_map())) + if needs_cell_sizes: cs = source_mesh.cell_sizes parloop_args.append(cs.dat(op2.READ, cs.cell_node_map())) @@ -1161,13 +904,17 @@ def _interpolator(tensor, expr, subset, access, bcs=None): parloop_args.append(target_ref_coords.dat(op2.READ, m_)) parloop = op2.ParLoop(*parloop_args) + parloop_compute_callable = parloop.compute if isinstance(tensor, op2.Mat): - return parloop, tensor.assemble + return parloop_compute_callable, tensor.assemble else: - return copyin + (parloop, ) + copyout + extra = copyin + callables + if access == op2.INC: + extra += (tensor.zero,) + return extra + (parloop_compute_callable, ) + copyout -def get_interp_node_map(source_mesh, target_mesh, fs): +def get_interp_node_map(source_mesh: MeshGeometry, target_mesh: MeshGeometry, fs: WithGeometry) -> op2.Map | None: """Return the map between cells of the target mesh and nodes of the function space. If the function space is defined on the source mesh then the node map is composed @@ -1269,7 +1016,7 @@ def rebuild_te(element, expr_cell, rt_var_name): transpose=element._transpose) -def compose_map_and_cache(map1, map2): +def compose_map_and_cache(map1: op2.Map, map2: op2.Map | None) -> op2.ComposedMap | None: """ Retrieve a :class:`pyop2.ComposedMap` map from the cache of map1 using map2 as the cache key. The composed map maps from the iterset @@ -1292,7 +1039,7 @@ def compose_map_and_cache(map1, map2): return cmap -def vom_cell_parent_node_map_extruded(vertex_only_mesh, extruded_cell_node_map): +def vom_cell_parent_node_map_extruded(vertex_only_mesh: MeshGeometry, extruded_cell_node_map: op2.Map) -> op2.Map: """Build a map from the cells of a vertex only mesh to the nodes of the nodes on the source mesh where the source mesh is extruded. @@ -1418,119 +1165,55 @@ def __init__(self, glob): self.ufl_domain = lambda: None -class VomOntoVomWrapper(object): - """Utility class for interpolating from one ``VertexOnlyMesh`` to it's - intput ordering ``VertexOnlyMesh``, or vice versa. +class VomOntoVomMat: + """Object that facilitates interpolation between two vertex-only meshes.""" + def __init__(self, interpolator: VomOntoVomInterpolator): + """Initialises the VomOntoVomMat. - Parameters - ---------- - V : `.FunctionSpace` - The P0DG function space (which may be vector or tensor valued) on the - source vertex-only mesh. - source_vom : `.VertexOnlyMesh` - The vertex-only mesh we interpolate from. - target_vom : `.VertexOnlyMesh` - The vertex-only mesh we interpolate to. - expr : `ufl.Expr` - The expression to interpolate. If ``arguments`` is not empty, those - arguments must be present within it. - matfree : bool - If ``False``, the matrix representating the permutation of the points is - constructed and used to perform the interpolation. If ``True``, then the - interpolation is performed using the broadcast and reduce operations on the - PETSc Star Forest. - """ + Parameters + ---------- + interpolator : VomOntoVomInterpolator + A :class:`VomOntoVomInterpolator` object. - def __init__(self, V, source_vom, target_vom, expr, matfree): - arguments = extract_arguments(expr) - reduce = False - if source_vom.input_ordering is target_vom: - reduce = True - original_vom = source_vom - elif target_vom.input_ordering is source_vom: - original_vom = target_vom + Raises + ------ + ValueError + If the source and target vertex-only meshes are not linked by input_ordering. + """ + if interpolator.source_mesh.input_ordering is interpolator.target_mesh: + self.forward_reduce = True + self.original_vom = interpolator.source_mesh + elif interpolator.target_mesh.input_ordering is interpolator.source_mesh: + self.forward_reduce = False + self.original_vom = interpolator.target_mesh else: raise ValueError( "The target vom and source vom must be linked by input ordering!" ) - self.V = V - self.source_vom = source_vom - self.expr = expr - self.arguments = arguments - self.reduce = reduce - # note that interpolation doesn't include halo cells - self.dummy_mat = VomOntoVomDummyMat( - original_vom.input_ordering_without_halos_sf, reduce, V, source_vom, expr, arguments - ) - if matfree: - # If matfree, we use the SF to perform the interpolation - self.handle = self.dummy_mat._wrap_dummy_mat() - else: - # Otherwise we create the permutation matrix - self.handle = self.dummy_mat._create_permutation_mat() - - @property - def mpi_type(self): - """ - The MPI type to use for the PETSc SF. - - Should correspond to the underlying data type of the PETSc Vec. - """ - return self.handle.mpi_type - - @mpi_type.setter - def mpi_type(self, val): - self.dummy_mat.mpi_type = val - - def forward_operation(self, target_dat): - coeff = self.dummy_mat.expr_as_coeff() - with coeff.dat.vec_ro as coeff_vec, target_dat.vec_wo as target_vec: - self.handle.mult(coeff_vec, target_vec) - - -class VomOntoVomDummyMat(object): - """Dummy object to stand in for a PETSc ``Mat`` when we are interpolating - between vertex-only meshes. + self.sf = self.original_vom.input_ordering_without_halos_sf + self.V = interpolator.target_space + self.source_vom = interpolator.source_mesh + self.expr = interpolator.operand + self.arguments = extract_arguments(self.expr) + self.is_adjoint = interpolator.expr.is_adjoint - Parameters - ---------- - sf: PETSc.sf - The PETSc Star Forest (SF) to use for the operation - forward_reduce : bool - If ``True``, the action of the operator (accessed via the `mult` - method) is to perform a SF reduce from the source vec to the target - vec, whilst the adjoint action (accessed via the `multHermitian` - method) is to perform a SF broadcast from the source vec to the target - vec. If ``False``, the opposite is true. - V : `.FunctionSpace` - The P0DG function space (which may be vector or tensor valued) on the - source vertex-only mesh. - source_vom : `.VertexOnlyMesh` - The vertex-only mesh we interpolate from. - expr : `ufl.Expr` - The expression to interpolate. If ``arguments`` is not empty, those - arguments must be present within it. - arguments : list of `ufl.Argument` - The arguments in the expression. - """ - - def __init__(self, sf, forward_reduce, V, source_vom, expr, arguments): - self.sf = sf - self.forward_reduce = forward_reduce - self.V = V - self.source_vom = source_vom - self.expr = expr - self.arguments = arguments # Calculate correct local and global sizes for the matrix - nroots, leaves, _ = sf.getGraph() + nroots, leaves, _ = self.sf.getGraph() self.nleaves = len(leaves) - self._local_sizes = V.comm.allgather(nroots) + self._local_sizes = self.V.comm.allgather(nroots) self.source_size = (self.V.block_size * nroots, self.V.block_size * sum(self._local_sizes)) self.target_size = ( self.V.block_size * self.nleaves, - self.V.block_size * V.comm.allreduce(self.nleaves, op=MPI.SUM), + self.V.block_size * self.V.comm.allreduce(self.nleaves, op=MPI.SUM), ) + if interpolator.matfree: + # If matfree, we use the SF to perform the interpolation + self.handle = self._wrap_python_mat() + else: + # Otherwise we create the permutation matrix + self.handle = self._create_permutation_mat() + @property def mpi_type(self): """ @@ -1544,13 +1227,25 @@ def mpi_type(self): def mpi_type(self, val): self._mpi_type = val - def expr_as_coeff(self, source_vec=None): - """ - Return a coefficient that corresponds to the expression used at + def expr_as_coeff(self, source_vec: PETSc.Vec | None = None) -> Function: + """Return a coefficient that corresponds to the expression used at construction, where the expression has been interpolated into the P0DG function space on the source vertex-only mesh. Will fail if there are no arguments. + + Parameters + ---------- + source_vec : PETSc.Vec | None, optional + Optional vector used to replace arguments in the expression. + By default None. + + Returns + ------- + Function + A Function representing the expression as a coefficient on the + source vertex-only mesh. + """ # Since we always output a coefficient when we don't have arguments in # the expression, we should evaluate the expression on the source mesh @@ -1577,7 +1272,16 @@ def expr_as_coeff(self, source_vec=None): coeff = firedrake.Function(P0DG).interpolate(coeff_expr) return coeff - def reduce(self, source_vec, target_vec): + def reduce(self, source_vec: PETSc.Vec, target_vec: PETSc.Vec) -> None: + """Reduce data in source_vec using the PETSc SF. + + Parameters + ---------- + source_vec : PETSc.Vec + The vector to reduce. + target_vec : PETSc.Vec + The vector to store the result in. + """ source_arr = source_vec.getArray(readonly=True) target_arr = target_vec.getArray() self.sf.reduceBegin( @@ -1593,7 +1297,16 @@ def reduce(self, source_vec, target_vec): MPI.REPLACE, ) - def broadcast(self, source_vec, target_vec): + def broadcast(self, source_vec: PETSc.Vec, target_vec: PETSc.Vec) -> None: + """Broadcast data in source_vec using the PETSc SF. + + Parameters + ---------- + source_vec : PETSc.Vec + The vector to broadcast. + target_vec : PETSc.Vec + The vector to store the result in. + """ source_arr = source_vec.getArray(readonly=True) target_arr = target_vec.getArray() self.sf.bcastBegin( @@ -1609,8 +1322,20 @@ def broadcast(self, source_vec, target_vec): MPI.REPLACE, ) - def mult(self, mat, source_vec, target_vec): - # need to evaluate expression before doing mult + def mult(self, mat: PETSc.Mat, source_vec: PETSc.Vec, target_vec: PETSc.Vec) -> None: + """Applies the interpolation operator. + + Parameters + ---------- + mat : PETSc.Mat + Required by petsc4py but unused. + source_vec : PETSc.Vec + The vector to interpolate. + target_vec : PETSc.Vec + The vector to store the result in. + """ + # Need to convert the expression into a coefficient + # so that we can broadcast/reduce it coeff = self.expr_as_coeff(source_vec) with coeff.dat.vec_ro as coeff_vec: if self.forward_reduce: @@ -1618,10 +1343,35 @@ def mult(self, mat, source_vec, target_vec): else: self.broadcast(coeff_vec, target_vec) - def multHermitian(self, mat, source_vec, target_vec): + def multHermitian(self, mat: PETSc.Mat, source_vec: PETSc.Vec, target_vec: PETSc.Vec) -> None: + """Applies the adjoint of the interpolation operator. + Since ``VomOntoVomMat`` represents a permutation, it is + real-valued and thus the Hermitian adjoint is the transpose. + + Parameters + ---------- + mat : PETSc.Mat + Required by petsc4py but unused. + source_vec : PETSc.Vec + The vector to adjoint interpolate. + target_vec : PETSc.Vec + The vector to store the result in. + """ self.multTranspose(mat, source_vec, target_vec) - def multTranspose(self, mat, source_vec, target_vec): + def multTranspose(self, mat: PETSc.Mat, source_vec: PETSc.Vec, target_vec: PETSc.Vec) -> None: + """Applies the tranpose of the interpolation operator. Called by `self.multHermitian`. + + Parameters + ---------- + mat : PETSc.Mat + Required by petsc4py but unused. + source_vec : PETSc.Vec + The vector to transpose interpolate. + target_vec : PETSc.Vec + The vector to store the result in. + + """ # can only do adjoint if our expression exclusively contains a # single argument, making the application of the adjoint operator # straightforward (haven't worked out how to do this otherwise!) @@ -1648,9 +1398,17 @@ def multTranspose(self, mat, source_vec, target_vec): target_vec.zeroEntries() self.reduce(source_vec, target_vec) - def _create_permutation_mat(self): + def _create_permutation_mat(self) -> PETSc.Mat: """Creates the PETSc matrix that represents the interpolation operator from a vertex-only mesh to - its input ordering vertex-only mesh""" + its input ordering vertex-only mesh. + + Returns + ------- + PETSc.Mat + PETSc seqaij matrix + """ + # To create the permutation matrix we broadcast an array of indices which are contiguous + # across all ranks and then use these indices to set the values of the matrix directly. mat = PETSc.Mat().createAIJ((self.target_size, self.source_size), nnz=1, comm=self.V.comm) mat.setUp() start = sum(self._local_sizes[:self.V.comm.rank]) @@ -1663,11 +1421,18 @@ def _create_permutation_mat(self): cols = (self.V.block_size * perm[:, None] + numpy.arange(self.V.block_size, dtype=utils.IntType)[None, :]).reshape(-1) mat.setValuesCSR(rows, cols, numpy.ones_like(cols, dtype=utils.IntType)) mat.assemble() - if self.forward_reduce: + if self.forward_reduce and not self.is_adjoint: mat.transpose() return mat - def _wrap_dummy_mat(self): + def _wrap_python_mat(self) -> PETSc.Mat: + """Wraps this object as a PETSc Mat. Used for matfree interpolation. + + Returns + ------- + PETSc.Mat + A PETSc Mat of type python with this object as its context. + """ mat = PETSc.Mat().create(comm=self.V.comm) if self.forward_reduce: mat_size = (self.source_size, self.target_size) @@ -1679,96 +1444,85 @@ def _wrap_dummy_mat(self): mat.setUp() return mat - def duplicate(self, mat=None, op=None): - return self._wrap_dummy_mat() + def duplicate(self, mat: PETSc.Mat | None = None, op: PETSc.Mat.DuplicateOption | None = None) -> PETSc.Mat: + """Duplicates the matrix. Needed to wrap as a PETSc Python Mat. + + Parameters + ---------- + mat : PETSc.Mat | None, optional + Unused, by default None + op : PETSc.Mat.DuplicateOption | None, optional + Unused, by default None + + Returns + ------- + PETSc.Mat + VomOntoVomMat wrapped as a PETSc Mat of type python. + """ + return self._wrap_python_mat() class MixedInterpolator(Interpolator): - """A reusable interpolation object between MixedFunctionSpaces. + """Interpolator between MixedFunctionSpaces.""" + def __init__(self, expr: Interpolate): + """Initialise MixedInterpolator. Should not be called directly; use `get_interpolator`. - Parameters - ---------- - expr - The underlying ufl.Interpolate or the operand to the ufl.Interpolate. - V - The :class:`.FunctionSpace` or :class:`.Function` to - interpolate into. - bcs - A list of boundary conditions. - **kwargs - Any extra kwargs are passed on to the sub Interpolators. - For details see :class:`firedrake.interpolation.Interpolator`. - """ - def __init__(self, expr, V, bcs=None, **kwargs): - super(MixedInterpolator, self).__init__(expr, V, bcs=bcs, **kwargs) - expr = self.ufl_interpolate - self.arguments = expr.arguments() - rank = len(self.arguments) + Parameters + ---------- + expr : Interpolate + Symbolic Interpolate expression. + """ + super().__init__(expr) # We need a Coargument in order to split the Interpolate - needs_action = len([a for a in self.arguments if isinstance(a, Coargument)]) == 0 + needs_action = not any(isinstance(a, Coargument) for a in self.expr_args) if needs_action: - dual_arg, operand = expr.argument_slots() # Split the dual argument - dual_split = dict(firedrake.formmanipulation.split_form(dual_arg)) + dual_split = dict(firedrake.formmanipulation.split_form(self.dual_arg)) # Create the Jacobian to be split into blocks - expr = expr._ufl_expr_reconstruct_(operand, V) + self.expr = self.expr._ufl_expr_reconstruct_(self.operand, self.target_space) - Isub = {} - # Split in the arguments of the Interpolate - for indices, form in firedrake.formmanipulation.split_form(expr): + # Get sub-interpolators for each block + self.Isub: dict[tuple[int, int], Interpolator] = {} + for indices, form in firedrake.formmanipulation.split_form(self.expr): if isinstance(form, ufl.ZeroBaseForm): # Ensure block sparsity continue vi, _ = form.argument_slots() Vtarget = vi.function_space().dual() - if bcs and rank != 0: + if self.bcs and self.rank != 0: args = form.arguments() - Vsource = args[1-vi.number()].function_space() - sub_bcs = [bc for bc in bcs if bc.function_space() in {Vsource, Vtarget}] + Vsource = args[1 - vi.number()].function_space() + sub_bcs = [bc for bc in self.bcs if bc.function_space() in {Vsource, Vtarget}] else: sub_bcs = None if needs_action: # Take the action of each sub-cofunction against each block form = action(form, dual_split[indices[-1:]]) - - Isub[indices] = Interpolator(form, Vtarget, bcs=sub_bcs, **kwargs) - - self._sub_interpolators = Isub - self.callable = self._assemble_matnest - - def __getitem__(self, item): - return self._sub_interpolators[item] - - def __iter__(self): - return iter(self._sub_interpolators) - - def _assemble_matnest(self): - """Assemble the operator.""" - shape = tuple(len(a.function_space()) for a in self.arguments) - blocks = numpy.full(shape, PETSc.Mat(), dtype=object) - # Assemble the sparse block matrix - for i in self: - blocks[i] = self[i].callable().handle - petscmat = PETSc.Mat().createNest(blocks) - tensor = firedrake.AssembledMatrix(self.arguments, self.bcs, petscmat) - return tensor.M - - def _interpolate(self, *function, output=None, adjoint=False, **kwargs): - """Assemble the action.""" - rank = len(self.arguments) - if rank == 0: - result = sum(self[i].assemble(**kwargs) for i in self) - return output.assign(result) if output else result - - if output is None: - output = firedrake.Function(self.arguments[-1].function_space().dual()) - - if rank == 1: - for k, sub_tensor in enumerate(output.subfunctions): - sub_tensor.assign(sum(self[i].assemble(**kwargs) for i in self if i[0] == k)) - elif rank == 2: - for k, sub_tensor in enumerate(output.subfunctions): - sub_tensor.assign(sum(self[i]._interpolate(*function, adjoint=adjoint, **kwargs) - for i in self if i[0] == k)) - return output + form.options.bcs = sub_bcs + self.Isub[indices] = get_interpolator(form) + + def _build_callable(self, tensor=None): + V_dest = self.expr.function_space() or self.target_space + f = tensor or Function(V_dest) + if self.rank == 2: + shape = tuple(len(a.function_space()) for a in self.expr_args) + blocks = numpy.full(shape, PETSc.Mat(), dtype=object) + for indices, interp in self.Isub.items(): + interp._build_callable() + blocks[indices] = interp.callable().handle + self.handle = PETSc.Mat().createNest(blocks) + + def callable() -> MixedInterpolator: + return self + elif self.rank == 1: + def callable() -> Function | Cofunction: + for k, sub_tensor in enumerate(f.subfunctions): + sub_tensor.assign(sum( + interp.assemble() for indices, interp in self.Isub.items() if indices[0] == k + )) + return f + else: + def callable() -> Number: + return sum(interp.assemble() for interp in self.Isub.values()) + self.callable = callable diff --git a/firedrake/mesh.py b/firedrake/mesh.py index 5c9f92fb02..b28b8ea6b1 100644 --- a/firedrake/mesh.py +++ b/firedrake/mesh.py @@ -4153,7 +4153,7 @@ def _parent_mesh_embedding( # nessesary, to other processes. P0DG = functionspace.FunctionSpace(parent_mesh, "DG", 0) with stop_annotating(): - visible_ranks = interpolation.Interpolate( + visible_ranks = interpolation.interpolate( constant.Constant(parent_mesh.comm.rank), P0DG ) visible_ranks = assemble(visible_ranks).dat.data_ro_with_halos.real diff --git a/firedrake/mg/utils.py b/firedrake/mg/utils.py index 37832b64dc..886cc7530c 100644 --- a/firedrake/mg/utils.py +++ b/firedrake/mg/utils.py @@ -143,7 +143,7 @@ def physical_node_locations(V): Vc = V.collapse().reconstruct(element=finat.ufl.VectorElement(element, dim=mesh.geometric_dimension())) # FIXME: This is unsafe for DG coordinates and CG target spaces. - locations = firedrake.assemble(firedrake.Interpolate(firedrake.SpatialCoordinate(mesh), Vc)) + locations = firedrake.assemble(firedrake.interpolate(firedrake.SpatialCoordinate(mesh), Vc)) return cache.setdefault(key, locations) diff --git a/firedrake/preconditioners/gtmg.py b/firedrake/preconditioners/gtmg.py index 6ce73cd6b4..2ac5df9a5d 100644 --- a/firedrake/preconditioners/gtmg.py +++ b/firedrake/preconditioners/gtmg.py @@ -4,7 +4,7 @@ from firedrake.petsc import PETSc from firedrake.preconditioners.base import PCBase from firedrake.parameters import parameters -from firedrake.interpolation import Interpolate +from firedrake.interpolation import interpolate from firedrake.solving_utils import _SNESContext from firedrake.matrix_free.operators import ImplicitMatrixContext import firedrake.dmhooks as dmhooks @@ -155,7 +155,7 @@ def initialize(self, pc): # Create interpolation matrix from coarse space to fine space fine_space = ctx.J.arguments()[0].function_space() coarse_test, coarse_trial = coarse_operator.arguments() - interp = assemble(Interpolate(coarse_trial, fine_space)) + interp = assemble(interpolate(coarse_trial, fine_space)) interp_petscmat = interp.petscmat restr_petscmat = appctx.get("restriction_matrix", None) diff --git a/firedrake/preconditioners/hiptmair.py b/firedrake/preconditioners/hiptmair.py index 14ec77fe1a..6c24b9cd84 100644 --- a/firedrake/preconditioners/hiptmair.py +++ b/firedrake/preconditioners/hiptmair.py @@ -10,7 +10,7 @@ from firedrake.preconditioners.hypre_ams import chop from firedrake.preconditioners.facet_split import restrict from firedrake.parameters import parameters -from firedrake.interpolation import Interpolator +from firedrake.interpolation import interpolate from ufl.algorithms.ad import expand_derivatives import firedrake.dmhooks as dmhooks import firedrake.utils as utils @@ -202,7 +202,7 @@ def coarsen(self, pc): coarse_space_bcs = tuple(coarse_space_bcs) if G_callback is None: - interp_petscmat = chop(Interpolator(dminus(trial), V, bcs=bcs + coarse_space_bcs).callable().handle) + interp_petscmat = chop(assemble(interpolate(dminus(trial), V, bcs=bcs + coarse_space_bcs)).mat()) else: interp_petscmat = G_callback(coarse_space, V, coarse_space_bcs, bcs) diff --git a/firedrake/preconditioners/hypre_ads.py b/firedrake/preconditioners/hypre_ads.py index 89c10dc438..98443f2c75 100644 --- a/firedrake/preconditioners/hypre_ads.py +++ b/firedrake/preconditioners/hypre_ads.py @@ -1,7 +1,7 @@ from firedrake.preconditioners.base import PCBase from firedrake.petsc import PETSc from firedrake.function import Function -from firedrake.ufl_expr import TestFunction +from firedrake.ufl_expr import TrialFunction from firedrake.dmhooks import get_function_space from firedrake.preconditioners.hypre_ams import chop from firedrake.interpolation import interpolate @@ -31,12 +31,12 @@ def initialize(self, obj): NC1 = V.reconstruct(family="N1curl" if mesh.ufl_cell().is_simplex() else "NCE", degree=1) G_callback = appctx.get("get_gradient", None) if G_callback is None: - G = chop(assemble(interpolate(grad(TestFunction(P1)), NC1)).petscmat) + G = chop(assemble(interpolate(grad(TrialFunction(P1)), NC1)).petscmat) else: G = G_callback(P1, NC1) C_callback = appctx.get("get_curl", None) if C_callback is None: - C = chop(assemble(interpolate(curl(TestFunction(NC1)), V)).petscmat) + C = chop(assemble(interpolate(curl(TrialFunction(NC1)), V)).petscmat) else: C = C_callback(NC1, V) diff --git a/firedrake/preconditioners/hypre_ams.py b/firedrake/preconditioners/hypre_ams.py index 9a59702af4..594fe88590 100644 --- a/firedrake/preconditioners/hypre_ams.py +++ b/firedrake/preconditioners/hypre_ams.py @@ -2,7 +2,7 @@ from firedrake.preconditioners.base import PCBase from firedrake.petsc import PETSc from firedrake.function import Function -from firedrake.ufl_expr import TestFunction +from firedrake.ufl_expr import TrialFunction from firedrake.dmhooks import get_function_space from firedrake.utils import complex_mode from firedrake.interpolation import interpolate @@ -51,7 +51,7 @@ def initialize(self, obj): P1 = V.reconstruct(family="Lagrange", degree=1) G_callback = appctx.get("get_gradient", None) if G_callback is None: - G = chop(assemble(interpolate(grad(TestFunction(P1)), V)).petscmat) + G = chop(assemble(interpolate(grad(TrialFunction(P1)), V)).petscmat) else: G = G_callback(P1, V) diff --git a/firedrake/preconditioners/patch.py b/firedrake/preconditioners/patch.py index bb47093a3d..4910ba1452 100644 --- a/firedrake/preconditioners/patch.py +++ b/firedrake/preconditioners/patch.py @@ -4,7 +4,7 @@ from firedrake.solving_utils import _SNESContext from firedrake.utils import cached_property, complex_mode, IntType from firedrake.dmhooks import get_appctx, push_appctx, pop_appctx -from firedrake.interpolation import Interpolate +from firedrake.interpolation import interpolate from firedrake.ufl_expr import extract_domains from collections import namedtuple @@ -668,7 +668,7 @@ def sort_entities(self, dm, axis, dir, ndiv=None, divisions=None): # with access descriptor MAX to define a consistent opinion # about where the vertices are. CGk = V.reconstruct(family="Lagrange") - coordinates = assemble(Interpolate(coordinates, CGk, access=op2.MAX)) + coordinates = assemble(interpolate(coordinates, CGk, access=op2.MAX)) select = partial(select_entity, dm=dm, exclude="pyop2_ghost") entities = [(p, self.coords(dm, p, coordinates)) for p in diff --git a/firedrake/pyplot/mpl.py b/firedrake/pyplot/mpl.py index 3cf010a1c9..d6a7aa5112 100644 --- a/firedrake/pyplot/mpl.py +++ b/firedrake/pyplot/mpl.py @@ -18,7 +18,7 @@ import mpl_toolkits.mplot3d from mpl_toolkits.mplot3d.art3d import Line3DCollection, Poly3DCollection from math import factorial -from firedrake import (Interpolate, sqrt, inner, Function, SpatialCoordinate, +from firedrake import (interpolate, sqrt, inner, Function, SpatialCoordinate, FunctionSpace, VectorFunctionSpace, PointNotInDomainError, Constant, assemble, dx) from firedrake.mesh import MeshGeometry @@ -120,7 +120,7 @@ def triplot(mesh, axes=None, interior_kw={}, boundary_kw={}): if element.degree() != 1: # Interpolate to piecewise linear. V = VectorFunctionSpace(mesh, element.family(), 1) - coordinates = assemble(Interpolate(coordinates, V)) + coordinates = assemble(interpolate(coordinates, V)) coords = toreal(coordinates.dat.data_ro_with_halos, "real") result = [] @@ -215,7 +215,7 @@ def _plot_2d_field(method_name, function, *args, complex_component="real", **kwa if len(function.ufl_shape) == 1: element = function.ufl_element().sub_elements[0] Q = FunctionSpace(mesh, element) - function = assemble(Interpolate(sqrt(inner(function, function)), Q)) + function = assemble(interpolate(sqrt(inner(function, function)), Q)) num_sample_points = kwargs.pop("num_sample_points", 10) function_plotter = FunctionPlotter(mesh, num_sample_points) @@ -326,7 +326,7 @@ def trisurf(function, *args, complex_component="real", **kwargs): if len(function.ufl_shape) == 1: element = function.ufl_element().sub_elements[0] Q = FunctionSpace(mesh, element) - function = assemble(Interpolate(sqrt(inner(function, function)), Q)) + function = assemble(interpolate(sqrt(inner(function, function)), Q)) num_sample_points = kwargs.pop("num_sample_points", 10) function_plotter = FunctionPlotter(mesh, num_sample_points) @@ -355,7 +355,7 @@ def quiver(function, *, complex_component="real", **kwargs): coords = toreal(extract_unique_domain(function).coordinates.dat.data_ro, "real") V = extract_unique_domain(function).coordinates.function_space() - function_interp = assemble(Interpolate(function, V)) + function_interp = assemble(interpolate(function, V)) vals = toreal(function_interp.dat.data_ro, complex_component) C = np.linalg.norm(vals, axis=1) return axes.quiver(*(coords.T), *(vals.T), C, **kwargs) @@ -816,7 +816,7 @@ def _bezier_plot(function, axes, complex_component="real", **kwargs): mesh = function.function_space().mesh() if deg == 0: V = FunctionSpace(mesh, "DG", 1) - interp = assemble(Interpolate(function, V)) + interp = assemble(interpolate(function, V)) return _bezier_plot(interp, axes, complex_component=complex_component, **kwargs) y_vals = _bezier_calculate_points(function) diff --git a/firedrake/utility_meshes.py b/firedrake/utility_meshes.py index 223fa59a2e..7b1818c2de 100644 --- a/firedrake/utility_meshes.py +++ b/firedrake/utility_meshes.py @@ -11,7 +11,7 @@ Function, Constant, assemble, - Interpolate, + interpolate, FiniteElement, interval, tetrahedron, @@ -2351,7 +2351,7 @@ def OctahedralSphereMesh( ) if degree > 1: # use it to build a higher-order mesh - m = assemble(Interpolate(ufl.SpatialCoordinate(m), VectorFunctionSpace(m, "CG", degree))) + m = assemble(interpolate(ufl.SpatialCoordinate(m), VectorFunctionSpace(m, "CG", degree))) m = mesh.Mesh( m, name=name, @@ -2386,11 +2386,11 @@ def OctahedralSphereMesh( # Make a copy of the coordinates so that we can blend two different # mappings near the pole Vc = m.coordinates.function_space() - Xlatitudinal = assemble(Interpolate( + Xlatitudinal = assemble(interpolate( Constant(radius) * ufl.as_vector([x * scale, y * scale, znew]), Vc )) Vlow = VectorFunctionSpace(m, "CG", 1) - Xlow = assemble(Interpolate(Xlatitudinal, Vlow)) + Xlow = assemble(interpolate(Xlatitudinal, Vlow)) r = ufl.sqrt(Xlow[0] ** 2 + Xlow[1] ** 2 + Xlow[2] ** 2) Xradial = Constant(radius) * Xlow / r diff --git a/tests/firedrake/adjoint/test_reduced_functional.py b/tests/firedrake/adjoint/test_reduced_functional.py index e440967fe9..eb20b4bf82 100644 --- a/tests/firedrake/adjoint/test_reduced_functional.py +++ b/tests/firedrake/adjoint/test_reduced_functional.py @@ -214,7 +214,7 @@ def test_interpolate(): f = Function(V) f.dat.data[:] = 2 - J = assemble(Interpolate(f**2, c)) + J = assemble(interpolate(f**2, c)) Jhat = ReducedFunctional(J, Control(f)) h = Function(V) @@ -244,7 +244,7 @@ def test_interpolate_mixed(): f1, f2 = split(f) exprs = [f2 * div(f1)**2, grad(f2) * div(f1)] expr = as_vector([e[i] for e in exprs for i in np.ndindex(e.ufl_shape)]) - J = assemble(Interpolate(expr, c)) + J = assemble(interpolate(expr, c)) Jhat = ReducedFunctional(J, Control(f)) h = Function(V) diff --git a/tests/firedrake/external_operators/test_external_operators.py b/tests/firedrake/external_operators/test_external_operators.py index b6153f3d1f..a47953a566 100644 --- a/tests/firedrake/external_operators/test_external_operators.py +++ b/tests/firedrake/external_operators/test_external_operators.py @@ -104,7 +104,7 @@ def test_assemble(V, f): assert isinstance(jac, MatrixBase) # Assemble the exact Jacobian, i.e. the interpolation matrix: `Interpolate(dexpr(u,v,w)/du, V)` - jac_exact = assemble(Interpolate(derivative(expr(u, v, w), u), V)) + jac_exact = assemble(interpolate(derivative(expr(u, v, w), u), V)) np.allclose(jac.petscmat[:, :], jac_exact.petscmat[:, :], rtol=1e-14) # -- dNdu(u, v, w; δu, v*) (TLM) -- # diff --git a/tests/firedrake/multigrid/test_poisson_gtmg.py b/tests/firedrake/multigrid/test_poisson_gtmg.py index f70e5c6825..a4154dd392 100644 --- a/tests/firedrake/multigrid/test_poisson_gtmg.py +++ b/tests/firedrake/multigrid/test_poisson_gtmg.py @@ -60,7 +60,7 @@ def p1_callback(): if custom_transfer: P1 = get_p1_space() V = FunctionSpace(mesh, "DGT", degree - 1) - I = assemble(Interpolate(TrialFunction(P1), V)).petscmat + I = assemble(interpolate(TrialFunction(P1), V)).petscmat R = PETSc.Mat().createTranspose(I) appctx['interpolation_matrix'] = I appctx['restriction_matrix'] = R diff --git a/tests/firedrake/regression/test_adjoint_operators.py b/tests/firedrake/regression/test_adjoint_operators.py index cc4f1ade43..57faf80477 100644 --- a/tests/firedrake/regression/test_adjoint_operators.py +++ b/tests/firedrake/regression/test_adjoint_operators.py @@ -729,7 +729,7 @@ def test_copy_function(): g = f.copy(deepcopy=True) J = assemble(g*dx) rf = ReducedFunctional(J, Control(f)) - a = assemble(Interpolate(-one, V)) + a = assemble(interpolate(-one, V)) assert np.isclose(rf(a), -J) diff --git a/tests/firedrake/regression/test_function.py b/tests/firedrake/regression/test_function.py index 85c6cdc0a0..1f9ed7a429 100644 --- a/tests/firedrake/regression/test_function.py +++ b/tests/firedrake/regression/test_function.py @@ -81,22 +81,22 @@ def test_firedrake_tensor_function_nonstandard_shape(W_nonstandard_shape): def test_mismatching_rank_interpolation(V): f = Function(V) - with pytest.raises(RuntimeError): + with pytest.raises(ValueError): f.interpolate(Constant((1, 2))) VV = VectorFunctionSpace(V.mesh(), 'CG', 1) f = Function(VV) - with pytest.raises(RuntimeError): + with pytest.raises(ValueError): f.interpolate(Constant((1, 2))) VVV = TensorFunctionSpace(V.mesh(), 'CG', 1) f = Function(VVV) - with pytest.raises(RuntimeError): + with pytest.raises(ValueError): f.interpolate(Constant((1, 2))) def test_mismatching_shape_interpolation(V): VV = VectorFunctionSpace(V.mesh(), 'CG', 1) f = Function(VV) - with pytest.raises(RuntimeError): + with pytest.raises(ValueError): f.interpolate(Constant([1] * (VV.value_shape[0] + 1))) diff --git a/tests/firedrake/regression/test_interp_dual.py b/tests/firedrake/regression/test_interp_dual.py index 50e29b05cb..b58eb3c0e1 100644 --- a/tests/firedrake/regression/test_interp_dual.py +++ b/tests/firedrake/regression/test_interp_dual.py @@ -54,7 +54,7 @@ def test_assemble_interp_adjoint_tensor(mesh, V1, f1): def test_assemble_interp_operator(V2, f1): # Check type - If1 = Interpolate(f1, V2) + If1 = interpolate(f1, V2) assert isinstance(If1, ufl.Interpolate) # -- I(f1, V2) -- # @@ -89,7 +89,7 @@ def test_assemble_interp_matrix(V1, V2, f1): def test_assemble_interp_tlm(V1, V2, f1): # -- Action(I(v1, V2), f1) -- # v1 = TrialFunction(V1) - Iv1 = Interpolate(v1, V2) + Iv1 = interpolate(v1, V2) b = assemble(interpolate(f1, V2)) assembled_action_Iv1 = assemble(action(Iv1, f1)) @@ -99,7 +99,7 @@ def test_assemble_interp_tlm(V1, V2, f1): def test_assemble_interp_adjoint_matrix(V1, V2): # -- Adjoint(I(v1, V2)) -- # v1 = TrialFunction(V1) - Iv1 = Interpolate(v1, V2) + Iv1 = interpolate(v1, V2) v2 = TestFunction(V2) c2 = assemble(conj(v2) * dx) @@ -120,11 +120,11 @@ def test_assemble_interp_adjoint_matrix(V1, V2): def test_assemble_interp_adjoint_model(V1, V2): # -- Action(Adjoint(I(v1, v2)), fstar) -- # v1 = TrialFunction(V1) - Iv1 = Interpolate(v1, V2) + Iv1 = interpolate(v1, V2) fstar = Cofunction(V2.dual()) v = Argument(V1, 0) - Ivfstar = assemble(Interpolate(v, fstar)) + Ivfstar = assemble(interpolate(v, fstar)) # Action(Adjoint(I(v1, v2)), fstar) <=> I(v, fstar) res = assemble(action(adjoint(Iv1), fstar)) assert np.allclose(res.dat.data, Ivfstar.dat.data) @@ -167,9 +167,9 @@ def test_assemble_base_form_operator_expressions(mesh): f2 = Function(V1).interpolate(sin(2*pi*y)) f3 = Function(V1).interpolate(cos(2*pi*x)) - If1 = Interpolate(f1, V2) - If2 = Interpolate(f2, V2) - If3 = Interpolate(f3, V2) + If1 = interpolate(f1, V2) + If2 = interpolate(f2, V2) + If3 = interpolate(f3, V2) # Sum of BaseFormOperators (1-form) res = assemble(If1 + If2 + If3) @@ -184,8 +184,8 @@ def test_assemble_base_form_operator_expressions(mesh): # Sum of BaseFormOperator (2-form) v1 = TrialFunction(V1) - Iv1 = Interpolate(v1, V2) - Iv2 = Interpolate(v1, V2) + Iv1 = interpolate(v1, V2) + Iv2 = interpolate(v1, V2) res = assemble(Iv1 + Iv2) mat_Iv1 = assemble(Iv1) mat_Iv2 = assemble(Iv2) @@ -210,7 +210,7 @@ def test_check_identity(mesh): V1 = FunctionSpace(mesh, "CG", 1) v2 = TestFunction(V2) v1 = TestFunction(V1) - a = assemble(Interpolate(v1, conj(v2)*dx)) + a = assemble(interpolate(v1, conj(v2)*dx)) b = assemble(conj(v1)*dx) assert np.allclose(a.dat.data, b.dat.data) @@ -234,7 +234,7 @@ def test_solve_interp_f(mesh): # -- Solution where the source term is interpolated via `ufl.Interpolate` u2 = Function(V1) - If = Interpolate(f1, V2) + If = interpolate(f1, V2) # This requires assembling If F2 = inner(grad(u2), grad(w))*dx + inner(u2, w)*dx - inner(If, w)*dx solve(F2 == 0, u2) @@ -267,7 +267,7 @@ def test_solve_interp_u(mesh): # -- Solution where u2 is interpolated via `ufl.Interpolate` (mat-free) u2 = Function(V1) # Iu is the identity - Iu = Interpolate(u2, V1) + Iu = interpolate(u2, V1) # This requires assembling the action the Jacobian of Iu F2 = inner(grad(u2), grad(w))*dx + inner(Iu, w)*dx - inner(f, w)*dx solve(F2 == 0, u2, solver_parameters={"mat_type": "matfree", @@ -278,7 +278,7 @@ def test_solve_interp_u(mesh): # Same problem with grad(Iu) instead of grad(Iu) u2 = Function(V1) # Iu is the identity - Iu = Interpolate(u2, V1) + Iu = interpolate(u2, V1) # This requires assembling the action the Jacobian of Iu F2 = inner(grad(Iu), grad(w))*dx + inner(Iu, w)*dx - inner(f, w)*dx solve(F2 == 0, u2, solver_parameters={"mat_type": "matfree", @@ -341,7 +341,7 @@ def test_interp_dual_mixed(source_space, target_space): expected = assemble(F_target) F_source = inner(b, v)*dx - I_source = Interpolate(expr, F_source) + I_source = interpolate(expr, F_source) c = Cofunction(W.dual()) c.assign(99) diff --git a/tests/firedrake/regression/test_interpolate.py b/tests/firedrake/regression/test_interpolate.py index e8b5cb595a..ec925d312a 100644 --- a/tests/firedrake/regression/test_interpolate.py +++ b/tests/firedrake/regression/test_interpolate.py @@ -584,7 +584,7 @@ def test_interpolator_reuse(family, degree, mode): u = Function(V.dual()) expr = interpolate(TestFunction(V), u) - I = Interpolator(expr, V) + I = get_interpolator(expr) for k in range(3): u.assign(rg.uniform(u.function_space())) diff --git a/tests/firedrake/regression/test_interpolate_cross_mesh.py b/tests/firedrake/regression/test_interpolate_cross_mesh.py index f6c22e48e8..65974c5e54 100644 --- a/tests/firedrake/regression/test_interpolate_cross_mesh.py +++ b/tests/firedrake/regression/test_interpolate_cross_mesh.py @@ -339,12 +339,18 @@ def test_exact_refinement(): expr_in_V_fine = x**2 + y**2 + 1 f_fine = Function(V_fine).interpolate(expr_in_V_fine) + # Build interpolation matrices in both directions + coarse_to_fine = assemble(interpolate(TrialFunction(V_coarse), V_fine)) + coarse_to_fine_adjoint = assemble(interpolate(TestFunction(V_coarse), TrialFunction(V_fine.dual()))) + # If we now interpolate f_coarse into V_fine we should get a function # which has no interpolation error versus f_fine because we were able to # exactly represent expr_in_V_coarse in V_coarse and V_coarse is a subset # of V_fine f_coarse_on_fine = assemble(interpolate(f_coarse, V_fine)) assert np.allclose(f_coarse_on_fine.dat.data_ro, f_fine.dat.data_ro) + f_coarse_on_fine_mat = assemble(coarse_to_fine @ f_coarse) + assert np.allclose(f_coarse_on_fine_mat.dat.data_ro, f_fine.dat.data_ro) # Adjoint interpolation takes us from V_fine^* to V_coarse^* so we should # also get an exact result here. @@ -354,6 +360,10 @@ def test_exact_refinement(): assert np.allclose( cofunction_fine_on_coarse.dat.data_ro, cofunction_coarse.dat.data_ro ) + cofunction_fine_on_coarse_mat = assemble(action(coarse_to_fine_adjoint, cofunction_fine)) + assert np.allclose( + cofunction_fine_on_coarse_mat.dat.data_ro, cofunction_coarse.dat.data_ro + ) # Now we test with expressions which are NOT exactly representable in the # function spaces by introducing a cube term. This can't be represented @@ -550,7 +560,7 @@ def test_missing_dofs(): V_src = FunctionSpace(m_src, "CG", 2) V_dest = FunctionSpace(m_dest, "CG", 3) with pytest.raises(DofNotDefinedError): - Interpolator(TestFunction(V_src), V_dest) + assemble(interpolate(TrialFunction(V_src), V_dest)) f_src = Function(V_src).interpolate(expr) f_dest = assemble(interpolate(f_src, V_dest, allow_missing_dofs=True)) dest_eval = PointEvaluator(m_dest, coords) @@ -680,6 +690,32 @@ def test_interpolate_matrix_cross_mesh(): f_interp2.dat.data_wo[:] = f_at_points_correct_order3.dat.data_ro[:] assert np.allclose(f_interp2.dat.data_ro, g.dat.data_ro) + interp_mat2 = assemble(interpolate(TrialFunction(U), V)) + assert interp_mat2.arguments() == (TestFunction(V.dual()), TrialFunction(U)) + f_interp3 = assemble(interp_mat2 @ f) + assert f_interp3.function_space() == V + assert np.allclose(f_interp3.dat.data_ro, g.dat.data_ro) + + +@pytest.mark.parallel([1, 3]) +def test_interpolate_matrix_cross_mesh_adjoint(): + mesh_fine = UnitSquareMesh(4, 4) + mesh_coarse = UnitSquareMesh(2, 2) + + V_coarse = FunctionSpace(mesh_coarse, "CG", 1) + V_fine = FunctionSpace(mesh_fine, "CG", 1) + + cofunc_fine = assemble(conj(TestFunction(V_fine)) * dx) + + interp = assemble(interpolate(TestFunction(V_coarse), TrialFunction(V_fine.dual()))) + cofunc_coarse = assemble(Action(interp, cofunc_fine)) + assert interp.arguments() == (TestFunction(V_coarse), TrialFunction(V_fine.dual())) + assert cofunc_coarse.function_space() == V_coarse.dual() + + # Compare cofunc_fine with direct interpolation + cofunc_coarse_direct = assemble(conj(TestFunction(V_coarse)) * dx) + assert np.allclose(cofunc_coarse.dat.data_ro, cofunc_coarse_direct.dat.data_ro) + @pytest.mark.parallel([2, 3, 4]) def test_voting_algorithm_edgecases(): diff --git a/tests/firedrake/regression/test_interpolate_zany.py b/tests/firedrake/regression/test_interpolate_zany.py index b2054843cd..875d0fb0c1 100644 --- a/tests/firedrake/regression/test_interpolate_zany.py +++ b/tests/firedrake/regression/test_interpolate_zany.py @@ -117,7 +117,7 @@ def test_interpolate_zany_into_vom(V, mesh, which, expr_at_vom): P0 = expr_at_vom.function_space() # Interpolate a Function into P0(vom) - f_at_vom = assemble(Interpolate(fexpr, P0)) + f_at_vom = assemble(interpolate(fexpr, P0)) assert numpy.allclose(f_at_vom.dat.data_ro, expr_at_vom.dat.data_ro) # Construct a Cofunction on P0(vom)* @@ -125,10 +125,10 @@ def test_interpolate_zany_into_vom(V, mesh, which, expr_at_vom): expected_action = assemble(action(Fvom, expr_at_vom)) # Interpolate a Function into Fvom - f_at_vom = assemble(Interpolate(fexpr, Fvom)) + f_at_vom = assemble(interpolate(fexpr, Fvom)) assert numpy.allclose(f_at_vom, expected_action) # Interpolate a TestFunction into Fvom - expr_vom = assemble(Interpolate(vexpr, Fvom)) + expr_vom = assemble(interpolate(vexpr, Fvom)) f_at_vom = assemble(action(expr_vom, f)) assert numpy.allclose(f_at_vom, expected_action) diff --git a/tests/firedrake/submesh/test_submesh_interpolate.py b/tests/firedrake/submesh/test_submesh_interpolate.py index a26c1acb08..92e422cf6a 100644 --- a/tests/firedrake/submesh/test_submesh_interpolate.py +++ b/tests/firedrake/submesh/test_submesh_interpolate.py @@ -51,7 +51,7 @@ def _test_submesh_interpolate_cell_cell(mesh, subdomain_cond, fe_fesub): f = Function(V_).interpolate(f) v0 = Coargument(V.dual(), 0) v1 = TrialFunction(Vsub) - interp = Interpolate(v1, v0, allow_missing_dofs=True) + interp = interpolate(v1, v0, allow_missing_dofs=True) A = assemble(interp) g = assemble(action(A, gsub)) assert assemble(inner(g - f, g - f) * dx(label_value)).real < 1e-14 @@ -165,7 +165,7 @@ def test_submesh_interpolate_subcell_subcell_2_processes(): f_l.dat.data_with_halos[:] = 3.0 v0 = Coargument(V_r.dual(), 0) v1 = TrialFunction(V_l) - interp = Interpolate(v1, v0, allow_missing_dofs=True) + interp = interpolate(v1, v0, allow_missing_dofs=True) A = assemble(interp) f_r = assemble(action(A, f_l)) g_r = Function(V_r).interpolate(conditional(x < 2.001, 3.0, 0.0))