|
2 | 2 | import numpy |
3 | 3 | import collections |
4 | 4 |
|
5 | | -from ufl import as_vector, split |
| 5 | +from ufl import as_tensor, as_vector, split |
6 | 6 | from ufl.classes import Zero, FixedIndex, ListTensor, ZeroBaseForm |
7 | 7 | from ufl.algorithms.map_integrands import map_integrand_dags |
8 | 8 | from ufl.algorithms import expand_derivatives |
|
14 | 14 | from firedrake.petsc import PETSc |
15 | 15 | from firedrake.functionspace import MixedFunctionSpace |
16 | 16 | from firedrake.cofunction import Cofunction |
| 17 | +from firedrake.ufl_expr import Coargument |
17 | 18 | from firedrake.matrix import AssembledMatrix |
18 | 19 |
|
19 | 20 |
|
@@ -133,6 +134,17 @@ def argument(self, o): |
133 | 134 | args.extend(Zero() for j in numpy.ndindex(V[i].value_shape)) |
134 | 135 | return self._arg_cache.setdefault(o, as_vector(args)) |
135 | 136 |
|
| 137 | + def coargument(self, o): |
| 138 | + V = o.function_space() |
| 139 | + |
| 140 | + if len(V) == 1: |
| 141 | + # Not on a mixed space, just return ourselves. |
| 142 | + return o |
| 143 | + |
| 144 | + indices = self.blocks[o.number()] |
| 145 | + W = subspace(V, indices) |
| 146 | + return Coargument(W, number=o.number(), part=o.part()) |
| 147 | + |
136 | 148 | def cofunction(self, o): |
137 | 149 | V = o.function_space() |
138 | 150 |
|
@@ -171,6 +183,42 @@ def matrix(self, o): |
171 | 183 | bcs = () |
172 | 184 | return AssembledMatrix(tuple(args), bcs, submat) |
173 | 185 |
|
| 186 | + def zero_base_form(self, o): |
| 187 | + return ZeroBaseForm(tuple(map(self, o.arguments()))) |
| 188 | + |
| 189 | + def interpolate(self, o, operand): |
| 190 | + if isinstance(operand, Zero): |
| 191 | + return self(ZeroBaseForm(o.arguments())) |
| 192 | + |
| 193 | + dual_arg, _ = o.argument_slots() |
| 194 | + if len(dual_arg.arguments()) == 1 or len(dual_arg.arguments()[-1].function_space()) == 1: |
| 195 | + # The dual argument has been contracted or does not need to be split |
| 196 | + return o._ufl_expr_reconstruct_(operand, dual_arg) |
| 197 | + |
| 198 | + if not isinstance(dual_arg, Coargument): |
| 199 | + raise NotImplementedError(f"I do not know how to split an Interpolate with a {type(dual_arg).__name__}.") |
| 200 | + |
| 201 | + indices = self.blocks[dual_arg.number()] |
| 202 | + V = dual_arg.function_space() |
| 203 | + |
| 204 | + # Split the target (dual) argument |
| 205 | + sub_dual_arg = self(dual_arg) |
| 206 | + W = sub_dual_arg.function_space() |
| 207 | + |
| 208 | + # Unflatten the expression into the target shape |
| 209 | + cur = 0 |
| 210 | + components = [] |
| 211 | + for i, Vi in enumerate(V): |
| 212 | + if i in indices: |
| 213 | + components.extend(operand[i] for i in range(cur, cur+Vi.value_size)) |
| 214 | + cur += Vi.value_size |
| 215 | + |
| 216 | + operand = as_tensor(numpy.reshape(components, W.value_shape)) |
| 217 | + if isinstance(operand, Zero): |
| 218 | + return self(ZeroBaseForm(o.arguments())) |
| 219 | + |
| 220 | + return o._ufl_expr_reconstruct_(operand, sub_dual_arg) |
| 221 | + |
174 | 222 |
|
175 | 223 | SplitForm = collections.namedtuple("SplitForm", ["indices", "form"]) |
176 | 224 |
|
|
0 commit comments