Skip to content

Commit 2ac0c02

Browse files
authored
MixedInterpolator (#4596)
* MixedInterpolator * Interpolate: support fieldsplit * Interpolate: zero-simplify
1 parent dec8487 commit 2ac0c02

File tree

6 files changed

+272
-142
lines changed

6 files changed

+272
-142
lines changed

firedrake/assemble.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import finat.ufl
1919
from firedrake import (extrusion_utils as eutils, matrix, parameters, solving,
2020
tsfc_interface, utils)
21-
from firedrake.formmanipulation import split_form
2221
from firedrake.adjoint_utils import annotate_assemble
2322
from firedrake.ufl_expr import extract_unique_domain
2423
from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit
@@ -570,36 +569,9 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
570569
rank = len(expr.arguments())
571570
if rank > 2:
572571
raise ValueError("Cannot assemble an Interpolate with more than two arguments")
573-
# If argument numbers have been swapped => Adjoint.
574-
arg_operand = ufl.algorithms.extract_arguments(operand)
575-
is_adjoint = (arg_operand and arg_operand[0].number() == 0)
576-
577572
# Get the target space
578573
V = v.function_space().dual()
579574

580-
# Dual interpolation from mixed source
581-
if is_adjoint and len(V) > 1:
582-
cur = 0
583-
sub_operands = []
584-
components = numpy.reshape(operand, (-1,))
585-
for Vi in V:
586-
sub_operands.append(ufl.as_tensor(components[cur:cur+Vi.value_size].reshape(Vi.value_shape)))
587-
cur += Vi.value_size
588-
589-
# Component-split of the primal operands interpolated into the dual argument-split
590-
split_interp = sum(reconstruct_interp(sub_operands[i], v=vi) for (i,), vi in split_form(v))
591-
return assemble(split_interp, tensor=tensor)
592-
593-
# Dual interpolation into mixed target
594-
if is_adjoint and len(arg_operand[0].function_space()) > 1 and rank == 1:
595-
V = arg_operand[0].function_space()
596-
tensor = tensor or firedrake.Cofunction(V.dual())
597-
598-
# Argument-split of the Interpolate gets assembled into the corresponding sub-tensor
599-
for (i,), sub_interp in split_form(expr):
600-
assemble(sub_interp, tensor=tensor.subfunctions[i])
601-
return tensor
602-
603575
# Get the interpolator
604576
interp_data = expr.interp_data.copy()
605577
default_missing_val = interp_data.pop('default_missing_val', None)

firedrake/formmanipulation.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy
33
import collections
44

5-
from ufl import as_vector, split
5+
from ufl import as_tensor, as_vector, split
66
from ufl.classes import Zero, FixedIndex, ListTensor, ZeroBaseForm
77
from ufl.algorithms.map_integrands import map_integrand_dags
88
from ufl.algorithms import expand_derivatives
@@ -14,6 +14,7 @@
1414
from firedrake.petsc import PETSc
1515
from firedrake.functionspace import MixedFunctionSpace
1616
from firedrake.cofunction import Cofunction
17+
from firedrake.ufl_expr import Coargument
1718
from firedrake.matrix import AssembledMatrix
1819

1920

@@ -133,6 +134,17 @@ def argument(self, o):
133134
args.extend(Zero() for j in numpy.ndindex(V[i].value_shape))
134135
return self._arg_cache.setdefault(o, as_vector(args))
135136

137+
def coargument(self, o):
138+
V = o.function_space()
139+
140+
if len(V) == 1:
141+
# Not on a mixed space, just return ourselves.
142+
return o
143+
144+
indices = self.blocks[o.number()]
145+
W = subspace(V, indices)
146+
return Coargument(W, number=o.number(), part=o.part())
147+
136148
def cofunction(self, o):
137149
V = o.function_space()
138150

@@ -171,6 +183,42 @@ def matrix(self, o):
171183
bcs = ()
172184
return AssembledMatrix(tuple(args), bcs, submat)
173185

186+
def zero_base_form(self, o):
187+
return ZeroBaseForm(tuple(map(self, o.arguments())))
188+
189+
def interpolate(self, o, operand):
190+
if isinstance(operand, Zero):
191+
return self(ZeroBaseForm(o.arguments()))
192+
193+
dual_arg, _ = o.argument_slots()
194+
if len(dual_arg.arguments()) == 1 or len(dual_arg.arguments()[-1].function_space()) == 1:
195+
# The dual argument has been contracted or does not need to be split
196+
return o._ufl_expr_reconstruct_(operand, dual_arg)
197+
198+
if not isinstance(dual_arg, Coargument):
199+
raise NotImplementedError(f"I do not know how to split an Interpolate with a {type(dual_arg).__name__}.")
200+
201+
indices = self.blocks[dual_arg.number()]
202+
V = dual_arg.function_space()
203+
204+
# Split the target (dual) argument
205+
sub_dual_arg = self(dual_arg)
206+
W = sub_dual_arg.function_space()
207+
208+
# Unflatten the expression into the target shape
209+
cur = 0
210+
components = []
211+
for i, Vi in enumerate(V):
212+
if i in indices:
213+
components.extend(operand[i] for i in range(cur, cur+Vi.value_size))
214+
cur += Vi.value_size
215+
216+
operand = as_tensor(numpy.reshape(components, W.value_shape))
217+
if isinstance(operand, Zero):
218+
return self(ZeroBaseForm(o.arguments()))
219+
220+
return o._ufl_expr_reconstruct_(operand, sub_dual_arg)
221+
174222

175223
SplitForm = collections.namedtuple("SplitForm", ["indices", "form"])
176224

0 commit comments

Comments
 (0)