diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 2992fad9d9..8728e8b38c 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -23,6 +23,7 @@ from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit from firedrake.functionspaceimpl import WithGeometry, FunctionSpace, FiredrakeDualSpace from firedrake.functionspacedata import entity_dofs_key, entity_permutations_key +from firedrake.interpolation import get_interpolator from firedrake.petsc import PETSc from firedrake.slate import slac, slate from firedrake.slate.slac.kernel_builder import CellFacetKernelArg, LayerCountKernelArg @@ -613,17 +614,8 @@ def base_form_assembly_visitor(self, expr, tensor, bcs, *args): rank = len(expr.arguments()) if rank > 2: raise ValueError("Cannot assemble an Interpolate with more than two arguments") - # Get the target space - V = v.function_space().dual() - - # Get the interpolator - interp_data = expr.interp_data.copy() - default_missing_val = interp_data.pop('default_missing_val', None) - if rank == 1 and isinstance(tensor, firedrake.Function): - V = tensor - interpolator = firedrake.Interpolator(expr, V, bcs=bcs, **interp_data) - # Assembly - return interpolator.assemble(tensor=tensor, default_missing_val=default_missing_val) + interpolator = get_interpolator(expr) + return interpolator.assemble(tensor=tensor, bcs=bcs) elif tensor and isinstance(expr, (firedrake.Function, firedrake.Cofunction, firedrake.MatrixBase)): return tensor.assign(expr) elif tensor and isinstance(expr, ufl.ZeroBaseForm): diff --git a/firedrake/bcs.py b/firedrake/bcs.py index 0e19c872e0..275744c75e 100644 --- a/firedrake/bcs.py +++ b/firedrake/bcs.py @@ -1,7 +1,7 @@ # A module implementing strong (Dirichlet) boundary conditions. import numpy as np -import functools +from functools import partial, reduce import itertools import ufl @@ -167,7 +167,7 @@ def hermite_stride(bcnodes): # Edge conditions have only been tested with Lagrange elements. # Need to expand the list. bcnodes1.append(hermite_stride(self._function_space.boundary_nodes(ss))) - bcnodes1 = functools.reduce(np.intersect1d, bcnodes1) + bcnodes1 = reduce(np.intersect1d, bcnodes1) bcnodes.append(bcnodes1) return np.concatenate(bcnodes) @@ -359,11 +359,11 @@ def function_arg(self, g): raise RuntimeError(f"Provided boundary value {g} does not match shape of space") try: self._function_arg = firedrake.Function(V) - # Use `Interpolator` instead of assembling an `Interpolate` form - # as the expression compilation needs to happen at this stage to - # determine if we should use interpolation or projection - # -> e.g. interpolation may not be supported for the element. - self._function_arg_update = firedrake.Interpolator(g, self._function_arg)._interpolate + interpolator = firedrake.get_interpolator(firedrake.interpolate(g, V)) + # Call this here to check if the element supports interpolation + # TODO: It's probably better to have a more explicit way of checking this + interpolator._get_callable() + self._function_arg_update = partial(interpolator.assemble, tensor=self._function_arg) except (NotImplementedError, AttributeError): # Element doesn't implement interpolation self._function_arg = firedrake.Function(V).project(g) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 0ead02b064..f9532e9e93 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -2,391 +2,338 @@ import os import tempfile import abc -import warnings -from collections.abc import Iterable -from functools import partial, singledispatch -from typing import Hashable, Literal -import FIAT -import ufl -import finat.ufl -from ufl.algorithms import extract_arguments, extract_coefficients -from ufl.domain import as_domain, extract_unique_domain +from functools import partial, singledispatch +from typing import Hashable, Literal, Callable, Iterable +from dataclasses import asdict, dataclass +from numbers import Number + +from ufl.algorithms import extract_arguments, replace +from ufl.domain import extract_unique_domain +from ufl.classes import Expr +from ufl.duals import is_dual +from ufl.constantvalue import zero, as_ufl +from ufl.form import ZeroBaseForm, BaseForm +from ufl.core.interpolate import Interpolate as UFLInterpolate from pyop2 import op2 from pyop2.caching import memory_and_disk_cache +from FIAT.reference_element import Point + from finat.element_factory import create_element, as_fiat_cell -from tsfc import compile_expression_dual_evaluation -from tsfc.ufl_utils import extract_firedrake_constants, hash_expr +from finat.ufl import TensorElement, VectorElement, MixedElement +from finat.fiat_elements import ScalarFiatElement +from finat.quadrature import QuadratureRule +from finat.quadrature_element import QuadratureElement +from finat.point_set import UnknownPointSet +from finat.tensorfiniteelement import TensorFiniteElement + +from gem.gem import Variable -import gem -import finat +from tsfc.driver import compile_expression_dual_evaluation +from tsfc.ufl_utils import extract_firedrake_constants, hash_expr -import firedrake -from firedrake import tsfc_interface, utils, functionspaceimpl -from firedrake.ufl_expr import Argument, Coargument, action, adjoint as expr_adjoint -from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology +from firedrake.utils import IntType, ScalarType, known_pyop2_safe, tuplify +from firedrake.tsfc_interface import extract_numbered_coefficients, _cachedir +from firedrake.ufl_expr import Argument, Coargument, action +from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology, MeshGeometry, MeshTopology, VertexOnlyMesh from firedrake.petsc import PETSc -from firedrake.halo import _get_mtype as get_dat_mpi_type +from firedrake.halo import _get_mtype +from firedrake.functionspaceimpl import WithGeometry +from firedrake.matrix import MatrixBase, AssembledMatrix +from firedrake.bcs import DirichletBC +from firedrake.formmanipulation import split_form +from firedrake.functionspace import VectorFunctionSpace, TensorFunctionSpace, FunctionSpace +from firedrake.constant import Constant +from firedrake.function import Function +from firedrake.cofunction import Cofunction + from mpi4py import MPI -from pyadjoint import stop_annotating, no_annotations +from pyadjoint.tape import stop_annotating, no_annotations __all__ = ( "interpolate", - "Interpolator", "Interpolate", + "get_interpolator", "DofNotDefinedError", - "CrossMeshInterpolator", - "SameMeshInterpolator", + "InterpolateOptions", + "Interpolator" ) -class Interpolate(ufl.Interpolate): +@dataclass +class InterpolateOptions: + """Options for interpolation operations. - def __init__(self, expr, V, - subset=None, - access=None, - allow_missing_dofs=False, - default_missing_val=None, - matfree=True): + Parameters + ---------- + subset : pyop2.types.set.Subset or None + An optional subset to apply the interpolation over. + Cannot, at present, be used when interpolating across meshes unless + the target mesh is a :func:`.VertexOnlyMesh`. + access : pyop2.types.access.Access or None + The pyop2 access descriptor for combining updates to shared + DoFs. Possible values include ``WRITE``, ``MIN``, ``MAX``, and ``INC``. + Only ``WRITE`` is supported at present when interpolating across meshes + unless the target mesh is a :func:`.VertexOnlyMesh`. Only ``INC`` is + supported for the matrix-free adjoint interpolation. + allow_missing_dofs : bool + For interpolation across meshes: allow degrees of freedom (aka DoFs/nodes) + in the target mesh that cannot be defined on the source mesh. + For example, where nodes are point evaluations, points in the target mesh + that are not in the source mesh. When ``False`` this raises a ``ValueError`` + should this occur. When ``True`` the corresponding values are either + (a) unchanged if some ``output`` is given to the :meth:`interpolate` method + or (b) set to zero. + Can be overwritten with the ``default_missing_val`` kwarg of :meth:`interpolate`. + This does not affect adjoint interpolation. Ignored if interpolating within + the same mesh or onto a :func:`.VertexOnlyMesh` (the behaviour of a + :func:`.VertexOnlyMesh` in this scenario is, at present, set when it is created). + default_missing_val : float or None + For interpolation across meshes: the optional value to assign to DoFs + in the target mesh that are outside the source mesh. If this is not set + then the values are either (a) unchanged if some ``output`` is given to + the :meth:`interpolate` method or (b) set to zero. + Ignored if interpolating within the same mesh or onto a :func:`.VertexOnlyMesh`. + matfree : bool + If ``False``, then construct the permutation matrix for interpolating + between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast + and reduce operations. + """ + subset: op2.Subset | None = None + access: Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC] | None = None + allow_missing_dofs: bool = False + default_missing_val: float | None = None + matfree: bool = True + + +class Interpolate(UFLInterpolate): + + def __init__(self, expr: Expr, V: WithGeometry | BaseForm, **kwargs): """Symbolic representation of the interpolation operator. Parameters ---------- - expr : ufl.core.expr.Expr or ufl.BaseForm + expr : ufl.core.expr.Expr The UFL expression to interpolate. - V : firedrake.functionspaceimpl.WithGeometryBase or firedrake.ufl_expr.Coargument + V : firedrake.functionspaceimpl.WithGeometry or ufl.BaseForm The function space to interpolate into or the coargument defined on the dual of the function space to interpolate into. - subset : pyop2.types.set.Subset - An optional subset to apply the interpolation over. - Cannot, at present, be used when interpolating across meshes unless - the target mesh is a :func:`.VertexOnlyMesh`. - access : pyop2.types.access.Access - The pyop2 access descriptor for combining updates to shared - DoFs. Possible values include ``WRITE`` and ``INC``. Only ``WRITE`` is - supported at present when interpolating across meshes. See note in - :func:`.interpolate` if changing this from default. - allow_missing_dofs : bool - For interpolation across meshes: allow degrees of freedom (aka DoFs/nodes) - in the target mesh that cannot be defined on the source mesh. - For example, where nodes are point evaluations, points in the target mesh - that are not in the source mesh. When ``False`` this raises a ``ValueError`` - should this occur. When ``True`` the corresponding values are either - (a) unchanged if some ``output`` is given to the :meth:`interpolate` method - or (b) set to zero. - Can be overwritten with the ``default_missing_val`` kwarg of :meth:`interpolate`. - This does not affect adjoint interpolation. Ignored if interpolating within - the same mesh or onto a :func:`.VertexOnlyMesh` (the behaviour of a - :func:`.VertexOnlyMesh` in this scenario is, at present, set when it is created). - default_missing_val : float - For interpolation across meshes: the optional value to assign to DoFs - in the target mesh that are outside the source mesh. If this is not set - then the values are either (a) unchanged if some ``output`` is given to - the :meth:`interpolate` method or (b) set to zero. - Ignored if interpolating within the same mesh or onto a :func:`.VertexOnlyMesh`. - matfree : bool - If ``False``, then construct the permutation matrix for interpolating - between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast - and reduce operations. + **kwargs + Additional interpolation options. See :class:`InterpolateOptions` + for available parameters and their descriptions. """ - expr = ufl.as_ufl(expr) - if isinstance(V, functionspaceimpl.WithGeometry): - expr_args = expr.arguments()[1:] if isinstance(expr, ufl.BaseForm) else extract_arguments(expr) - expr_arg_numbers = {arg.number() for arg in expr_args} + expr = as_ufl(expr) + expr_args = expr.arguments()[1:] if isinstance(expr, BaseForm) else extract_arguments(expr) + expr_arg_numbers = {arg.number() for arg in expr_args} + self.is_adjoint = expr_arg_numbers == {0} + if isinstance(V, WithGeometry): # Need to create a Firedrake Argument so that it has a .function_space() method - V = Argument(V.dual(), 1 if expr_arg_numbers == {0} else 0) + V = Argument(V.dual(), 1 if self.is_adjoint else 0) - target_shape = V.arguments()[0].function_space().value_shape - if expr.ufl_shape != target_shape: - raise ValueError(f"Shape mismatch: Expression shape {expr.ufl_shape}, FunctionSpace shape {target_shape}.") + self.target_space = V.arguments()[0].function_space() + if expr.ufl_shape != self.target_space.value_shape: + raise ValueError(f"Shape mismatch: Expression shape {expr.ufl_shape}, FunctionSpace shape {self.target_space.value_shape}.") super().__init__(expr, V) - # -- Interpolate data (e.g. `subset` or `access`) -- # - self.interp_data = {"subset": subset, - "access": access, - "allow_missing_dofs": allow_missing_dofs, - "default_missing_val": default_missing_val, - "matfree": matfree} + self._options = InterpolateOptions(**kwargs) - function_space = ufl.Interpolate.ufl_function_space + function_space = UFLInterpolate.ufl_function_space - def _ufl_expr_reconstruct_(self, expr, v=None, **interp_data): - interp_data = interp_data or self.interp_data.copy() - return ufl.Interpolate._ufl_expr_reconstruct_(self, expr, v=v, **interp_data) + def _ufl_expr_reconstruct_( + self, expr: Expr, v: WithGeometry | BaseForm | None = None, **interp_data + ): + interp_data = interp_data or asdict(self.options) + return UFLInterpolate._ufl_expr_reconstruct_(self, expr, v=v, **interp_data) + + @property + def options(self) -> InterpolateOptions: + """Access the interpolation options. + + Returns + ------- + InterpolateOptions + An :class:`InterpolateOptions` instance containing the interpolation options. + """ + return self._options @PETSc.Log.EventDecorator() -def interpolate(expr, V, subset=None, access=None, allow_missing_dofs=False, default_missing_val=None, matfree=True): +def interpolate(expr: Expr, V: WithGeometry | BaseForm, **kwargs) -> Interpolate: """Returns a UFL expression for the interpolation operation of ``expr`` into ``V``. - :arg expr: a UFL expression. - :arg V: a :class:`.FunctionSpace` to interpolate into, or a :class:`.Cofunction`, - or :class:`.Coargument`, or a :class:`ufl.form.Form` with one argument (a one-form). - If a :class:`.Cofunction` or a one-form is provided, then we do adjoint interpolation. - :kwarg subset: An optional :class:`pyop2.types.set.Subset` to apply the - interpolation over. Cannot, at present, be used when interpolating - across meshes unless the target mesh is a :func:`.VertexOnlyMesh`. - :kwarg access: The pyop2 access descriptor for combining updates to shared - DoFs. Possible values include ``WRITE`` and ``INC``. Only ``WRITE`` is - supported at present when interpolating across meshes unless the target - mesh is a :func:`.VertexOnlyMesh`. See note below. - :kwarg allow_missing_dofs: For interpolation across meshes: allow - degrees of freedom (aka DoFs/nodes) in the target mesh that cannot be - defined on the source mesh. For example, where nodes are point - evaluations, points in the target mesh that are not in the source mesh. - When ``False`` this raises a ``ValueError`` should this occur. When - ``True`` the corresponding values are either (a) unchanged if - some ``output`` is given to the :meth:`interpolate` method or (b) set - to zero. In either case, if ``default_missing_val`` is specified, that - value is used. This does not affect adjoint interpolation. Ignored if - interpolating within the same mesh or onto a :func:`.VertexOnlyMesh` - (the behaviour of a :func:`.VertexOnlyMesh` in this scenario is, at - present, set when it is created). - :kwarg default_missing_val: For interpolation across meshes: the optional - value to assign to DoFs in the target mesh that are outside the source - mesh. If this is not set then the values are either (a) unchanged if - some ``output`` is given to the :meth:`interpolate` method or (b) set - to zero. Ignored if interpolating within the same mesh or onto a - :func:`.VertexOnlyMesh`. - :kwarg matfree: If ``False``, then construct the permutation matrix for interpolating - between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast - and reduce operations. - :returns: A symbolic :class:`.Interpolate` object - - .. note:: - - If you use an access descriptor other than ``WRITE``, the - behaviour of interpolation changes if interpolating into a - function space, or an existing function. If the former, then - the newly allocated function will be initialised with - appropriate values (e.g. for MIN access, it will be initialised - with MAX_FLOAT). On the other hand, if you provide a function, - then it is assumed that its values should take part in the - reduction (hence using MIN will compute the MIN between the - existing values and any new values). + Parameters + ---------- + expr : ufl.core.expr.Expr + The UFL expression to interpolate. + V : firedrake.functionspaceimpl.WithGeometry or ufl.BaseForm + The function space to interpolate into or the coargument defined + on the dual of the function space to interpolate into. + **kwargs + Additional interpolation options. See :class:`InterpolateOptions` + for available parameters and their descriptions. + + Returns + ------- + Interpolate + A symbolic :class:`Interpolate` object representing the interpolation operation. """ - return Interpolate( - expr, V, subset=subset, access=access, allow_missing_dofs=allow_missing_dofs, - default_missing_val=default_missing_val, matfree=matfree - ) + return Interpolate(expr, V, **kwargs) class Interpolator(abc.ABC): - """A reusable interpolation object. - - This object can be used to carry out the same interpolation - multiple times (for example in a timestepping loop). + """Base class for calculating interpolation. Should not be instantiated directly; use the + :func:`get_interpolator` function. Parameters ---------- - expr - The underlying ufl.Interpolate or the operand to the ufl.Interpolate. - V - The :class:`.FunctionSpace` or :class:`.Function` to - interpolate into. - subset - An optional :class:`pyop2.types.set.Subset` to apply the - interpolation over. Cannot, at present, be used when interpolating - across meshes unless the target mesh is a :func:`.VertexOnlyMesh`. - freeze_expr - Set to True to prevent the expression being - re-evaluated on each call. Cannot, at present, be used when - interpolating across meshes unless the target mesh is a - :func:`.VertexOnlyMesh`. - access - The pyop2 access descriptor for combining updates to shared DoFs. - Only ``op2.WRITE`` is supported at present when interpolating across meshes. - Only ``op2.INC`` is supported for the matrix-free adjoint interpolation. - See note in :func:`.interpolate` if changing this from default. - bcs - An optional list of boundary conditions to zero-out in the - output function space. Interpolator rows or columns which are - associated with boundary condition nodes are zeroed out when this is - specified. - allow_missing_dofs - For interpolation across meshes: allow - degrees of freedom (aka DoFs/nodes) in the target mesh that cannot be - defined on the source mesh. For example, where nodes are point - evaluations, points in the target mesh that are not in the source mesh. - When ``False`` this raises a ``ValueError`` should this occur. When - ``True`` the corresponding values are either (a) unchanged if - some ``output`` is given to the :meth:`interpolate` method or (b) set - to zero. Can be overwritten with the ``default_missing_val`` kwarg - of :meth:`interpolate`. This does not affect adjoint interpolation. - Ignored if interpolating within the same mesh or onto a - :func:`.VertexOnlyMesh` (the behaviour of a :func:`.VertexOnlyMesh` in - this scenario is, at present, set when it is created). - matfree - If ``False``, then construct the permutation matrix for interpolating - between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast - and reduce operations. - - Notes - ----- - - The :class:`Interpolator` holds a reference to the provided - arguments (such that they won't be collected until the - :class:`Interpolator` is also collected). + expr : Interpolate + The symbolic interpolation expression. """ - - def __new__(cls, expr, V, **kwargs): - V_target = V if isinstance(V, ufl.FunctionSpace) else V.function_space() - if not isinstance(expr, ufl.Interpolate): - expr = interpolate(expr, V_target) - - arguments = expr.arguments() - has_mixed_arguments = any(len(a.function_space()) > 1 for a in arguments) - if len(arguments) == 2 and has_mixed_arguments: - return object.__new__(MixedInterpolator) - - operand, = expr.ufl_operands - target_mesh = as_domain(V) - source_mesh = extract_unique_domain(operand) or target_mesh - submesh_interp_implemented = \ - all(isinstance(m.topology, firedrake.mesh.MeshTopology) for m in [target_mesh, source_mesh]) and \ - target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1] and \ - target_mesh.topological_dimension == source_mesh.topological_dimension - if target_mesh is source_mesh or submesh_interp_implemented: - return object.__new__(SameMeshInterpolator) - else: - if isinstance(target_mesh.topology, VertexOnlyMeshTopology): - return object.__new__(SameMeshInterpolator) - elif has_mixed_arguments or len(V_target) > 1: - return object.__new__(MixedInterpolator) - else: - return object.__new__(CrossMeshInterpolator) - - def __init__( - self, - expr: ufl.Interpolate | ufl.classes.Expr, - V: ufl.FunctionSpace | firedrake.function.Function, - subset: op2.Subset | None = None, - freeze_expr: bool = False, - access: Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC] | None = None, - bcs: Iterable[firedrake.bcs.BCBase] | None = None, - allow_missing_dofs: bool = False, - matfree: bool = True - ): - if not isinstance(expr, ufl.Interpolate): - expr = interpolate(expr, V if isinstance(V, ufl.FunctionSpace) else V.function_space()) + def __init__(self, expr: Interpolate): dual_arg, operand = expr.argument_slots() self.ufl_interpolate = expr - self.expr = operand - self.V = V - self.subset = subset - self.freeze_expr = freeze_expr - self.bcs = bcs - self._allow_missing_dofs = allow_missing_dofs - self.matfree = matfree - self.callable = None - - # TODO CrossMeshInterpolator and VomOntoVomXXX are not yet aware of - # self.ufl_interpolate (which carries the dual argument). - # See github issue https://github.com/firedrakeproject/firedrake/issues/4592 - target_mesh = as_domain(V) - source_mesh = extract_unique_domain(operand) or target_mesh - vom_onto_other_vom = ((source_mesh is not target_mesh) - and isinstance(self, SameMeshInterpolator) - and isinstance(source_mesh.topology, VertexOnlyMeshTopology) - and isinstance(target_mesh.topology, VertexOnlyMeshTopology)) - if isinstance(self, CrossMeshInterpolator) or vom_onto_other_vom: - # For bespoke interpolation, we currently rely on different assembly procedures: - # 1) Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) -> Forward operator (2-form) - # 2) Interpolate(Argument(V1, 0), Argument(V2.dual(), 1)) -> Adjoint operator (2-form) - # 3) Interpolate(Coefficient(V1), Argument(V2.dual(), 0)) -> Forward action (1-form) - # 4) Interpolate(Argument(V1, 0), Cofunction(V2.dual()) -> Adjoint action (1-form) - # 5) Interpolate(Coefficient(V1), Cofunction(V2.dual()) -> Double action (0-form) - - # CrossMeshInterpolator._interpolate only supports forward interpolation (cases 1 and 3). - # For case 2, we first redundantly assemble case 1 and then construct the transpose. - # For cases 4 and 5, we take the forward Interpolate that corresponds to dropping the Cofunction, - # and we separately compute the action against the dropped Cofunction within assemble(). - if not isinstance(dual_arg, ufl.Coargument): - # Drop the Cofunction - expr = expr._ufl_expr_reconstruct_(operand, dual_arg.function_space().dual()) - expr_args = extract_arguments(operand) - if expr_args and expr_args[0].number() == 0: - # Construct the symbolic forward Interpolate - v0, v1 = expr.arguments() - expr = ufl.replace(expr, {v0: v0.reconstruct(number=v1.number()), - v1: v1.reconstruct(number=v0.number())}) + """The symbolic UFL Interpolate expression.""" + self.interpolate_args = expr.arguments() + """Arguments of the Interpolate expression.""" + self.rank = len(self.interpolate_args) + """Number of arguments in the Interpolate expression.""" + self.operand = operand + """The primal argument slot of the Interpolate expression.""" + self.dual_arg = dual_arg + """The dual argument slot of the Interpolate expression.""" + self.target_space = dual_arg.function_space().dual() + """The primal space we are interpolating into.""" + self.target_mesh = self.target_space.mesh() + """The domain we are interpolating into.""" + self.source_mesh = extract_unique_domain(operand) or self.target_mesh + """The domain we are interpolating from.""" + + # Interpolation options + self.subset = expr.options.subset + self.allow_missing_dofs = expr.options.allow_missing_dofs + self.default_missing_val = expr.options.default_missing_val + self.matfree = expr.options.matfree + self.access = expr.options.access - dual_arg, operand = expr.argument_slots() - self.expr_renumbered = operand - self.ufl_interpolate_renumbered = expr + @abc.abstractmethod + def _get_callable( + self, + tensor: Function | Cofunction | MatrixBase | None = None, + bcs: Iterable[DirichletBC] | None = None + ) -> Callable[[], Function | Cofunction | PETSc.Mat | Number]: + """Return a callable to perform interpolation. - if not isinstance(dual_arg, ufl.Coargument): - # Matrix-free assembly of 0-form or 1-form requires INC access - if access and access != op2.INC: - raise ValueError("Matfree adjoint interpolation requires INC access") - access = op2.INC - elif access is None: - # Default access for forward 1-form or 2-form (forward and adjoint) - access = op2.WRITE - self.access = access + If ``self.rank == 2``, then the callable must return a PETSc matrix. + If ``self.rank == 1``, then the callable must return a ``Function`` + or ``Cofunction`` (in the forward and adjoint cases respectively). + If ``self.rank == 0``, then the callable must return a number. - def interpolate(self, *function, transpose=None, adjoint=False, default_missing_val=None): + Parameters + ---------- + tensor + Optional tensor to store the result in, by default None. + bcs + An optional list of boundary conditions to zero-out in the + output function space. Interpolator rows or columns which are + associated with boundary condition nodes are zeroed out when this is + specified. By default None. """ - .. warning:: + pass - This method has been removed. Use the function :func:`interpolate` to return a symbolic - :class:`Interpolate` object. - """ - raise FutureWarning( - "The 'interpolate' method on `Interpolator` objects has been " - "removed. Use the `interpolate` function instead." - ) + def assemble( + self, + tensor: Function | Cofunction | MatrixBase | None = None, + bcs: Iterable[DirichletBC] | None = None + ) -> Function | Cofunction | MatrixBase | Number: + """Assemble the interpolation. The result depends on the rank (number of arguments) + of the :class:`Interpolate` expression: - @abc.abstractmethod - def _interpolate(self, *args, **kwargs): - """ - Compute the interpolation operation of interest. + * rank-2: assemble the operator and return a matrix + * rank-1: assemble the action and return a function or cofunction + * rank-0: assemble the action and return a scalar by applying the dual argument - .. note:: - This method is called when an :class:`Interpolate` object is being assembled. + Parameters + ---------- + tensor + Optional tensor to store the interpolated result. For rank-2 + expressions this is expected to be a subclass of + :class:`~firedrake.matrix.MatrixBase`. For lower-rank expressions + this is a :class:`~firedrake.function.Function` or :class:`~firedrake.cofunction.Cofunction`, + for forward and adjoint interpolation respectively. + bcs + An optional list of boundary conditions to zero-out in the + output function space. Interpolator rows or columns which are + associated with boundary condition nodes are zeroed out when this is + specified. By default None. + Returns + ------- + Function | Cofunction | MatrixBase | numbers.Number + The function, cofunction, matrix, or scalar resulting from the + interpolation. """ - pass - - def assemble(self, tensor=None, default_missing_val=None): - """Assemble the operator (or its action).""" - from firedrake.assemble import assemble - needs_adjoint = self.ufl_interpolate_renumbered != self.ufl_interpolate - arguments = self.ufl_interpolate.arguments() - if len(arguments) == 2: + result = self._get_callable(tensor=tensor, bcs=bcs)() + if self.rank == 2: # Assembling the operator - res = tensor.petscmat if tensor else PETSc.Mat() - # Get the interpolation matrix - op2mat = self.callable() - petsc_mat = op2mat.handle - if needs_adjoint: - # Out-of-place Hermitian transpose - petsc_mat.hermitianTranspose(out=res) - elif tensor: - petsc_mat.copy(tensor.petscmat) - else: - res = petsc_mat - return tensor or firedrake.AssembledMatrix(arguments, self.bcs, res) + assert isinstance(tensor, MatrixBase | None) + assert isinstance(result, PETSc.Mat) + if tensor: + result.copy(tensor.petscmat) + return tensor + return AssembledMatrix(self.interpolate_args, bcs, result) else: - # Assembling the action - cofunctions = () - if needs_adjoint: - # The renumbered Interpolate has dropped Cofunctions. - # We need to explicitly operate on them. - dual_arg, _ = self.ufl_interpolate.argument_slots() - if not isinstance(dual_arg, ufl.Coargument): - cofunctions = (dual_arg,) - - if needs_adjoint and len(arguments) == 0: - Iu = self._interpolate(default_missing_val=default_missing_val) - return assemble(ufl.Action(*cofunctions, Iu), tensor=tensor) - else: - return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint, - default_missing_val=default_missing_val) + assert isinstance(tensor, Function | Cofunction | None) + return tensor.assign(result) if tensor else result + + +def get_interpolator(expr: Interpolate) -> Interpolator: + """Create an Interpolator. + + Parameters + ---------- + expr : Interpolate + Symbolic interpolation expression. + + Returns + ------- + Interpolator + An appropriate :class:`Interpolator` subclass for the given + interpolation expression. + """ + arguments = expr.arguments() + has_mixed_arguments = any(len(arg.function_space()) > 1 for arg in arguments) + if len(arguments) == 2 and has_mixed_arguments: + return MixedInterpolator(expr) + + operand, = expr.ufl_operands + target_mesh = expr.target_space.mesh() + source_mesh = extract_unique_domain(operand) or target_mesh + submesh_interp_implemented = ( + all(isinstance(m.topology, MeshTopology) for m in [target_mesh, source_mesh]) + and target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1] + and target_mesh.topological_dimension == source_mesh.topological_dimension + ) + if target_mesh is source_mesh or submesh_interp_implemented: + return SameMeshInterpolator(expr) + + target_topology = target_mesh.topology + source_topology = source_mesh.topology + + if isinstance(target_topology, VertexOnlyMeshTopology): + if isinstance(source_topology, VertexOnlyMeshTopology): + return VomOntoVomInterpolator(expr) + if target_mesh.geometric_dimension != source_mesh.geometric_dimension: + raise ValueError("Cannot interpolate onto a VertexOnlyMesh of a different geometric dimension.") + return SameMeshInterpolator(expr) + + if has_mixed_arguments or len(expr.target_space) > 1: + return MixedInterpolator(expr) + + return CrossMeshInterpolator(expr) class DofNotDefinedError(Exception): @@ -426,26 +373,15 @@ class CrossMeshInterpolator(Interpolator): """ @no_annotations - def __init__( - self, - expr, - V, - subset=None, - freeze_expr=False, - access=None, - bcs=None, - allow_missing_dofs=False, - matfree=True - ): - if subset: - raise NotImplementedError("subset not implemented") - if freeze_expr: - # Probably just need to pass freeze_expr to the various - # interpolators for this to work. - raise NotImplementedError("freeze_expr not implemented") - if bcs: - raise NotImplementedError("bcs not implemented") - if V.ufl_element().mapping() != "identity": + def __init__(self, expr: Interpolate): + super().__init__(expr) + if self.access and self.access != op2.WRITE: + raise NotImplementedError( + "Access other than op2.WRITE not implemented for cross-mesh interpolation." + ) + else: + self.access = op2.WRITE + if self.target_space.ufl_element().mapping() != "identity": # Identity mapping between reference cell and physical coordinates # implies point evaluation nodes. A more general version would # require finding the global coordinates of all quadrature points @@ -453,246 +389,150 @@ def __init__( raise NotImplementedError( "Can only interpolate into spaces with point evaluation nodes." ) - super().__init__(expr, V, subset, freeze_expr, access, bcs, allow_missing_dofs, matfree) - - if self.access != op2.WRITE: - raise NotImplementedError("access other than op2.WRITE not implemented") + if self.allow_missing_dofs: + self.missing_points_behaviour = MissingPointsBehaviour.IGNORE + else: + self.missing_points_behaviour = MissingPointsBehaviour.ERROR - expr = self.expr_renumbered - self.arguments = extract_arguments(expr) - self.nargs = len(self.arguments) + if self.source_mesh.geometric_dimension != self.target_mesh.geometric_dimension: + raise ValueError("Geometric dimensions of source and destination meshes must match.") - if self._allow_missing_dofs: - missing_points_behaviour = MissingPointsBehaviour.IGNORE + dest_element = self.target_space.ufl_element() + if isinstance(dest_element, MixedElement): + if isinstance(dest_element, VectorElement | TensorElement): + # In this case all sub elements are equal + base_element = dest_element.sub_elements[0] + if base_element.reference_value_shape != (): + raise NotImplementedError( + "Can't yet cross-mesh interpolate onto function spaces made from VectorElements " + "or TensorElements made from sub elements with value shape other than ()." + ) + self.dest_element = base_element + else: + raise NotImplementedError("Interpolation with MixedFunctionSpace requires MixedInterpolator.") else: - missing_points_behaviour = MissingPointsBehaviour.ERROR - - # setup - V_dest = V.function_space() if isinstance(V, firedrake.Function) else V - src_mesh = extract_unique_domain(expr) - dest_mesh = as_domain(V_dest) - src_mesh_gdim = src_mesh.geometric_dimension - dest_mesh_gdim = dest_mesh.geometric_dimension - if src_mesh_gdim != dest_mesh_gdim: - raise ValueError( - "geometric dimensions of source and destination meshes must match" - ) - self.src_mesh = src_mesh - self.dest_mesh = dest_mesh + # scalar fiat/finat element + self.dest_element = dest_element - # Create a VOM at the nodes of V_dest in src_mesh. We don't include halo - # node coordinates because interpolation doesn't usually include halos. - # NOTE: it is very important to set redundant=False, otherwise the - # input ordering VOM will only contain the points on rank 0! - # QUESTION: Should any of the below have annotation turned off? - ufl_scalar_element = V_dest.ufl_element() - if isinstance(ufl_scalar_element, finat.ufl.MixedElement): - if type(ufl_scalar_element) is finat.ufl.MixedElement: - raise TypeError("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator") - - # For a VectorElement or TensorElement the correct - # VectorFunctionSpace equivalent is built from the scalar - # sub-element. - ufl_scalar_element, = set(ufl_scalar_element.sub_elements) - if ufl_scalar_element.reference_value_shape != (): - raise NotImplementedError( - "Can't yet cross-mesh interpolate onto function spaces made from VectorElements or TensorElements made from sub elements with value shape other than ()." - ) + def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]: + """Return the symbolic ``Interpolate`` expressions for cross-mesh interpolation. + Raises + ------ + DofNotDefinedError + If some DoFs in the target function space cannot be defined + in the source function space. + """ from firedrake.assemble import assemble - V_dest_vec = firedrake.VectorFunctionSpace(dest_mesh, ufl_scalar_element) - f_dest_node_coords = interpolate(dest_mesh.coordinates, V_dest_vec) - f_dest_node_coords = assemble(f_dest_node_coords) - dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, dest_mesh_gdim) + # Immerse coordinates of target space point evaluation dofs in src_mesh + target_space_vec = VectorFunctionSpace(self.target_mesh, self.dest_element) + f_dest_node_coords = assemble(interpolate(self.target_mesh.coordinates, target_space_vec)) + dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, self.target_mesh.geometric_dimension) try: - self.vom_dest_node_coords_in_src_mesh = firedrake.VertexOnlyMesh( - src_mesh, + vom = VertexOnlyMesh( + self.source_mesh, dest_node_coords, redundant=False, - missing_points_behaviour=missing_points_behaviour, + missing_points_behaviour=self.missing_points_behaviour, ) except VertexOnlyMeshMissingPointsError: - raise DofNotDefinedError(src_mesh, dest_mesh) - # vom_dest_node_coords_in_src_mesh uses the parallel decomposition of - # the global node coordinates of V_dest in the SOURCE mesh (src_mesh). - # I first point evaluate my expression at these locations, giving a - # P0DG function on the VOM. As described in the manual, this is an - # interpolation operation. - shape = V_dest.ufl_function_space().value_shape + raise DofNotDefinedError(self.source_mesh, self.target_mesh) + + # Get the correct type of function space + shape = self.target_space.ufl_function_space().value_shape if len(shape) == 0: - fs_type = firedrake.FunctionSpace + fs_type = FunctionSpace elif len(shape) == 1: - fs_type = partial(firedrake.VectorFunctionSpace, dim=shape[0]) + fs_type = partial(VectorFunctionSpace, dim=shape[0]) else: - fs_type = partial(firedrake.TensorFunctionSpace, shape=shape) - P0DG_vom = fs_type(self.vom_dest_node_coords_in_src_mesh, "DG", 0) - self.point_eval_interpolate = interpolate(self.expr_renumbered, P0DG_vom) - # The parallel decomposition of the nodes of V_dest in the DESTINATION - # mesh (dest_mesh) is retrieved using the input_ordering attribute of the - # VOM. This again is an interpolation operation, which, under the hood - # is a PETSc SF reduce. - P0DG_vom_i_o = fs_type( - self.vom_dest_node_coords_in_src_mesh.input_ordering, "DG", 0 - ) - self.to_input_ordering_interpolate = interpolate( - firedrake.TrialFunction(P0DG_vom), P0DG_vom_i_o - ) - # The P0DG function outputted by the above interpolation has the - # correct parallel decomposition for the nodes of V_dest in dest_mesh so - # we can safely assign the dat values. This is all done in the actual - # interpolation method below. + fs_type = partial(TensorFunctionSpace, shape=shape) - @PETSc.Log.EventDecorator() - def _interpolate( - self, - *function, - output=None, - transpose=None, - adjoint=False, - default_missing_val=None, - **kwargs, - ): - """Compute the interpolation. + # Get expression for point evaluation at the dest_node_coords + P0DG_vom = fs_type(vom, "DG", 0) + point_eval = interpolate(self.operand, P0DG_vom) - For arguments, see :class:`.Interpolator`. - """ - from firedrake.assemble import assemble + # If assembling the operator, we need the concrete permutation matrix + matfree = False if self.rank == 2 else self.matfree - if transpose is not None: - warnings.warn("'transpose' argument is deprecated, use 'adjoint' instead", FutureWarning) - adjoint = transpose or adjoint - if adjoint and not self.nargs: - raise ValueError( - "Can currently only apply adjoint interpolation with arguments." - ) - if self.nargs != len(function): - raise ValueError( - "Passed %d Functions to interpolate, expected %d" - % (len(function), self.nargs) - ) + # Interpolate into the input-ordering VOM + P0DG_vom_input_ordering = fs_type(vom.input_ordering, "DG", 0) - if self.nargs: - (f_src,) = function - if not hasattr(f_src, "dat"): - raise ValueError( - "The expression had arguments: we therefore need to be given a Function (not an expression) to interpolate!" - ) - else: - f_src = self.expr - - if adjoint: - try: - V_dest = self.expr.function_space().dual() - except AttributeError: - if self.nargs: - V_dest = self.arguments[-1].function_space().dual() - else: - coeffs = extract_coefficients(self.expr) - if len(coeffs): - V_dest = coeffs[0].function_space().dual() - else: - raise ValueError( - "Can't adjoint interpolate an expression with no coefficients or arguments." - ) - else: - if isinstance(self.V, (firedrake.Function, firedrake.Cofunction)): - V_dest = self.V.function_space() + arg = Argument(P0DG_vom, 0 if self.ufl_interpolate.is_adjoint else 1) + point_eval_input_ordering = interpolate(arg, P0DG_vom_input_ordering, matfree=matfree) + return point_eval, point_eval_input_ordering + + def _get_callable(self, tensor=None, bcs=None): + from firedrake.assemble import assemble + if bcs: + raise NotImplementedError("bcs not implemented for cross-mesh interpolation.") + # self.ufl_interpolate.function_space() is None in the 0-form case + V_dest = self.ufl_interpolate.function_space() or self.target_space + f = tensor or Function(V_dest) + + point_eval, point_eval_input_ordering = self._get_symbolic_expressions() + P0DG_vom_input_ordering = point_eval_input_ordering.argument_slots()[0].function_space().dual() + + if self.rank == 2: + # The cross-mesh interpolation matrix is the product of the + # `self.point_eval_interpolate` and the permutation + # given by `self.to_input_ordering_interpolate`. + if self.ufl_interpolate.is_adjoint: + symbolic = action(point_eval, point_eval_input_ordering) else: - V_dest = self.V - if output: - if output.function_space() != V_dest: - raise ValueError("Given output has the wrong function space!") + symbolic = action(point_eval_input_ordering, point_eval) + + def callable() -> PETSc.Mat: + return assemble(symbolic).petscmat + elif self.ufl_interpolate.is_adjoint: + assert self.rank == 1 + # f_src is a cofunction on V_dest.dual + cofunc = self.dual_arg + assert isinstance(cofunc, Cofunction) + + # Our first adjoint operation is to assign the dat values to a + # P0DG cofunction on our input ordering VOM. + f_input_ordering = Cofunction(P0DG_vom_input_ordering.dual()) + f_input_ordering.dat.data_wo[:] = cofunc.dat.data_ro[:] + + # The rest of the adjoint interpolation is the composition + # of the adjoint interpolators in the reverse direction. + # We don't worry about skipping over missing points here + # because we're going from the input ordering VOM to the original VOM + # and all points from the input ordering VOM are in the original. + def callable() -> Cofunction: + f_src_at_src_node_coords = assemble(action(point_eval_input_ordering, f_input_ordering)) + assemble(action(point_eval, f_src_at_src_node_coords), tensor=f) + return f else: - if isinstance(self.V, (firedrake.Function, firedrake.Cofunction)): - output = self.V - else: - output = firedrake.Function(V_dest) - - if not adjoint: - if f_src is self.expr: - # f_src is already contained in self.point_eval_interpolate - assert not self.nargs - f_src_at_dest_node_coords_src_mesh_decomp = ( - assemble(self.point_eval_interpolate) - ) - else: - f_src_at_dest_node_coords_src_mesh_decomp = ( - assemble(action(self.point_eval_interpolate, f_src)) - ) - f_src_at_dest_node_coords_dest_mesh_decomp = firedrake.Function( - self.to_input_ordering_interpolate.function_space() - ) - # We have to create the Function before interpolating so we can - # set default missing values (if requested). - if default_missing_val is not None: - f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_wo[ - : - ] = default_missing_val - elif self._allow_missing_dofs: - # If we have allowed missing points we know we might end up - # with points in the target mesh that are not in the source - # mesh. However, since we haven't specified a default missing - # value we expect the interpolation to leave these points - # unchanged. By setting the dat values to NaN we can later - # identify these points and skip over them when assigning to - # the output function. - f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_wo[:] = numpy.nan - - interp = action(self.to_input_ordering_interpolate, f_src_at_dest_node_coords_src_mesh_decomp) - assemble(interp, tensor=f_src_at_dest_node_coords_dest_mesh_decomp) - - # we can now confidently assign this to a function on V_dest - if self._allow_missing_dofs and default_missing_val is None: - indices = numpy.where( - ~numpy.isnan(f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_ro) - )[0] - output.dat.data_wo[ - indices - ] = f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_ro[indices] - else: - output.dat.data_wo[ - : - ] = f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_ro[:] + assert self.rank in {0, 1} + # We create the input-ordering Function before interpolating so we can + # set default missing values if required. + f_point_eval_input_ordering = Function(P0DG_vom_input_ordering) + if self.default_missing_val is not None: + f_point_eval_input_ordering.assign(self.default_missing_val) + elif self.allow_missing_dofs: + # If we allow missing points there may be points in the target + # mesh that are not in the source mesh. If we don't specify a + # default missing value we set these to NaN so we can identify + # them later. + f_point_eval_input_ordering.dat.data_wo[:] = numpy.nan + + def callable() -> Function | Number: + assemble(action(point_eval_input_ordering, point_eval), tensor=f_point_eval_input_ordering) + # We assign these values to the output function + if self.allow_missing_dofs and self.default_missing_val is None: + indices = numpy.where(~numpy.isnan(f_point_eval_input_ordering.dat.data_ro))[0] + f.dat.data_wo[indices] = f_point_eval_input_ordering.dat.data_ro[indices] + else: + f.dat.data_wo[:] = f_point_eval_input_ordering.dat.data_ro[:] - else: - # adjoint interpolation - - # f_src is a cofunction on V_dest.dual as originally specified when - # creating the interpolator. Our first adjoint operation is to - # assign the dat values to a P0DG cofunction on our input ordering - # VOM. This has the parallel decomposition V_dest on our orinally - # specified dest_mesh. We can therefore safely create a P0DG - # cofunction on the input-ordering VOM (which has this parallel - # decomposition and ordering) and assign the dat values. - f_src_at_dest_node_coords_dest_mesh_decomp = firedrake.Cofunction( - self.to_input_ordering_interpolate.function_space().dual() - ) - f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_wo[ - : - ] = f_src.dat.data_ro[:] - - # The rest of the adjoint interpolation is merely the composition - # of the adjoint interpolators in the reverse direction. NOTE: I - # don't have to worry about skipping over missing points here - # because I'm going from the input ordering VOM to the original VOM - # and all points from the input ordering VOM are in the original. - interp = action(expr_adjoint(self.to_input_ordering_interpolate), f_src_at_dest_node_coords_dest_mesh_decomp) - f_src_at_src_node_coords = assemble(interp) - # NOTE: if I wanted the default missing value to be applied to - # adjoint interpolation I would have to do it here. However, - # this would require me to implement default missing values for - # adjoint interpolation from a point evaluation interpolator - # which I haven't done. I wonder if it is necessary - perhaps the - # adjoint operator always sets all the values of the resulting - # cofunction? My initial attempt to insert setting the dat values - # prior to performing the multHermitian operation in - # SameMeshInterpolator.interpolate did not effect the result. For - # now, I say in the docstring that it only applies to forward - # interpolation. - interp = action(expr_adjoint(self.point_eval_interpolate), f_src_at_src_node_coords) - assemble(interp, tensor=output) - - return output + if self.rank == 0: + # We take the action of the dual_arg on the interpolated function + assert isinstance(self.dual_arg, Cofunction) + return assemble(action(self.dual_arg, f)) + else: + return f + return callable class SameMeshInterpolator(Interpolator): @@ -704,247 +544,196 @@ class SameMeshInterpolator(Interpolator): """ @no_annotations - def __init__(self, expr, V, subset=None, freeze_expr=False, access=None, - bcs=None, matfree=True, allow_missing_dofs=False, **kwargs): + def __init__(self, expr): + super().__init__(expr) + subset = self.subset if subset is None: - if isinstance(expr, ufl.Interpolate): - operand, = expr.ufl_operands - else: - operand = expr - target_mesh = as_domain(V) - source_mesh = extract_unique_domain(operand) or target_mesh - target = target_mesh.topology - source = source_mesh.topology - if all(isinstance(m, firedrake.mesh.MeshTopology) for m in [target, source]) and target is not source: + target = self.target_mesh.topology + source = self.source_mesh.topology + if all(isinstance(m, MeshTopology) for m in [target, source]) and target is not source: composed_map, result_integral_type = source.trans_mesh_entity_map(target, "cell", "everywhere", None) if result_integral_type != "cell": - raise AssertionError("Only cell-cell interpolation supported") + raise AssertionError("Only cell-cell interpolation supported.") indices_active = composed_map.indices_active_with_halo make_subset = not indices_active.all() make_subset = target.comm.allreduce(make_subset, op=MPI.LOR) if make_subset: - if not allow_missing_dofs: - raise ValueError("iteration (sub)set unclear: run with `allow_missing_dofs=True`") + if not self.allow_missing_dofs: + raise ValueError("Iteration (sub)set unclear: run with `allow_missing_dofs=True`.") subset = op2.Subset(target.cell_set, numpy.where(indices_active)) else: # Do not need subset as target <= source. pass - super().__init__(expr, V, subset=subset, freeze_expr=freeze_expr, - access=access, bcs=bcs, matfree=matfree, allow_missing_dofs=allow_missing_dofs) - expr = self.ufl_interpolate_renumbered - try: - self.callable = make_interpolator(expr, V, subset, self.access, bcs=bcs, matfree=matfree) - except FIAT.hdiv_trace.TraceError: - raise NotImplementedError("Can't interpolate onto traces sorry") - self.arguments = expr.arguments() - - @PETSc.Log.EventDecorator() - def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **kwargs): - """Compute the interpolation. - - For arguments, see :class:`.Interpolator`. - """ - - if transpose is not None: - warnings.warn("'transpose' argument is deprecated, use 'adjoint' instead", FutureWarning) - adjoint = transpose or adjoint - try: - assembled_interpolator = self.frozen_assembled_interpolator - copy_required = True - except AttributeError: - assembled_interpolator = self.callable() - copy_required = False # Return the original - if self.freeze_expr: - if len(self.arguments) == 2: - # Interpolation operator - self.frozen_assembled_interpolator = assembled_interpolator - else: - # Interpolation action - self.frozen_assembled_interpolator = assembled_interpolator.copy() - - if len(self.arguments) == 2 and len(function) > 0: - function, = function - if not hasattr(function, "dat"): - raise ValueError("The expression had arguments: we therefore need to be given a Function (not an expression) to interpolate!") - if adjoint: - mul = assembled_interpolator.handle.multHermitian - col, row = self.arguments - else: - mul = assembled_interpolator.handle.mult - row, col = self.arguments - V = row.function_space().dual() - assert function.function_space() == col.function_space() - - result = output or firedrake.Function(V) - with function.dat.vec_ro as x, result.dat.vec_wo as out: - if x is not out: - mul(x, out) - else: - out_ = out.duplicate() - mul(x, out_) - out_.copy(result=out) - return result + self.subset = subset - else: - if output: - output.assign(assembled_interpolator) - return output - if isinstance(self.V, firedrake.Function): - if copy_required: - self.V.assign(assembled_interpolator) - return self.V - else: - if len(self.arguments) == 0: - return assembled_interpolator.dat.data.item() - elif copy_required: - return assembled_interpolator.copy() - else: - return assembled_interpolator + if not isinstance(self.dual_arg, Coargument): + # Matrix-free assembly of 0-form or 1-form requires INC access + if self.access and self.access != op2.INC: + raise ValueError("Matfree adjoint interpolation requires INC access") + self.access = op2.INC + elif self.access is None: + # Default access for forward 1-form or 2-form (forward and adjoint) + self.access = op2.WRITE + def _get_tensor(self) -> op2.Mat | Function | Cofunction: + """Return a suitable tensor to interpolate into. -@PETSc.Log.EventDecorator() -def make_interpolator(expr, V, subset, access, bcs=None, matfree=True): - if not isinstance(expr, ufl.Interpolate): - raise ValueError(f"Expecting to interpolate a ufl.Interpolate, got {type(expr).__name__}.") - dual_arg, operand = expr.argument_slots() - target_mesh = as_domain(dual_arg) - source_mesh = extract_unique_domain(operand) or target_mesh - vom_onto_other_vom = ((source_mesh is not target_mesh) - and isinstance(source_mesh.topology, VertexOnlyMeshTopology) - and isinstance(target_mesh.topology, VertexOnlyMeshTopology)) - - arguments = expr.arguments() - rank = len(arguments) - if rank <= 1: - if rank == 0: - R = firedrake.FunctionSpace(target_mesh, "Real", 0) - f = firedrake.Function(R, dtype=utils.ScalarType) - elif isinstance(V, firedrake.Function): - f = V - V = f.function_space() - else: - V_dest = arguments[0].function_space().dual() - f = firedrake.Function(V_dest) - if access in {firedrake.MIN, firedrake.MAX}: + Returns + ------- + op2.Mat | Function | Cofunction + The tensor to interpolate into. + """ + if self.rank == 0: + R = FunctionSpace(self.target_mesh, "Real", 0) + f = Function(R, dtype=ScalarType) + elif self.rank == 1: + f = Function(self.ufl_interpolate.function_space()) + if self.access in {op2.MIN, op2.MAX}: finfo = numpy.finfo(f.dat.dtype) - if access == firedrake.MIN: - val = firedrake.Constant(finfo.max) + if self.access == op2.MIN: + val = Constant(finfo.max) else: - val = firedrake.Constant(finfo.min) + val = Constant(finfo.min) f.assign(val) - tensor = f.dat - elif rank == 2: - if isinstance(V, firedrake.Function): - raise ValueError("Cannot interpolate an expression with an argument into a Function") - Vrow = arguments[0].function_space() - Vcol = arguments[1].function_space() - if len(Vrow) > 1 or len(Vcol) > 1: - raise TypeError("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator") - if isinstance(target_mesh.topology, VertexOnlyMeshTopology) and target_mesh is not source_mesh and not vom_onto_other_vom: - if not isinstance(target_mesh.topology, VertexOnlyMeshTopology): - raise NotImplementedError("Can only interpolate onto a VertexOnlyMesh") - if target_mesh.geometric_dimension != source_mesh.geometric_dimension: - raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension") - if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh: - raise ValueError("Can only interpolate across meshes where the source mesh is the parent of the target") - - if vom_onto_other_vom: - # We make our own linear operator for this case using PETSc SFs - tensor = None - else: - Vrow_map = get_interp_node_map(source_mesh, target_mesh, Vrow) - Vcol_map = get_interp_node_map(source_mesh, target_mesh, Vcol) + elif self.rank == 2: + Vrow = self.interpolate_args[0].function_space() + Vcol = self.interpolate_args[1].function_space() + if len(Vrow) > 1 or len(Vcol) > 1: + raise NotImplementedError("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator") + Vrow_map = get_interp_node_map(self.source_mesh, self.target_mesh, Vrow) + Vcol_map = get_interp_node_map(self.source_mesh, self.target_mesh, Vcol) sparsity = op2.Sparsity((Vrow.dof_dset, Vcol.dof_dset), [(Vrow_map, Vcol_map, None)], # non-mixed - name="%s_%s_sparsity" % (Vrow.name, Vcol.name), + name=f"{Vrow.name}_{Vcol.name}_sparsity", nest=False, block_sparse=True) - tensor = op2.Mat(sparsity) - f = tensor - else: - raise ValueError(f"Cannot interpolate an expression with {rank} arguments") - - if vom_onto_other_vom: - wrapper = VomOntoVomWrapper(V, source_mesh, target_mesh, operand, matfree) - # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the - # data, including the correct data size and dimensional information - # (so for vector function spaces in 2 dimensions we might need a - # concatenation of 2 MPI.DOUBLE types when we are in real mode) - if tensor is not None: - # Callable will do interpolation into our pre-supplied function f - # when it is called. - assert f.dat is tensor - wrapper.mpi_type, _ = get_dat_mpi_type(f.dat) - assert len(arguments) == 1 - - def callable(): - wrapper.forward_operation(f.dat) - return f + f = op2.Mat(sparsity) else: - assert len(arguments) == 2 - assert tensor is None - # we know we will be outputting either a function or a cofunction, - # both of which will use a dat as a data carrier. At present, the - # data type does not depend on function space dimension, so we can - # safely use the argument function space. NOTE: If this changes - # after cofunctions are fully implemented, this will need to be - # reconsidered. - temp_source_func = firedrake.Function(Vcol) - wrapper.mpi_type, _ = get_dat_mpi_type(temp_source_func.dat) - - # Leave wrapper inside a callable so we can access the handle - # property. If matfree is True, then the handle is a PETSc SF - # pretending to be a PETSc Mat. If matfree is False, then this - # will be a PETSc Mat representing the equivalent permutation - # matrix - def callable(): - return wrapper + raise ValueError(f"Cannot interpolate an expression with {self.rank} arguments") + return f + + def _get_callable(self, tensor=None, bcs=None): + f = tensor or self._get_tensor() + op2_tensor = f if isinstance(f, op2.Mat) else f.dat - return callable - else: loops = [] - # Initialise to zero if needed - if access is op2.INC: - loops.append(tensor.zero) # Arguments in the operand are allowed to be from a MixedFunctionSpace # We need to split the target space V and generate separate kernels - if len(arguments) == 2: - # Matrix case assumes that the spaces are not mixed - expressions = {(0,): expr} - elif isinstance(dual_arg, Coargument): + if self.rank == 2: + expressions = {(0,): self.ufl_interpolate} + elif isinstance(self.dual_arg, Coargument): # Split in the coargument - expressions = dict(firedrake.formmanipulation.split_form(expr)) + expressions = dict(split_form(self.ufl_interpolate)) else: + assert isinstance(self.dual_arg, Cofunction) # Split in the cofunction: split_form can only split in the coargument # Replace the cofunction with a coargument to construct the Jacobian - interp = expr._ufl_expr_reconstruct_(operand, V) + interp = self.ufl_interpolate._ufl_expr_reconstruct_(self.operand, self.target_space) # Split the Jacobian into blocks - interp_split = dict(firedrake.formmanipulation.split_form(interp)) + interp_split = dict(split_form(interp)) # Split the cofunction - dual_split = dict(firedrake.formmanipulation.split_form(dual_arg)) + dual_split = dict(split_form(self.dual_arg)) # Combine the splits by taking their action expressions = {i: action(interp_split[i], dual_split[i[-1:]]) for i in interp_split} # Interpolate each sub expression into each function space for indices, sub_expr in expressions.items(): - sub_tensor = tensor[indices[0]] if rank == 1 else tensor - loops.extend(_interpolator(sub_tensor, sub_expr, subset, access, bcs=bcs)) - # Apply bcs - if bcs and rank == 1: + sub_op2_tensor = op2_tensor[indices[0]] if self.rank == 1 else op2_tensor + loops.extend(_build_interpolation_callables(sub_expr, sub_op2_tensor, self.access, self.subset, bcs)) + + if bcs and self.rank == 1: loops.extend(partial(bc.apply, f) for bc in bcs) - def callable(loops, f): + def callable() -> Function | Cofunction | PETSc.Mat | Number: for l in loops: l() - return f + if self.rank == 0: + return f.dat.data.item() + elif self.rank == 2: + return f.handle # In this case f is an op2.Mat + else: + return f + + return callable - return partial(callable, loops, f) + +class VomOntoVomInterpolator(SameMeshInterpolator): + + def __init__(self, expr: Interpolate): + super().__init__(expr) + + def _get_callable(self, tensor=None, bcs=None): + if bcs: + raise NotImplementedError("bcs not implemented for vom-to-vom interpolation.") + self.mat = VomOntoVomMat(self) + if self.rank == 1: + f = tensor or self._get_tensor() + # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the + # data, including the correct data size and dimensional information + # (so for vector function spaces in 2 dimensions we might need a + # concatenation of 2 MPI.DOUBLE types when we are in real mode) + self.mat.mpi_type = _get_mtype(f.dat)[0] + if self.ufl_interpolate.is_adjoint: + assert isinstance(self.dual_arg, Cofunction) + assert isinstance(f, Cofunction) + + def callable() -> Cofunction: + with self.dual_arg.dat.vec_ro as source_vec, f.dat.vec_wo as target_vec: + self.mat.handle.multHermitian(source_vec, target_vec) + return f + else: + assert isinstance(f, Function) + + def callable() -> Function: + coeff = self.mat.expr_as_coeff() + with coeff.dat.vec_ro as coeff_vec, f.dat.vec_wo as target_vec: + self.mat.handle.mult(coeff_vec, target_vec) + return f + elif self.rank == 2: + # Create a temporary function to get the correct MPI type + temp_source_func = Function(self.interpolate_args[1].function_space()) + self.mat.mpi_type = _get_mtype(temp_source_func.dat)[0] + + def callable() -> PETSc.Mat: + return self.mat.handle + + return callable -@utils.known_pyop2_safe -def _interpolator(tensor, expr, subset, access, bcs=None): - if isinstance(expr, ufl.ZeroBaseForm): +@known_pyop2_safe +def _build_interpolation_callables( + expr: Interpolate | ZeroBaseForm, + tensor: op2.Dat | op2.Mat | op2.Global, + access: Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC], + subset: op2.Subset | None = None, + bcs: Iterable[DirichletBC] | None = None +) -> tuple[Callable, ...]: + """Return a tuple of callables which calculate the interpolation. + + Parameters + ---------- + expr : ufl.Interpolate | ufl.ZeroBaseForm + The symbolic interpolation expression, or a zero form. Zero forms + are simplified here to avoid code generation when access is WRITE or INC. + tensor : op2.Dat | op2.Mat | op2.Global + Object to hold the result of the interpolation. + access : Literal[op2.WRITE, op2.MIN, op2.MAX, op2.INC] + op2 access descriptor + subset : op2.Subset | None + An optional subset to apply the interpolation over, by default None. + bcs : Iterable[DirichletBC] | None + An optional list of boundary conditions to zero-out in the + output function space. Interpolator rows or columns which are + associated with boundary condition nodes are zeroed out when this is + specified. By default None, by default None. + + Returns + ------- + tuple[Callable, ...] + Tuple of callables which perform the interpolation. + """ + if isinstance(expr, ZeroBaseForm): # Zero simplification, avoid code-generation if access is op2.INC: return () @@ -953,55 +742,35 @@ def _interpolator(tensor, expr, subset, access, bcs=None): # Unclear how to avoid codegen for MIN and MAX # Reconstruct the expression as an Interpolate V = expr.arguments()[-1].function_space().dual() - expr = interpolate(ufl.zero(V.value_shape), V) - - if not isinstance(expr, ufl.Interpolate): - raise ValueError("Expecting to interpolate a ufl.Interpolate") - - arguments = expr.arguments() + expr = interpolate(zero(V.value_shape), V) + if not isinstance(expr, Interpolate): + raise ValueError("Expecting to interpolate a symbolic Interpolate expression.") dual_arg, operand = expr.argument_slots() - V = dual_arg.arguments()[0].function_space() - + V = dual_arg.function_space().dual() try: to_element = create_element(V.ufl_element()) except KeyError: # FInAT only elements - raise NotImplementedError("Don't know how to create FIAT element for %s" % V.ufl_element()) + raise NotImplementedError(f"Don't know how to create FIAT element for {V.ufl_element()}") if access is op2.READ: raise ValueError("Can't have READ access for output function") # NOTE: The par_loop is always over the target mesh cells. - target_mesh = as_domain(V) + target_mesh = V.mesh() source_mesh = extract_unique_domain(operand) or target_mesh if isinstance(target_mesh.topology, VertexOnlyMeshTopology): - if target_mesh is not source_mesh: - if not isinstance(target_mesh.topology, VertexOnlyMeshTopology): - raise NotImplementedError("Can only interpolate onto a Vertex Only Mesh") - if target_mesh.geometric_dimension != source_mesh.geometric_dimension: - raise ValueError("Cannot interpolate onto a mesh of a different geometric dimension") - if not hasattr(target_mesh, "_parent_mesh") or target_mesh._parent_mesh is not source_mesh: - raise ValueError("Can only interpolate across meshes where the source mesh is the parent of the target") - # For trans-mesh interpolation we use a FInAT QuadratureElement as the - # (base) target element with runtime point set expressions as their - # quadrature rule point set and weights from their dual basis. - # NOTE: This setup is useful for thinking about future design - in the - # future this `rebuild` function can be absorbed into FInAT as a - # transformer that eats an element and gives you an equivalent (which - # may or may not be a QuadratureElement) that lets you do run time - # tabulation. Alternatively (and this all depends on future design - # decision about FInAT how dual evaluation should work) the - # to_element's dual basis (which look rather like quadrature rules) can - # have their pointset(s) directly replaced with run-time tabulated - # equivalent(s) (i.e. finat.point_set.UnknownPointSet(s)) - rt_var_name = 'rt_X' - try: - cell = operand.ufl_element().ufl_cell() - except AttributeError: - # expression must be pure function of spatial coordinates so - # domain has correct ufl cell - cell = source_mesh.ufl_cell() - to_element = rebuild(to_element, cell, rt_var_name) + # For interpolation onto a VOM, we use a FInAT QuadratureElement as the + # target element with runtime point set expressions as their + # quadrature rule point set. + rt_var_name = 'rt_X' + try: + cell = operand.ufl_element().ufl_cell() + except AttributeError: + # expression must be pure function of spatial coordinates so + # domain has correct ufl cell + cell = source_mesh.ufl_cell() + to_element = rebuild(to_element, cell, rt_var_name) cell_set = target_mesh.cell_set if subset is not None: @@ -1009,7 +778,7 @@ def _interpolator(tensor, expr, subset, access, bcs=None): cell_set = subset parameters = {} - parameters['scalar_type'] = utils.ScalarType + parameters['scalar_type'] = ScalarType copyin = () copyout = () @@ -1017,11 +786,11 @@ def _interpolator(tensor, expr, subset, access, bcs=None): # For the matfree adjoint 1-form and the 0-form, the cellwise kernel will add multiple # contributions from the facet DOFs of the dual argument. # The incoming Cofunction needs to be weighted by the reciprocal of the DOF multiplicity. - needs_weight = isinstance(dual_arg, ufl.Cofunction) and not to_element.is_dg() + needs_weight = isinstance(dual_arg, Cofunction) and not to_element.is_dg() if needs_weight: # Create a buffer for the weighted Cofunction W = dual_arg.function_space() - v = firedrake.Function(W) + v = Function(W) expr = expr._ufl_expr_reconstruct_(operand, v=v) copyin += (partial(dual_arg.dat.copy, v.dat),) @@ -1061,7 +830,7 @@ def _interpolator(tensor, expr, subset, access, bcs=None): parloop_args = [kernel, cell_set] - coefficients = tsfc_interface.extract_numbered_coefficients(expr, coefficient_numbers) + coefficients = extract_numbered_coefficients(expr, coefficient_numbers) if needs_external_coords: coefficients = [source_mesh.coordinates] + coefficients @@ -1071,6 +840,8 @@ def _interpolator(tensor, expr, subset, access, bcs=None): if access is not op2.WRITE: copyin += (partial(output.copy, tensor), ) copyout += (partial(tensor.copy, output), ) + + arguments = expr.arguments() if isinstance(tensor, op2.Global): parloop_args.append(tensor(access)) elif isinstance(tensor, op2.Dat): @@ -1084,20 +855,21 @@ def _interpolator(tensor, expr, subset, access, bcs=None): assert tensor.handle.getSize() == (Vrow.dim(), Vcol.dim()) rows_map = get_interp_node_map(source_mesh, target_mesh, Vrow) columns_map = get_interp_node_map(source_mesh, target_mesh, Vcol) - lgmaps = None if bcs: - if ufl.duals.is_dual(Vrow): + if is_dual(Vrow): Vrow = Vrow.dual() - if ufl.duals.is_dual(Vcol): + if is_dual(Vcol): Vcol = Vcol.dual() bc_rows = [bc for bc in bcs if bc.function_space() == Vrow] bc_cols = [bc for bc in bcs if bc.function_space() == Vcol] lgmaps = [(Vrow.local_to_global_map(bc_rows), Vcol.local_to_global_map(bc_cols))] parloop_args.append(tensor(access, (rows_map, columns_map), lgmaps=lgmaps)) + if oriented: co = target_mesh.cell_orientations() parloop_args.append(co.dat(op2.READ, co.cell_node_map())) + if needs_cell_sizes: cs = source_mesh.cell_sizes parloop_args.append(cs.dat(op2.READ, cs.cell_node_map())) @@ -1137,10 +909,12 @@ def _interpolator(tensor, expr, subset, access, bcs=None): if isinstance(tensor, op2.Mat): return parloop, tensor.assemble else: + if access == op2.INC: + copyin += (tensor.zero,) return copyin + (parloop, ) + copyout -def get_interp_node_map(source_mesh, target_mesh, fs): +def get_interp_node_map(source_mesh: MeshGeometry, target_mesh: MeshGeometry, fs: WithGeometry) -> op2.Map | None: """Return the map between cells of the target mesh and nodes of the function space. If the function space is defined on the source mesh then the node map is composed @@ -1185,12 +959,12 @@ def get_interp_node_map(source_mesh, target_mesh, fs): def _compile_expression_key(comm, expr, to_element, ufl_element, domain, parameters) -> tuple[Hashable, ...]: """Generate a cache key suitable for :func:`tsfc.compile_expression_dual_evaluation`.""" dual_arg, operand = expr.argument_slots() - return (hash_expr(operand), type(dual_arg), hash(ufl_element), utils.tuplify(parameters)) + return (hash_expr(operand), type(dual_arg), hash(ufl_element), tuplify(parameters)) @memory_and_disk_cache( hashkey=_compile_expression_key, - cachedir=tsfc_interface._cachedir + cachedir=_cachedir ) @PETSc.Log.EventDecorator() def compile_expression(comm, *args, **kwargs): @@ -1199,50 +973,57 @@ def compile_expression(comm, *args, **kwargs): @singledispatch def rebuild(element, expr_cell, rt_var_name): - raise NotImplementedError(f"Cross mesh interpolation not implemented for a {element} element.") + """Construct a FInAT QuadratureElement for interpolation onto a + VertexOnlyMesh. The quadrature point is an UnknownPointSet of shape + (1, tdim) where tdim is the topological dimension of expr_cell. The + weight is [1.0], since the single local dof in the VertexOnlyMesh function + space corresponds to a point evaluation at the vertex. + Parameters + ---------- + element : finat.FiniteElementBase + The FInAT element to construct a QuadratureElement for. + expr_cell : ufl.Cell + The UFL cell of the expression being interpolated. + rt_var_name : str + String beginning with 'rt_' which is used as the name of the + gem.Variable used to represent the UnknownPointSet. The `rt_` prefix + forces TSFC to do runtime tabulation. + + Raises + ------ + NotImplementedError + If the element type is not implemented yet. + """ + raise NotImplementedError(f"Point evaluation not implemented for a {element} element.") -@rebuild.register(finat.fiat_elements.ScalarFiatElement) + +@rebuild.register(ScalarFiatElement) def rebuild_dg(element, expr_cell, rt_var_name): - # To tabulate on the given element (which is on a different mesh to the - # expression) we must do so at runtime. We therefore create a quadrature - # element with runtime points to evaluate for each point in the element's - # dual basis. This exists on the same reference cell as the input element - # and we can interpolate onto it before mapping the result back onto the - # target space. - expr_tdim = expr_cell.topological_dimension - # Need point evaluations and matching weights from dual basis. - # This could use FIAT's dual basis as below: - # num_points = sum(len(dual.get_point_dict()) for dual in element.fiat_equivalent.dual_basis()) - # weights = [] - # for dual in element.fiat_equivalent.dual_basis(): - # pts = dual.get_point_dict().keys() - # for p in pts: - # for w, _ in dual.get_point_dict()[p]: - # weights.append(w) - # assert len(weights) == num_points - # but for now we just fix the values to what we know works: - if element.degree != 0 or not isinstance(element.cell, FIAT.reference_element.Point): - raise NotImplementedError("Cross mesh interpolation only implemented for P0DG on vertex cells.") - num_points = 1 - weights = [1.]*num_points + # QuadratureElements have a dual basis which is point evaluation at the + # quadrature points. By using an UnknownPointSet with one point, TSFC + # will generate a kernel with an argument to which we can pass the reference + # coordinates of a point and evaluate the expression at that point at runtime. + if element.degree != 0 or not isinstance(element.cell, Point): + raise NotImplementedError("Interpolation onto a VOM only implemented for P0DG on vertex cells.") + # gem.Variable name starting with rt_ forces TSFC runtime tabulation assert rt_var_name.startswith("rt_") - runtime_points_expr = gem.Variable(rt_var_name, (num_points, expr_tdim)) - rule_pointset = finat.point_set.UnknownPointSet(runtime_points_expr) - rule = finat.quadrature.QuadratureRule(rule_pointset, weights=weights) - return finat.QuadratureElement(as_fiat_cell(expr_cell), rule) + runtime_points_expr = Variable(rt_var_name, (1, expr_cell.topological_dimension)) + rule_pointset = UnknownPointSet(runtime_points_expr) + # What we use for the weight doesn't matter since we are not integrating + rule = QuadratureRule(rule_pointset, weights=[0.0]) + return QuadratureElement(as_fiat_cell(expr_cell), rule) -@rebuild.register(finat.TensorFiniteElement) +@rebuild.register(TensorFiniteElement) def rebuild_te(element, expr_cell, rt_var_name): - return finat.TensorFiniteElement(rebuild(element.base_element, - expr_cell, rt_var_name), - element._shape, - transpose=element._transpose) + return TensorFiniteElement(rebuild(element.base_element, expr_cell, rt_var_name), + element._shape, + transpose=element._transpose) -def compose_map_and_cache(map1, map2): +def compose_map_and_cache(map1: op2.Map, map2: op2.Map | None) -> op2.ComposedMap | None: """ Retrieve a :class:`pyop2.ComposedMap` map from the cache of map1 using map2 as the cache key. The composed map maps from the iterset @@ -1265,7 +1046,7 @@ def compose_map_and_cache(map1, map2): return cmap -def vom_cell_parent_node_map_extruded(vertex_only_mesh, extruded_cell_node_map): +def vom_cell_parent_node_map_extruded(vertex_only_mesh: MeshGeometry, extruded_cell_node_map: op2.Map) -> op2.Map: """Build a map from the cells of a vertex only mesh to the nodes of the nodes on the source mesh where the source mesh is extruded. @@ -1391,118 +1172,74 @@ def __init__(self, glob): self.ufl_domain = lambda: None -class VomOntoVomWrapper(object): - """Utility class for interpolating from one ``VertexOnlyMesh`` to it's - intput ordering ``VertexOnlyMesh``, or vice versa. - - Parameters - ---------- - V : `.FunctionSpace` - The P0DG function space (which may be vector or tensor valued) on the - source vertex-only mesh. - source_vom : `.VertexOnlyMesh` - The vertex-only mesh we interpolate from. - target_vom : `.VertexOnlyMesh` - The vertex-only mesh we interpolate to. - expr : `ufl.Expr` - The expression to interpolate. If ``arguments`` is not empty, those - arguments must be present within it. - matfree : bool - If ``False``, the matrix representating the permutation of the points is - constructed and used to perform the interpolation. If ``True``, then the - interpolation is performed using the broadcast and reduce operations on the - PETSc Star Forest. +class VomOntoVomMat: + """ + Object that facilitates interpolation between a VertexOnlyMesh and its + input_ordering VertexOnlyMesh. This is either a PETSc Star Forest wrapped + as a PETSc Mat, or a concrete PETSc seqaij Mat, depending on whether + matfree interpolation is requested. """ + def __init__(self, interpolator: VomOntoVomInterpolator): + """Initialise the VomOntoVomMat. - def __init__(self, V, source_vom, target_vom, expr, matfree): - arguments = extract_arguments(expr) - reduce = False - if source_vom.input_ordering is target_vom: - reduce = True - original_vom = source_vom - elif target_vom.input_ordering is source_vom: - original_vom = target_vom + Parameters + ---------- + interpolator : VomOntoVomInterpolator + A :class:`VomOntoVomInterpolator` object. + + Raises + ------ + ValueError + If the source and target vertex-only meshes are not linked by input_ordering. + """ + if interpolator.source_mesh.input_ordering is interpolator.target_mesh: + self.forward_reduce = True + """True if the forward interpolation is a star forest reduction, False if broadcast.""" + self.original_vom = interpolator.source_mesh + """The original VOM from which the SF is constructed.""" + elif interpolator.target_mesh.input_ordering is interpolator.source_mesh: + self.forward_reduce = False + self.original_vom = interpolator.target_mesh else: raise ValueError( "The target vom and source vom must be linked by input ordering!" ) - self.V = V - self.source_vom = source_vom - self.expr = expr - self.arguments = arguments - self.reduce = reduce - # note that interpolation doesn't include halo cells - self.dummy_mat = VomOntoVomDummyMat( - original_vom.input_ordering_without_halos_sf, reduce, V, source_vom, expr, arguments - ) - if matfree: - # If matfree, we use the SF to perform the interpolation - self.handle = self.dummy_mat._wrap_dummy_mat() - else: - # Otherwise we create the permutation matrix - self.handle = self.dummy_mat._create_permutation_mat() - - @property - def mpi_type(self): - """ - The MPI type to use for the PETSc SF. - - Should correspond to the underlying data type of the PETSc Vec. - """ - return self.handle.mpi_type + self.sf = self.original_vom.input_ordering_without_halos_sf + """The PETSc Star Forest representing the permutation between the VOMs.""" + self.target_space = interpolator.target_space + """The FunctionSpace being interpolated into.""" + self.source_vom = interpolator.source_mesh + """The VOM being interpolated from.""" + self.operand = interpolator.operand + """The expression in the primal slot of the Interpolate.""" + self.arguments = extract_arguments(self.operand) + """The arguments of the expression being interpolated.""" + self.is_adjoint = interpolator.ufl_interpolate.is_adjoint + """Are we doing the adjoint interpolation?""" - @mpi_type.setter - def mpi_type(self, val): - self.dummy_mat.mpi_type = val - - def forward_operation(self, target_dat): - coeff = self.dummy_mat.expr_as_coeff() - with coeff.dat.vec_ro as coeff_vec, target_dat.vec_wo as target_vec: - self.handle.mult(coeff_vec, target_vec) - - -class VomOntoVomDummyMat(object): - """Dummy object to stand in for a PETSc ``Mat`` when we are interpolating - between vertex-only meshes. - - Parameters - ---------- - sf: PETSc.sf - The PETSc Star Forest (SF) to use for the operation - forward_reduce : bool - If ``True``, the action of the operator (accessed via the `mult` - method) is to perform a SF reduce from the source vec to the target - vec, whilst the adjoint action (accessed via the `multHermitian` - method) is to perform a SF broadcast from the source vec to the target - vec. If ``False``, the opposite is true. - V : `.FunctionSpace` - The P0DG function space (which may be vector or tensor valued) on the - source vertex-only mesh. - source_vom : `.VertexOnlyMesh` - The vertex-only mesh we interpolate from. - expr : `ufl.Expr` - The expression to interpolate. If ``arguments`` is not empty, those - arguments must be present within it. - arguments : list of `ufl.Argument` - The arguments in the expression. - """ - - def __init__(self, sf, forward_reduce, V, source_vom, expr, arguments): - self.sf = sf - self.forward_reduce = forward_reduce - self.V = V - self.source_vom = source_vom - self.expr = expr - self.arguments = arguments # Calculate correct local and global sizes for the matrix - nroots, leaves, _ = sf.getGraph() + nroots, leaves, _ = self.sf.getGraph() self.nleaves = len(leaves) - self._local_sizes = V.comm.allgather(nroots) - self.source_size = (self.V.block_size * nroots, self.V.block_size * sum(self._local_sizes)) + """The local number of leaves in the SF.""" + self._local_sizes = self.target_space.comm.allgather(nroots) + """List of local number of roots on each process.""" + self.source_size = (self.target_space.block_size * nroots, self.target_space.block_size * sum(self._local_sizes)) + """Tuple containing the local and global size of the source space.""" self.target_size = ( - self.V.block_size * self.nleaves, - self.V.block_size * V.comm.allreduce(self.nleaves, op=MPI.SUM), + self.target_space.block_size * self.nleaves, + self.target_space.block_size * self.target_space.comm.allreduce(self.nleaves, op=MPI.SUM), ) + """Tuple containing the local and global size of the target space.""" + + if interpolator.matfree: + # If matfree, we use the SF wrapped as a PETSc Mat + # to perform the permutation. This is the default. + self.handle = self._wrap_python_mat() + else: + # If matfree=False, then we build the concrete permutation + # matrix as a PETSc seqaij Mat. This is used to build the + # cross-mesh interpolation matrix. + self.handle = self._create_permutation_mat() @property def mpi_type(self): @@ -1517,23 +1254,35 @@ def mpi_type(self): def mpi_type(self, val): self._mpi_type = val - def expr_as_coeff(self, source_vec=None): - """ - Return a coefficient that corresponds to the expression used at + def expr_as_coeff(self, source_vec: PETSc.Vec | None = None) -> Function: + """Return a Function that corresponds to the expression used at construction, where the expression has been interpolated into the P0DG function space on the source vertex-only mesh. Will fail if there are no arguments. + + Parameters + ---------- + source_vec : PETSc.Vec | None, optional + Optional vector used to replace arguments in the expression. + By default None. + + Returns + ------- + Function + A Function representing the expression as a coefficient on the + source vertex-only mesh. + """ # Since we always output a coefficient when we don't have arguments in # the expression, we should evaluate the expression on the source mesh # so its dat can be sent to the target mesh. with stop_annotating(): - element = self.V.ufl_element() # Could be vector/tensor valued - P0DG = firedrake.FunctionSpace(self.source_vom, element) + element = self.target_space.ufl_element() # Could be vector/tensor valued + P0DG = FunctionSpace(self.source_vom, element) # if we have any arguments in the expression we need to replace # them with equivalent coefficients now - coeff_expr = self.expr + coeff_expr = self.operand if len(self.arguments): if len(self.arguments) > 1: raise NotImplementedError( @@ -1542,15 +1291,24 @@ def expr_as_coeff(self, source_vec=None): if source_vec is None: raise ValueError("Need to provide a source dat for the argument!") arg = self.arguments[0] - arg_coeff = firedrake.Function(arg.function_space()) + arg_coeff = Function(arg.function_space()) arg_coeff.dat.data_wo[:] = source_vec.getArray(readonly=True).reshape( arg_coeff.dat.data_wo.shape ) - coeff_expr = ufl.replace(self.expr, {arg: arg_coeff}) - coeff = firedrake.Function(P0DG).interpolate(coeff_expr) + coeff_expr = replace(self.operand, {arg: arg_coeff}) + coeff = Function(P0DG).interpolate(coeff_expr) return coeff - def reduce(self, source_vec, target_vec): + def reduce(self, source_vec: PETSc.Vec, target_vec: PETSc.Vec) -> None: + """Reduce data in source_vec using the PETSc SF. + + Parameters + ---------- + source_vec : PETSc.Vec + The vector to reduce. + target_vec : PETSc.Vec + The vector to store the result in. + """ source_arr = source_vec.getArray(readonly=True) target_arr = target_vec.getArray() self.sf.reduceBegin( @@ -1566,7 +1324,17 @@ def reduce(self, source_vec, target_vec): MPI.REPLACE, ) - def broadcast(self, source_vec, target_vec): + def broadcast(self, source_vec: PETSc.Vec, target_vec: PETSc.Vec) -> None: + """Broadcast data in source_vec using the PETSc SF, storing the + result in target_vec. + + Parameters + ---------- + source_vec : PETSc.Vec + The vector to broadcast. + target_vec : PETSc.Vec + The vector to store the result in. + """ source_arr = source_vec.getArray(readonly=True) target_arr = target_vec.getArray() self.sf.bcastBegin( @@ -1582,8 +1350,21 @@ def broadcast(self, source_vec, target_vec): MPI.REPLACE, ) - def mult(self, mat, source_vec, target_vec): - # need to evaluate expression before doing mult + def mult(self, mat: PETSc.Mat, source_vec: PETSc.Vec, target_vec: PETSc.Vec) -> None: + """Apply the interpolation operator to source_vec, storing the + result in target_vec. + + Parameters + ---------- + mat : PETSc.Mat + Required by petsc4py but unused. + source_vec : PETSc.Vec + The vector to interpolate. + target_vec : PETSc.Vec + The vector to store the result in. + """ + # Need to convert the expression into a coefficient + # so that we can broadcast/reduce it coeff = self.expr_as_coeff(source_vec) with coeff.dat.vec_ro as coeff_vec: if self.forward_reduce: @@ -1591,10 +1372,36 @@ def mult(self, mat, source_vec, target_vec): else: self.broadcast(coeff_vec, target_vec) - def multHermitian(self, mat, source_vec, target_vec): + def multHermitian(self, mat: PETSc.Mat, source_vec: PETSc.Vec, target_vec: PETSc.Vec) -> None: + """Apply the adjoint of the interpolation operator to source_vec, storing the + result in target_vec. Since ``VomOntoVomMat`` represents a permutation, it is + real-valued and thus the Hermitian adjoint is the transpose. + + Parameters + ---------- + mat : PETSc.Mat + Required by petsc4py but unused. + source_vec : PETSc.Vec + The vector to adjoint interpolate. + target_vec : PETSc.Vec + The vector to store the result in. + """ self.multTranspose(mat, source_vec, target_vec) - def multTranspose(self, mat, source_vec, target_vec): + def multTranspose(self, mat: PETSc.Mat, source_vec: PETSc.Vec, target_vec: PETSc.Vec) -> None: + """Apply the tranpose of the interpolation operator to source_vec, storing the + result in target_vec. Called by `self.multHermitian`. + + Parameters + ---------- + mat : PETSc.Mat + Required by petsc4py but unused. + source_vec : PETSc.Vec + The vector to transpose interpolate. + target_vec : PETSc.Vec + The vector to store the result in. + + """ # can only do adjoint if our expression exclusively contains a # single argument, making the application of the adjoint operator # straightforward (haven't worked out how to do this otherwise!) @@ -1602,7 +1409,7 @@ def multTranspose(self, mat, source_vec, target_vec): raise NotImplementedError( "Can only apply adjoint to expressions with one argument!" ) - if self.arguments[0] is not self.expr: + if self.arguments[0] is not self.operand: raise NotImplementedError( "Can only apply adjoint to expressions consisting of a single argument at the moment." ) @@ -1621,27 +1428,47 @@ def multTranspose(self, mat, source_vec, target_vec): target_vec.zeroEntries() self.reduce(source_vec, target_vec) - def _create_permutation_mat(self): - """Creates the PETSc matrix that represents the interpolation operator from a vertex-only mesh to - its input ordering vertex-only mesh""" - mat = PETSc.Mat().createAIJ((self.target_size, self.source_size), nnz=1, comm=self.V.comm) + def _create_permutation_mat(self) -> PETSc.Mat: + """Create the PETSc matrix that represents the interpolation operator from a vertex-only mesh to + its input ordering vertex-only mesh. + + Returns + ------- + PETSc.Mat + PETSc seqaij matrix + """ + # To create the permutation matrix we broadcast an array of indices which are contiguous + # across all ranks and then use these indices to set the values of the matrix directly. + mat = PETSc.Mat().createAIJ((self.target_size, self.source_size), nnz=1, comm=self.target_space.comm) mat.setUp() - start = sum(self._local_sizes[:self.V.comm.rank]) + start = sum(self._local_sizes[:self.target_space.comm.rank]) end = start + self.source_size[0] - contiguous_indices = numpy.arange(start, end, dtype=utils.IntType) - perm = numpy.zeros(self.nleaves, dtype=utils.IntType) + contiguous_indices = numpy.arange(start, end, dtype=IntType) + perm = numpy.zeros(self.nleaves, dtype=IntType) # result stored in here self.sf.bcastBegin(MPI.INT, contiguous_indices, perm, MPI.REPLACE) self.sf.bcastEnd(MPI.INT, contiguous_indices, perm, MPI.REPLACE) - rows = numpy.arange(self.target_size[0] + 1, dtype=utils.IntType) - cols = (self.V.block_size * perm[:, None] + numpy.arange(self.V.block_size, dtype=utils.IntType)[None, :]).reshape(-1) - mat.setValuesCSR(rows, cols, numpy.ones_like(cols, dtype=utils.IntType)) + rows = numpy.arange(self.target_size[0] + 1, dtype=IntType) + # Vector and Tensor valued spaces are stored in a flattened array, so + # we need to space out the column indices according to the block size + cols = (self.target_space.block_size * perm[:, None] + numpy.arange(self.target_space.block_size, dtype=IntType)[None, :]).reshape(-1) + mat.setValuesCSR(rows, cols, numpy.ones_like(cols, dtype=IntType)) mat.assemble() - if self.forward_reduce: + if self.forward_reduce and not self.is_adjoint: + # The mat we have constructed thus far takes us from the input-ordering VOM to the + # immersed VOM. If we're going the other way, then we need to transpose it, + # unless we're doing the adjoint interpolation in which ca mat.transpose() return mat - def _wrap_dummy_mat(self): - mat = PETSc.Mat().create(comm=self.V.comm) + def _wrap_python_mat(self) -> PETSc.Mat: + """Wrap this object as a PETSc Mat. Used for matfree interpolation. + + Returns + ------- + PETSc.Mat + A PETSc Mat of type python with this object as its context. + """ + mat = PETSc.Mat().create(comm=self.target_space.comm) if self.forward_reduce: mat_size = (self.source_size, self.target_size) else: @@ -1652,54 +1479,67 @@ def _wrap_dummy_mat(self): mat.setUp() return mat - def duplicate(self, mat=None, op=None): - return self._wrap_dummy_mat() + def duplicate(self, mat: PETSc.Mat | None = None, op: PETSc.Mat.DuplicateOption | None = None) -> PETSc.Mat: + """Duplicate the matrix. Needed to wrap as a PETSc Python Mat. + + Parameters + ---------- + mat : PETSc.Mat | None, optional + Unused, by default None + op : PETSc.Mat.DuplicateOption | None, optional + Unused, by default None + + Returns + ------- + PETSc.Mat + VomOntoVomMat wrapped as a PETSc Mat of type python. + """ + return self._wrap_python_mat() class MixedInterpolator(Interpolator): - """A reusable interpolation object between MixedFunctionSpaces. + """Interpolator between MixedFunctionSpaces.""" + def __init__(self, expr: Interpolate): + """Initialise MixedInterpolator. Should not be called directly; use `get_interpolator`. - Parameters - ---------- - expr - The underlying ufl.Interpolate or the operand to the ufl.Interpolate. - V - The :class:`.FunctionSpace` or :class:`.Function` to - interpolate into. - bcs - A list of boundary conditions. - **kwargs - Any extra kwargs are passed on to the sub Interpolators. - For details see :class:`firedrake.interpolation.Interpolator`. - """ - def __init__(self, expr, V, bcs=None, **kwargs): - super(MixedInterpolator, self).__init__(expr, V, bcs=bcs, **kwargs) - expr = self.ufl_interpolate - self.arguments = expr.arguments() + Parameters + ---------- + expr : Interpolate + Symbolic Interpolate expression. + """ + super().__init__(expr) + + def _get_sub_interpolators(self, bcs: Iterable[DirichletBC] | None = None) -> dict[tuple[int, int], tuple[Interpolator, list[DirichletBC]]]: + """Gets `Interpolator`s anf boundary conditions for each sub-Interpolate + in the mixed expression. + + Returns + ------- + dict[tuple[int, int], tuple[Interpolator, list[DirichletBC]]] + A map from block index tuples to `Interpolator`s and bcs. + """ # Get the primal spaces - spaces = tuple(a.function_space().dual() if isinstance(a, Coargument) else a.function_space() - for a in self.arguments) + spaces = tuple( + a.function_space().dual() if isinstance(a, Coargument) else a.function_space() for a in self.interpolate_args + ) # TODO consider a stricter equality test for indexed MixedFunctionSpace # See https://github.com/firedrakeproject/firedrake/issues/4668 space_equals = lambda V1, V2: V1 == V2 and V1.parent == V2.parent and V1.index == V2.index # We need a Coargument in order to split the Interpolate - needs_action = len([a for a in self.arguments if isinstance(a, Coargument)]) == 0 + needs_action = not any(isinstance(a, Coargument) for a in self.interpolate_args) if needs_action: - dual_arg, operand = expr.argument_slots() # Split the dual argument - dual_split = dict(firedrake.formmanipulation.split_form(dual_arg)) + dual_split = dict(split_form(self.dual_arg)) # Create the Jacobian to be split into blocks - expr = expr._ufl_expr_reconstruct_(operand, V) + self.ufl_interpolate = self.ufl_interpolate._ufl_expr_reconstruct_(self.operand, self.target_space) - Isub = {} - # Split in the arguments of the Interpolate - for indices, form in firedrake.formmanipulation.split_form(expr): - if isinstance(form, ufl.ZeroBaseForm): + # Get sub-interpolators and sub-bcs for each block + Isub: dict[tuple[int, int], tuple[Interpolator, list[DirichletBC]]] = {} + for indices, form in split_form(self.ufl_interpolate): + if isinstance(form, ZeroBaseForm): # Ensure block sparsity continue - vi, _ = form.argument_slots() - Vtarget = vi.function_space().dual() sub_bcs = [] for space, index in zip(spaces, indices): subspace = space.sub(index) @@ -1707,44 +1547,29 @@ def __init__(self, expr, V, bcs=None, **kwargs): if needs_action: # Take the action of each sub-cofunction against each block form = action(form, dual_split[indices[-1:]]) - - Isub[indices] = Interpolator(form, Vtarget, bcs=sub_bcs, **kwargs) - - self._sub_interpolators = Isub - self.callable = self._assemble_matnest - - def __getitem__(self, item): - return self._sub_interpolators[item] - - def __iter__(self): - return iter(self._sub_interpolators) - - def _assemble_matnest(self): - """Assemble the operator.""" - shape = tuple(len(a.function_space()) for a in self.arguments) - blocks = numpy.full(shape, PETSc.Mat(), dtype=object) - # Assemble the sparse block matrix - for i in self: - blocks[i] = self[i].callable().handle - petscmat = PETSc.Mat().createNest(blocks) - tensor = firedrake.AssembledMatrix(self.arguments, self.bcs, petscmat) - return tensor.M - - def _interpolate(self, *function, output=None, adjoint=False, **kwargs): - """Assemble the action.""" - rank = len(self.arguments) - if rank == 0: - result = sum(self[i].assemble(**kwargs) for i in self) - return output.assign(result) if output else result - - if output is None: - output = firedrake.Function(self.arguments[-1].function_space().dual()) - - if rank == 1: - for k, sub_tensor in enumerate(output.subfunctions): - sub_tensor.assign(sum(self[i].assemble(**kwargs) for i in self if i[0] == k)) - elif rank == 2: - for k, sub_tensor in enumerate(output.subfunctions): - sub_tensor.assign(sum(self[i]._interpolate(*function, adjoint=adjoint, **kwargs) - for i in self if i[0] == k)) - return output + Isub[indices] = (get_interpolator(form), sub_bcs) + + return Isub + + def _get_callable(self, tensor=None, bcs=None): + Isub = self._get_sub_interpolators(bcs=bcs) + V_dest = self.ufl_interpolate.function_space() or self.target_space + f = tensor or Function(V_dest) + if self.rank == 2: + def callable() -> PETSc.Mat: + shape = tuple(len(a.function_space()) for a in self.interpolate_args) + blocks = numpy.full(shape, PETSc.Mat(), dtype=object) + for indices, (interp, sub_bcs) in Isub.items(): + blocks[indices] = interp._get_callable(bcs=sub_bcs)() + return PETSc.Mat().createNest(blocks) + elif self.rank == 1: + def callable() -> Function | Cofunction: + for k, sub_tensor in enumerate(f.subfunctions): + sub_tensor.assign(sum( + interp.assemble(bcs=sub_bcs) for indices, (interp, sub_bcs) in Isub.items() if indices[0] == k + )) + return f + else: + def callable() -> Number: + return sum(interp.assemble(bcs=sub_bcs) for (interp, sub_bcs) in Isub.values()) + return callable diff --git a/firedrake/preconditioners/hiptmair.py b/firedrake/preconditioners/hiptmair.py index 1ee54fcf41..4021190a3d 100644 --- a/firedrake/preconditioners/hiptmair.py +++ b/firedrake/preconditioners/hiptmair.py @@ -10,7 +10,7 @@ from firedrake.preconditioners.hypre_ams import chop from firedrake.preconditioners.facet_split import restrict from firedrake.parameters import parameters -from firedrake.interpolation import Interpolator +from firedrake.interpolation import interpolate from ufl.algorithms.ad import expand_derivatives import firedrake.dmhooks as dmhooks import firedrake.utils as utils @@ -202,7 +202,7 @@ def coarsen(self, pc): coarse_space_bcs = tuple(coarse_space_bcs) if G_callback is None: - interp_petscmat = chop(Interpolator(dminus(trial), V, bcs=bcs + coarse_space_bcs).callable().handle) + interp_petscmat = chop(assemble(interpolate(dminus(trial), V), bcs=bcs + coarse_space_bcs).mat()) else: interp_petscmat = G_callback(coarse_space, V, coarse_space_bcs, bcs) diff --git a/firedrake/preconditioners/pmg.py b/firedrake/preconditioners/pmg.py index 35bf39de53..7d4a1ff1f2 100644 --- a/firedrake/preconditioners/pmg.py +++ b/firedrake/preconditioners/pmg.py @@ -1248,14 +1248,14 @@ def _kernels(self): return self._build_custom_interpolators() def _build_native_interpolators(self): - from firedrake.interpolation import interpolate, Interpolator - P = Interpolator(interpolate(self.uc, self.Vf), self.Vf) + from firedrake.interpolation import interpolate, get_interpolator + P = get_interpolator(interpolate(self.uc, self.Vf)) prolong = partial(P.assemble, tensor=self.uf) rf = firedrake.Function(self.Vf.dual(), val=self.uf.dat) rc = firedrake.Function(self.Vc.dual(), val=self.uc.dat) vc = firedrake.TestFunction(self.Vc) - R = Interpolator(interpolate(vc, rf), self.Vf) + R = get_interpolator(interpolate(vc, rf)) restrict = partial(R.assemble, tensor=rc) return prolong, restrict diff --git a/tests/firedrake/regression/test_interpolate.py b/tests/firedrake/regression/test_interpolate.py index a2a0d11618..b44477927d 100644 --- a/tests/firedrake/regression/test_interpolate.py +++ b/tests/firedrake/regression/test_interpolate.py @@ -584,7 +584,7 @@ def test_interpolator_reuse(family, degree, mode): u = Function(V.dual()) expr = interpolate(TestFunction(V), u) - I = Interpolator(expr, V) + I = get_interpolator(expr) for k in range(3): u.assign(rg.uniform(u.function_space())) diff --git a/tests/firedrake/regression/test_interpolate_cross_mesh.py b/tests/firedrake/regression/test_interpolate_cross_mesh.py index f6c22e48e8..65974c5e54 100644 --- a/tests/firedrake/regression/test_interpolate_cross_mesh.py +++ b/tests/firedrake/regression/test_interpolate_cross_mesh.py @@ -339,12 +339,18 @@ def test_exact_refinement(): expr_in_V_fine = x**2 + y**2 + 1 f_fine = Function(V_fine).interpolate(expr_in_V_fine) + # Build interpolation matrices in both directions + coarse_to_fine = assemble(interpolate(TrialFunction(V_coarse), V_fine)) + coarse_to_fine_adjoint = assemble(interpolate(TestFunction(V_coarse), TrialFunction(V_fine.dual()))) + # If we now interpolate f_coarse into V_fine we should get a function # which has no interpolation error versus f_fine because we were able to # exactly represent expr_in_V_coarse in V_coarse and V_coarse is a subset # of V_fine f_coarse_on_fine = assemble(interpolate(f_coarse, V_fine)) assert np.allclose(f_coarse_on_fine.dat.data_ro, f_fine.dat.data_ro) + f_coarse_on_fine_mat = assemble(coarse_to_fine @ f_coarse) + assert np.allclose(f_coarse_on_fine_mat.dat.data_ro, f_fine.dat.data_ro) # Adjoint interpolation takes us from V_fine^* to V_coarse^* so we should # also get an exact result here. @@ -354,6 +360,10 @@ def test_exact_refinement(): assert np.allclose( cofunction_fine_on_coarse.dat.data_ro, cofunction_coarse.dat.data_ro ) + cofunction_fine_on_coarse_mat = assemble(action(coarse_to_fine_adjoint, cofunction_fine)) + assert np.allclose( + cofunction_fine_on_coarse_mat.dat.data_ro, cofunction_coarse.dat.data_ro + ) # Now we test with expressions which are NOT exactly representable in the # function spaces by introducing a cube term. This can't be represented @@ -550,7 +560,7 @@ def test_missing_dofs(): V_src = FunctionSpace(m_src, "CG", 2) V_dest = FunctionSpace(m_dest, "CG", 3) with pytest.raises(DofNotDefinedError): - Interpolator(TestFunction(V_src), V_dest) + assemble(interpolate(TrialFunction(V_src), V_dest)) f_src = Function(V_src).interpolate(expr) f_dest = assemble(interpolate(f_src, V_dest, allow_missing_dofs=True)) dest_eval = PointEvaluator(m_dest, coords) @@ -680,6 +690,32 @@ def test_interpolate_matrix_cross_mesh(): f_interp2.dat.data_wo[:] = f_at_points_correct_order3.dat.data_ro[:] assert np.allclose(f_interp2.dat.data_ro, g.dat.data_ro) + interp_mat2 = assemble(interpolate(TrialFunction(U), V)) + assert interp_mat2.arguments() == (TestFunction(V.dual()), TrialFunction(U)) + f_interp3 = assemble(interp_mat2 @ f) + assert f_interp3.function_space() == V + assert np.allclose(f_interp3.dat.data_ro, g.dat.data_ro) + + +@pytest.mark.parallel([1, 3]) +def test_interpolate_matrix_cross_mesh_adjoint(): + mesh_fine = UnitSquareMesh(4, 4) + mesh_coarse = UnitSquareMesh(2, 2) + + V_coarse = FunctionSpace(mesh_coarse, "CG", 1) + V_fine = FunctionSpace(mesh_fine, "CG", 1) + + cofunc_fine = assemble(conj(TestFunction(V_fine)) * dx) + + interp = assemble(interpolate(TestFunction(V_coarse), TrialFunction(V_fine.dual()))) + cofunc_coarse = assemble(Action(interp, cofunc_fine)) + assert interp.arguments() == (TestFunction(V_coarse), TrialFunction(V_fine.dual())) + assert cofunc_coarse.function_space() == V_coarse.dual() + + # Compare cofunc_fine with direct interpolation + cofunc_coarse_direct = assemble(conj(TestFunction(V_coarse)) * dx) + assert np.allclose(cofunc_coarse.dat.data_ro, cofunc_coarse_direct.dat.data_ro) + @pytest.mark.parallel([2, 3, 4]) def test_voting_algorithm_edgecases():